"""
Analysis depth — EEG / fNIRS via MNE-Python (Phase 4).

Turns a recorded multichannel neuro stream into MNE objects, filters, epochs
around stimulus markers (the event-locking primitive from `analysis.load`), and
averages to an ERP. Works on any stream the recorder captured — a BrainFlow board,
an actiCHamp LSL outlet, or a synthetic EEG source — because they all land as the
same labelled, time-stamped HDF5 group.

The session's per-sample timestamps share one clock with the Markers stream, so
marker times map straight onto EEG sample indices.
"""

from __future__ import annotations

import mne
import numpy as np

mne.set_log_level("ERROR")


def to_raw(session: dict, stream: str = "EEG") -> "mne.io.RawArray":
    """Wrap one recorded stream of ``session`` in an ``mne.io.RawArray``.

    Reads ``x`` (samples), ``t`` (per-sample timestamps), ``srate`` and
    ``channels`` from ``session["streams"][stream]``. Every channel is
    declared as type ``"eeg"``. The samples are treated as evenly spaced at
    the stream's nominal rate; when that rate is falsy (0) it is estimated
    from the timestamps instead.

    The timestamp of the first sample is stored on the returned object as
    ``_biosync_t0`` so marker times can later be converted to sample indices.
    """
    rec = session["streams"][stream]

    samples = np.asarray(rec["x"], dtype="float64")
    if samples.ndim == 1:
        # A flat vector is a single channel: make it one column.
        samples = samples.reshape(-1, 1)

    sfreq = float(rec["srate"])
    if not sfreq:
        # No usable nominal rate recorded; estimate it from the timestamps.
        sfreq = _infer_sfreq(rec["t"])

    info = mne.create_info(ch_names=list(rec["channels"]), sfreq=sfreq, ch_types="eeg")
    # The samples are transposed here because MNE wants [n_channels, n_times].
    raw = mne.io.RawArray(samples.T, info)
    raw._biosync_t0 = float(rec["t"][0])    # clock origin used for event mapping
    return raw


def _infer_sfreq(t) -> float:
    t = np.asarray(t, dtype="float64")
    return float(1.0 / np.median(np.diff(t))) if len(t) > 1 else 1.0


def _events_from_markers(session: dict, raw, prefix: str = "stim_on"):
    """Map marker times -> MNE events array, one event_id per stimulus name.

    Iterates the ``(time, label)`` pairs yielded by ``markers(session)`` and
    keeps those whose label starts with ``prefix``. The stimulus name is the
    text after the first ``":"`` in the label, or the whole label when it has
    no colon. Marker times are converted to sample indices relative to
    ``raw._biosync_t0`` using ``raw.info["sfreq"]``; markers that land outside
    ``[0, raw.n_times)`` are dropped.

    Returns ``(events, event_id)``: ``events`` is an ``(n, 3)`` int array of
    ``[sample, 0, id]`` rows sorted by sample (shape ``(0, 3)`` when nothing
    matched), and ``event_id`` maps each stimulus name to its integer id
    (1-based, in order of first kept occurrence).
    """
    from .load import markers
    t0, sfreq = raw._biosync_t0, raw.info["sfreq"]
    rows, event_id = [], {}
    for t, label in markers(session):
        label = str(label)
        if not label.startswith(prefix):
            continue
        samp = int(round((t - t0) * sfreq))
        if not 0 <= samp < raw.n_times:
            continue
        # Register the id only once a marker is actually kept: an entry in
        # event_id with no matching row in events makes mne.Epochs raise
        # under its default on_missing="raise".
        name = label.split(":", 1)[1] if ":" in label else label
        eid = event_id.setdefault(name, len(event_id) + 1)
        rows.append([samp, 0, eid])
    events = np.array(sorted(rows), dtype=int) if rows else np.empty((0, 3), int)
    return events, event_id


def epochs_around(session: dict, stream: str = "EEG", *, tmin=-0.2, tmax=0.8,
                  l_freq=1.0, h_freq=40.0, baseline=(None, 0)) -> "mne.Epochs":
    """Build filtered, stimulus-locked epochs from one recorded stream.

    Steps:
      1. Convert ``session["streams"][stream]`` to an MNE Raw via ``to_raw``.
      2. Filter it in place with ``raw.filter(l_freq, h_freq)``. This step is
         skipped when both ``l_freq`` and ``h_freq`` are falsy (``None``/0).
      3. Turn the stimulus markers into MNE events and cut epochs spanning
         ``tmin``..``tmax`` seconds around each one. ``baseline`` is passed
         through to ``mne.Epochs`` unchanged.

    Raises:
        RuntimeError: if no stimulus marker maps to a sample inside the
            recording.
    """
    raw = to_raw(session, stream)

    filtering_requested = bool(l_freq or h_freq)
    if filtering_requested:
        raw.filter(l_freq, h_freq, verbose="ERROR")

    events, event_id = _events_from_markers(session, raw)
    if len(events) == 0:
        raise RuntimeError("no stimulus markers fell inside the EEG span")

    epoch_options = dict(
        event_id=event_id,
        tmin=tmin,
        tmax=tmax,
        baseline=baseline,
        preload=True,
        verbose="ERROR",
    )
    return mne.Epochs(raw, events, **epoch_options)


def erp(session: dict, stream: str = "EEG", **kw) -> dict:
    """Average all stimulus-locked epochs into one ERP and summarise it.

    ``**kw`` is forwarded to ``epochs_around`` (window, filter band, baseline).
    Epochs from every condition are averaged together.

    Returns a dict of plain Python values:
        n_epochs        number of epochs that went into the average
        times           epoch time axis in seconds
        channels        channel names
        data_uv         averaged waveforms, [n_ch, n_times]
        gfp             standard deviation across channels at each time point
        rms             root-mean-square across channels at each time point
        peak_latency_s  time (>= 0 s) at which ``rms`` is largest
        conditions      stimulus names present in the epochs' event_id
    """
    epochs = epochs_around(session, stream, **kw)
    evoked = epochs.average()
    waveforms = evoked.data          # [n_ch, n_times]
    times = evoked.times

    # Both summaries collapse the channel axis, leaving one value per sample.
    gfp = np.std(waveforms, axis=0)
    rms = np.sqrt(np.mean(np.square(waveforms), axis=0))

    # The peak search is restricted to samples at or after stimulus onset.
    # With the default baseline correction the pre-stimulus RMS is ~0 anyway.
    post_idx = np.flatnonzero(times >= 0)
    peak_i = int(post_idx[np.argmax(rms[post_idx])])

    summary = {
        "n_epochs": int(len(epochs)),
        "times": times.tolist(),
        "channels": evoked.ch_names,
        # NOTE(review): no unit scaling is applied in this function; the
        # values are in whatever units the recorded stream used. Confirm
        # that is microvolts, as the key name implies.
        "data_uv": waveforms.tolist(),
        "gfp": gfp.tolist(),
        "rms": rms.tolist(),
        "peak_latency_s": float(times[peak_i]),
        "conditions": list(epochs.event_id.keys()),
    }
    return summary
