150 lines
6.1 KiB
Python
150 lines
6.1 KiB
Python
"""Sega Master System background encoder: 256x192, 2 palettes of 16 (of 64),
per-8x8-tile palette select, <=448 unique tiles in VRAM.

Each tile is 4bpp (16 colours) and picks one of two 16-colour palettes, so up to
32 colours on screen -- far less constrained than the NES. We pick palette 0 for
the whole image, palette 1 for the colours it serves worst, assign each tile its
better palette, dither, then vector-quantise the 8x8 patterns down to at most
448 tiles (the pattern-table space below the name table at $3800).
"""
|
|
from __future__ import annotations
|
|
|
|
import numpy as np
|
|
|
|
from ... import dither, palette as c64pal
|
|
from ...convert.base import perceptual_error
|
|
from .. import palette as smspal
|
|
|
|
W, H = 256, 192                 # visible background, in pixels
TCOLS, TROWS = W // 8, H // 8   # 32 x 24 tiles of 8x8 pixels

# VRAM is 16K: pattern table $0000-$37FF, name table $3800, sprite attribute
# table $3F00 -- each 8x8 4bpp tile takes 32 bytes, so the pattern table holds
# $3800 / 32 = 448 unique background tiles.
NTILES = 0x3800 // 32
|
|
|
|
|
|
def _choose(img_lab, plab, n, weight=None):
|
|
flat = img_lab.reshape(-1, 3)
|
|
d = np.sum((flat[:, None, :] - plab[None, :, :]) ** 2, axis=-1) # (px,64)
|
|
if weight is not None:
|
|
d = d * weight[:, None]
|
|
chosen, best = [], np.full(len(flat), np.inf)
|
|
for _ in range(n):
|
|
cand = np.minimum(best[:, None], d).sum(0)
|
|
for c in chosen:
|
|
cand[c] = np.inf
|
|
c = int(cand.argmin())
|
|
chosen.append(c)
|
|
best = np.minimum(best, d[:, c])
|
|
return sorted(chosen)
|
|
|
|
|
|
def _tile_codebook(patterns, k, iters=8):
|
|
uniq, counts = np.unique(patterns, axis=0, return_counts=True)
|
|
if len(uniq) <= k:
|
|
code = np.zeros((k, patterns.shape[1]), patterns.dtype)
|
|
code[:len(uniq)] = uniq
|
|
lut = {tuple(p): i for i, p in enumerate(uniq)}
|
|
return code, np.array([lut[tuple(p)] for p in patterns])
|
|
code = uniq[np.argsort(-counts)[:k]].copy()
|
|
labels = np.zeros(len(patterns), np.int64)
|
|
for _ in range(iters):
|
|
for s in range(0, len(patterns), 2048):
|
|
blk = patterns[s:s + 2048]
|
|
labels[s:s + 2048] = (blk[:, None, :] != code[None]).sum(2).argmin(1)
|
|
moved = False
|
|
for j in range(k):
|
|
mem = patterns[labels == j]
|
|
if len(mem):
|
|
med = np.array([np.bincount(mem[:, p], minlength=16).argmax()
|
|
for p in range(mem.shape[1])], patterns.dtype)
|
|
if not np.array_equal(med, code[j]):
|
|
code[j] = med; moved = True
|
|
if not moved:
|
|
break
|
|
for s in range(0, len(patterns), 2048):
|
|
blk = patterns[s:s + 2048]
|
|
labels[s:s + 2048] = (blk[:, None, :] != code[None]).sum(2).argmin(1)
|
|
return code, labels
|
|
|
|
|
|
def _palettes(img_lab, mono, base_color):
    """Pick the two 16-entry background palettes.

    Args:
        img_lab: Lab image; reshaped to (-1, 3) pixels.
        mono: if true, both palettes are ``smspal.GREYS`` sorted by lightness
            and repeated cyclically to fill all 16 pens.
        base_color: accepted for interface compatibility; not used here.

    Returns:
        ``[pal0, pal1]``: two lists of 16 indices into the master palette.
        Palette 0 is chosen greedily to cover the whole image; palette 1 is
        chosen with each pixel weighted by its squared distance to the nearest
        palette-0 colour, so it favours the colours palette 0 serves worst.

    Raises:
        ValueError: in mono mode, if ``smspal.GREYS`` is empty.
    """
    plab = smspal.palette_lab()
    if mono:
        greys = sorted(smspal.GREYS, key=lambda i: plab[i, 0])
        if not greys:
            raise ValueError("smspal.GREYS is empty; cannot build a mono palette")
        # Repeat the greys cyclically so there are always exactly 16 pens,
        # whatever the number of greys (encode relies on 16 per palette).
        pal0 = [greys[i % len(greys)] for i in range(16)]
        return [pal0, pal0]

    pal0 = _choose(img_lab, plab, 16)
    # palette 1 covers the colours palette 0 reproduces worst
    flat = img_lab.reshape(-1, 3)
    resid = np.min(np.sum((flat[:, None, :] - plab[pal0][None]) ** 2, 2), 1)
    pal1 = _choose(img_lab, plab, 16, weight=resid)
    return [pal0, pal1]
