8bitlenser/lenser/sms/convert/_common.py
2026-07-03 19:35:35 -07:00

150 lines
6.1 KiB
Python

"""Sega Master System background encoder: 256x192, 2 palettes of 16 (of 64),
per-8x8-tile palette select, <=448 unique background tiles in VRAM.

Each tile is 4bpp (16 colours) and picks one of two 16-colour palettes, so up to
32 colours on screen -- far less constrained than the NES. We pick palette 0 for
the whole image, palette 1 for the colours it serves worst, assign each tile its
better palette, dither, then vector-quantise the 8x8 patterns to 448 tiles (the
tile index is 9 bits, but the name table at $3800 leaves room for only 448
32-byte patterns below it).
"""
from __future__ import annotations
import numpy as np
from ... import dither, palette as c64pal
from ...convert.base import perceptual_error
from .. import palette as smspal
# Visible background area in pixels, and the same area in 8x8 tiles.
W, H = 256, 192
TCOLS, TROWS = W // 8, H // 8
# VRAM is 16K: pattern table $0000-$37FF, name table at $3800, sprite attribute
# table at $3F00.  Each background pattern is 32 bytes, so the space below the
# name table holds 0x3800 // 32 = 448 unique background tiles.
NTILES = 0x3800 // 32
def _choose(img_lab, plab, n, weight=None):
    """Greedily pick ``n`` rows of ``plab`` that best cover the pixels of ``img_lab``.

    Every round adds the candidate colour that most reduces the total squared
    distance from each pixel to its nearest already-picked colour.  When
    ``weight`` is given, each pixel's squared distances are scaled by its
    weight first.  A colour is never picked twice.

    Args:
        img_lab: image whose last axis has length 3 (flattened to (pixels, 3)).
        plab: (colours, 3) candidate colours.
        n: number of colours to pick.
        weight: optional per-pixel weights, one per flattened pixel.

    Returns:
        The picked row indices of ``plab`` as a list of ints, sorted ascending.
    """
    pixels = img_lab.reshape(-1, 3)
    cost = ((pixels[:, None, :] - plab[None, :, :]) ** 2).sum(axis=-1)  # (pixels, colours)
    if weight is not None:
        cost = cost * weight[:, None]

    picked = []
    nearest = np.full(len(pixels), np.inf)  # best cost per pixel so far
    for _ in range(n):
        totals = np.minimum(nearest[:, None], cost).sum(axis=0)
        totals[picked] = np.inf  # exclude colours already taken
        pick = int(totals.argmin())
        picked.append(pick)
        nearest = np.minimum(nearest, cost[:, pick])
    picked.sort()
    return picked
def _nearest_code(patterns, code, chunk=2048):
    """Label each row of ``patterns`` with the index of its closest ``code`` row.

    Distance is the number of differing elements (Hamming distance); ties go
    to the lowest code index.  Rows are compared ``chunk`` at a time so the
    (chunk, len(code), row_length) boolean temporary stays bounded.

    Returns an int64 array with one label per pattern row.
    """
    labels = np.zeros(len(patterns), np.int64)
    for start in range(0, len(patterns), chunk):
        block = patterns[start:start + chunk]
        labels[start:start + chunk] = (block[:, None, :] != code[None]).sum(2).argmin(1)
    return labels


