206 lines
8.6 KiB
Python
206 lines
8.6 KiB
Python
"""NES background encoder: 4 sub-palettes + per-16x16 attribute + 256-tile CHR.

The NES PPU draws a 32x30 grid of 8x8 tiles (pattern table, <=256 unique), each
tile 2bpp. Colour comes from 4 background sub-palettes (each = a shared universal
background + 3 colours), one chosen per 16x16 region via the attribute table. So
the pipeline is: choose the universal bg, cluster the image into 4 sub-palettes,
assign each region its best one, dither, then vector-quantise the 8x8 tile
patterns down to 256 CHR tiles.
"""
|
|
from __future__ import annotations
|
|
|
|
import numpy as np
|
|
|
|
from ... import dither, palette as c64pal
|
|
from ...convert import base
|
|
from .. import palette as npal
|
|
|
|
# Screen size in pixels.
W = 256
H = 240

# Attribute regions are 16x16 pixels (2x2 tiles): 16 across, 15 down.
RCOLS = W // 16
RROWS = H // 16

# Nametable grid of 8x8 tiles: 32 across, 30 down.
TCOLS = W // 8
TROWS = H // 8

# Maximum number of distinct background tiles emitted into CHR.
NTILES = 256
|
|
|
|
|
|
def _best_colors(pix_lab, plab, bg0, n, k_cand=20):
    """Pick the ``n`` NES colours that, together with fixed ``bg0``, best cover a pixel pool.

    Candidates are the ``k_cand`` palette entries nearest (squared distance
    in the supplied Lab space) to the pool's mean colour, with ``bg0`` removed.
    Every ``n``-subset of the candidates is scored by the summed squared
    distance from each pixel to its nearest colour among ``bg0`` and the
    subset. The lowest-scoring subset wins; ties go to the earliest subset.

    Args:
        pix_lab: ``(px, 3)`` array of pixel colours.
        plab: ``(P, 3)`` array of palette colours, indexed by palette index.
        bg0: palette index of the fixed universal background colour.
        n: number of colours to choose.
        k_cand: size of the candidate shortlist (before ``bg0`` is dropped).

    Returns:
        A list of exactly ``n`` palette indices as Python ints. If fewer than
        ``n`` candidates exist, the chosen ones are padded by repeating the
        last. If there are no candidates at all, ``bg0`` is used instead of
        returning ``None``.
    """
    from itertools import combinations

    if n <= 0:
        return []
    mean_d = np.sum((pix_lab.mean(0)[None, :] - plab) ** 2, 1)
    cand = [int(c) for c in np.argsort(mean_d)[:k_cand] if c != bg0]
    m = min(n, len(cand))
    if m == 0:
        return [int(bg0)] * n

    dist = np.sum((pix_lab[:, None, :] - plab[None, cand, :]) ** 2, 2)  # (px, k)
    d_bg = np.sum((pix_lab - plab[bg0]) ** 2, 1)  # (px,)

    # Start from the first subset so a NaN-only scoring still yields a result.
    best, best_err = cand[:m], np.inf
    for combo in combinations(range(len(cand)), m):
        err = np.minimum(d_bg, dist[:, combo].min(1)).sum()
        if err < best_err:
            best_err, best = err, [cand[i] for i in combo]
    # Pad when the palette offers fewer usable colours than requested.
    return best + [best[-1]] * (n - m)
|
|
|
|
|
|
def _nearest_codes(patterns, code, chunk=2048):
    """Return, per pattern, the index of the code row with the fewest differing pens.

    Compares ``chunk`` patterns at a time to bound the ``(chunk, k, D)``
    boolean temporary. Ties go to the lowest code index (``argmin``).
    """
    labels = np.empty(len(patterns), np.int64)
    for s in range(0, len(patterns), chunk):
        blk = patterns[s:s + chunk]
        labels[s:s + chunk] = (blk[:, None, :] != code[None]).sum(2).argmin(1)
    return labels


