"""Sega Master System background encoder.

Target: 256x192 pixels, two 16-colour palettes chosen from the 64-colour master
palette, a per-8x8-tile palette select, and at most ``NTILES`` (448) unique
tiles in VRAM. Each tile is 4bpp (16 colours) and picks one of the two
palettes, so up to 32 colours can be on screen -- far less constrained than
the NES.

Pipeline: pick palette 0 for the whole image, palette 1 for the colours
palette 0 serves worst, assign each tile its better palette, dither, then
vector-quantise the 8x8 pen patterns down to ``NTILES`` tiles.
"""
from __future__ import annotations

import numpy as np

from ... import dither, palette as c64pal
from ...convert.base import perceptual_error
from .. import palette as smspal

W, H = 256, 192
TCOLS, TROWS = 32, 24
# VRAM is 16K: pattern table $0000-$37FF (448 tiles), name table $3800, sprite
# attribute table $3F00 -- so at most 448 unique background tiles.
NTILES = 448


def _nearest_sq_err(lab, pal_lab):
    """Return, for each row of ``lab`` (N, 3), the squared Lab distance to the
    closest colour in ``pal_lab`` (P, 3)."""
    return np.min(np.sum((lab[:, None, :] - pal_lab[None]) ** 2, 2), 1)


def _choose(img_lab, plab, n, weight=None):
    """Greedily pick ``n`` candidate-colour indices that minimise total error.

    Each step adds the colour that most reduces the sum, over all pixels, of
    the (optionally weighted) squared Lab distance to the nearest colour
    chosen so far.

    img_lab: image in Lab; any leading shape, last axis of size 3.
    plab:    (P, 3) Lab values of the candidate colours.
    n:       number of colours to pick; must not exceed P, otherwise every
             candidate is masked and argmin falls back to index 0 again.
    weight:  optional per-pixel weight, one value per flattened pixel.

    Returns the chosen indices sorted ascending.
    """
    flat = img_lab.reshape(-1, 3)
    d = np.sum((flat[:, None, :] - plab[None, :, :]) ** 2, axis=-1)  # (px, P)
    if weight is not None:
        d = d * weight[:, None]
    chosen, best = [], np.full(len(flat), np.inf)
    for _ in range(n):
        # total error if each candidate were added to the current selection
        cand = np.minimum(best[:, None], d).sum(0)
        cand[chosen] = np.inf  # never pick the same colour twice
        c = int(cand.argmin())
        chosen.append(c)
        best = np.minimum(best, d[:, c])
    return sorted(chosen)


def _nearest_codes(patterns, code, chunk=2048):
    """Label each pattern with the index of the code word at the smallest
    Hamming distance (count of differing positions).

    Work is done ``chunk`` patterns at a time to bound the (chunk, k, D)
    comparison temporary; ties go to the lowest code index.
    """
    labels = np.empty(len(patterns), np.int64)
    for s in range(0, len(patterns), chunk):
        blk = patterns[s:s + chunk]
        labels[s:s + chunk] = (blk[:, None, :] != code[None]).sum(2).argmin(1)
    return labels


def _tile_codebook(patterns, k, iters=8):
    """Reduce ``patterns`` (N, D) of small non-negative ints to <= ``k`` code words.

    If there are at most ``k`` distinct patterns they are used verbatim and
    the codebook is zero-padded to ``k`` rows. Otherwise a k-modes style
    clustering runs: start from the ``k`` most frequent patterns, assign each
    pattern by Hamming distance, then replace every non-empty cluster's code
    word with the per-position most common value of its members. This repeats
    for up to ``iters`` rounds or until no code word changes.

    Returns ``(code, labels)``: ``code`` is (k, D) in the patterns' dtype and
    ``labels`` (N,) gives each pattern's code-word index.
    """
    uniq, counts = np.unique(patterns, axis=0, return_counts=True)
    if len(uniq) <= k:
        code = np.zeros((k, patterns.shape[1]), patterns.dtype)
        code[:len(uniq)] = uniq
        lut = {tuple(p): i for i, p in enumerate(uniq)}
        return code, np.array([lut[tuple(p)] for p in patterns])

    code = uniq[np.argsort(-counts)[:k]].copy()
    for _ in range(iters):
        labels = _nearest_codes(patterns, code)
        moved = False
        for j in range(k):
            members = patterns[labels == j]
            if not len(members):
                continue  # an empty cluster keeps its previous code word
            # per-position mode; minlength=16 covers the 4bpp pen range
            mode = np.array(
                [np.bincount(members[:, p], minlength=16).argmax()
                 for p in range(members.shape[1])],
                patterns.dtype,
            )
            if not np.array_equal(mode, code[j]):
                code[j] = mode
                moved = True
        if not moved:
            break
    # relabel against the final codebook (it may have changed in the last round)
    return code, _nearest_codes(patterns, code)


def _palettes(img_lab, mono, base_color):
    """Return the two 16-entry palettes as lists of master-palette indices.

    mono:      both palettes are the master palette's greys, ordered by
               lightness (Lab L) and repeated to fill 16 slots.
    otherwise: palette 0 is the greedy best 16 colours for the whole image;
               palette 1 is chosen the same way, with each pixel weighted by
               how badly palette 0 reproduces it.

    ``base_color`` is accepted but currently unused.
    """
    plab = smspal.palette_lab()
    if mono:
        greys = sorted(smspal.GREYS, key=lambda i: plab[i, 0])
        pal0 = (greys * 4)[:16]  # 4 greys, padded to 16
        return [pal0, pal0]
    pal0 = _choose(img_lab, plab, 16)
    # palette 1 covers the colours palette 0 reproduces worst
    resid = _nearest_sq_err(img_lab.reshape(-1, 3), plab[pal0])
    pal1 = _choose(img_lab, plab, 16, weight=resid)
    return [pal0, pal1]


