8bitlenser/lenser/nes/convert/_common.py
2026-07-03 19:35:35 -07:00

206 lines
8.6 KiB
Python

"""NES background encoder: 4 sub-palettes + per-16x16 attribute + 256-tile CHR.
The NES PPU draws a 32x30 grid of 8x8 tiles (pattern table, <=256 unique), each
tile 2bpp. Colour comes from 4 background sub-palettes (each = a shared universal
background + 3 colours), one chosen per 16x16 region via the attribute table. So
the pipeline is: choose the universal bg, cluster the image into 4 sub-palettes,
assign each region its best one, dither, then vector-quantise the 8x8 tile
patterns down to 256 CHR tiles.
"""
from __future__ import annotations
import numpy as np
from ... import dither, palette as c64pal
from ...convert import base
from .. import palette as npal
# Picture size in pixels: a 32x30 grid of 8x8 tiles.
W = 256
H = 240

# Attribute regions: one sub-palette choice per 16x16 pixels (2x2 tiles).
RCOLS = W // 16
RROWS = H // 16

# Pattern tiles: 8x8 pixels each.
TCOLS = W // 8
TROWS = H // 8

# Upper bound on distinct tile patterns in the CHR data.
NTILES = 256
def _best_colors(pix_lab, plab, bg0, n, k_cand=20):
"""Best ``n`` NES colours (besides fixed bg0) for a pool of pixels."""
mean_d = np.sum((pix_lab.mean(0)[None, :] - plab) ** 2, 1)
cand = [c for c in np.argsort(mean_d)[:k_cand] if c != bg0]
dist = np.sum((pix_lab[:, None, :] - plab[None, cand, :]) ** 2, 2) # (px, k)
d_bg = np.sum((pix_lab - plab[bg0]) ** 2, 1) # (px,)
from itertools import combinations
best, best_err = None, np.inf
for combo in combinations(range(len(cand)), n):
m = np.minimum(d_bg, dist[:, combo].min(1)).sum()
if m < best_err:
best_err, best = m, [cand[i] for i in combo]
return best
def _tile_codebook(patterns, k, iters=8):
"""k-medoids over 8x8 pen patterns (values 0-3); distance = differing pixels."""
uniq, counts = np.unique(patterns, axis=0, return_counts=True)
if len(uniq) <= k:
code = np.zeros((k, patterns.shape[1]), patterns.dtype)
code[:len(uniq)] = uniq
lut = {tuple(p): i for i, p in enumerate(uniq)}
return code, np.array([lut[tuple(p)] for p in patterns])
code = uniq[np.argsort(-counts)[:k]].copy()
labels = np.zeros(len(patterns), np.int64)
for _ in range(iters):
for s in range(0, len(patterns), 2048):
blk = patterns[s:s + 2048]
labels[s:s + 2048] = (blk[:, None, :] != code[None]).sum(2).argmin(1)
moved = False
for j in range(k):
mem = patterns[labels == j]
if len(mem):
med = np.array([np.bincount(mem[:, p], minlength=4).argmax()
for p in range(mem.shape[1])], patterns.dtype)
if not np.array_equal(med, code[j]):
code[j] = med; moved = True
if not moved:
break
for s in range(0, len(patterns), 2048):
blk = patterns[s:s + 2048]
labels[s:s + 2048] = (blk[:, None, :] != code[None]).sum(2).argmin(1)
return code, labels
def encode(img_rgb, dither_mode, subpalettes):
    """Encode an RGB image as NES background data.

    img_rgb: the source picture; it is sliced as a W x H (256x240) image and
        no resizing happens here.
    dither_mode: passed through to ``dither.quantize``.
    subpalettes: four ``[bg0, c1, c2, c3]`` lists of NES colour indices; the
        universal background ``bg0`` is read from the first one and is
        expected to be identical across all four.

    Returns ``(palette, nametable, chr_rom, final_idx, err, plab, prgb)``:
    32 palette bytes (background palettes mirrored into the sprite half),
    960 nametable bytes followed by 64 attribute bytes, 8192 CHR bytes (only
    the first ``NTILES * 16`` are filled), the per-pixel NES colour indices of
    the picture as encoded, its perceptual error against the source, and the
    Lab / RGB palettes used.
    """
    plab = npal.palette_lab()
    prgb = npal.get_palette().astype(np.uint8)
    img_lab = c64pal.srgb_to_lab(img_rgb)
    universal_bg = subpalettes[0][0]
    sp = np.array(subpalettes)  # (4, 4) NES colour indices

    # Pixel window of every 16x16 attribute region, in row-major order.
    regions = [
        (ry, rx, slice(16 * ry, 16 * ry + 16), slice(16 * rx, 16 * rx + 16))
        for ry in range(RROWS)
        for rx in range(RCOLS)
    ]

    # Give each region the sub-palette with the least nearest-colour error and
    # record, per pixel, the colours that choice makes available.
    sub_lab = [plab[sp[s]] for s in range(4)]
    region_pal = np.zeros((RROWS, RCOLS), np.int64)
    allowed = np.zeros((H, W, 4), np.int64)
    for ry, rx, ys, xs in regions:
        px = img_lab[ys, xs].reshape(-1, 3)
        costs = [
            ((px[:, None, :] - lab[None, :, :]) ** 2).sum(axis=2).min(axis=1).sum()
            for lab in sub_lab
        ]
        choice = int(np.argmin(costs))
        region_pal[ry, rx] = choice
        allowed[ys, xs] = sp[choice]

    idx = dither.quantize(img_lab, allowed, plab, dither_mode).astype(np.int64)

    # Pen = position of each pixel's colour inside its region's sub-palette.
    pen = np.zeros((H, W), np.uint8)
    for ry, rx, ys, xs in regions:
        colours = idx[ys, xs]
        window = pen[ys, xs]  # slice indexing yields a view, so writes land in ``pen``
        for slot, colour in enumerate(sp[region_pal[ry, rx]]):
            window[colours == colour] = slot

    # Cut into 8x8 tiles and vector-quantise the patterns to <=256 CHR tiles.
    tiles = pen.reshape(TROWS, 8, TCOLS, 8).swapaxes(1, 2).reshape(TROWS * TCOLS, 64)
    code, labels = _tile_codebook(tiles, NTILES)
    nametable = labels.astype(np.uint8).reshape(TROWS, TCOLS)

    # ---- emit NES data ----
    # CHR: per tile, 8 row bytes of bit-plane 0 then 8 row bytes of bit-plane 1,
    # leftmost pixel in the most significant bit (packbits' default bit order).
    tile_rows = code[:NTILES].reshape(NTILES, 8, 8)
    plane0 = np.packbits(tile_rows & 1, axis=2)[:, :, 0]
    plane1 = np.packbits((tile_rows >> 1) & 1, axis=2)[:, :, 0]
    chr_rom = bytearray(8192)
    chr_rom[:NTILES * 16] = np.concatenate([plane0, plane1], axis=1).tobytes()

    # Attribute table: one byte per 2x2 group of regions, two bits each in the
    # order top-left, top-right, bottom-left, bottom-right.  Group slots that
    # fall outside the RROWS x RCOLS grid use sub-palette 0.
    padded = np.zeros((16, 16), np.int64)
    padded[:RROWS, :RCOLS] = region_pal
    quad = (padded & 3).reshape(8, 2, 8, 2)
    attr_grid = (
        quad[:, 0, :, 0]
        | (quad[:, 0, :, 1] << 2)
        | (quad[:, 1, :, 0] << 4)
        | (quad[:, 1, :, 1] << 6)
    )
    attr = attr_grid.astype(np.uint8).tobytes()

    pal32 = bytearray(32)
    for s in range(4):
        first = 4 * s
        pal32[first] = universal_bg
        for slot in (1, 2, 3):
            pal32[first + slot] = int(sp[s][slot])
    pal32[16:32] = pal32[0:16]  # sprite palettes mirror the background ones

    nametable_full = bytes(nametable.reshape(-1)) + attr  # 960 + 64 bytes

    # Rebuild the picture from the quantised tiles and region palettes, both
    # for preview and to measure the error actually incurred.
    shown_pen = code[labels].reshape(TROWS, TCOLS, 8, 8).swapaxes(1, 2).reshape(H, W)
    final_idx = np.zeros((H, W), np.uint16)
    for ry, rx, ys, xs in regions:
        final_idx[ys, xs] = sp[region_pal[ry, rx]][shown_pen[ys, xs]]
    err = base.perceptual_error(final_idx, img_lab, plab)
    return bytes(pal32), nametable_full, bytes(chr_rom), final_idx, err, plab, prgb
