176 lines
6.9 KiB
Python
176 lines
6.9 KiB
Python
"""Shared helpers for the Atari encoders."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import numpy as np
|
|
|
|
DATA_ADDR = 0x4000  # bitmap base address
COLOR_ADDR = 0x6000  # colour data base (fixed, after the bitmap)
BYTES_PER_LINE = 40  # bytes in one screen line
LINES = 192  # screen lines in the bitmap
# Whole lines that fit in the first 4K page: 0x1000 // 40 == 102 lines,
# i.e. 4080 bytes occupying $4000-$4FEF.
SPLIT_LINE = 0x1000 // BYTES_PER_LINE
|
|
|
|
|
|
def split_screen(line_bytes: list[bytes]) -> bytes:
    """Build the bitmap memory image from a list of per-line byte strings.

    The first ``SPLIT_LINE`` lines are laid out from DATA_ADDR ($4000). Zero
    bytes then fill the rest of that 4K page (16 bytes for 102 lines of 40
    bytes), so line ``SPLIT_LINE`` starts exactly at $5000 and no ANTIC line
    crosses a 4K boundary. The remaining lines follow. The image is then
    zero-padded up to COLOR_ADDR, so colour data can follow at a fixed address.

    Raises ValueError (from ``bytes(n)`` with a negative ``n``) if the lines
    do not fit in the space allotted.
    """
    head = b"".join(line_bytes[:SPLIT_LINE])  # 4080 bytes for 102 lines -> $4000
    tail = b"".join(line_bytes[SPLIT_LINE:])  # 3600 bytes for 90 lines -> $5000
    page_fill = bytes(0x1000 - len(head))  # pushes the tail onto the $5000 page
    image = b"".join((head, page_fill, tail))
    trailer = bytes((COLOR_ADDR - DATA_ADDR) - len(image))
    return image + trailer
|
|
|
|
|
|
def luminance_lab(img_rgb, plab):
    """Reduce the image and the palette to lightness-only CIELAB colours.

    Every colour becomes ``(L*, 0, 0)``, so nearest-colour matching in this
    space compares brightness alone; the single-hue modes use this.
    ``img_rgb`` is converted with ``srgb_to_lab`` and only its first (L*)
    channel is kept. ``plab`` is copied with every column except the first
    zeroed.

    Returns ``(img_mono, plab_mono)``:
    - ``img_mono`` is a new float64 array of shape ``img_rgb.shape[:2] + (3,)``;
    - ``plab_mono`` has ``plab``'s shape and dtype.
    """
    from ...palette import srgb_to_lab

    lightness = srgb_to_lab(img_rgb)[..., 0]

    plab_mono = np.zeros_like(plab)
    plab_mono[:, 0] = plab[:, 0]

    img_mono = np.zeros((*img_rgb.shape[:2], 3))
    img_mono[..., 0] = lightness
    return img_mono, plab_mono
|
|
|
|
|
|
def choose_palette(img_lab: np.ndarray, plab: np.ndarray, k: int,
                   iters: int = 12) -> list[int]:
    """Pick the ``k`` palette register values (0..255) that best represent the
    image, by palette-constrained k-means in CIELAB.

    ``img_lab`` is flattened to ``(-1, 3)`` pixels. ``plab`` holds one CIELAB
    colour per palette entry.

    Seeding is greedy: first the entry nearest the image's mean colour, then
    repeatedly the entry that most reduces the total squared error. Up to
    ``iters`` Lloyd passes follow. In each pass every centroid snaps to the
    palette entry with the lowest summed squared distance to its cluster.

    Picks are kept distinct where possible:
    - An entry whose colour was already taken by an earlier pick is skipped.
      This falls back to index-distinct, and finally to no restriction, when
      the palette runs out of options.
    - An empty cluster whose colour was claimed by an earlier cluster is
      re-seeded with the free entry that most reduces the current error.

    Returns a list of ``k`` palette indices (``[]`` when ``k <= 0``).
    """
    if k <= 0:
        return []

    flat = img_lab.reshape(-1, 3).astype(np.float32)
    D = np.sum((flat[:, None, :] - plab[None, :, :].astype(np.float32)) ** 2, axis=-1)  # (N, P)
    n_pal = D.shape[1]

    def blocked(indices):
        """Mask of palette entries a new pick must avoid to stay distinct from
        ``indices``.

        Matches by colour if that leaves something free, else by index, else
        blocks nothing.
        """
        idx = np.asarray(indices, dtype=np.int64)
        # Compare colours rather than indices, so a duplicated palette row
        # cannot bring the same colour back under another register value.
        same_colour = (plab[:, None, :] == plab[idx][None, :, :]).all(axis=-1).any(axis=-1)
        if not same_colour.all():
            return same_colour
        same_index = np.zeros(n_pal, dtype=bool)
        same_index[idx] = True
        return same_index if not same_index.all() else np.zeros(n_pal, dtype=bool)

    # k-means++-ish greedy init.
    chosen = [int(np.argmin(np.sum((plab - flat.mean(0)) ** 2, axis=-1)))]
    for _ in range(k - 1):
        md = D[:, chosen].min(axis=1)
        improv = np.maximum(0.0, md[:, None] - D).sum(axis=0)
        improv[blocked(chosen)] = -1.0
        chosen.append(int(np.argmax(improv)))

    # Lloyd refinement, each centroid snapped to its best palette colour.
    for _ in range(iters):
        d_chosen = D[:, chosen]  # (N, k)
        assign = np.argmin(d_chosen, axis=1)
        reseed_gain = None  # only computed if an empty cluster needs a new colour
        new: list[int] = []
        for j in range(k):
            avoid = blocked(new)
            mask = assign == j
            if mask.any():
                cost = D[mask].sum(axis=0)
                cost[avoid] = np.inf  # keep distinct where possible
                new.append(int(np.argmin(cost)))
            elif not avoid[chosen[j]]:
                new.append(chosen[j])  # empty cluster: nothing pulls it elsewhere
            else:
                # Empty cluster whose colour an earlier cluster just claimed:
                # re-seed it with the free colour that most reduces the error.
                if reseed_gain is None:
                    md = d_chosen.min(axis=1)
                    reseed_gain = np.maximum(0.0, md[:, None] - D).sum(axis=0)
                gain = reseed_gain.copy()
                gain[avoid] = -1.0
                new.append(int(np.argmax(gain)))
        if new == chosen:
            break
        chosen = new
    return chosen
|
|
|
|
|
|
def _seg_all(sub, c1all, c2):
    """Squared distance from every ``sub`` point to each of a family of segments.

    Segment ``i`` runs from ``c1all[i]`` to the shared endpoint ``c2``.
    ``c1all`` is ``(C, 3)``: the whole palette, a candidate subset, or a single
    row. The projection parameter is clamped to [0, 1], so points beyond either
    end measure to that endpoint. Returns an array of shape ``(C, len(sub))``.
    """
    starts = c1all[:, None, :]  # (C, 1, 3)
    points = sub[None, :, :]  # (1, Nsub, 3)
    direction = c2 - c1all  # (C, 3)
    dirs = direction[:, None, :]  # (C, 1, 3)
    # The tiny epsilon keeps zero-length (c1 == c2) segments from dividing by 0.
    denom = np.sum(direction * direction, axis=1) + 1e-9  # (C,)
    along = np.sum((points - starts) * dirs, axis=2) / denom[:, None]
    t = np.clip(along, 0.0, 1.0)  # (C, Nsub)
    nearest = starts + t[:, :, None] * dirs
    return np.sum((points - nearest) ** 2, axis=2)
|
|
|
|
|
|
def relevant_candidates(img_lab, plab):
    """Return the sorted palette indices that are the nearest colour to at least
    one (sampled) image pixel.

    This is the image's own gamut, a small set to which the dither-aware search
    can be restricted. Images with more than 4000 pixels are subsampled with a
    stride of ``n_pixels // 4000``. Returns an int64 array.
    """
    pixels = img_lab.reshape(-1, 3).astype(np.float32)
    n_pixels = len(pixels)
    if n_pixels > 4000:
        pixels = pixels[::n_pixels // 4000]
    palette = plab[None, :, :].astype(np.float32)
    sq_dist = np.sum((pixels[:, None, :] - palette) ** 2, axis=-1)
    nearest = sq_dist.argmin(axis=1)
    return np.unique(nearest).astype(np.int64)
|
|
|
|
|
|
def choose_palette_dither(img_lab, plab, k, init=None, n_sample=900, iters=5,
                          candidates=None):
    """Dither-aware palette: pick the ``k`` colours whose pairwise *segment* blends
    (what error diffusion can reproduce) best cover the image -- so the colours
    span the gamut instead of sitting at k-means centroids. Vectorised local
    search (all candidates per slot at once) from a k-means start.

    Parameters:
        img_lab: image in CIELAB, reshaped to ``(-1, 3)`` pixels.
        plab: palette in CIELAB, one row of 3 per palette index.
        k: number of colours to choose.
        init: optional starting list of ``k`` palette indices. Defaults to
            ``choose_palette(img_lab, plab, k)``.
        n_sample: approximate number of pixels (strided sample) used for scoring.
        iters: maximum number of full passes over the ``k`` slots.
        candidates: palette indices a slot may switch to. Defaults to every
            row of ``plab``.

    Each slot is re-chosen in turn. For every candidate, a sample's error is
    its squared distance to the nearest segment among two kinds: those joining
    the candidate to each other colour, and those between pairs of the other
    colours. The candidate with the lowest total error wins.

    Candidates already held by another slot are excluded. If that leaves none,
    the slot keeps its current colour instead of duplicating one. With
    ``k == 1`` there are no segments, so plain squared distance to the
    candidate is used. The search stops early once a full pass changes nothing.
    Returns the list of palette indices.
    """
    from itertools import combinations

    flat = img_lab.reshape(-1, 3)
    sub = flat[::max(1, len(flat) // n_sample)] if len(flat) > n_sample else flat
    colors = list(init) if init is not None else choose_palette(img_lab, plab, k)
    pool = candidates if candidates is not None else range(len(plab))
    cand = np.asarray(pool, np.int64)
    cand_lab = plab[cand].astype(np.float64)  # (C,3)
    point_err = None  # (C, Nsub) point distances, built lazily for k == 1

    for _ in range(iters):
        changed = False
        for i in range(k):
            others = [colors[j] for j in range(k) if j != i]

            # Coverage by segments that do not involve slot i (needs k >= 3).
            fixed = None
            for x, y in combinations(others, 2):
                s = _seg_all(sub, plab[x][None], plab[y])[0]
                fixed = s if fixed is None else np.minimum(fixed, s)

            # Coverage by segments joining each candidate to the other colours.
            m = None
            for o in others:
                d = _seg_all(sub, cand_lab, plab[o])  # (C, Nsub)
                m = d if m is None else np.minimum(m, d)
            if m is None:
                # k == 1: nothing to blend with, so a sample can only be
                # matched by the colour itself.
                if point_err is None:
                    point_err = np.sum(
                        (sub[None, :, :] - cand_lab[:, None, :]) ** 2, axis=2)
                m = point_err
            if fixed is not None:
                m = np.minimum(m, fixed[None, :])

            err = m.sum(axis=1)  # (C,)
            # Avoid duplicate colours: a candidate held by another slot is out.
            err = np.where(np.isin(cand, others), np.inf, err)
            if np.isinf(err).all():
                # No admissible candidate (all taken, or none given): keep the
                # current colour rather than duplicating one.
                continue
            best = int(cand[np.argmin(err)])
            if best != colors[i]:
                colors[i] = best
                changed = True
        if not changed:
            break
    return colors
|
|
|
|
|
|
def quantize_global(img_lab, plab, colors, dither_mode):
    """Dither the whole image against one image-wide set of palette indices.

    ``colors`` is repeated for every pixel as an int64 array of shape
    ``(H, W, len(colors))``. That array is passed, together with ``img_lab``,
    ``plab`` and ``dither_mode``, to ``dither.quantize``, whose result is
    returned converted to int64. ``img_lab`` must be 3-D (H, W, channels).
    """
    from ... import dither

    height, width, _ = img_lab.shape
    per_pixel = np.tile(np.array(colors, dtype=np.int64), (height, width, 1))
    result = dither.quantize(img_lab, per_pixel, plab, dither_mode)
    return result.astype(np.int64)
|
|
|
|
|
|
def pack_2bpp(val_image: np.ndarray) -> list[bytes]:
    """Pack a 2-D array of 2-bit pixel values, four pixels per byte.

    The leftmost pixel of each group of four lands in the two most significant
    bits. A 160-wide image yields 40 bytes per row. One ``bytes`` object is
    returned per row (192 rows for a full screen).
    """
    _, width = val_image.shape
    packed = []
    for row in val_image:
        packed.append(bytes(
            (row[x] << 6) | (row[x + 1] << 4) | (row[x + 2] << 2) | row[x + 3]
            for x in range(0, width, 4)
        ))
    return packed
|
|
|
|
|
|
def pack_4bpp(val_image: np.ndarray) -> list[bytes]:
    """Pack a 2-D array of 4-bit pixel values, two pixels per byte.

    The left pixel of each pair goes in the high nibble. An 80-wide image
    yields 40 bytes per row. One ``bytes`` object is returned per row (192 rows
    for a full screen).
    """
    _, width = val_image.shape
    return [
        bytes((row[x] << 4) | row[x + 1] for x in range(0, width, 2))
        for row in val_image
    ]
|
|
|
|
|
|
def pack_1bpp(val_image: np.ndarray) -> list[bytes]:
    """Pack a 2-D array of 1-bit pixel values, eight pixels per byte.

    The leftmost pixel of each group of eight becomes the most significant
    bit. A 320-wide image yields 40 bytes per row. One ``bytes`` object is
    returned per row (192 rows for a full screen).
    """
    _, width = val_image.shape

    def byte_at(row, start):
        """Fold the eight pixels from ``start`` on, MSB first, into one int."""
        acc = 0
        for bit in range(start, start + 8):
            acc = (acc << 1) | int(row[bit])
        return acc

    return [bytes(byte_at(row, x) for x in range(0, width, 8)) for row in val_image]
|