"""
Hardware-validation harness.

The honest gate between biosync-as-prototype and biosync-as-tool is proving each
device actually works. This module runs a structured, per-modality validation on a
recorded window and returns a pass/fail report:

  * stream present + correct channel count,
  * actual sample rate within tolerance of nominal,
  * live signal quality is not rated "bad" (e.g. flatline/clipping/invalid),
  * a modality-specific KNOWN-SIGNAL check:
      - Gaze : calibration accuracy < 1 deg when calibration targets are supplied,
               otherwise >= 80% of samples valid and on-screen,
      - ECG  : recovered heart rate is physiologically plausible,
      - EEG  : real band-limited power present (not noise/flat),
      - EDA  : level in a sensible microsiemens range,
      - EMG  : a non-zero peak activation amplitude is recovered.

Pure analysis over a session window (numpy + the existing analyzers), so it runs
the same whether the data came from a real device or a synthetic stand-in. The app
wraps this in a guided routine; `VALIDATION.md` is the human protocol.
"""

from __future__ import annotations

from dataclasses import dataclass, field, asdict

import numpy as np

from .analysis import load as A
from .analysis import physio, eyetracking as ET, quality as Q
from .analysis.calibration import evaluate as eval_calibration


@dataclass
class Check:
    """Outcome of a single validation check.

    Numpy scalars passed in (e.g. the ``np.bool_`` produced by comparing numpy
    floats, or ``np.float64`` from ``round`` on a numpy value) are converted to
    builtin Python types on construction, so ``dict()`` output contains only
    plain, JSON-serializable values.
    """

    name: str                        # human-readable label, e.g. "sample rate"
    passed: bool                     # whether the check succeeded
    detail: str = ""                 # short explanation shown with the result
    value: float | None = None       # measured quantity, when the check has one
    threshold: float | None = None   # reference the measurement was compared to

    def __post_init__(self):
        # Comparisons on numpy data yield np.bool_ rather than bool; normalize.
        self.passed = bool(self.passed)
        # .item() turns any numpy scalar into the matching builtin (float/int/bool).
        if isinstance(self.value, np.generic):
            self.value = self.value.item()
        if isinstance(self.threshold, np.generic):
            self.threshold = self.threshold.item()

    def dict(self):
        """Return the check as a plain ``dict`` keyed by field name."""
        return asdict(self)


@dataclass
class Report:
    """Aggregated validation result for one recorded stream.

    ``passed`` is True only when at least one check ran and every check passed.
    """

    device: str     # device label supplied by the caller
    modality: str   # modality key the checks were selected for, e.g. "gaze"
    stream: str     # stream name inside the session
    checks: list = field(default_factory=list)  # ordered Check results

    @property
    def passed(self) -> bool:
        # An empty report is a failure, not a vacuous pass.
        return bool(self.checks) and all(c.passed for c in self.checks)

    def dict(self):
        """Return a plain-dict summary of the report and all of its checks.

        ``n_pass`` is counted as a builtin ``int`` even when individual
        ``passed`` flags are numpy booleans (summing those yields ``np.int64``).
        """
        n_pass = sum(1 for c in self.checks if c.passed)
        return {"device": self.device, "modality": self.modality, "stream": self.stream,
                "passed": self.passed,
                "n_pass": n_pass, "n": len(self.checks),
                "checks": [c.dict() for c in self.checks]}


# --------------------------------------------------------------------------
def _stream(session, name):
    s = session["streams"].get(name)
    if not s:
        return None
    return {"t": np.asarray(s["t"], float), "x": np.asarray(s["x"], float),
            "srate": float(s["srate"]), "type": s["type"], "channels": s["channels"]}


def _rate(t):
    if len(t) < 2:
        return 0.0
    return (len(t) - 1) / ((t[-1] - t[0]) or 1e-9)


def validate_stream(session, *, device="device", stream="EyeTracker",
                    modality="gaze", min_channels=1, rate_tol=0.25,
                    calibration=None, hr_range=(40, 180), eda_range=(0.05, 60),
                    screen=None) -> Report:
    """Validate one recorded stream and return a ``Report`` of ordered checks.

    Checks run in this order:
      1. stream present. If there are no samples, the report is returned
         with this single failing check.
      2. channel count >= ``min_channels``.
      3. measured sample rate within ``rate_tol`` (relative) of the nominal
         rate, or merely > 0 when the nominal rate is not positive.
      4. live signal quality (``Q.stream_quality``) not rated "bad".
      5. one modality-specific known-signal check (``_known_signal``).
    """
    report = Report(device=device, modality=modality, stream=stream)
    add = report.checks.append
    data = _stream(session, stream)

    # -- 1) presence: without samples no other check is meaningful
    if data is None or len(data["t"]) == 0:
        add(Check("stream present", False, f"no samples on stream '{stream}'"))
        return report
    t = data["t"]
    x = data["x"]
    add(Check("stream present", True, f"{len(t)} samples"))

    # -- 2) channels: a 1-D sample array counts as a single channel
    n_channels = x.shape[1] if x.ndim > 1 else 1
    add(Check("channel count", n_channels >= min_channels,
              f"{n_channels} channel(s)", value=n_channels, threshold=min_channels))

    # -- 3) sample rate, relative to nominal when one is declared
    measured = _rate(t)
    nominal = data["srate"]
    if not nominal > 0:
        # no usable nominal rate: only require that timestamps advance
        add(Check("sample rate", measured > 0, f"{measured:.1f} Hz (irregular)"))
    else:
        deviation = abs(measured - nominal) / nominal
        add(Check("sample rate", deviation <= rate_tol,
                  f"{measured:.0f} Hz vs nominal {nominal:.0f}",
                  value=round(measured, 1), threshold=nominal))

    # -- 4) live quality; window is capped at 5 s, and a zero span falls back to 1.0
    span = t[-1] - t[0]
    quality = Q.stream_quality(t, x, stype=data["type"], srate=nominal,
                               window=min(5.0, span or 1.0))
    status = quality["status"]
    reasons = ", ".join(quality["reasons"]) or "clean"
    add(Check("signal quality", status != "bad", f"{status}: {reasons}"))

    # -- 5) modality-specific known-signal check
    add(_known_signal(session, data, modality, calibration, hr_range,
                      eda_range, screen))
    return report