def _tile_codebook(patterns, k, iters=8):
    """Reduce a set of tile patterns to a codebook of exactly ``k`` patterns.

    Args:
        patterns: (n, P) integer array of pens; values must be non-negative
            because the per-position mode uses ``np.bincount``.
        k: codebook size (number of rows in the returned ``code``).
        iters: maximum number of assign/update rounds.

    Returns:
        ``(code, labels)``: ``code`` is (k, P) with ``patterns.dtype``;
        ``labels`` is an int64 array giving each pattern's code row.

    If there are at most ``k`` distinct patterns, they are stored verbatim in
    ``np.unique`` order, unused rows are zero-filled, and every label is an
    exact match.  Otherwise the codebook is seeded with the ``k`` most frequent
    patterns and refined k-means style under Hamming distance:
      - assign each pattern to its nearest code;
      - replace each non-empty cluster's code with the per-position mode of
        its members;
      - stop when no code changes or after ``iters`` rounds.
    """
    uniq, counts = np.unique(patterns, axis=0, return_counts=True)
    if len(uniq) <= k:
        code = np.zeros((k, patterns.shape[1]), patterns.dtype)
        code[:len(uniq)] = uniq
        # uniq rows are distinct, so the only zero-distance match is the exact one
        return code, _nearest_code(patterns, uniq)

    code = uniq[np.argsort(-counts)[:k]].copy()
    for _ in range(iters):
        labels = _nearest_code(patterns, code)
        moved = False
        for j in range(k):
            members = patterns[labels == j]
            if not len(members):
                continue  # an empty cluster keeps its previous code
            # per-position mode minimises the summed Hamming distance to members
            mode = np.array(
                [np.bincount(members[:, p], minlength=16).argmax()
                 for p in range(members.shape[1])],
                patterns.dtype,
            )
            if not np.array_equal(mode, code[j]):
                code[j] = mode
                moved = True
        if not moved:
            break
    # final assignment against the settled codebook
    return code, _nearest_code(patterns, code)
def _palettes(img_lab, mono, base_color):
    """Choose the two 16-entry background palettes as lists of colour indices.

    Colour mode:
      - Palette 0 is the greedy best 16 colours for the whole image.
      - Palette 1 is chosen the same way, with every pixel weighted by its
        squared distance to the nearest palette-0 colour, so it favours the
        pixels palette 0 reproduces worst.

    Mono mode: both entries are the same list, ``smspal.GREYS`` sorted by
    ascending ``palette_lab()[:, 0]``, repeated 4 times and cut to 16 slots
    (presumably 4 greys -- confirm against smspal).

    ``base_color`` is accepted but not used.
    """
    lab64 = smspal.palette_lab()
    if not mono:
        first = _choose(img_lab, lab64, 16)
        pixels = img_lab.reshape(-1, 3)
        # squared distance from each pixel to its nearest palette-0 colour
        residual = ((pixels[:, None, :] - lab64[first][None]) ** 2).sum(axis=2).min(axis=1)
        second = _choose(img_lab, lab64, 16, weight=residual)
        return [first, second]
    greys = sorted(smspal.GREYS, key=lambda g: lab64[g, 0])
    ramp = (greys * 4)[:16]  # repeat the grey ramp to fill all 16 slots
    return [ramp, ramp]
def _assign_tile_palettes(img_lab, plab_pal):
    """Pick palette 0 or 1 for every 8x8 tile.

    For each palette, every pixel's squared Lab distance to its nearest
    palette colour is summed over the tile; the palette with the lower total
    wins, and ties go to palette 0.

    Args:
        img_lab: (H, W, 3) image in Lab.
        plab_pal: (2, 16, 3) Lab values of the two palettes.

    Returns:
        (TROWS, TCOLS) int64 array of 0/1.
    """
    errs = []
    for pal in plab_pal:
        d = np.sum((img_lab[..., None, :] - pal) ** 2, axis=-1)  # (H, W, 16)
        per_px = d.min(axis=-1)                                   # (H, W)
        errs.append(per_px.reshape(TROWS, 8, TCOLS, 8).sum(axis=(1, 3)))
    return np.where(errs[0] <= errs[1], 0, 1).astype(np.int64)


def _pack_patterns(code):
    """Encode 8x8 4bpp tiles as VDP pattern bytes, 32 bytes per tile.

    Byte ``t*32 + r*4 + k`` holds bit ``k`` of row ``r`` of tile ``t``
    (bitplane 0 first), with pixel x=0 in bit 7.
    """
    pens = np.asarray(code).reshape(-1, 8, 8)                             # (tile, row, x)
    shifts = np.arange(4)[None, None, :, None]
    planes = ((pens[:, :, None, :] >> shifts) & 1).astype(np.uint8)       # (tile, row, plane, x)
    # packbits' default big bit order puts x=0 in the most significant bit
    return np.packbits(planes, axis=-1).tobytes()


