"""Shared helpers for the Atari encoders."""

from __future__ import annotations

import numpy as np

DATA_ADDR = 0x4000  # bitmap base
COLOR_ADDR = 0x6000  # colour data base (fixed, after the bitmap)
SPLIT_LINE = 102  # lines that fit in the first 4K ($4000-$4FEF)
BYTES_PER_LINE = 40
LINES = 192


def split_screen(line_bytes: list[bytes]) -> bytes:
    """Lay out the screen lines as one memory image starting at DATA_ADDR.

    The first SPLIT_LINE lines go at $4000. A gap then pads up to the $5000
    boundary, so line 102 starts exactly at $5000 and no ANTIC line crosses
    a 4K boundary. The rest of the lines follow. The image is then
    zero-padded up to COLOR_ADDR so colour data can follow at a fixed
    address.

    For the full 192 x 40-byte screen:
    - the first 102 lines take 4080 bytes,
    - the gap is 16 bytes,
    - the last 90 lines take 3600 bytes,
    - the result is 0x2000 bytes long.

    ``bytes()`` raises ``ValueError`` (negative count) if the lines do not
    fit.
    """
    first = b"".join(line_bytes[:SPLIT_LINE])  # 4080 bytes -> $4000
    second = b"".join(line_bytes[SPLIT_LINE:])  # 3600 bytes -> $5000
    body = first + bytes(0x1000 - len(first)) + second  # gap fills to $5000
    pad = (COLOR_ADDR - DATA_ADDR) - len(body)
    return body + bytes(pad)


def luminance_lab(img_rgb, plab):
    """Return (image, palette) recast into luminance-only CIELAB (L, 0, 0).

    Matching then depends on brightness alone; the single-hue modes use this.

    ``img_rgb`` is converted with the package's ``srgb_to_lab``. Only its
    first two dimensions are used to size the output (presumably (H, W, 3)
    sRGB -- TODO confirm against callers). ``plab`` is an array of palette
    Lab colours, one per row. Its a/b columns are zeroed and L is kept.
    """
    from ...palette import srgb_to_lab
    L = srgb_to_lab(img_rgb)[..., 0]
    img_mono = np.zeros(img_rgb.shape[:2] + (3,))
    img_mono[..., 0] = L
    plab_mono = np.zeros_like(plab)
    plab_mono[:, 0] = plab[:, 0]
    return img_mono, plab_mono


def choose_palette(img_lab: np.ndarray, plab: np.ndarray, k: int,
                   iters: int = 12) -> list[int]:
    """Pick ``k`` palette indices that best represent the image.

    The method is palette-constrained k-means in CIELAB. A greedy
    k-means++-style initialisation is followed by up to ``iters`` Lloyd
    steps, and each centroid is snapped to the palette entry with the
    smallest summed squared distance to its cluster.

    Args:
        img_lab: image in Lab; any shape whose last axis is 3.
        plab: (P, 3) palette in Lab (P is 256 for the full register range).
        k: number of colours to return. ``k <= 0`` returns ``[]``.
        iters: maximum number of Lloyd refinement steps.

    Returns:
        A list of ``k`` palette indices. They are distinct whenever
        ``k <= P``.
    """
    if k <= 0:
        return []
    flat = img_lab.reshape(-1, 3).astype(np.float32)
    # (N, P) squared distance from every pixel to every palette entry.
    D = np.sum((flat[:, None, :] - plab[None, :, :].astype(np.float32)) ** 2,
               axis=-1)
    n_pal = D.shape[1]

    # k-means++-ish greedy init: start at the entry nearest the image mean,
    # then repeatedly add the entry that most reduces the pixels' current
    # nearest-distance. Already-chosen entries are scored -1 so they are
    # never picked twice.
    chosen = [int(np.argmin(np.sum((plab - flat.mean(0)) ** 2, axis=-1)))]
    for _ in range(k - 1):
        md = D[:, chosen].min(axis=1)
        improv = np.maximum(0.0, md[:, None] - D).sum(axis=0)
        improv[chosen] = -1.0
        chosen.append(int(np.argmax(improv)))

    # Lloyd refinement, each centroid snapped to its best palette colour.
    for _ in range(iters):
        assign = np.argmin(D[:, chosen], axis=1)
        members = [assign == j for j in range(k)]
        taken = np.zeros(n_pal, dtype=bool)
        # Empty clusters keep their colour. Reserve those first so that a
        # non-empty cluster cannot snap onto them and create a duplicate.
        for j in range(k):
            if not members[j].any():
                taken[chosen[j]] = True
        new = list(chosen)
        for j in range(k):
            if not members[j].any():
                continue
            cost = D[members[j]].sum(axis=0)
            # Keep the colours distinct where possible: skip entries already
            # claimed this round. That is only impossible when k exceeds the
            # palette size.
            if not taken.all():
                cost = np.where(taken, np.inf, cost)
            best = int(np.argmin(cost))
            new[j] = best
            taken[best] = True
        if new == chosen:
            break
        chosen = new
    return chosen