|
|
|
|
|
|
def _assign_tile_palettes(img_lab, plab_pal):
    """Return a (TROWS, TCOLS) int64 array: per tile, the palette (0/1) with the lower error.

    A tile's error under a palette is the sum over its 64 pixels of the squared
    Lab distance to that palette's nearest colour; ties go to palette 0.
    """
    pixels = img_lab.reshape(-1, 3)
    tile_err = []
    for pal_lab in plab_pal:  # plab_pal: (2, 16, 3)
        px_err = np.min(np.sum((pixels[:, None, :] - pal_lab[None]) ** 2, 2), 1)
        # row-major pixels -> (tile row, y, tile col, x); sum each 8x8 block
        tile_err.append(px_err.reshape(TROWS, 8, TCOLS, 8).sum(axis=(1, 3)))
    return np.where(tile_err[0] <= tile_err[1], 0, 1).astype(np.int64)


def _pack_patterns(code):
    """Encode flattened 8x8 pen patterns (0-15) as SMS VDP tiles, 32 bytes each.

    Per row: four bytes holding bitplanes 0-3, leftmost pixel in bit 7.
    """
    pens = np.asarray(code, np.uint8).reshape(-1, 8, 8)        # (tile, row, x)
    shifts = np.arange(4, dtype=np.uint8).reshape(1, 1, 4, 1)
    planes = (pens[:, :, None, :] >> shifts) & 1               # (tile, row, plane, x)
    # packbits' default big bit order puts x=0 in bit 7
    return np.packbits(planes, axis=-1).tobytes()              # order: tile, row, plane


def _pack_nametable(name_pat, tile_pal):
    """Pack name-table entries as little-endian 16-bit words.

    Tile index goes in bits 0-8 and the palette select in bit 11.
    """
    entries = (name_pat.astype(np.int64) & 0x1FF) | (tile_pal.astype(np.int64) << 11)
    return entries.astype("<u2").tobytes()


def encode(img_rgb, dither_mode, mono=False, base_color=None):
    """Encode an image as an SMS background (tiles, name table, palettes).

    Args:
        img_rgb: source image, converted with ``c64pal.srgb_to_lab``; the
            result must be H x W (192 x 256) pixels.
        dither_mode: passed through to ``dither.quantize``.
        mono: build grey-only palettes and measure the error on lightness
            only (Lab a/b zeroed).
        base_color: forwarded to ``_palettes``.

    Returns:
        ``(patterns, nametable, palette, preview, err)``:
          - patterns: NTILES * 32 bytes of 4bpp tile data.
          - nametable: TROWS * TCOLS little-endian 16-bit entries (tile index
            in bits 0-8, palette select in bit 11).
          - palette: 32 bytes, the master-palette index (0-63) of every entry
            of palettes 0 and 1.
          - preview: ``prgb[final]``, the RGB image the hardware would show
            (clustered tiles through each tile's palette).
          - err: ``perceptual_error`` of that image against the source.

    Raises:
        ValueError: if the image is not W x H pixels.
    """
    plab = smspal.palette_lab()
    prgb = smspal.get_palette().astype(np.uint8)
    img_lab = c64pal.srgb_to_lab(img_rgb)
    if img_lab.shape[:2] != (H, W):
        raise ValueError(
            f"SMS background must be {W}x{H} pixels; got (rows, cols) = {img_lab.shape[:2]}"
        )

    pal_idx = np.array(_palettes(img_lab, mono, base_color))  # (2, 16) master indices
    plab_pal = plab[pal_idx]                                    # (2, 16, 3)

    tile_pal = _assign_tile_palettes(img_lab, plab_pal)        # (TROWS, TCOLS)
    pal_px = np.repeat(np.repeat(tile_pal, 8, 0), 8, 1)        # (H, W) palette per pixel

    # Each pixel may only use its tile's 16 colours, given as indices 0-31
    # into plab32 (palette 0 then palette 1).
    plab32 = plab[pal_idx.reshape(-1)]                          # (32, 3)
    allowed = pal_px[..., None] * 16 + np.arange(16)            # (H, W, 16)
    idx = dither.quantize(img_lab, allowed, plab32, dither_mode).astype(np.int64)
    pen = (idx - pal_px * 16).astype(np.uint8)                  # 0-15 within own palette

    # 8x8 tiles -> 64-pen patterns; cluster to at most NTILES (448) patterns.
    tiles = pen.reshape(TROWS, 8, TCOLS, 8).transpose(0, 2, 1, 3).reshape(TROWS * TCOLS, 64)
    code, labels = _tile_codebook(tiles, NTILES)
    name_pat = labels.reshape(TROWS, TCOLS)

    # ---- emit VDP data ----
    patterns = _pack_patterns(code)
    nametable = _pack_nametable(name_pat, tile_pal)
    palette = bytes(int(c) for c in pal_idx.reshape(-1))  # 32 colour indices (0-63)

    # Preview: clustered patterns seen through each tile's own palette.
    disp = code[labels].reshape(TROWS, TCOLS, 8, 8).transpose(0, 2, 1, 3).reshape(H, W)
    final = pal_idx[pal_px, disp].astype(np.uint16)

    if mono:
        lum = img_lab.copy(); lum[..., 1:] = 0.0
        pl = plab.copy(); pl[:, 1:] = 0.0
        err = perceptual_error(final, lum, pl)
    else:
        err = perceptual_error(final, img_lab, plab)
    return patterns, nametable, palette, prgb[final], err
|