def _pack_nametable(name_pat, tile_pal):
    """Encode name-table entries as little-endian 16-bit words, row-major.

    Bits 0-8 hold the pattern index and bit 11 holds the tile's palette (0/1).
    """
    entry = (np.asarray(name_pat, np.int64) & 0x1FF) | (np.asarray(tile_pal, np.int64) << 11)
    return entry.astype("<u2").tobytes()


def encode(img_rgb, dither_mode, mono=False, base_color=None):
    """Encode an RGB image as SMS background data plus a preview.

    Args:
        img_rgb: source image.  Assumed to be exactly W x H (256x192), because
            the tile reshapes below require that size -- TODO confirm that
            callers resize first.
        dither_mode: passed through to ``dither.quantize``.
        mono: use grey palettes and measure error on the first Lab channel only.
        base_color: forwarded to ``_palettes``.

    Returns:
        ``(patterns, nametable, palette, preview, err)``:
          - patterns: NTILES * 32 bytes of 4bpp pattern data.
          - nametable: TROWS * TCOLS little-endian 16-bit entries.
          - palette: 32 colour indices (0-63), palette 0 then palette 1.
          - preview: uint8 image indexed from ``smspal.get_palette()``.
          - err: ``perceptual_error`` of the preview against the source.
    """
    plab = smspal.palette_lab()
    prgb = smspal.get_palette().astype(np.uint8)
    img_lab = c64pal.srgb_to_lab(img_rgb)

    pal_idx = np.array(_palettes(img_lab, mono, base_color))  # (2, 16) colour indices
    plab_pal = plab[pal_idx]                                   # (2, 16, 3)

    # Per-tile palette choice, broadcast to a per-pixel map.
    tile_pal = _assign_tile_palettes(img_lab, plab_pal)        # (TROWS, TCOLS) in {0, 1}
    pix_pal = np.repeat(np.repeat(tile_pal, 8, 0), 8, 1)      # (H, W)

    # Dither each pixel against its tile's 16 colours.  The allowed values are
    # global indices 0-31 into plab32; subtracting the tile's offset gives the
    # 0-15 pen.
    plab32 = plab[pal_idx.reshape(-1)]                         # (32, 3)
    allowed = pix_pal[..., None] * 16 + np.arange(16)          # (H, W, 16)
    idx = dither.quantize(img_lab, allowed, plab32, dither_mode).astype(np.int64)
    pen = (idx - pix_pal * 16).astype(np.uint8)

    # 8x8 tiles -> 64-pen patterns; vector-quantise down to NTILES (448).
    tiles = pen.reshape(TROWS, 8, TCOLS, 8).transpose(0, 2, 1, 3).reshape(TROWS * TCOLS, 64)
    code, labels = _tile_codebook(tiles, NTILES)
    name_pat = labels.reshape(TROWS, TCOLS)

    # ---- emit VDP data ----
    patterns = _pack_patterns(code)
    nametable = _pack_nametable(name_pat, tile_pal)
    palette = bytes(int(c) for c in pal_idx.reshape(-1))  # 32 colour indices (0-63)

    # Preview: clustered tiles seen through each tile's palette.
    disp = code[labels].reshape(TROWS, TCOLS, 8, 8).transpose(0, 2, 1, 3).reshape(H, W)
    final = pal_idx[pix_pal, disp].astype(np.uint16)

    if mono:
        # compare on channel 0 only: zero the remaining Lab channels on both sides
        lum = img_lab.copy()
        lum[..., 1:] = 0.0
        pl = plab.copy()
        pl[:, 1:] = 0.0
        err = perceptual_error(final, lum, pl)
    else:
        err = perceptual_error(final, img_lab, plab)
    return patterns, nametable, palette, prgb[final], err