"""
Analysis depth — validated physiology (Phase 4).

One reader per metric family, each taking the loaded `session` dict (from
`analysis.load.load_session`) plus the name(s) of the stream(s) it reads, and
returning a plain dict of features; `summary` runs every applicable reader.
NeuroKit2 does the heavy lifting (the validated, open replacement for iMotions'
locked metrics). Nothing here mutates the session.

Readers degrade gracefully: too-short or missing signals return {} rather than
raising, so a dashboard can call them unconditionally.
"""

from __future__ import annotations

import neurokit2 as nk
import numpy as np


def _signal(session: dict, stream: str, channel: int = 0):
    """Fetch one channel of a stream as a float64 vector plus its sampling rate.

    Args:
        session: Loaded session dict; reads ``session["streams"][stream]``,
            using its ``"x"`` samples and ``"srate"`` value.
        stream: Stream name to look up.
        channel: Column taken when ``x`` is 2-D (negative indices allowed);
            ignored for 1-D ``x``.

    Returns:
        ``(signal, srate)`` where ``signal`` is a float64 array and ``srate``
        an int. ``(None, None)`` when the stream is absent, empty, or has no
        such channel. ``srate`` is rounded to the nearest integer and is ``0``
        when missing, non-numeric or non-finite, so the callers' ``fs <= 0``
        guard rejects it instead of this helper raising.
    """
    s = session["streams"].get(stream)
    if not s or not len(s["x"]):
        return None, None
    x = np.asarray(s["x"], dtype="float64")
    if x.ndim > 1:
        n_channels = x.shape[1]
        if not -n_channels <= channel < n_channels:
            return None, None
        sig = x[:, channel]
    else:
        sig = x.ravel()
    try:
        srate = float(s.get("srate") or 0.0)
    except (TypeError, ValueError):
        srate = 0.0
    # Round rather than truncate: an effective rate such as 499.9998 Hz
    # must map to 500, not 499.
    fs = int(round(srate)) if np.isfinite(srate) else 0
    return sig, fs


def heart_metrics(session: dict, stream: str = "ECG") -> dict:
    """Heart rate + time/frequency HRV from an ECG stream.

    Args:
        session: Loaded session dict; the stream is fetched via ``_signal``
            (channel 0).
        stream: Name of the ECG stream.

    Returns:
        ``{}`` if the stream is missing, has a non-positive sampling rate, is
        shorter than 10 s, fails NeuroKit cleaning/peak detection, or yields
        fewer than 4 R-peaks. Otherwise a dict with ``n_beats``, ``hr_bpm``,
        ``sdnn_ms`` and ``rmssd_ms`` computed from the R-R intervals, plus
        any of ``sdnn``, ``rmssd``, ``pnn50``, ``lf``, ``hf``, ``lfhf`` that
        ``nk.hrv`` reports when it succeeds.
    """
    sig, fs = _signal(session, stream)
    if sig is None or fs <= 0 or len(sig) < fs * 10:
        return {}
    try:
        # NeuroKit can raise on degenerate input; keep the module's
        # "return {} rather than raise" contract, as eda_metrics does.
        clean = nk.ecg_clean(sig, sampling_rate=fs)
        _, info = nk.ecg_peaks(clean, sampling_rate=fs)
    except Exception:
        return {}
    peaks = info.get("ECG_R_Peaks", [])
    if len(peaks) < 4:  # 4+ beats -> 3+ R-R intervals for the stats below
        return {}
    rr = np.diff(peaks) / fs  # R-R intervals in seconds
    out = {
        "n_beats": int(len(peaks)),
        "hr_bpm": float(60.0 / np.mean(rr)),
        "sdnn_ms": float(np.std(rr, ddof=1) * 1000.0),
        "rmssd_ms": float(np.sqrt(np.mean(np.diff(rr) ** 2)) * 1000.0),
    }
    try:  # full HRV (adds frequency-domain) when the record is long enough
        hrv = nk.hrv(peaks, sampling_rate=fs, show=False)
        for k in ("HRV_SDNN", "HRV_RMSSD", "HRV_pNN50", "HRV_LF", "HRV_HF", "HRV_LFHF"):
            if k in hrv:
                out[k.replace("HRV_", "").lower()] = float(hrv[k].iloc[0])
    except Exception:
        # Deliberately best-effort: the R-R based metrics above are still
        # returned when NeuroKit's HRV pipeline fails on this record.
        pass
    return out


