#!/usr/bin/env python3
"""
Cluster face encodings from per-directory faces.json sidecars into people.

Reads all faces.json files under MEDIA_ROOT, groups face encodings by
similarity (L2 distance < threshold), and writes MEDIA_ROOT/people.json.

Existing names and UUIDs in people.json are preserved across re-runs: a new
cluster that shares at least one photo with a record in the old file can
inherit that record's UUID and name. Pairs with the most shared face
instances (then shared photos) are matched first, so a cluster that merely
shares a group photo with a named person does not take that name ahead of
the person's own cluster. If people.json exists but cannot be read, the
script exits instead of overwriting it.

Usage:
  python3 cluster_faces.py [--threshold 0.55] [--media-root /var/albumen]

Threshold guide:
  0.50 — very strict; may split the same person across lighting/angle
  0.55 — good default for family photos
  0.60 — lenient; may merge different people who look alike
"""

import argparse
import glob
import json
import os
import sys
import uuid as _uuid
from datetime import datetime, timezone

try:
    import numpy as np
except ImportError:
    print("numpy is required: /opt/albumen/venv/bin/pip install numpy", file=sys.stderr)
    sys.exit(1)

MEDIA_ROOT = os.environ.get('MEDIA_ROOT', '/var/albumen')

# Length of a face encoding accepted from faces.json sidecars.
ENCODING_DIM = 128

# Upper bound on elements in the (faces x clusters x dim) difference tensor
# built per chunk during refinement; keeps peak memory bounded for large k.
_DIST_CHUNK_ELEMS = 1 << 22


def _valid_encoding(enc):
    """Return True if enc is a list of ENCODING_DIM finite numbers."""
    if not isinstance(enc, list) or len(enc) != ENCODING_DIM:
        return False
    try:
        arr = np.asarray(enc, dtype=np.float64)
    except (TypeError, ValueError):
        return False
    return arr.shape == (ENCODING_DIM,) and bool(np.all(np.isfinite(arr)))


def collect_faces(media_root):
    """
    Return a list of {rel, box, encoding} dicts, one per usable face instance.

    Every faces.json under media_root (recursively, in sorted path order) is
    read as a mapping of image filename -> list of face dicts carrying an
    'encoding' (ENCODING_DIM numbers) and a truthy 'box'. 'rel' is the image
    path relative to media_root, '/'-separated. Unreadable sidecars are
    reported on stderr and skipped; malformed entries are skipped silently.
    """
    faces = []
    # Escape the root so '[', '*' or '?' in the directory name are literal.
    pattern = os.path.join(glob.escape(media_root), '**', 'faces.json')
    for path in sorted(glob.glob(pattern, recursive=True)):
        dir_rel = os.path.relpath(os.path.dirname(path), media_root)
        dir_rel = '' if dir_rel == os.curdir else dir_rel.replace(os.sep, '/')
        try:
            with open(path, encoding='utf-8') as fh:
                data = json.load(fh)
        except (OSError, ValueError) as err:
            print(f"  skipping {path}: {err}", file=sys.stderr)
            continue
        if not isinstance(data, dict):
            continue
        for filename, face_list in data.items():
            if not isinstance(face_list, list):
                continue
            rel = f"{dir_rel}/{filename}" if dir_rel else filename
            for face in face_list:
                if not isinstance(face, dict):
                    continue
                enc = face.get('encoding')
                box = face.get('box')
                if not box or not _valid_encoding(enc):
                    continue
                faces.append({'rel': rel, 'box': box, 'encoding': enc})
    return faces