def _seg_all(sub, c1all, c2):
    """Return squared distances from pixels to a set of colour segments.

    Each segment runs from one palette colour in ``c1all`` (shape (C, 3))
    to the fixed endpoint ``c2`` (shape (3,)). Each ``sub`` pixel (shape
    (Nsub, 3)) is measured to its nearest point on the segment.

    Returns:
        Array of shape (C, Nsub).
    """
    seg = c2 - c1all  # (C,3)
    L = np.sum(seg * seg, axis=1) + 1e-9  # (C,); epsilon guards zero-length segments
    rel = sub[None, :, :] - c1all[:, None, :]  # (C,Nsub,3)
    # Projection parameter, clamped to [0, 1] so we stay on the segment.
    t = np.clip(np.sum(rel * seg[:, None, :], axis=2) / L[:, None], 0.0, 1.0)
    proj = c1all[:, None, :] + t[:, :, None] * seg[:, None, :]
    return np.sum((sub[None, :, :] - proj) ** 2, axis=2)


def relevant_candidates(img_lab, plab):
    """Return the palette colours that are the nearest match to some pixel.

    This is a small set (the image's own gamut) used to restrict the
    dither-aware search. Images with more than 4000 pixels are subsampled
    with stride ``len // 4000`` first.

    Returns:
        Sorted unique palette indices as an int64 array.
    """
    flat = img_lab.reshape(-1, 3).astype(np.float32)
    if len(flat) > 4000:
        flat = flat[::len(flat) // 4000]
    d = np.sum((flat[:, None, :] - plab[None, :, :].astype(np.float32)) ** 2,
               axis=-1)
    return np.unique(np.argmin(d, axis=1)).astype(np.int64)


def choose_palette_dither(img_lab, plab, k, init=None, n_sample=900, iters=5,
                          candidates=None):
    """Pick a dither-aware palette of ``k`` colours.

    Chooses the ``k`` colours whose pairwise *segment* blends (what error
    diffusion can reproduce) best cover the image. The colours therefore
    span the gamut instead of sitting at k-means centroids.

    The search is a vectorised local search: for each slot in turn, all
    candidates are scored at once. It starts from ``init`` or, if that is
    not given, from :func:`choose_palette`.

    Args:
        img_lab: image in Lab; any shape whose last axis is 3.
        plab: (P, 3) palette in Lab.
        k: number of colours. With ``k < 2`` there are no segments to
            optimise, so the starting set is returned unchanged.
        init: optional starting list of ``k`` palette indices.
        n_sample: approximate number of pixels used for scoring.
        iters: maximum number of full sweeps over the slots.
        candidates: palette indices a slot may take (default: every entry).

    Returns:
        A list of ``k`` palette indices. A slot whose only candidates are
        already used by other slots keeps its current colour.
    """
    from itertools import combinations

    flat = img_lab.reshape(-1, 3)
    sub = flat[::max(1, len(flat) // n_sample)] if len(flat) > n_sample else flat
    colors = list(init) if init is not None else choose_palette(img_lab, plab, k)
    if k < 2:
        # The objective is built from segments between colour pairs; with
        # fewer than two colours there is nothing to optimise.
        return colors
    cand = np.asarray(candidates if candidates is not None else range(len(plab)),
                      np.int64)
    cand_lab = plab[cand].astype(np.float64)  # (C,3)
    for _ in range(iters):
        changed = False
        for i in range(k):
            others = [colors[j] for j in range(k) if j != i]
            # Coverage already provided by segments that do not involve slot i.
            fixed = None
            for x, y in combinations(others, 2):
                s = _seg_all(sub, plab[x][None], plab[y])[0]
                fixed = s if fixed is None else np.minimum(fixed, s)
            # Coverage if slot i held each candidate (its segments to the others).
            m = None
            for o in others:
                d = _seg_all(sub, cand_lab, plab[o])  # (C, Nsub)
                m = d if m is None else np.minimum(m, d)
            if fixed is not None:
                m = np.minimum(m, fixed[None, :])
            err = m.sum(axis=1)  # (C,)
            err[np.isin(cand, others)] = np.inf  # avoid duplicate colours
            if not np.isfinite(err).any():
                # Every candidate is used by another slot: keep this slot's
                # colour rather than let argmin pick a duplicate.
                continue
            best = int(cand[np.argmin(err)])
            if best != colors[i]:
                colors[i] = best
                changed = True
        if not changed:
            break
    return colors


def quantize_global(img_lab, plab, colors, dither_mode):
    """Dither the whole image to a fixed global set of palette indices.

    Every pixel is offered the same ``colors``. The work is delegated to the
    package's ``dither.quantize``, and its result is returned as int64.
    """
    from ... import dither
    H, W, _ = img_lab.shape
    allowed = np.tile(np.array(colors, dtype=np.int64), (H, W, 1))
    return dither.quantize(img_lab, allowed, plab, dither_mode).astype(np.int64)


def _pack_lines(val_image, bits: int) -> list[bytes]:
    """Pack a 2-D array of ``bits``-wide pixel values MSB-first, one row per line.

    Raises:
        ValueError: if the width is not a whole number of bytes, or a value
            is outside ``0 .. 2**bits - 1`` (it would spill into a
            neighbouring pixel's bits).
    """
    vals = np.asarray(val_image)
    H, W = vals.shape
    per_byte = 8 // bits
    if W % per_byte:
        raise ValueError(f"width {W} is not a multiple of {per_byte}")
    vals = vals.astype(np.int64)
    limit = 1 << bits
    if vals.size and (vals.min() < 0 or vals.max() >= limit):
        raise ValueError(f"pixel values must be in 0..{limit - 1}")
    # Leftmost pixel goes into the most significant bits of each byte.
    shifts = np.arange(per_byte - 1, -1, -1, dtype=np.int64) * bits
    packed = (vals.reshape(H, W // per_byte, per_byte) << shifts).sum(axis=2)
    return [row.tobytes() for row in packed.astype(np.uint8)]


def pack_2bpp(val_image: np.ndarray) -> list[bytes]:
    """Pack 2-bits-per-pixel values (0..3), four pixels per byte.

    Returns one ``bytes`` per row. A 160 x 192 screen gives 192 lines of
    40 bytes.
    """
    return _pack_lines(val_image, 2)


def pack_4bpp(val_image: np.ndarray) -> list[bytes]:
    """Pack 4-bits-per-pixel values (0..15), two pixels per byte.

    Returns one ``bytes`` per row. An 80 x 192 screen gives 192 lines of
    40 bytes.
    """
    return _pack_lines(val_image, 4)


def pack_1bpp(val_image: np.ndarray) -> list[bytes]:
    """Pack 1-bit-per-pixel values (0/1), eight pixels per byte.

    Returns one ``bytes`` per row. A 320 x 192 screen gives 192 lines of
    40 bytes.
    """
    return _pack_lines(val_image, 1)