def eda_metrics(session: dict, stream: str = "GSR") -> dict:
    """Tonic level plus phasic skin-conductance-response (SCR) features.

    Args:
        session: Loaded session dict; the stream is fetched via ``_signal``
            (channel 0).
        stream: Name of the EDA/GSR stream.

    Returns:
        ``{}`` when the stream is missing, has a non-positive rate, is shorter
        than 5 s, or ``nk.eda_process`` raises. Otherwise the mean tonic and
        phasic components, the SCR count (non-NaN onsets), SCRs per minute and
        the mean SCR amplitude (``0.0`` when NeuroKit reports none).
    """
    raw, rate = _signal(session, stream)
    if raw is None or rate <= 0 or len(raw) < rate * 5:
        return {}
    try:
        components, details = nk.eda_process(raw, sampling_rate=rate)
    except Exception:
        return {}

    onsets = details.get("SCR_Onsets", [])
    n_scr = 0
    if len(onsets):
        # Only onsets that are not NaN are counted as responses.
        n_scr = int(np.sum(~np.isnan(onsets)))
    minutes = len(raw) / rate / 60.0
    amplitudes = details["SCR_Amplitude"] if "SCR_Amplitude" in details else ()

    features = {
        "tonic_mean_uS": float(np.nanmean(components["EDA_Tonic"])),
        "phasic_mean_uS": float(np.nanmean(components["EDA_Phasic"])),
        "scr_count": n_scr,
        "scr_per_min": float(n_scr / minutes) if minutes else 0.0,
    }
    features["scr_amplitude_mean"] = (
        float(np.nanmean(amplitudes)) if len(amplitudes) else 0.0
    )
    return features


def emg_metrics(session: dict, stream: str = "EMG") -> dict:
    """Activation and amplitude features from an EMG stream.

    Args:
        session: Loaded session dict; the stream is fetched via ``_signal``
            (channel 0).
        stream: Name of the EMG stream.

    Returns:
        ``{}`` when the stream is missing, has a non-positive rate, or is
        shorter than 2 s. Otherwise ``amplitude_mean``, ``amplitude_max`` and
        ``n_activations`` from ``nk.emg_process``. If any NeuroKit step fails,
        the amplitudes come from the mean-removed, rectified raw signal
        instead and ``n_activations`` is 0.
    """
    raw, rate = _signal(session, stream)
    if raw is None or rate <= 0 or len(raw) < rate * 2:
        return {}
    try:
        processed, meta = nk.emg_process(raw, sampling_rate=rate)
        envelope = processed["EMG_Amplitude"]
        features = {
            "amplitude_mean": float(np.nanmean(envelope)),
            "amplitude_max": float(np.nanmax(envelope)),
            "n_activations": int(len(meta.get("EMG_Onsets", []))),
        }
    except Exception:
        # Crude fallback so the dashboard still gets amplitude numbers;
        # no onset detection is attempted here.
        centred = raw - np.mean(raw)
        rectified = np.abs(centred)
        features = {
            "amplitude_mean": float(np.mean(rectified)),
            "amplitude_max": float(np.max(rectified)),
            "n_activations": 0,
        }
    return features


def rsa(session: dict, ecg_stream: str = "ECG", rsp_stream: str = "RSP") -> dict:
    """Respiratory sinus arrhythmia (needs synchronized ECG + respiration).

    Args:
        session: Loaded session dict; both streams are fetched via
            ``_signal`` (channel 0).
        ecg_stream: Name of the ECG stream.
        rsp_stream: Name of the respiration stream.

    Returns:
        ``{}`` if either stream is missing, the two sampling rates differ or
        are non-positive, the span both streams cover is shorter than 20 s,
        or any NeuroKit step fails. Otherwise the scalar entries of
        ``nk.hrv_rsa`` converted to ``float``.
    """
    ecg, fs = _signal(session, ecg_stream)
    rsp, fs2 = _signal(session, rsp_stream)
    if ecg is None or rsp is None or fs <= 0 or fs != fs2:
        return {}
    # NOTE(review): assumes both streams start at the same instant (the
    # "synchronized" precondition) — confirm against load_session. Trim to
    # the common length so the ECG and RSP frames index the same samples,
    # and require 20 s of *both* signals, not just the ECG.
    n = min(len(ecg), len(rsp))
    if n < fs * 20:
        return {}
    ecg, rsp = ecg[:n], rsp[:n]
    try:
        ecg_sig, _ = nk.ecg_process(ecg, sampling_rate=fs)
        rsp_sig, _ = nk.rsp_process(rsp, sampling_rate=fs)
        out = nk.hrv_rsa(ecg_sig, rsp_sig, sampling_rate=fs, continuous=False)
        return {k: float(v) for k, v in out.items() if np.isscalar(v)}
    except Exception:
        # Best-effort like the other readers: any NeuroKit failure -> {}.
        return {}


def summary(session: dict) -> dict:
    """Dashboard entry point: run each reader whose input streams exist.

    A reader runs only when every stream it needs is a key of
    ``session["streams"]``. Its result (possibly ``{}``) is stored under the
    family name: ``heart``, ``eda``, ``emg``, ``rsa``, in that order.
    """
    available = session["streams"]
    # (result key, streams that must be present, reader), in output order.
    plan = (
        ("heart", ("ECG",), heart_metrics),
        ("eda", ("GSR",), eda_metrics),
        ("emg", ("EMG",), emg_metrics),
        ("rsa", ("ECG", "RSP"), rsa),
    )
    return {
        key: reader(session)
        for key, needed, reader in plan
        if all(name in available for name in needed)
    }