def _tile_codebook(patterns, k, iters=8):
    """Vector-quantise pen patterns to a codebook of ``k`` entries (k-modes).

    ``patterns`` is an ``(N, D)`` array of small non-negative pen values
    (0-3 for NES tiles, D = 64 for 8x8). Distance is the number of
    differing pixels.

    If there are at most ``k`` distinct patterns, the codebook is exact. The
    distinct patterns come first, in ``np.unique`` order, and the remaining
    rows are zero. Otherwise the codebook is seeded with the ``k`` most
    frequent patterns, and then each round:

    1. assigns every pattern to its nearest code;
    2. replaces each non-empty cluster's code with the per-pixel majority
       pen (ties -> lowest pen). Empty clusters keep their code.

    It stops after ``iters`` rounds or when no code changes. A majority code
    need not be a cluster member, so this is k-modes rather than k-medoids.

    Returns:
        ``(code, labels)``: ``code`` is ``(k, D)`` in ``patterns.dtype``;
        ``labels`` is ``(N,)`` and indexes into ``code``.
    """
    uniq, inverse, counts = np.unique(
        patterns, axis=0, return_inverse=True, return_counts=True)
    if len(uniq) <= k:
        code = np.zeros((k, patterns.shape[1]), patterns.dtype)
        code[:len(uniq)] = uniq
        # reshape: some NumPy 2.0.x releases return a non-1-D inverse with axis=0
        return code, inverse.reshape(-1)

    width = patterns.shape[1]
    n_vals = max(4, int(patterns.max()) + 1)  # pen alphabet size (NES: 4)
    cols = np.arange(width)
    code = uniq[np.argsort(-counts)[:k]].copy()
    for _ in range(iters):
        labels = _nearest_codes(patterns, code)
        # hist[j, p, v] = members of cluster j whose pixel p has pen v
        flat = (labels[:, None] * width + cols) * n_vals + patterns
        hist = np.bincount(flat.ravel(), minlength=k * width * n_vals)
        modes = hist.reshape(k, width, n_vals).argmax(2).astype(patterns.dtype)
        filled = np.bincount(labels, minlength=k) > 0
        changed = filled & (modes != code).any(1)
        if not changed.any():
            break
        code[changed] = modes[changed]
    # Final assignment against the settled codebook.
    return code, _nearest_codes(patterns, code)
|
|
|
|
|
|
def _chr_bytes(code):
    """Encode pen patterns as an 8 KiB CHR-ROM image.

    ``code`` holds rows of 64 pens (0-3), each a row-major 8x8 tile; at most
    ``NTILES`` rows are used. Each tile becomes 16 bytes: 8 bytes of pen
    bit 0, then 8 bytes of pen bit 1, one byte per pixel row with the leftmost
    pixel in the most-significant bit. Bytes past the last tile stay zero.
    """
    pats = np.asarray(code)[:NTILES].astype(np.uint8).reshape(-1, 8, 8)
    plane0 = np.packbits(pats & 1, axis=2)[..., 0]         # (tiles, 8)
    plane1 = np.packbits((pats >> 1) & 1, axis=2)[..., 0]  # (tiles, 8)
    tiles = np.concatenate([plane0, plane1], axis=1)        # (tiles, 16)
    chr_rom = bytearray(8192)
    chr_rom[:tiles.size] = tiles.tobytes()
    return chr_rom


def _attr_bytes(region_pal):
    """Pack per-region sub-palette numbers into the 64-byte attribute table.

    Each byte covers 2x2 regions: bits 0-1 top-left, 2-3 top-right,
    4-5 bottom-left, 6-7 bottom-right. The 8x8 byte grid spans 16 region
    rows but the screen has only ``RROWS``, so missing regions use 0.
    """
    grid = np.zeros((16, 16), np.uint8)
    grid[:RROWS, :RCOLS] = np.asarray(region_pal)[:RROWS, :RCOLS] & 3
    q = grid.reshape(8, 2, 8, 2)  # (attr row, dy, attr col, dx)
    packed = (q[:, 0, :, 0] | (q[:, 0, :, 1] << 2)
              | (q[:, 1, :, 0] << 4) | (q[:, 1, :, 1] << 6))
    return bytearray(packed.astype(np.uint8).tobytes())