def _known_signal(session, s, modality, calibration, hr_range, eda_range, screen):
    """Run the modality-specific known-signal check and return a single ``Check``.

    Args:
        session: the recorded session, forwarded to the ``physio`` analyzers.
        s: stream dict as built by ``_stream`` (float arrays ``t``/``x``, ``srate``).
        modality: "gaze", "ecg", "eda", "eeg" or "emg", matched case-insensitively.
            Any other value yields a passing placeholder check.
        calibration: optional mapping for the gaze calibration evaluation. Keys
            are converted with ``tuple()`` and values with ``np.asarray``
            (presumably target -> recorded gaze samples; confirm with the caller).
        hr_range: inclusive (lo, hi) heart-rate bounds in bpm for ECG.
        eda_range: inclusive (lo, hi) bounds on the mean EDA level (uS).
        screen: forwarded to ``eval_calibration``.

    Per modality:
        gaze: if ``calibration`` is given, the result of ``eval_calibration``
            at a 1.0 deg threshold. Otherwise the fraction of samples that are
            finite and within [-0.05, 1.05] on both axes must be >= 0.8.
        ecg: ``hr_bpm`` from ``physio.heart_metrics`` must lie in ``hr_range``.
        eda: the NaN-ignoring mean of the first channel must lie in ``eda_range``.
        eeg: more than 30% of the FFT power of the mean-removed first channel
            must fall in 1-40 Hz. At least one second of samples and a
            positive, finite sample rate are required.
        emg: ``amplitude_max`` from ``physio.emg_metrics`` must be > 0.
    """
    # The module docs write "Gaze"/"ECG"/"EEG"; don't let capitalization fall
    # through to the always-passing placeholder below.
    kind = str(modality).strip().lower()

    if kind == "gaze":
        if calibration:
            res = eval_calibration({tuple(k): np.asarray(v) for k, v in calibration.items()},
                                   screen=screen, threshold_deg=1.0)
            return Check("calibration accuracy", res.passed,
                         f"{res.accuracy_deg:.2f}deg (prec {res.precision_deg:.2f})",
                         value=round(res.accuracy_deg, 3), threshold=1.0)
        X = s["x"]
        if X.ndim == 1:
            # Single gaze column: view it as (n, 1) so X[:, 0] below is valid.
            X = X[:, None]
        gx = X[:, 0]
        gy = X[:, 1] if X.shape[1] > 1 else gx
        # Bounds suggest coordinates normalized to [0, 1] (assumption; confirm
        # with the device adapter), with a 5% margin past each screen edge.
        on = (np.isfinite(gx) & np.isfinite(gy)
              & (gx >= -0.05) & (gx <= 1.05)
              & (gy >= -0.05) & (gy <= 1.05))
        v = float(np.mean(on))
        return Check("gaze validity", v >= 0.8, f"{v*100:.0f}% on-screen",
                     value=round(v, 3), threshold=0.8)

    if kind == "ecg":
        hr = physio.heart_metrics(session).get("hr_bpm")
        if hr is None:
            return Check("plausible heart rate", False, "no HR recovered", value=None)
        ok = hr_range[0] <= hr <= hr_range[1]
        return Check("plausible heart rate", ok, f"{hr:.1f} bpm",
                     value=round(hr, 1))

    if kind == "eda":
        x = s["x"][:, 0] if s["x"].ndim > 1 else s["x"]
        lvl = float(np.nanmean(x))
        ok = eda_range[0] <= lvl <= eda_range[1]
        return Check("EDA level in range", ok, f"{lvl:.2f} uS",
                     value=round(lvl, 2))

    if kind == "eeg":
        x = s["x"][:, 0] if s["x"].ndim > 1 else s["x"]
        fs = s["srate"] or _rate(s["t"])
        if not np.isfinite(fs) or fs <= 0:
            # No nominal rate and none recoverable (e.g. a single sample): the
            # spectrum is undefined, so fail instead of dividing by zero.
            return Check("EEG band power", False, "no usable sample rate")
        # Real EEG has structured band power; flat/DC noise does not.
        xf = x - np.mean(x)
        if len(xf) < int(fs):
            return Check("EEG band power", False, "too short")
        freqs = np.fft.rfftfreq(len(xf), 1 / fs)
        psd = np.abs(np.fft.rfft(xf)) ** 2
        band = (freqs >= 1) & (freqs <= 40)
        frac = float(np.sum(psd[band]) / (np.sum(psd) + 1e-12))
        return Check("EEG band power (1-40 Hz)", frac > 0.3,
                     f"{frac*100:.0f}% of power in band", value=round(frac, 3),
                     threshold=0.3)

    if kind == "emg":
        amp = physio.emg_metrics(session).get("amplitude_max")
        if amp is None:
            # Treat a missing or unrecovered amplitude as "nothing detected".
            amp = 0.0
        return Check("EMG activation detectable", amp > 0,
                     f"max amplitude {amp:.3f}", value=round(amp, 3))

    return Check("known-signal check", True, "no modality-specific check defined")
