"""
Gaze over the real stimulus — heatmap / scanpath / AOIs drawn on the actual image
the participant saw (not a blank screen). This is the credibility view iMotions is
known for: a heatmap that sits on top of the ad, the scene, the web page.

Given a session and a stimulus image, it pulls the gaze recorded while that
stimulus was on screen (across all its presentations), and renders:

  * heatmap_over_image  — gaussian gaze density on top of the stimulus,
  * scanpath_over_image — ordered fixations + saccade lines on the stimulus,
  * with optional AOI rectangles/polygons overlaid.

Gaze is in normalized [0,1] screen coords, so it maps onto the image regardless of
resolution. The image is embedded as a data URI, so the output HTML is self-contained.
"""

from __future__ import annotations

import base64
import mimetypes

import numpy as np

from . import load as A
from . import eyetracking as ET


# --------------------------------------------------------------------------
def gaze_for_stimulus(session: dict, stimulus: str, *, stream="EyeTracker",
                      pad=0.0) -> tuple:
    """Concatenate gaze (x, y) recorded while `stimulus` was on, across all onsets.

    Presentation windows are rebuilt from the marker stream: a marker labelled
    ``stim_on:<stimulus>`` opens a window and the next ``stim_off:<stimulus>``
    closes it. A window still open when the markers run out extends to the last
    gaze sample. If a second onset arrives before an offset, the later onset
    replaces the earlier one.

    Args:
        session: mapping with a ``"streams"`` dict; the selected stream must
            provide ``"t"`` (sample times) and ``"x"`` (samples whose columns
            0 and 1 are gaze x and y).
        stimulus: stimulus name as it appears in the marker labels.
        stream: key of the gaze stream inside ``session["streams"]``.
        pad: amount subtracted from each onset and added to each offset, in the
            same units as the sample times.

    Returns:
        ``(gx, gy)`` arrays holding the samples that fall inside any window.
        Both are empty when the stream is missing, holds no samples, or the
        stimulus was never shown.
    """
    s = session["streams"].get(stream)
    if not s:
        return np.array([]), np.array([])
    t = np.asarray(s["t"], float)
    X = np.asarray(s["x"], float)
    if t.size == 0 or X.size == 0:
        # An empty sample list converts to a 1-D array, so the column indexing
        # below would raise IndexError; there is nothing to select anyway.
        return np.array([]), np.array([])
    gx, gy = X[:, 0], X[:, 1]

    # build on/off windows for this stimulus from the marker stream
    on_label = f"stim_on:{stimulus}"
    off_label = f"stim_off:{stimulus}"
    windows, open_t = [], None
    for mt, label in A.markers(session):
        label = str(label)
        if label == on_label:
            open_t = float(mt)
        elif label == off_label and open_t is not None:
            windows.append((open_t - pad, float(mt) + pad))
            open_t = None
    if open_t is not None:
        # onset without a matching offset: keep everything up to the last sample
        windows.append((open_t - pad, t[-1]))
    if not windows:
        return np.array([]), np.array([])

    mask = np.zeros(len(t), bool)
    for lo, hi in windows:
        mask |= (t >= lo) & (t <= hi)
    return gx[mask], gy[mask]