def encode(img_rgb, dither_mode, subpalettes):
    """Convert an image into NES background data.

    Args:
        img_rgb: source image, converted with ``c64pal.srgb_to_lab``; the
            reshapes below require an ``(H, W)`` (240x256) image.
        dither_mode: passed through to ``dither.quantize``.
        subpalettes: list of 4 lists ``[bg0, c1, c2, c3]`` of NES colour
            indices. The universal bg is ``subpalettes[*][0]``, identical
            across all four; ``subpalettes[0][0]`` is the one emitted.

    Returns:
        ``(pal32, nametable_full, chr_rom, final_idx, err, plab, prgb)``:
        - ``pal32``: 32 palette bytes. The background sub-palettes are copied
          into the sprite half.
        - ``nametable_full``: 1024 bytes, 960 tile indices followed by 64
          attribute bytes.
        - ``chr_rom``: 8192 bytes of CHR data.
        - ``final_idx``: an ``(H, W)`` uint16 NES-index preview built from
          the quantised tiles.
        - ``err``: ``base.perceptual_error`` of that preview.
        - ``plab``, ``prgb``: the palette as Lab and as uint8 RGB.
    """
    plab = npal.palette_lab()
    prgb = npal.get_palette().astype(np.uint8)
    img_lab = c64pal.srgb_to_lab(img_rgb)
    bg0 = subpalettes[0][0]
    sp = np.array(subpalettes)  # (4, 4) NES indices

    # 1. Each 16x16 region takes the sub-palette with the least summed
    #    nearest-colour error over its pixels.
    region_pal = np.zeros((RROWS, RCOLS), np.int64)
    for ry in range(RROWS):
        for rx in range(RCOLS):
            blk = img_lab[ry * 16:ry * 16 + 16, rx * 16:rx * 16 + 16].reshape(-1, 3)
            errs = [np.sum((blk[:, None, :] - plab[pal][None]) ** 2, 2).min(1).sum()
                    for pal in sp[:4]]
            region_pal[ry, rx] = int(np.argmin(errs))

    # region_map[y, x] = sub-palette number of the pixel's region.
    region_map = np.repeat(np.repeat(region_pal, 16, axis=0), 16, axis=1)
    allowed = sp[region_map].astype(np.int64)  # (H, W, 4)

    # 2. Dither: every pixel may only use its region's four colours.
    idx = dither.quantize(img_lab, allowed, plab, dither_mode).astype(np.int64)

    # 3. Pen = position of the pixel's colour within its region's sub-palette.
    #    argmax picks the FIRST match. Repeated colours (mono ramps, empty-group
    #    [bg0]*4 palettes) therefore map to the lowest pen, so bg0 pixels always
    #    use pen 0 and identical tiles can be shared across regions. A colour
    #    missing from the palette falls back to pen 0.
    pen = (allowed == idx[..., None]).argmax(-1).astype(np.uint8)

    # 4. Cut into 8x8 tiles (row-major) and vector-quantise to <= NTILES.
    tiles = pen.reshape(TROWS, 8, TCOLS, 8).transpose(0, 2, 1, 3).reshape(TROWS * TCOLS, 64)
    code, labels = _tile_codebook(tiles, NTILES)
    nametable = labels.astype(np.uint8).reshape(TROWS, TCOLS)

    # 5. Emit PPU data.
    chr_rom = _chr_bytes(code)
    attr = _attr_bytes(region_pal)
    bg_pal = b"".join(bytes([int(bg0)] + [int(c) for c in sp[s][1:4]])
                      for s in range(4))
    pal32 = bg_pal + bg_pal  # sprite palettes = copy of the background ones
    nametable_full = nametable.tobytes() + bytes(attr)  # 960 + 64

    # 6. Preview: quantised tiles coloured by each pixel's region sub-palette.
    disp = code[labels].reshape(TROWS, TCOLS, 8, 8).transpose(0, 2, 1, 3).reshape(H, W)
    final_idx = sp[region_map, disp].astype(np.uint16)

    err = base.perceptual_error(final_idx, img_lab, plab)
    return pal32, nametable_full, bytes(chr_rom), final_idx, err, plab, prgb
