#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Spotify album art (4x20) -> WLED (20x5 = 100 LEDs with the default size), UDP DRGB
# Modes, cycled by a push button on GPIO (BCM 26 by default):
# 0) Default: 4x20 art on the top 4 rows + left-to-right progress bar on the bottom row (WHITE)
# 1) Default + DYNAMIC progress color taken from the album art
# 2) Art rotated 90 degrees + progress bar (WHITE)
# 3) Art rotated 90 degrees + DYNAMIC progress color taken from the album art
# 4) Off

import os, io, sys, time, argparse, socket, threading
import requests
from PIL import Image, ImageOps, ImageEnhance
from dotenv import load_dotenv
import spotipy
from spotipy.oauth2 import SpotifyOAuth

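# Optional GPIO: if RPi.GPIO cannot be imported, the mode button is disabled.
# (The polling fallback in setup_button applies only when edge-detect IRQs fail.)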
GPIO = None
try:
    import RPi.GPIO as GPIO  # type: ignore
except Exception:
    GPIO = None

load_dotenv()  # load variables from a .env file (python-dotenv) into the environment

# ===== Env / config =====
# WLED target. The script exits immediately if WLED_IP is empty or unset.
WLED_IP       = os.getenv("WLED_IP", "").strip() or sys.exit("Set WLED_IP in .env")
UDP_PORT      = int(os.getenv("WLED_UDP_PORT", "21324"))
# Sent as a single byte in every DRGB packet (see send_full_frame); keep it within 0..255.
UDP_TIMEOUT_SECONDS = int(os.getenv("UDP_TIMEOUT_SECONDS", "3"))

# Matrix geometry in LEDs. ART_H_SMALL is the art height used together with the
# progress bar (drawn on PROGRESS_ROW); ART_H_FULL is the height of the
# full-matrix (no bar) art variant.
MATRIX_W      = int(os.getenv("MATRIX_W", "20"))
MATRIX_H      = int(os.getenv("MATRIX_H", "5"))
ART_H_SMALL   = 4
ART_H_FULL    = 5
PROGRESS_ROW  = int(os.getenv("PROGRESS_ROW", "4"))
# Physical corner where LED 0 sits: TL, TR, BL or BR (validated in _orient_xy).
START_CORNER  = os.getenv("START_CORNER", "TL").upper()

def _parse_row_dirs(s, h):
    vals = [v.strip().upper() for v in (s or "").split(",") if v.strip()]
    if len(vals) != h: vals = ["LR"] * h
    return [1 if v in ("LR","L2R","+","RIGHT") else -1 for v in vals]
# Per-row wiring direction (+1 = left-to-right, -1 = right-to-left), e.g.
# ROW_DIRS="LR,RL,LR,RL,LR". It must list exactly MATRIX_H rows, otherwise
# every row falls back to LR.
ROW_DIR       = _parse_row_dirs(os.getenv("ROW_DIRS"), MATRIX_H)

# Button / debounce / polling
# BUTTON_PULL: "UP" selects the internal pull-up; any other value selects pull-down.
# BUTTON_ACTIVE: pin level that counts as "pressed".
BUTTON_PIN    = int(os.getenv("BUTTON_PIN", "26"))
BUTTON_PULL   = os.getenv("BUTTON_PULL", "UP").upper()          # UP or DOWN
BUTTON_ACTIVE = os.getenv("BUTTON_ACTIVE_STATE", "LOW").upper() # LOW if wired to GND
DEBOUNCE_MS   = int(os.getenv("BUTTON_DEBOUNCE_MS", "250"))
POLL_BUTTON_HZ= float(os.getenv("POLL_BUTTON_HZ", "60"))        # used only if IRQ fails

# Spotify / look
# OAuth scopes: comma-separated in the env var, space-joined for SpotifyOAuth.
SCOPES        = [s.strip() for s in os.getenv(
    "SCOPES","user-read-currently-playing,user-read-playback-state"
).split(",")]
# Seconds between "now playing" requests to Spotify.
NP_POLL_SEC   = float(os.getenv("NP_POLL_SEC", "3"))
# Render-loop rate: run_live sleeps 1/PROGRESS_FPS seconds between frames.
PROGRESS_FPS  = float(os.getenv("PROGRESS_FPS", "20"))
# Exponent used by gamma_scale for the partially lit leading LED of the bar.
PROGRESS_GAMMA= float(os.getenv("PROGRESS_GAMMA", "2.2"))
# Total art fade time on track change: half fading out, half fading in.
ART_FADE_MS   = int(os.getenv("ART_FADE_MS", "800"))
# NOTE(review): BAR_LINK_FADEIN is not read anywhere in this file; the progress
# bar is currently drawn at full brightness regardless of the art fade.
BAR_LINK_FADEIN = True  # link bar & art fades when both shown

CLIENT_ID     = os.getenv("SPOTIFY_CLIENT_ID", "")
CLIENT_SECRET = os.getenv("SPOTIFY_CLIENT_SECRET", "")
REDIRECT_URI  = os.getenv("SPOTIFY_REDIRECT_URI", "")

# Color pop controls
# ART_GAMMA: shapes the fade alpha in blit (alpha ** (1 / ART_GAMMA)).
# SATURATION / CONTRAST: ImageEnhance factors in fit_album; 1.0 = unchanged (skipped).
# BRIGHTNESS_SCALE: multiplier on art pixels in blit (results clipped at 255).
# AUTO_CONTRAST: apply ImageOps.autocontrast in fit_album when truthy ("1", "true", "yes", "on").
ART_GAMMA        = float(os.getenv("ART_GAMMA", "2.0"))
SATURATION       = float(os.getenv("SATURATION", "1.35"))
CONTRAST         = float(os.getenv("CONTRAST", "1.20"))
BRIGHTNESS_SCALE = float(os.getenv("BRIGHTNESS_SCALE", "1.15"))
AUTO_CONTRAST    = os.getenv("AUTO_CONTRAST", "0").strip().lower() in ("1","true","yes","on")

# Black cutoff controls (album art only)
# In blit, art pixels whose scaled Rec.709 luma is <= BLACK_CUTOFF are turned off,
# or set to gray level BLACK_DIM_LEVEL when BLACK_DIM_INSTEAD is truthy.
BLACK_CUTOFF       = int(os.getenv("BLACK_CUTOFF", "12"))
BLACK_DIM_INSTEAD  = os.getenv("BLACK_DIM_INSTEAD", "0").strip().lower() in ("1","true","yes","on")
BLACK_DIM_LEVEL    = int(os.getenv("BLACK_DIM_LEVEL", "3"))

# ===== UDP realtime (DRGB) =====
# One module-level UDP socket, reused by send_full_frame() for every frame.
_udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# NOTE(review): SO_REUSEADDR only affects bind(); no bind() call is visible in
# this file, so the option looks harmless but unnecessary for a send-only socket.
_udp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# Destination (host, port) for every DRGB packet.
TARGET = (WLED_IP, UDP_PORT)

def send_full_frame(rgb_by_index):
    """Send one complete frame to WLED as a DRGB realtime UDP packet.

    Args:
        rgb_by_index: sequence (list or tuple) of (r, g, b) values in strip
            order. Shorter input is padded with black and longer input is
            truncated, so exactly MATRIX_W * MATRIX_H LEDs are sent.

    Packet layout: byte 0 = 2 (DRGB), byte 1 = timeout in seconds, then one
    r, g, b byte triple per LED. Channel values are clamped to 0..255 rather
    than wrapped (so 256 stays bright instead of turning off). The timeout is
    clamped into one byte, so an out-of-range UDP_TIMEOUT_SECONDS cannot make
    packet building raise.

    Network send errors (OSError) are ignored on purpose: this is called every
    frame and the next frame retries.
    """
    n = MATRIX_W * MATRIX_H
    pixels = list(rgb_by_index[:n])
    if len(pixels) < n:
        pixels.extend([(0, 0, 0)] * (n - len(pixels)))

    timeout = max(0, min(255, int(UDP_TIMEOUT_SECONDS)))
    pkt = bytearray((2, timeout))  # 2 = DRGB
    for r, g, b in pixels:
        pkt.extend(max(0, min(255, int(c))) for c in (r, g, b))

    try:
        _udp.sendto(pkt, TARGET)
    except OSError:
        # Best effort: WLED unreachable / network down must not stop rendering.
        pass

def clear_all():
    """Blank the whole matrix by sending a single all-black frame."""
    black_frame = [(0, 0, 0) for _ in range(MATRIX_W * MATRIX_H)]
    send_full_frame(black_frame)

# ===== Mapping =====
def _orient_xy(x, y):
    """Re-origin image coordinates so that (0, 0) is the wiring's START_CORNER.

    Image coordinates run x to the right and y downward from the top-left.
    A right-hand start corner (TR/BR) mirrors x; a bottom start corner (BL/BR)
    mirrors y.

    Raises:
        ValueError: if START_CORNER is not one of TL/TR/BL/BR.
    """
    corner = START_CORNER
    if corner not in ("TL", "TR", "BL", "BR"):
        raise ValueError("START_CORNER must be TL/TR/BL/BR")
    mirror_x = corner.endswith("R")
    mirror_y = corner.startswith("B")
    ox = MATRIX_W - 1 - x if mirror_x else x
    oy = MATRIX_H - 1 - y if mirror_y else y
    return ox, oy

def xy_to_index(x, y):
    """Return the LED strip index for image coordinate (x, y).

    The start corner is applied first (_orient_xy). Then each row's direction
    from ROW_DIR decides whether the column counts from the left (+1) or the
    right (-1).
    """
    col, row = _orient_xy(x, y)
    reversed_row = ROW_DIR[row] == -1
    strip_col = MATRIX_W - 1 - col if reversed_row else col
    return row * MATRIX_W + strip_col

# ===== Image helpers =====
def fetch_album_image(url, max_bytes=3*1024*1024):
    """Download album art and return its raw bytes, or None on failure.

    Args:
        url: image URL (the Spotify album art URL from now_playing()).
        max_bytes: size limit. Larger bodies are rejected (None) rather than
            truncated, because a truncated encoded image is a corrupt file
            that fails later when Pillow decodes it.

    Returns:
        The body as bytes, or None if *url* is empty, the request fails
        (connection error, timeout, HTTP error status) or the body exceeds
        *max_bytes*.
    """
    if not url:
        return None
    try:
        # stream=True lets us stop reading an oversized body early; the `with`
        # releases the connection on every path.
        with requests.get(url, timeout=2.5, stream=True) as r:
            r.raise_for_status()
            body = bytearray()
            for chunk in r.iter_content(chunk_size=64 * 1024):
                body += chunk
                if len(body) > max_bytes:
                    return None
            return bytes(body)
    except requests.RequestException:
        return None

def fit_album(img_bytes, w, h, rotate90=False):
    """Decode album-art bytes into a small w x h RGB image for the matrix.

    Pipeline:
      1. convert to RGB;
      2. center-crop to a square;
      3. optionally rotate 90 degrees (Pillow rotates counter-clockwise);
      4. resize to (w, h) with LANCZOS;
      5. optional autocontrast (AUTO_CONTRAST) and saturation / contrast
         boosts (SATURATION / CONTRAST; each skipped when ~1.0).

    Converting to RGB *before* cropping/resizing matters for palette ("P") and
    1-bit images. Pillow resamples those modes with nearest-neighbour
    regardless of the requested filter, which defeats the LANCZOS downscale.

    Args:
        img_bytes: encoded image data in any format Pillow can open.
        w, h: output size in pixels.
        rotate90: rotate the square crop before resizing.

    Returns:
        A new PIL image in RGB mode with size (w, h).

    Raises:
        OSError (including PIL.UnidentifiedImageError) if the bytes cannot be decoded.
    """
    with Image.open(io.BytesIO(img_bytes)) as im:
        rgb = im.convert("RGB")  # independent copy; safe to use after `im` closes
    side = min(rgb.size)
    sq = ImageOps.fit(rgb, (side, side), method=Image.LANCZOS)
    if rotate90:
        sq = sq.rotate(90, expand=False)  # square input, so nothing is cropped
    img = sq.resize((w, h), resample=Image.LANCZOS)
    if AUTO_CONTRAST:
        img = ImageOps.autocontrast(img, cutoff=1)
    if abs(SATURATION - 1.0) > 1e-3:
        img = ImageEnhance.Color(img).enhance(SATURATION)
    if abs(CONTRAST - 1.0) > 1e-3:
        img = ImageEnhance.Contrast(img).enhance(CONTRAST)
    return img

def img_to_rgb_list(img_pil):
    """Flatten an image into a row-major list of RGB tuples.

    Returns:
        (pixels, width, height). pixels[y * width + x] is the color at (x, y)
        of the image converted to RGB.
    """
    rgb_img = img_pil.convert("RGB")
    width, height = rgb_img.size
    pixels = [rgb_img.getpixel((col, row)) for row in range(height) for col in range(width)]
    return pixels, width, height

# ===== Dynamic progress color =====
def _saturate(rgb, factor=1.3):
    r,g,b = rgb
    # simple HSV-like saturation bump without dependencies
    mx, mn = max(rgb), min(rgb)
    if mx == mn:  # gray
        return rgb
    # push away from mean
    mean = (r+g+b)/3.0
    rr = int(max(0, min(255, mean + (r-mean)*factor)))
    gg = int(max(0, min(255, mean + (g-mean)*factor)))
    bb = int(max(0, min(255, mean + (b-mean)*factor)))
    return (rr,gg,bb)

def _dynamic_color_from_img(img_pil):
    """Pick a progress-bar color that represents the (small) album art.

    Pixels of *img_pil* (converted to RGB) are filtered. Near-black pixels
    (Rec.709 luma < 18) and near-white pixels (every channel > 235) are
    dropped. The remaining pixels are averaged per channel with integer
    division, and the average is passed through _saturate(..., 1.25).

    Returns:
        (r, g, b). Falls back to white (255, 255, 255) when img_pil is None
        or no pixel survives the filter.
    """
    fallback = (255, 255, 255)
    if img_pil is None:
        return fallback
    rgb_img = img_pil.convert("RGB")
    width, height = rgb_img.size

    def usable(px):
        r, g, b = px
        if 0.2126*r + 0.7152*g + 0.0722*b < 18:
            return False
        return not (r > 235 and g > 235 and b > 235)

    candidates = (rgb_img.getpixel((col, row)) for row in range(height) for col in range(width))
    kept = [tuple(px) for px in candidates if usable(px)]
    if not kept:
        return fallback
    count = len(kept)
    avg = tuple(sum(channel) // count for channel in zip(*kept))
    return _saturate(avg, 1.25)

# Easing / gamma
def ease_smoothstep(t):
    """Cubic smoothstep easing: clamp t to [0, 1], then return 3t^2 - 2t^3."""
    upper_clamped = min(1.0, t)
    x = max(0.0, upper_clamped)
    return x * x * (3 - 2 * x)

def gamma_scale(color, s, g):
    """Scale an RGB color by brightness *s*, shaped by exponent 1/g.

    Args:
        color: (r, g, b) tuple.
        s: brightness level; clamped to [0, 1].
        g: gamma. The applied factor is s ** (1 / g). A value <= 0 is treated
           as 1.0 (linear). Otherwise g == 0 raises ZeroDivisionError and
           g < 0 inverts the curve, over-driving colors past their input value.

    Returns:
        (r, g, b) tuple of ints (each channel truncated toward zero).
    """
    s = max(0.0, min(1.0, s))
    if g <= 0:
        g = 1.0
    factor = s ** (1.0/g)
    r, gg, b = color
    return (int(r*factor), int(gg*factor), int(b*factor))

def luma_709(r, g, b):
    """Return the Rec.709-weighted luma of an RGB triple (same units as the inputs)."""
    weighted_r = 0.2126*r
    weighted_g = 0.7152*g
    weighted_b = 0.0722*b
    return weighted_r + weighted_g + weighted_b

# ===== Composers =====
def blit(img_rgb, iw, ih, dst, y_offset=0, alpha=1.0):
    """Draw an iw x ih row-major RGB pixel list into the frame buffer *dst* (in place).

    Args:
        img_rgb: list of (r, g, b) tuples, row-major, at least iw * ih long.
        iw, ih: image width and height in pixels.
        dst: frame buffer indexed by LED strip position (see xy_to_index).
        y_offset: matrix row where image row 0 is drawn.
        alpha: global fade level. Values above 1 are clamped; <= 0 draws nothing.

    Each pixel is multiplied by alpha ** (1 / ART_GAMMA) * BRIGHTNESS_SCALE and
    clipped at 255. A non-positive ART_GAMMA is treated as 1.0 (linear) instead
    of dividing by zero or inverting the curve. Pixels whose resulting Rec.709
    luma is <= BLACK_CUTOFF are written as black, or as gray BLACK_DIM_LEVEL
    when BLACK_DIM_INSTEAD is set.
    """
    if not img_rgb or alpha <= 0:
        return
    gamma = ART_GAMMA if ART_GAMMA > 0 else 1.0
    scale = min(1.0, alpha) ** (1.0 / gamma) * BRIGHTNESS_SCALE
    cutoff = max(0, min(255, BLACK_CUTOFF))
    dim = max(0, min(255, BLACK_DIM_LEVEL))
    # Replacement for pixels at/below the cutoff, hoisted out of the pixel loop.
    black = (dim, dim, dim) if BLACK_DIM_INSTEAD else (0, 0, 0)
    for y in range(ih):
        row_start = y * iw
        for x in range(iw):
            r, g, b = img_rgb[row_start + x]
            rr = int(min(255, r * scale))
            gg = int(min(255, g * scale))
            bb = int(min(255, b * scale))
            out = black if luma_709(rr, gg, bb) <= cutoff else (rr, gg, bb)
            dst[xy_to_index(x, y + y_offset)] = out

def compose_with_bar(art_rgb, pr, bar_color, art_alpha=1.0):
    """Build a frame with album art from row 0 plus a progress bar on PROGRESS_ROW.

    Args:
        art_rgb: (pixels, w, h) as returned by img_to_rgb_list, or a falsy value for no art.
        pr: playback progress; clamped to [0, 1]. None omits the bar.
        bar_color: (r, g, b) of a fully lit bar LED.
        art_alpha: fade level forwarded to blit (the bar itself is not faded).

    Returns:
        List of MATRIX_W * MATRIX_H (r, g, b) tuples in strip order.

    The bar fills left to right. The LED under the fractional leading edge is
    partially lit via ease_smoothstep, then gamma_scale with PROGRESS_GAMMA.
    Bar pixels overwrite any art on the same row.
    """
    frame = [(0, 0, 0)] * (MATRIX_W * MATRIX_H)
    if art_rgb:
        pixels, width, height = art_rgb[0], art_rgb[1], art_rgb[2]
        blit(pixels, width, height, frame, y_offset=0, alpha=art_alpha)
    if pr is None:
        return frame
    lead = max(0.0, min(1.0, pr)) * MATRIX_W  # bar length in LEDs (fractional)
    for col in range(MATRIX_W):
        fill = lead - col
        level = 0.0 if fill <= 0 else (1.0 if fill >= 1 else ease_smoothstep(fill))
        frame[xy_to_index(col, PROGRESS_ROW)] = gamma_scale(bar_color, level, PROGRESS_GAMMA)
    return frame

def compose_stretched_full(full_rgb, art_alpha=1.0):
    """Build a frame holding only album art (no progress bar), drawn from row 0.

    full_rgb is a (pixels, w, h) triple from img_to_rgb_list, or a falsy value
    for an all-black frame. NOTE(review): not called anywhere else in this file.
    """
    frame = [(0, 0, 0)] * (MATRIX_W * MATRIX_H)
    if not full_rgb:
        return frame
    pixels, width, height = full_rgb[0], full_rgb[1], full_rgb[2]
    blit(pixels, width, height, frame, y_offset=0, alpha=art_alpha)
    return frame

# ===== Spotify =====
def get_sp_client():
    """Return an authenticated Spotipy client, prompting for OAuth on first run.

    Tokens are cached in ".spotipy_cache". With no usable cached token, the user
    opens the printed authorize URL and pastes back the redirected URL, whose
    ?code=... is exchanged for a token.

    The client is driven solely by the SpotifyOAuth manager so access tokens
    are refreshed automatically. Passing a fixed ``auth=`` token as well makes
    Spotipy keep sending that token, so every request fails once it expires.

    Raises:
        SystemExit: if the Spotify credentials are not configured.
        RuntimeError: if the pasted URL contains no authorization code.
    """
    if not (CLIENT_ID and CLIENT_SECRET and REDIRECT_URI):
        raise SystemExit("Set SPOTIFY_CLIENT_ID / SPOTIFY_CLIENT_SECRET / SPOTIFY_REDIRECT_URI in .env")
    auth = SpotifyOAuth(client_id=CLIENT_ID, client_secret=CLIENT_SECRET,
                        redirect_uri=REDIRECT_URI, scope=" ".join(SCOPES),
                        cache_path=".spotipy_cache", open_browser=False)
    if not auth.get_cached_token():
        url = auth.get_authorize_url()
        print("\nOpen this URL, log in, then paste the FINAL URL (?code=...):\n")
        print(url, "\n")
        redirected = input("Paste FULL redirected URL: ").strip()
        code = auth.parse_response_code(redirected)
        if not code:
            raise RuntimeError("No ?code=... in pasted URL.")
        # Exchange the code for a token; SpotifyOAuth stores it in its cache.
        auth.get_access_token(code, check_cache=True)
    return spotipy.Spotify(auth_manager=auth)

def now_playing(sp):
    """Summarize the current playback into the fields the renderer needs.

    Returns None when Spotify reports nothing playing or the payload has no
    "item". Otherwise returns a dict with:
        id   - track id, or the track URI when the id is missing
        art  - URL of the first album image, or None
        prog - playback position in ms (0 if missing)
        dur  - track duration in ms (1 if missing or 0, so division is safe)
        play - True if Spotify reports is_playing
    """
    playback = sp.current_user_playing_track()
    track = playback.get("item") if playback else None
    if not track:
        return None
    album = track.get("album") or {}
    album_images = album.get("images") or []
    art_url = album_images[0]["url"] if album_images else None
    return {
        "id": track.get("id") or track.get("uri"),
        "art": art_url,
        "prog": playback.get("progress_ms") or 0,
        "dur":  track.get("duration_ms") or 1,
        "play": bool(playback.get("is_playing")),
    }

# ===== Modes / button =====
MODE_DEFAULT_WHITE   = 0  # normal art, white bar
MODE_DEFAULT_DYNAMIC = 1  # normal art, dynamic bar
MODE_ROT_WHITE       = 2  # rotated art, white bar
MODE_ROT_DYNAMIC     = 3  # rotated art, dynamic bar
MODE_OFF             = 4
NUM_MODES            = 5  # number of modes above; _cycle_mode wraps modulo this

# Current display mode. It is advanced from the button IRQ callback / polling
# thread via _cycle_mode() and read by run_live(); both accesses hold mode_lock.
mode_lock = threading.Lock()
mode = MODE_DEFAULT_WHITE

def _cycle_mode():
    """Advance the shared display mode by one (wrapping at NUM_MODES) and log the new value."""
    global mode
    with mode_lock:
        mode += 1
        mode %= NUM_MODES
        print(f"[MODE] -> {mode}")

def setup_button():
    """Configure the mode button on BUTTON_PIN; each press calls _cycle_mode().

    BUTTON_ACTIVE_STATE is the single definition of "pressed" for both the IRQ
    path and the polling fallback. Active LOW listens for a falling edge,
    active HIGH for a rising edge, and both paths compare the pin to the same
    level. BUTTON_PULL only selects the internal pull resistor.

    Returns:
        "disabled" if RPi.GPIO is unavailable or the pin cannot be configured;
        "irq"      if an edge-detect callback was registered;
        "polling"  if edge detection failed and a daemon thread polls the pin
                   at POLL_BUTTON_HZ instead.
    """
    if GPIO is None:
        print("RPi.GPIO not available; button disabled.")
        return "disabled"

    active_level = 0 if BUTTON_ACTIVE == "LOW" else 1
    pud = GPIO.PUD_UP if BUTTON_PULL == "UP" else GPIO.PUD_DOWN
    edge = GPIO.FALLING if active_level == 0 else GPIO.RISING

    try:
        GPIO.setmode(GPIO.BCM)
        GPIO.setup(BUTTON_PIN, GPIO.IN, pull_up_down=pud)
    except (RuntimeError, ValueError) as e:
        # e.g. unsupported board or no permission to access GPIO. The button is
        # optional, so keep the display running without it.
        print(f"[Button] GPIO setup failed ({e}); button disabled.")
        return "disabled"

    try:
        GPIO.remove_event_detect(BUTTON_PIN)  # drop any stale registration
    except Exception:
        pass  # best effort: nothing registered / not supported

    last_fire = 0.0  # monotonic time of the last accepted IRQ press

    def irq_cb(_ch):
        nonlocal last_fire
        now = time.monotonic()
        if (now - last_fire) * 1000.0 < DEBOUNCE_MS:
            return
        last_fire = now
        # Re-read the pin so an edge that isn't a real press is ignored.
        if GPIO.input(BUTTON_PIN) == active_level:
            _cycle_mode()

    try:
        GPIO.add_event_detect(BUTTON_PIN, edge, callback=irq_cb, bouncetime=DEBOUNCE_MS)
        print(f"Button IRQ on GPIO{BUTTON_PIN} ({BUTTON_PULL}, active {BUTTON_ACTIVE}); debounce {DEBOUNCE_MS}ms.")
        return "irq"
    except Exception as e:
        print(f"[Button] IRQ failed ({e}); falling back to polling at {POLL_BUTTON_HZ} Hz.")

    def poll():
        interval = 1.0/max(1.0, POLL_BUTTON_HZ)
        last = GPIO.input(BUTTON_PIN)
        last_t = time.monotonic()
        while True:
            time.sleep(interval)
            try:
                s = GPIO.input(BUTTON_PIN)
            except RuntimeError:
                return  # GPIO was cleaned up (program exiting): stop quietly
            now = time.monotonic()
            if s != last and (now - last_t)*1000.0 >= DEBOUNCE_MS:
                last_t = now
                last = s
                if s == active_level:
                    _cycle_mode()

    threading.Thread(target=poll, daemon=True).start()
    return "polling"

# ===== Live loop =====
def _render_art(data):
    """Turn downloaded album-art bytes into what the render loop displays.

    Returns (art_small, art_small_rot, bar_color):
        art_small, art_small_rot: (pixels, w, h) triples of the normal and the
            90-degree-rotated MATRIX_W x ART_H_SMALL art.
        bar_color: the dynamic progress color derived from the non-rotated art.
    Returns None when *data* is empty or Pillow cannot decode it, so one bad
    image never stops the loop.
    """
    if not data:
        return None
    try:
        img_small = fit_album(data, MATRIX_W, ART_H_SMALL, rotate90=False)
        img_small_rot = fit_album(data, MATRIX_W, ART_H_SMALL, rotate90=True)
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        print(f"[Art] could not decode album image: {e}")
        return None
    return (img_to_rgb_list(img_small),
            img_to_rgb_list(img_small_rot),
            _dynamic_color_from_img(img_small))


def run_live():
    """Main loop: poll Spotify, draw art + progress bar, stream frames to WLED.

    Each iteration:
      * at most every NP_POLL_SEC seconds, ask Spotify what is playing (also
        when nothing plays or the last poll failed, so the API is not hammered).
        On a track change, download and pre-render the new art and start a fade;
      * fade the *old* art out, swap in the new art at the bottom of the fade,
        then fade it in;
      * interpolate playback progress between polls so the bar moves smoothly;
      * compose the frame for the button-selected mode and send it.

    Intervals use time.monotonic(), so wall-clock jumps (e.g. NTP sync on a
    Pi without an RTC) cannot stall polling or make the bar jump.
    Runs until Ctrl+C; on exit the matrix is blanked and GPIO released.
    """
    white = (255, 255, 255)
    no_art = (None, None, white)  # shown when a track has no (decodable) art
    art_modes = (MODE_DEFAULT_WHITE, MODE_DEFAULT_DYNAMIC, MODE_ROT_WHITE, MODE_ROT_DYNAMIC)
    rotated_modes = (MODE_ROT_WHITE, MODE_ROT_DYNAMIC)
    dynamic_modes = (MODE_DEFAULT_DYNAMIC, MODE_ROT_DYNAMIC)

    setup_button()
    sp = get_sp_client()

    last_track_id = None
    playing = None
    last_poll = None  # monotonic seconds of the last poll; None forces an immediate first poll

    # (art_small, art_small_rot, bar_color_dynamic) currently on screen, and the
    # set waiting to be swapped in once the fade-out reaches zero.
    shown = no_art
    pending = None

    # Fade machine for art modes
    art_alpha = 1.0
    state = "steady"  # "steady" | "fade_out" | "fade_in"
    t0_ms = 0.0
    half_ms = max(1, ART_FADE_MS // 2)
    frame_delay = 1.0 / PROGRESS_FPS if PROGRESS_FPS > 0 else 0.05  # guard FPS <= 0

    # Progress interpolation (both values in monotonic milliseconds)
    base_prog_ms = 0.0
    base_wall_ms = time.monotonic() * 1000.0

    try:
        while True:
            now_s = time.monotonic()

            # Poll Spotify
            if last_poll is None or (now_s - last_poll) >= NP_POLL_SEC:
                last_poll = now_s
                try:
                    fresh = now_playing(sp)
                except Exception as e:  # network / auth / rate limit: retry next poll
                    print(f"[Spotify] poll failed: {e}")
                    fresh = None
                if fresh and fresh.get("id"):
                    if fresh["id"] != last_track_id:
                        art_url = fresh.get("art")
                        data = fetch_album_image(art_url) if art_url else None
                        pending = _render_art(data) or no_art
                        # Fade out from the current alpha. Timed after the
                        # (possibly slow) download so the fade is not skipped.
                        state = "fade_out"
                        t0_ms = time.monotonic() * 1000.0 - (1.0 - art_alpha) * half_ms
                        last_track_id = fresh["id"]
                    base_prog_ms = float(fresh["prog"])
                    base_wall_ms = now_s * 1000.0
                    playing = fresh
                else:
                    playing = None

            # Fade state for art modes (visual nicety when art changes)
            now_ms = time.monotonic() * 1000.0
            with mode_lock:
                m = mode
            if m not in art_modes:
                # No fading while the display is off; apply new art at once.
                if pending is not None:
                    shown, pending = pending, None
                state, art_alpha = "steady", 1.0
            elif state == "fade_out":
                art_alpha = max(0.0, 1.0 - (now_ms - t0_ms) / half_ms)
                if art_alpha <= 0.0:
                    if pending is not None:
                        shown, pending = pending, None
                    state, t0_ms = "fade_in", now_ms
            elif state == "fade_in":
                art_alpha = min(1.0, (now_ms - t0_ms) / half_ms)
                if art_alpha >= 1.0:
                    state = "steady"
            else:
                art_alpha = 1.0

            # Compute progress ratio
            pr = None
            if playing and playing["dur"] > 0:
                if playing["play"]:
                    elapsed = now_ms - base_wall_ms
                    prog_ms = min(float(playing["dur"]), base_prog_ms + elapsed)
                else:
                    prog_ms = base_prog_ms
                pr = max(0.0, min(1.0, prog_ms / float(playing["dur"])))

            # Compose per mode
            if m in art_modes:
                art_small, art_small_rot, dyn_color = shown
                art = art_small_rot if m in rotated_modes else art_small
                bar_color = dyn_color if m in dynamic_modes else white
                frame = compose_with_bar(art, pr, bar_color, art_alpha=art_alpha)
            else:  # MODE_OFF
                frame = [(0, 0, 0)] * (MATRIX_W * MATRIX_H)

            send_full_frame(frame)
            time.sleep(frame_delay)
    except KeyboardInterrupt:
        pass
    finally:
        clear_all()
        if GPIO:
            try:
                GPIO.cleanup()
            except Exception:
                pass  # best effort during shutdown

# ===== CLI =====
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Spotify -> WLED (Modes: default-white/default-dynamic/rot-white/rot-dynamic/off)")
    parser.add_argument("--testbars", action="store_true")
    cli = parser.parse_args()
    if cli.testbars:
        # Wiring check: paint each matrix row a solid color (palette repeats
        # every 5 rows), send one frame, and exit.
        palette = ((255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255))
        bars = [(0, 0, 0)] * (MATRIX_W * MATRIX_H)
        for row in range(MATRIX_H):
            row_color = palette[row % len(palette)]
            for col in range(MATRIX_W):
                bars[xy_to_index(col, row)] = row_color
        send_full_frame(bars)
        sys.exit(0)
    run_live()