def _img_data_uri(path: str) -> str:
    mime = mimetypes.guess_type(path)[0] or "image/png"
    with open(path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    return f"data:{mime};base64,{b64}"


def _base_fig(image_path: str, title: str):
    """Create a plotly figure with the stimulus image as its background.

    The image is embedded as a data URI and stretched over the unit square of
    the x/y axes. Both axes are hidden; the y range is reversed so y grows
    downward like screen coordinates, and y is scale-anchored to x.

    Args:
        image_path: path of the stimulus image file to embed.
        title: figure title.

    Returns:
        The ``plotly.graph_objects.Figure``, ready for traces and shapes.
    """
    import plotly.graph_objects as go

    backdrop = dict(
        source=_img_data_uri(image_path),
        xref="x", yref="y",
        x=0, y=0, sizex=1, sizey=1,
        sizing="stretch", layer="below", opacity=1.0,
    )
    frame_margin = dict(l=10, r=10, t=46, b=10)

    canvas = go.Figure()
    canvas.add_layout_image(backdrop)
    canvas.update_xaxes(range=[0, 1], visible=False, constrain="domain")
    # reversed range: y = 0 is the top edge, as on a screen
    canvas.update_yaxes(range=[1, 0], visible=False, scaleanchor="x")
    canvas.update_layout(title=title, height=460, margin=frame_margin,
                         plot_bgcolor="white")
    return canvas


def _aoi_shapes(fig, aois):
    for a in (aois or []):
        if getattr(a, "rect", None):
            x0, y0, x1, y1 = a.rect
            fig.add_shape(type="rect", x0=x0, y0=y0, x1=x1, y1=y1, xref="x", yref="y",
                          line=dict(color="#F58518", width=2), fillcolor="rgba(0,0,0,0)")
            fig.add_annotation(x=x0, y=y0, text=a.name, showarrow=False, xref="x", yref="y",
                               font=dict(color="#F58518", size=11), yshift=8, xanchor="left")
        elif getattr(a, "polygon", None):
            pts = a.polygon + [a.polygon[0]]
            path = "M " + " L ".join(f"{x},{y}" for x, y in pts) + " Z"
            fig.add_shape(type="path", path=path, line=dict(color="#F58518", width=2))


# --------------------------------------------------------------------------
def heatmap_over_image(session, stimulus, image_path, *, bins=60, sigma=1.6,
                       aois=None, opacity=0.55):
    """Render a gaze-density heatmap on top of the stimulus image.

    Gaze recorded while `stimulus` was shown is binned on a `bins` grid over
    the unit square, passed through ``ET._gauss_blur`` with `sigma`
    (presumably a gaussian blur measured in bins — confirm against that
    helper) and drawn as a semi-transparent heatmap over the image. When no
    gaze is available only the image and the AOIs are drawn.

    Args:
        session: recording session to take gaze and markers from.
        stimulus: stimulus name used in the marker labels.
        image_path: path of the stimulus image file.
        bins: bin specification forwarded to ``np.histogram2d``.
        sigma: blur amount forwarded to ``ET._gauss_blur``.
        aois: optional AOIs to outline on the figure.
        opacity: opacity of the heatmap layer.

    Returns:
        The plotly figure.
    """
    import plotly.graph_objects as go

    gaze_x, gaze_y = gaze_for_stimulus(session, stimulus)
    fig = _base_fig(image_path, f"Gaze heatmap — {stimulus}")
    if len(gaze_x) > 0:
        counts, x_edges, y_edges = np.histogram2d(
            gaze_x, gaze_y, bins=bins, range=[[0, 1], [0, 1]])
        # histogram2d indexes [x, y]; Heatmap wants rows = y, so transpose
        density = ET._gauss_blur(counts, sigma).T
        x_mid = (x_edges[:-1] + x_edges[1:]) / 2
        y_mid = (y_edges[:-1] + y_edges[1:]) / 2
        overlay = go.Heatmap(z=density, x=x_mid, y=y_mid,
                             colorscale="Turbo", zsmooth="best", opacity=opacity,
                             showscale=False, hoverinfo="skip")
        fig.add_trace(overlay)
    _aoi_shapes(fig, aois)
    return fig


def scanpath_over_image(session, stimulus, image_path, *, aois=None, screen=None):
    """Render numbered fixations joined by saccade lines on the stimulus image.

    The gaze recorded while `stimulus` was shown is wrapped in a temporary
    one-stream session and classified with ``ET.classify_ivt``. Each fixation
    is drawn as a numbered marker whose size grows with its duration, and
    consecutive fixations are joined by a line. With two or fewer gaze samples,
    or no fixations, only the image and the AOIs are drawn.

    Args:
        session: recording session; its ``"EyeTracker"`` stream supplies the
            sampling rate.
        stimulus: stimulus name used in the marker labels.
        image_path: path of the stimulus image file.
        aois: optional AOIs to outline on the figure.
        screen: forwarded unchanged to ``ET.classify_ivt``.

    Returns:
        The plotly figure.
    """
    import plotly.graph_objects as go

    fig = _base_fig(image_path, f"Scanpath — {stimulus}")
    gx, gy = gaze_for_stimulus(session, stimulus)

    fixations = []
    if len(gx) > 2:
        srate = session["streams"]["EyeTracker"]["srate"]
        # NOTE(review): samples from every presentation are concatenated and put
        # on one evenly spaced synthetic time base, so the end of one presentation
        # is adjacent to the start of the next — confirm this is acceptable for
        # the fixation classifier.
        timeline = np.arange(len(gx)) / max(1.0, srate)
        sliced = {
            "streams": {
                "EyeTracker": {
                    "type": "Gaze",
                    "srate": srate,
                    "channels": ["gaze_x", "gaze_y"],
                    "t": timeline,
                    "x": np.c_[gx, gy],
                },
            },
        }
        fixations, _, _ = ET.classify_ivt(sliced, screen=screen)

    if fixations:
        xs = [f.x for f in fixations]
        ys = [f.y for f in fixations]
        marker_sizes = [10 + 70 * f.duration for f in fixations]
        order_labels = [str(k) for k in range(1, len(fixations) + 1)]
        # saccade lines first, so the numbered markers sit on top of them
        fig.add_trace(go.Scatter(x=xs, y=ys, mode="lines",
                                 line=dict(color="rgba(58,154,168,.9)", width=2),
                                 hoverinfo="skip", showlegend=False))
        fig.add_trace(go.Scatter(x=xs, y=ys, mode="markers+text",
                                 text=order_labels,
                                 textfont=dict(color="#fff", size=9),
                                 marker=dict(size=marker_sizes,
                                             color="rgba(47,100,112,.8)",
                                             line=dict(color="#fff", width=1)),
                                 showlegend=False))
    _aoi_shapes(fig, aois)
    return fig


def stimulus_images(study_plan, media_root) -> dict:
    """Map stimulus name -> local image file (from uploaded /media/stimuli paths).

    Walks every leaf of every block in `study_plan`. For leaves of kind
    ``"image"`` with a non-empty ``image`` path, the last ``/``-separated
    segment of that path is looked up inside `media_root`.

    Args:
        study_plan: object exposing ``blocks``; each block exposes ``leaves()``
            yielding items with ``kind``, ``image`` and ``name`` attributes.
        media_root: directory that holds the stimulus image files.

    Returns:
        Dict from leaf name to the image path as a string. Leaves whose image
        is not a regular file under `media_root` are left out.
    """
    from pathlib import Path

    root = Path(media_root)
    out = {}
    for b in study_plan.blocks:
        for leaf in b.leaves():
            if leaf.kind != "image" or not leaf.image:
                continue
            name = leaf.image.split("/")[-1]
            p = root / name
            # is_file(), not exists(): a trailing "/" yields an empty name and
            # would otherwise map the stimulus to the media directory itself.
            if p.is_file():
                out[leaf.name] = str(p)
    return out