|
|
|
|
|
|
def pick_subpalettes(img_rgb, n_groups=4, mono=False, base_color=None):
    """Choose the universal bg + ``n_groups`` sub-palettes (each bg + 3 colours).

    Mono mode builds one grey ramp from ``npal.GREYS``:
    - If ``base_color`` is a palette index 0-63, the ramp is the darkest grey,
      ``base_color`` and the lightest grey.
    - Otherwise it is the four greys nearest to an even lightness spread
      from black to white.

    The darkest ramp entry becomes bg0. The ramp is returned as 4 identical,
    independent sub-palettes.

    Colour mode works in three steps:
    1. Pick bg0 with ``base.best_global_color``.
    2. Cluster the 16x16 regions by mean colour into ``n_groups`` with a
       seeded k-means.
    3. Pick the 3 colours that best cover each group's pixels.

    Empty groups get ``[bg0] * 4``.

    Returns:
        A list of fresh ``[bg0, c1, c2, c3]`` lists of NES colour indices
        (4 in mono mode, ``n_groups`` otherwise).
    """
    plab = npal.palette_lab()
    img_lab = c64pal.srgb_to_lab(img_rgb)

    if mono:
        greys = sorted(npal.GREYS, key=lambda i: plab[i, 0])
        if base_color in range(64):
            ramp = sorted({greys[0], int(base_color), greys[-1]}, key=lambda i: plab[i, 0])
        else:
            # 4 greys spanning black->white (include the lightest so highlights
            # actually reach white -- otherwise the image comes out muddy/dark)
            lums = np.array([plab[i, 0] for i in greys])
            ramp = [greys[int(np.argmin(np.abs(lums - t)))]
                    for t in np.linspace(lums.min(), lums.max(), 4)]
        bg0 = ramp[0]
        others = [c for c in ramp if c != bg0][:3]
        # Pad to 3 colours; a single-grey ramp falls back to bg0.
        others += [others[-1] if others else bg0] * (3 - len(others))
        # Independent lists: ``[x] * 4`` would alias one list four times.
        return [[bg0] + others for _ in range(4)]

    bg0 = base.best_global_color(img_lab, plab)

    def region_pixels(m):
        """Pixels of region ``m`` (row-major region index) as ``(256, 3)``."""
        ry, rx = divmod(m, RCOLS)
        return img_lab[ry * 16:ry * 16 + 16, rx * 16:rx * 16 + 16].reshape(-1, 3)

    # cluster regions by mean colour into n_groups, then pick 3 colours per group
    feats = np.array([region_pixels(m).mean(0) for m in range(RROWS * RCOLS)])
    rng = np.random.default_rng(0)
    cen = feats[rng.choice(len(feats), n_groups, replace=False)]
    labels = None
    for _ in range(12):
        new = np.argmin(np.sum((feats[:, None, :] - cen[None]) ** 2, 2), 1)
        # Unchanged assignment => centroids are already at a fixed point.
        if labels is not None and np.array_equal(new, labels):
            break
        labels = new
        for g in range(n_groups):
            in_group = labels == g
            if in_group.any():
                cen[g] = feats[in_group].mean(0)

    max_pool = 4000  # subsample large groups to bound _best_colors' cost
    subs = []
    for g in range(n_groups):
        members = np.flatnonzero(labels == g)
        if len(members) == 0:
            subs.append([bg0, bg0, bg0, bg0])
            continue
        pool = np.concatenate([region_pixels(m) for m in members])
        if len(pool) > max_pool:
            pool = pool[rng.choice(len(pool), max_pool, replace=False)]
        subs.append([bg0] + _best_colors(pool, plab, bg0, 3))
    return subs
|