def pick_subpalettes(img_rgb, n_groups=4, mono=False, base_color=None):
    """Choose the universal bg + ``n_groups`` sub-palettes (each bg + 3 colours).

    Colour mode (``mono=False``): the universal background comes from
    ``base.best_global_color``; the 16x16 regions are clustered by mean colour
    into ``n_groups`` groups (seeded, so results are repeatable) and each
    group gets the 3 colours that best fit a sample of its pixels.  A group
    that ends up empty gets ``[bg0, bg0, bg0, bg0]``.

    Mono mode (``mono=True``): every sub-palette is the same lightness ramp
    built from ``npal.GREYS``.  If ``base_color`` is a palette index in
    ``range(64)`` the ramp is darkest grey / ``base_color`` / lightest grey;
    otherwise it is 4 greys spread evenly in lightness.  Mono mode always
    returns four entries, regardless of ``n_groups``.

    Returns a list of ``[bg0, c1, c2, c3]`` lists.  The lists are independent
    objects, so a caller may edit one sub-palette without touching the others.
    """
    plab = npal.palette_lab()
    img_lab = c64pal.srgb_to_lab(img_rgb)

    if mono:
        greys = sorted(npal.GREYS, key=lambda i: plab[i, 0])
        if base_color in range(64):
            ramp = sorted({greys[0], int(base_color), greys[-1]},
                          key=lambda i: plab[i, 0])
        else:
            # 4 greys spanning black->white (include the lightest so highlights
            # actually reach white -- otherwise the image comes out muddy/dark)
            lums = np.array([plab[i, 0] for i in greys])
            ramp = [greys[int(np.argmin(np.abs(lums - t)))]
                    for t in np.linspace(lums.min(), lums.max(), 4)]
        bg0 = ramp[0]
        others = [c for c in ramp if c != bg0][:3]
        # Pad short ramps by repeating the last colour.  If the ramp collapsed
        # to the background alone there is no "last" colour, so fall back to
        # bg0 instead of indexing an empty list.
        filler = others[-1] if others else bg0
        others += [filler] * (3 - len(others))
        # Build four separate lists; ``[x] * 4`` would alias a single list.
        return [[bg0] + others for _ in range(4)]

    bg0 = base.best_global_color(img_lab, plab)

    def region_pixels(m):
        """Lab pixels of region ``m`` (row-major region index) as (256, 3)."""
        ry, rx = divmod(m, RCOLS)
        return img_lab[ry * 16:ry * 16 + 16, rx * 16:rx * 16 + 16].reshape(-1, 3)

    # cluster regions by mean colour into n_groups, then pick 3 colours per group
    feats = np.array([region_pixels(m).mean(0) for m in range(RROWS * RCOLS)])
    rng = np.random.default_rng(0)  # fixed seed keeps the choice repeatable
    cen = feats[rng.choice(len(feats), n_groups, replace=False)]
    for _ in range(12):
        group_of = np.argmin(np.sum((feats[:, None, :] - cen[None]) ** 2, 2), 1)
        for g in range(n_groups):
            if (group_of == g).any():
                cen[g] = feats[group_of == g].mean(0)

    subs = []
    for g in range(n_groups):
        members = np.where(group_of == g)[0]
        if len(members) == 0:
            subs.append([bg0, bg0, bg0, bg0])
            continue
        pool = np.concatenate([region_pixels(m) for m in members])
        if len(pool) > 4000:
            # Subsample to keep the exhaustive colour search cheap.
            pool = pool[rng.choice(len(pool), 4000, replace=False)]
        subs.append([bg0] + _best_colors(pool, plab, bg0, 3))
    return subs