def _assign_tile_palettes(img_lab, plab_pal):
    """Pick, for each 8x8 tile, the palette (0 or 1) whose nearest colours give
    the lower summed squared Lab error; ties go to palette 0.

    Returns a (TROWS, TCOLS) int64 array of palette selects.
    """
    tile_pal = np.zeros((TROWS, TCOLS), np.int64)
    for ty in range(TROWS):
        for tx in range(TCOLS):
            blk = img_lab[ty * 8:ty * 8 + 8, tx * 8:tx * 8 + 8].reshape(-1, 3)
            e0 = _nearest_sq_err(blk, plab_pal[0]).sum()
            e1 = _nearest_sq_err(blk, plab_pal[1]).sum()
            tile_pal[ty, tx] = 0 if e0 <= e1 else 1
    return tile_pal


def _emit_patterns(code):
    """Serialise (n, 64) pen codes as SMS 4bpp planar tile data.

    Each tile takes 32 bytes: 4 bytes per pixel row, bit-plane 0 first, and
    the leftmost pixel in bit 7 of each byte.
    """
    pix = code.reshape(len(code), 8, 8).astype(np.uint8)  # (tile, row, x)
    shifts = np.arange(4, dtype=np.uint8)[None, None, :, None]
    planes = (pix[:, :, None, :] >> shifts) & 1  # (tile, row, plane, x)
    # packbits' default big-endian bit order puts x=0 in the MSB
    return np.packbits(planes, axis=-1).tobytes()


def _emit_nametable(name_pat, tile_pal):
    """Serialise the name table as little-endian 16-bit words in row-major
    cell order: tile index in bits 0-8, palette select in bit 11."""
    entry = (name_pat.astype(np.int64) & 0x1FF) | (tile_pal.astype(np.int64) << 11)
    return entry.astype("<u2").tobytes()


def encode(img_rgb, dither_mode, mono=False, base_color=None):
    """Encode an RGB image as SMS background data.

    img_rgb:     source image; after ``c64pal.srgb_to_lab`` it must be a
                 256x192 (H=192, W=256) Lab array.
    dither_mode: passed through to ``dither.quantize``.
    mono:        restrict both palettes to the master palette's greys and
                 measure error on Lab lightness only.
    base_color:  forwarded to ``_palettes``, which currently ignores it.

    Returns ``(patterns, nametable, palette, preview, err)``:
      patterns  -- NTILES * 32 bytes of 4bpp planar tile data;
      nametable -- TROWS * TCOLS little-endian 16-bit name-table entries;
      palette   -- 32 bytes of master-palette indices (palette 0, then 1);
      preview   -- the master RGB palette indexed by the displayed colours;
      err       -- ``perceptual_error`` of the displayed image vs the source.

    Raises:
        ValueError: if the image is not 256x192.
    """
    plab = smspal.palette_lab()
    prgb = smspal.get_palette().astype(np.uint8)
    img_lab = c64pal.srgb_to_lab(img_rgb)
    if tuple(img_lab.shape[:2]) != (H, W):
        raise ValueError(
            f"SMS encoder needs a {W}x{H} image; got array of shape {img_lab.shape}"
        )

    pal_idx = np.array(_palettes(img_lab, mono, base_color))  # (2, 16) master indices
    plab_pal = plab[pal_idx]  # (2, 16, 3)

    tile_pal = _assign_tile_palettes(img_lab, plab_pal)
    pix_pal = np.repeat(np.repeat(tile_pal, 8, 0), 8, 1)  # (H, W) palette select

    # Each pixel may use only its tile's 16 colours, addressed 0-31 in the
    # concatenated two-palette table; dither within that restriction.
    plab32 = plab[pal_idx.reshape(-1)]  # (32, 3)
    allowed = pix_pal[..., None] * 16 + np.arange(16)  # (H, W, 16)
    idx = dither.quantize(img_lab, allowed, plab32, dither_mode).astype(np.int64)
    pen = (idx - pix_pal * 16).astype(np.uint8)  # pen 0-15 within the tile palette

    # 8x8 tiles in row-major tile order -> (768, 64) patterns; quantise to NTILES
    tiles = (pen.reshape(TROWS, 8, TCOLS, 8)
             .transpose(0, 2, 1, 3)
             .reshape(TROWS * TCOLS, 64))
    code, labels = _tile_codebook(tiles, NTILES)
    name_pat = labels.reshape(TROWS, TCOLS)

    # ---- emit VDP data ----
    patterns = _emit_patterns(code)
    nametable = _emit_nametable(name_pat, tile_pal)
    palette = bytes(int(c) for c in pal_idx.reshape(-1))  # 32 colour indices (0-63)

    # rebuild the displayed image (clustered tiles + per-tile palette) for preview
    disp = (code[labels]
            .reshape(TROWS, TCOLS, 8, 8)
            .transpose(0, 2, 1, 3)
            .reshape(H, W))
    final = pal_idx[pix_pal, disp].astype(np.uint16)  # master-palette index per pixel

    if mono:
        # compare lightness only: zero the a/b channels of image and palette
        lum = img_lab.copy()
        lum[..., 1:] = 0.0
        pl = plab.copy()
        pl[:, 1:] = 0.0
        err = perceptual_error(final, lum, pl)
    else:
        err = perceptual_error(final, img_lab, plab)
    return patterns, nametable, palette, prgb[final], err