def cluster(encodings, threshold):
    """
    Greedy centroid clustering with up to 3 refinement passes.

    Args:
        encodings: array-like of shape (n, d), one encoding per row
            (converted to float32); d is taken from the data.
        threshold: a face joins (or moves to) a cluster only if its L2
            distance to that cluster's centroid is strictly below this value.

    Returns:
        int32 array of length n. Labels are compacted to 0..k-1, keeping the
        relative order in which clusters were created by the greedy pass.

    Raises:
        ValueError: if non-empty input is not 2-D.
    """
    encodings = np.asarray(encodings, dtype=np.float32)
    n = len(encodings)
    if n == 0:
        return np.array([], dtype=np.int32)
    if encodings.ndim != 2:
        raise ValueError(f"encodings must be 2-D (n, d), got shape {encodings.shape}")
    dim = encodings.shape[1]

    # Greedy pass: each face joins the nearest running centroid if within
    # threshold, otherwise it founds a new cluster (order-dependent).
    # At most n clusters can exist, so preallocate instead of vstack-ing.
    labels = np.zeros(n, dtype=np.int32)
    sums = np.zeros((n, dim), dtype=np.float32)
    means = np.zeros((n, dim), dtype=np.float32)
    counts = np.zeros(n, dtype=np.float32)
    sums[0] = encodings[0]
    means[0] = encodings[0]
    counts[0] = 1
    k = 1
    for i in range(1, n):
        dists = np.sqrt(np.sum((encodings[i] - means[:k]) ** 2, axis=1))
        best = int(np.argmin(dists))
        if dists[best] < threshold:
            labels[i] = best
            sums[best] += encodings[i]
            counts[best] += 1
            means[best] = sums[best] / counts[best]
        else:
            labels[i] = k
            sums[k] = encodings[i]
            means[k] = encodings[i]
            counts[k] = 1
            k += 1

    # Refinement: recompute centroids and move each face to the nearest one
    # within threshold; faces with no centroid in range keep their label.
    for _ in range(3):
        k = int(labels.max()) + 1
        sizes = np.bincount(labels, minlength=k)
        centroids = np.zeros((k, dim), dtype=np.float32)
        np.add.at(centroids, labels, encodings)
        populated = sizes > 0
        centroids[populated] /= sizes[populated, np.newaxis]

        new_labels = np.empty_like(labels)
        step = max(1, _DIST_CHUNK_ELEMS // max(1, k * dim))
        for start in range(0, n, step):
            stop = min(start + step, n)
            diff = encodings[start:stop, np.newaxis, :] - centroids[np.newaxis, :, :]
            dist = np.sqrt(np.sum(diff ** 2, axis=2))  # (stop - start, k)
            # A cluster emptied by a previous pass has a zero centroid that
            # is not a real cluster; it must not attract faces.
            dist[:, ~populated] = np.inf
            nearest = np.argmin(dist, axis=1)
            nearest_dist = dist[np.arange(stop - start), nearest]
            new_labels[start:stop] = np.where(
                nearest_dist < threshold, nearest, labels[start:stop])
        if np.array_equal(new_labels, labels):
            break
        labels = new_labels

    # Compact: remap surviving labels to 0..k-1 (ascending old label order).
    _, compact = np.unique(labels, return_inverse=True)
    return compact.reshape(-1).astype(np.int32)


def load_people(path, *, strict=False):
    """
    Return the 'people' mapping stored in an existing people.json.

    A missing file, or one without a 'people' key, yields {}. A file that
    cannot be read, is not a JSON object, or whose 'people' value is not an
    object yields {} — or, when strict=True, raises ValueError so the caller
    can avoid overwriting it (which would discard every stored name).
    """
    if not os.path.exists(path):
        return {}
    try:
        with open(path, encoding='utf-8') as fh:
            data = json.load(fh)
        if not isinstance(data, dict):
            raise ValueError('top-level JSON value is not an object')
        people = data.get('people', {})
        if not isinstance(people, dict):
            raise ValueError("'people' is not an object")
    except (OSError, ValueError) as err:
        if strict:
            raise ValueError(f"cannot read {path}: {err}") from err
        return {}
    return people


def _face_key(rel, box):
    """Hashable identity of one face instance: image path plus serialized box."""
    return rel, json.dumps(box, sort_keys=True)


def build_people(faces, labels, existing):
    """
    Group faces into clusters and carry UUIDs/names over from `existing`.

    Args:
        faces: list of {'rel', 'box', ...} dicts, as from collect_faces().
        labels: cluster label per face (same length as faces).
        existing: 'people' mapping from a previous people.json
            (uuid -> {'name', 'slug', 'members': [{'rel', 'box'}, ...]}).

    Returns:
        dict uuid -> {'name', 'slug', 'members'}, one entry per cluster,
        largest cluster first.

    A cluster may inherit an old record's UUID, name and slug only if they
    share at least one photo; each record and each cluster is matched at most
    once. (cluster, record) pairs are matched greedily by shared face
    instances (same rel and box), then shared photos, then cluster size, so
    a cluster that only shares a group photo with a named person does not
    claim that name ahead of the person's own cluster. Unmatched clusters get
    a fresh UUID and no name.
    """
    clusters = {}
    for face, label in zip(faces, np.asarray(labels).tolist()):
        clusters.setdefault(int(label), []).append(face)

    # Reverse indexes over the old file: photo -> uids, face instance -> uids.
    photo_owners = {}
    face_owners = {}
    for uid, person in existing.items():
        if not isinstance(person, dict):
            continue
        for member in person.get('members') or ():
            if not isinstance(member, dict) or 'rel' not in member:
                continue
            photo_owners.setdefault(member['rel'], set()).add(uid)
            face_owners.setdefault(
                _face_key(member['rel'], member.get('box')), set()).add(uid)

    candidates = []
    for label, members in clusters.items():
        shared_photos = {}
        for rel in {m['rel'] for m in members}:
            for uid in photo_owners.get(rel, ()):
                shared_photos[uid] = shared_photos.get(uid, 0) + 1
        shared_faces = {}
        for key in {_face_key(m['rel'], m['box']) for m in members}:
            for uid in face_owners.get(key, ()):
                shared_faces[uid] = shared_faces.get(uid, 0) + 1
        for uid, n_photos in shared_photos.items():
            # Ascending sort = most shared faces, then photos, then largest
            # cluster; label and uid only make ties deterministic.
            candidates.append(
                (-shared_faces.get(uid, 0), -n_photos, -len(members), label, uid))

    assigned = {}
    used = set()
    for *_, label, uid in sorted(candidates):
        if label in assigned or uid in used:
            continue
        assigned[label] = uid
        used.add(uid)

    people = {}
    # Emit largest clusters first (ties keep first-seen order).
    for label, members in sorted(clusters.items(), key=lambda item: -len(item[1])):
        uid = assigned.get(label)
        if uid is None:
            uid, name, slug = str(_uuid.uuid4()), None, None
        else:
            name = existing[uid].get('name')
            slug = existing[uid].get('slug')
        people[uid] = {
            'name': name,
            'slug': slug,
            'members': [{'rel': m['rel'], 'box': m['box']} for m in members],
        }
    return people


def atomic_write(path, obj):
    """
    Write obj as JSON to path atomically.

    The data goes to a sibling temp file, is fsync'd, then os.replace()d over
    path, so readers see either the old file or the complete new one. The
    temp file is removed if anything fails.
    """
    tmp = path + '.tmp.cluster'
    try:
        with open(tmp, 'w', encoding='utf-8') as f:
            json.dump(obj, f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, path)
    except BaseException:
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise


def main():
    """Command-line entry point: collect faces, cluster, match names, write people.json."""
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument('--threshold', type=float, default=0.55,
                    help='L2 distance threshold (default 0.55; lower = stricter)')
    ap.add_argument('--media-root', default=MEDIA_ROOT)
    args = ap.parse_args()
    # `not x > 0` also rejects NaN.
    if not args.threshold > 0:
        ap.error('--threshold must be a positive number')

    people_path = os.path.join(args.media_root, 'people.json')

    print(f"Collecting faces from {args.media_root} ...")
    faces = collect_faces(args.media_root)
    print(f"  {len(faces)} face instances")
    if not faces:
        print("No faces to cluster. Run the face daemon first.")
        return

    # Load before clustering so an unreadable people.json aborts early and is
    # never overwritten (overwriting would lose every stored name).
    try:
        existing = load_people(people_path, strict=True)
    except ValueError as err:
        print(f"Refusing to overwrite: {err}", file=sys.stderr)
        sys.exit(1)

    print(f"Clustering (threshold={args.threshold}) ...")
    encodings = np.array([f['encoding'] for f in faces], dtype=np.float32)
    labels = cluster(encodings, args.threshold)
    k = int(labels.max()) + 1
    print(f"  {k} clusters from {len(faces)} faces")
    print(f"  {len(existing)} existing records (names preserved)")

    people = build_people(faces, labels, existing)

    dropped = sorted(str(p['name']) for uid, p in existing.items()
                     if isinstance(p, dict) and p.get('name') and uid not in people)
    if dropped:
        print(f"  warning: {len(dropped)} named record(s) matched no cluster "
              f"and were dropped: {', '.join(dropped)}", file=sys.stderr)

    result = {
        'updated_at': datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ'),
        'threshold': args.threshold,
        'people': people,
    }
    atomic_write(people_path, result)

    named = sum(1 for p in people.values() if p.get('name'))
    print(f"Wrote {people_path}: {named} named, {len(people) - named} unnamed")


if __name__ == '__main__':
    main()