"""NES background encoder: 4 sub-palettes + per-16x16 attribute + 256-tile CHR. The NES PPU draws a 32x30 grid of 8x8 tiles (pattern table, <=256 unique), each tile 2bpp. Colour comes from 4 background sub-palettes (each = a shared universal background + 3 colours), one chosen per 16x16 region via the attribute table. So the pipeline is: choose the universal bg, cluster the image into 4 sub-palettes, assign each region its best one, dither, then vector-quantise the 8x8 tile patterns down to 256 CHR tiles. """ from __future__ import annotations import numpy as np from ... import dither, palette as c64pal from ...convert import base from .. import palette as npal W, H = 256, 240 RCOLS, RROWS = 16, 15 # 16x16-pixel regions (2x2 tiles) TCOLS, TROWS = 32, 30 # 8x8 tiles NTILES = 256 def _best_colors(pix_lab, plab, bg0, n, k_cand=20): """Best ``n`` NES colours (besides fixed bg0) for a pool of pixels.""" mean_d = np.sum((pix_lab.mean(0)[None, :] - plab) ** 2, 1) cand = [c for c in np.argsort(mean_d)[:k_cand] if c != bg0] dist = np.sum((pix_lab[:, None, :] - plab[None, cand, :]) ** 2, 2) # (px, k) d_bg = np.sum((pix_lab - plab[bg0]) ** 2, 1) # (px,) from itertools import combinations best, best_err = None, np.inf for combo in combinations(range(len(cand)), n): m = np.minimum(d_bg, dist[:, combo].min(1)).sum() if m < best_err: best_err, best = m, [cand[i] for i in combo] return best def _tile_codebook(patterns, k, iters=8): """k-medoids over 8x8 pen patterns (values 0-3); distance = differing pixels.""" uniq, counts = np.unique(patterns, axis=0, return_counts=True) if len(uniq) <= k: code = np.zeros((k, patterns.shape[1]), patterns.dtype) code[:len(uniq)] = uniq lut = {tuple(p): i for i, p in enumerate(uniq)} return code, np.array([lut[tuple(p)] for p in patterns]) code = uniq[np.argsort(-counts)[:k]].copy() labels = np.zeros(len(patterns), np.int64) for _ in range(iters): for s in range(0, len(patterns), 2048): blk = patterns[s:s + 2048] labels[s:s + 2048] = (blk[:, None, :] != code[None]).sum(2).argmin(1) moved = False for j in range(k): mem = patterns[labels == j] if len(mem): med = np.array([np.bincount(mem[:, p], minlength=4).argmax() for p in range(mem.shape[1])], patterns.dtype) if not np.array_equal(med, code[j]): code[j] = med; moved = True if not moved: break for s in range(0, len(patterns), 2048): blk = patterns[s:s + 2048] labels[s:s + 2048] = (blk[:, None, :] != code[None]).sum(2).argmin(1) return code, labels def encode(img_rgb, dither_mode, subpalettes): """subpalettes: list of 4 lists [bg0,c1,c2,c3] of NES colour indices (the universal bg = subpalettes[*][0], identical across all four).""" plab = npal.palette_lab() prgb = npal.get_palette().astype(np.uint8) img_lab = c64pal.srgb_to_lab(img_rgb) bg0 = subpalettes[0][0] sp = np.array(subpalettes) # (4,4) NES indices # assign each 16x16 region to the sub-palette giving least nearest-colour error region_pal = np.zeros((RROWS, RCOLS), np.int64) for ry in range(RROWS): for rx in range(RCOLS): blk = img_lab[ry * 16:ry * 16 + 16, rx * 16:rx * 16 + 16].reshape(-1, 3) errs = [] for s in range(4): d = np.sum((blk[:, None, :] - plab[sp[s]][None, :, :]) ** 2, 2) errs.append(d.min(1).sum()) region_pal[ry, rx] = int(np.argmin(errs)) # per-pixel allowed colours = the pixel's region sub-palette; dither allowed = np.zeros((H, W, 4), np.int64) for ry in range(RROWS): for rx in range(RCOLS): allowed[ry * 16:ry * 16 + 16, rx * 16:rx * 16 + 16] = sp[region_pal[ry, rx]] idx = dither.quantize(img_lab, allowed, plab, dither_mode).astype(np.int64) # pen per pixel = position of its colour within its region's sub-palette pen = np.zeros((H, W), np.uint8) for ry in range(RROWS): for rx in range(RCOLS): ys, xs = slice(ry * 16, ry * 16 + 16), slice(rx * 16, rx * 16 + 16) pal = sp[region_pal[ry, rx]] block = idx[ys, xs] pmap = np.zeros(block.shape, np.uint8) for k, col in enumerate(pal): pmap[block == col] = k pen[ys, xs] = pmap # 8x8 tiles -> patterns; vector-quantise to <=256 CHR tiles tiles = pen.reshape(TROWS, 8, TCOLS, 8).transpose(0, 2, 1, 3).reshape(TROWS * TCOLS, 64) code, labels = _tile_codebook(tiles, NTILES) nametable = labels.astype(np.uint8).reshape(TROWS, TCOLS) # ---- emit NES data ---- chr_rom = bytearray(8192) for t in range(NTILES): pat = code[t].reshape(8, 8) for r in range(8): p0 = p1 = 0 for x in range(8): v = int(pat[r, x]) p0 |= (v & 1) << (7 - x) p1 |= ((v >> 1) & 1) << (7 - x) chr_rom[t * 16 + r] = p0 chr_rom[t * 16 + 8 + r] = p1 attr = bytearray(64) for ar in range(8): for ac in range(8): b = 0 for q, (dy, dx) in enumerate(((0, 0), (0, 1), (1, 0), (1, 1))): ry, rx = ar * 2 + dy, ac * 2 + dx s = int(region_pal[ry, rx]) if ry < RROWS and rx < RCOLS else 0 b |= (s & 3) << (q * 2) attr[ar * 8 + ac] = b pal32 = bytearray(32) for s in range(4): pal32[s * 4] = bg0 pal32[s * 4 + 1] = int(sp[s][1]) pal32[s * 4 + 2] = int(sp[s][2]) pal32[s * 4 + 3] = int(sp[s][3]) pal32[16:32] = pal32[0:16] # sprite palette = mirror nametable_full = bytes(nametable.reshape(-1)) + bytes(attr) # 960 + 64 # rebuild the displayed image (clustered tiles + region palettes) for preview disp = code[labels].reshape(TROWS, TCOLS, 8, 8).transpose(0, 2, 1, 3).reshape(H, W) final_idx = np.zeros((H, W), np.uint16) for ry in range(RROWS): for rx in range(RCOLS): ys, xs = slice(ry * 16, ry * 16 + 16), slice(rx * 16, rx * 16 + 16) pal = sp[region_pal[ry, rx]] final_idx[ys, xs] = pal[disp[ys, xs]] err = base.perceptual_error(final_idx, img_lab, plab) return bytes(pal32), nametable_full, bytes(chr_rom), final_idx, err, plab, prgb def pick_subpalettes(img_rgb, n_groups=4, mono=False, base_color=None): """Choose the universal bg + ``n_groups`` sub-palettes (each bg + 3 colours).""" plab = npal.palette_lab() img_lab = c64pal.srgb_to_lab(img_rgb) if mono: greys = sorted(npal.GREYS, key=lambda i: plab[i, 0]) if base_color in range(64): ramp = sorted({greys[0], int(base_color), greys[-1]}, key=lambda i: plab[i, 0]) else: # 4 greys spanning black->white (include the lightest so highlights # actually reach white -- otherwise the image comes out muddy/dark) lums = np.array([plab[i, 0] for i in greys]) ramp = [greys[int(np.argmin(np.abs(lums - t)))] for t in np.linspace(lums.min(), lums.max(), 4)] bg0 = ramp[0] others = [c for c in ramp if c != bg0][:3] while len(others) < 3: others.append(others[-1]) return [[bg0] + others] * 4 bg0 = base.best_global_color(img_lab, plab) # cluster regions by mean colour into n_groups, then pick 3 colours per group feats = [] for ry in range(RROWS): for rx in range(RCOLS): feats.append(img_lab[ry * 16:ry * 16 + 16, rx * 16:rx * 16 + 16] .reshape(-1, 3).mean(0)) feats = np.array(feats) rng = np.random.default_rng(0) cen = feats[rng.choice(len(feats), n_groups, replace=False)] for _ in range(12): lab = np.argmin(np.sum((feats[:, None, :] - cen[None]) ** 2, 2), 1) for g in range(n_groups): if (lab == g).any(): cen[g] = feats[lab == g].mean(0) subs = [] for g in range(n_groups): members = np.where(lab == g)[0] if len(members) == 0: subs.append([bg0, bg0, bg0, bg0]); continue pool = np.concatenate([ img_lab[(m // RCOLS) * 16:(m // RCOLS) * 16 + 16, (m % RCOLS) * 16:(m % RCOLS) * 16 + 16].reshape(-1, 3) for m in members]) if len(pool) > 4000: pool = pool[rng.choice(len(pool), 4000, replace=False)] subs.append([bg0] + _best_colors(pool, plab, bg0, 3)) return subs