# scanomaly/plot.py
from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from .anomaly_models import get_flat_plot_model_masked, get_anom_plot_model_masked
from .singlelens_model import (
A_pspl_func,
A_fspl_func,
A_pspl_parallax_func,
A_fspl_parallax_func,
A_cv_asymexp_logtau_func,
)
def _single_lens_model_flux(fit, time) -> np.ndarray:
"""
Evaluate the fitted single-lens model at arbitrary times.
"""
names = tuple(getattr(fit, "param_names", ()))
params = jnp.asarray(fit.params)
time_j = jnp.asarray(time)
if names == ("t0", "tE", "u0"):
A = A_pspl_func(params, time_j)
elif names == ("t0", "tE", "u0", "rho"):
A = A_fspl_func(params, time_j)
elif names == ("t0", "tE", "u0", "piEN", "piEE"):
P = getattr(fit, "parallax_projector", None)
if P is None:
raise ValueError("Cannot plot parallax model without fit.parallax_projector.")
A = A_pspl_parallax_func(params, time_j, P)
elif names == ("t0", "tE", "u0", "rho", "piEN", "piEE"):
P = getattr(fit, "parallax_projector", None)
if P is None:
raise ValueError("Cannot plot parallax model without fit.parallax_projector.")
A = A_fspl_parallax_func(params, time_j, P)
elif names == ("t0", "tau_rise", "tau_decay"):
t0, tau_rise, tau_decay = params
q = jnp.array([t0, jnp.log(tau_rise), jnp.log(tau_decay)])
A = A_cv_asymexp_logtau_func(q, time_j)
else:
raise ValueError(f"Unsupported single-lens parameter set: {names}")
flux = fit.fs * A + fit.fb
return np.asarray(jax.device_get(flux), dtype=float)
def _default_flux_tolerance(fit) -> float:
    """
    Derive a flux refinement tolerance from the fit's data.

    Uses a quarter of the median of the finite, strictly positive
    ``fit.ferr`` values. If there are none, falls back to
    ``0.0025 * max(ptp(finite fit.flux), 1.0)``; with no finite flux the
    peak-to-peak spread is taken as 1.0.
    """
    errors = np.asarray(getattr(fit, "ferr", []), dtype=float)
    good_errors = errors[np.isfinite(errors) & (errors > 0)]
    if good_errors.size:
        return 0.25 * float(np.median(good_errors))
    fluxes = np.asarray(getattr(fit, "flux", []), dtype=float)
    fluxes = fluxes[np.isfinite(fluxes)]
    spread = float(np.ptp(fluxes)) if fluxes.size else 1.0
    return 0.0025 * max(spread, 1.0)


def _adaptive_single_lens_curve(
    fit,
    xlim: Tuple[float, float],
    *,
    base_points: int = 128,
    max_flux_step: float | None = None,
    min_time_step: float | None = None,
    max_points: int = 12000,
    max_iter: int = 24,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Sample the fitted single-lens model on ``xlim`` with adaptive refinement.

    Starts from ``max(2, base_points)`` uniform samples between the two
    limits. The limits may be given in either order, and a zero-width window
    is widened to one time unit. Each pass probes the midpoint of every
    interval wider than ``min_time_step``. An interval is split when the flux
    jump across it, or the midpoint's deviation from linear interpolation,
    exceeds ``max_flux_step``. The deviation check catches bumps between equal
    endpoints that a pure jump test would miss.

    Parameters
    ----------
    fit
        Object accepted by ``_single_lens_model_flux``. Its optional ``ferr``
        and ``flux`` attributes are used to derive the default tolerance.
    xlim
        (start, end) of the sampling window.
    base_points
        Size of the initial uniform grid (at least 2).
    max_flux_step
        Flux tolerance. Defaults to 0.25 * median(ferr), or a flux-range
        fallback. Floored at machine epsilon.
    min_time_step
        Intervals this narrow are never split. Defaults to
        ``max(width / max_points, 1e-6)``. Floored at machine epsilon.
    max_points
        Refinement stops once the grid reaches this size. When more splits are
        requested than fit, the intervals with the largest jump/deviation win.
        The initial grid itself is not truncated.
    max_iter
        Maximum number of refinement passes.

    Returns
    -------
    (t, flux)
        Sample times sorted ascending and the model flux at those times.
    """
    lo, hi = map(float, xlim)
    if hi < lo:
        lo, hi = hi, lo
    elif hi == lo:
        hi = lo + 1.0

    grid = np.linspace(lo, hi, max(2, int(base_points)))
    flux = _single_lens_model_flux(fit, grid)

    tol = _default_flux_tolerance(fit) if max_flux_step is None else max_flux_step
    tol = max(float(tol), np.finfo(float).eps)
    dt_min = max((hi - lo) / max_points, 1e-6) if min_time_step is None else min_time_step
    dt_min = max(float(dt_min), np.finfo(float).eps)

    for _iteration in range(max_iter):
        if grid.size >= max_points:
            break
        left = np.flatnonzero(np.diff(grid) > dt_min)
        if left.size == 0:
            break
        right = left + 1
        x_l, x_r = grid[left], grid[right]
        f_l, f_r = flux[left], flux[right]

        # Probe every candidate interval at its midpoint.
        x_mid = 0.5 * (x_l + x_r)
        f_mid = _single_lens_model_flux(fit, x_mid)
        jump = np.abs(f_r - f_l)
        bow = np.abs(f_mid - 0.5 * (f_l + f_r))

        split = np.flatnonzero((jump > tol) | (bow > tol))
        if split.size == 0:
            break
        budget = max_points - grid.size
        if budget <= 0:
            break
        if split.size > budget:
            # Keep only the `budget` worst intervals, in grid order.
            score = np.maximum(jump[split], bow[split])
            split = np.sort(split[np.argsort(score)[-budget:]])

        merged_x = np.concatenate([grid, x_mid[split]])
        merged_f = np.concatenate([flux, f_mid[split]])
        perm = np.argsort(merged_x)
        grid, flux = merged_x[perm], merged_f[perm]

    return grid, flux
@dataclass
class AnomalyPlotter:
    """
    Plot utilities for scanomaly results.

    Conventions
    -----------
    - "pspl": uses the single-lens best-fit parameters (``result.fit.params``;
      entries 0..2 are read as t0, tE, u0) to define center/width
    - "anomaly": uses the best anomaly candidate (``result.best``) to define
      center/width

    Notes
    -----
    - All plot methods return (fig, ax/axes). If show=False, nothing is displayed
      and you can adjust limits/styles outside.
    - ``plot_anomaly_window`` returns (None, None) when ``result.best`` is None.
    """

    # ----------------------------
    # helpers
    # ----------------------------
    @staticmethod
    def _robust_ylim(
        low_values: np.ndarray,
        high_values: np.ndarray | None = None,
        *,
        margin: float = 0.15,
    ) -> Tuple[float, float] | None:
        """
        Return padded y-limits from robust percentiles, or None.

        The lower bound is the 1st percentile of ``low_values``. The upper
        bound is the 99th percentile of ``high_values``, which defaults to
        ``low_values``. Non-finite samples are ignored, so one NaN no longer
        disables the limits. ``margin`` (a fraction of the span) is added on
        each side. Returns None when no finite samples remain or the span is
        not positive, so the caller can keep matplotlib's autoscaling.
        """
        lo_arr = np.asarray(low_values, dtype=float).ravel()
        if high_values is None:
            hi_arr = lo_arr
        else:
            hi_arr = np.asarray(high_values, dtype=float).ravel()
        lo_arr = lo_arr[np.isfinite(lo_arr)]
        hi_arr = hi_arr[np.isfinite(hi_arr)]
        if lo_arr.size == 0 or hi_arr.size == 0:
            return None
        y_lo = float(np.percentile(lo_arr, 1))
        y_hi = float(np.percentile(hi_arr, 99))
        if not y_hi > y_lo:
            return None
        mg = margin * (y_hi - y_lo)
        return (y_lo - mg, y_hi + mg)

    def _compute_xlim(
        self,
        result,
        *,
        center: str = "pspl",  # "pspl" or "anomaly"
        width_mode: str = "pspl",  # "pspl" or "anomaly" or "custom"
        a: float = 3.0,
        xlim: Tuple[float, float] | None = None,
        half_width: float | None = None,
        min_hw: float | None = None,
    ) -> Tuple[float, float]:
        """
        Compute xlim for plots.

        Parameters
        ----------
        center : {"pspl", "anomaly"} (compat: "best")
            Center of the view window. "anomaly" falls back to the single-lens
            t0 when ``result.best`` is None. Unknown values act as "pspl".
        width_mode : {"pspl", "anomaly", "custom"} (compat: "best")
            How to compute the window half-width.
            - "pspl": half_width = a * |tE * u0|
            - "anomaly": half_width = a * teff (from result.best); falls back
              to the "pspl" rule when result.best is None
            - "custom": half_width must be provided.
            The half-width is taken in absolute value, so a negative ``a`` or
            ``half_width`` never produces an inverted axis.
        a : float
            Multiplier for the default half-width rule.
        xlim : tuple | None
            If provided, returned as-is.
        half_width : float | None
            Used only when width_mode="custom".
        min_hw : float | None
            Minimum half-width enforced after the normal calculation.
            Prevents degenerate (u0≈0) fits from producing a near-zero window.

        If ``result.best`` lies outside the computed window but within 2*hw of
        its nearest edge, that edge is moved to include it, with a margin of
        ``max(5 * teff, 5.0)``.

        Returns
        -------
        (xmin, xmax)

        Raises
        ------
        ValueError
            If width_mode="custom" and half_width is None.
        """
        if xlim is not None:
            return xlim
        # backward compat: "best" is an alias for "anomaly"
        if center == "best":
            center = "anomaly"
        if width_mode == "best":
            width_mode = "anomaly"
        best = getattr(result, "best", None)
        # center
        if center == "anomaly" and best is not None:
            t_center = float(best.t0)
        else:
            t_center = float(np.asarray(result.fit.params)[0])
        # width
        if width_mode == "custom":
            if half_width is None:
                raise ValueError("When width_mode='custom', half_width must be specified.")
            hw = float(half_width)
        elif width_mode == "anomaly" and best is not None:
            hw = float(a * best.teff)
        else:  # "pspl", or "anomaly" without a candidate
            _, tE, u0 = map(float, np.asarray(result.fit.params)[:3])
            hw = float(a * abs(tE * u0))
        # A negative multiplier / half_width must not invert the axis.
        hw = abs(hw)
        # Apply minimum half-width (prevents degenerate u0≈0 from collapsing window)
        if min_hw is not None:
            hw = max(hw, float(min_hw))
        xmin = t_center - hw
        xmax = t_center + hw
        # If the best anomaly candidate is just outside the window, extend to
        # include it. Only extend when it lies within 2*hw of the nearest edge
        # (i.e. within 3*hw of the center), not in a completely different season.
        if best is not None:
            ano_t0 = float(best.t0)
            if not (xmin <= ano_t0 <= xmax):
                dist = min(abs(ano_t0 - xmin), abs(ano_t0 - xmax))
                if dist <= 2.0 * hw:
                    buf = max(float(best.teff) * 5.0, 5.0)
                    if ano_t0 < xmin:
                        xmin = ano_t0 - buf
                    else:
                        xmax = ano_t0 + buf
        return (xmin, xmax)

    # ----------------------------
    # basic plots
    # ----------------------------
    def plot_lc(
        self,
        result,
        *,
        show: bool = True,
        ax=None,
        center: str = "pspl",  # "pspl" or "anomaly"
        width_mode: str = "pspl",  # "pspl" or "anomaly" or "custom"
        a: float = 3.0,
        xlim: Tuple[float, float] | None = None,
        half_width: float | None = None,
        model_base_points: int = 128,
        model_max_flux_step: float | None = None,
        model_min_time_step: float | None = None,
        model_max_points: int = 12000,
        min_hw: float | None = None,
    ):
        """
        Plot the light curve and the best-fit single-lens model.

        The view window comes from ``_compute_xlim`` (center/width_mode/a/xlim/
        half_width/min_hw). ``min_hw`` guards against near-zero windows, as in
        ``plot_result``. The model is sampled adaptively inside that window
        only; the ``model_*`` arguments are forwarded to the sampler as
        base_points / max_flux_step / min_time_step / max_points.

        Returns (fig, ax).
        """
        t, f, e = result.time, result.flux, result.ferr
        xl = self._compute_xlim(
            result, center=center, width_mode=width_mode, a=a, xlim=xlim,
            half_width=half_width, min_hw=min_hw,
        )
        t_plot, m_plot = _adaptive_single_lens_curve(
            result.fit,
            xl,
            base_points=model_base_points,
            max_flux_step=model_max_flux_step,
            min_time_step=model_min_time_step,
            max_points=model_max_points,
        )
        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = ax.figure
        ax.errorbar(t, f, yerr=e, fmt=".", label="data", zorder=0)
        ax.plot(t_plot, m_plot, label="PSPL model", zorder=1)
        ax.set_xlabel("time")
        ax.set_ylabel("flux")
        ax.legend()
        ax.set_xlim(xl)
        if show:
            plt.show()
        return fig, ax

    def plot_residual(
        self,
        result,
        *,
        show: bool = True,
        ax=None,
        center: str = "pspl",  # "pspl" or "anomaly"
        width_mode: str = "pspl",  # "pspl" or "anomaly" or "custom"
        a: float = 50.0,
        xlim: Tuple[float, float] | None = None,
        half_width: float | None = None,
        min_hw: float | None = None,
    ):
        """
        Plot the single-lens residual (``result.residual``) with a zero line.

        The view window comes from ``_compute_xlim``. ``min_hw`` guards against
        near-zero windows, as in ``plot_result``.

        Returns (fig, ax).
        """
        t, r = result.time, result.residual
        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = ax.figure
        ax.plot(t, r, ".", zorder=0)
        ax.axhline(0.0, zorder=1, c="C1")
        ax.set_xlabel("time")
        ax.set_ylabel("residual")
        xl = self._compute_xlim(
            result, center=center, width_mode=width_mode, a=a, xlim=xlim,
            half_width=half_width, min_hw=min_hw,
        )
        ax.set_xlim(xl)
        if show:
            plt.show()
        return fig, ax

    # ----------------------------
    # anomaly window plot
    # ----------------------------
    def plot_anomaly_window(
        self,
        result,
        *,
        show: bool = True,
        ax=None,
        xlim: Tuple[float, float] | None = None,
        a: float = 5.0,  # default view half-width = a * teff
        show_flat: bool = True,
        show_anom: bool = True,
        teff_coeff: float = 3.0,
        use_errorbar: bool = True,
        model_step: float = 0.001,
        model_max_points: int = 20000,
    ):
        """
        Plot residuals around the best anomaly candidate, with the template
        models overlaid ONLY inside the chi2 evaluation window.

        - Data: residual vs time (optionally with error bars).
        - The chi2 window is |t - t0| <= teff_coeff * teff, with (t0, teff)
          taken from ``result.best``. The model helpers receive the residuals,
          weights 1 / ferr**2 and a mask selecting the data in that window.
        - Model lines are evaluated on a uniform grid covering exactly that
          window.

        Parameters
        ----------
        xlim : tuple | None
            View limits; default (t0 - a * teff, t0 + a * teff).
        model_step : float
            Target spacing (time units) of the model grid. The default matches
            the former fixed 0.001 step. Must be positive.
        model_max_points : int
            Upper bound on the grid size. The spacing is coarsened for wide
            windows instead of allocating an unbounded grid.

        Returns
        -------
        (fig, ax), or (None, None) if ``result.best`` is None.

        Raises
        ------
        ValueError
            If ``model_step`` is not positive, or the window center/half-width
            is not finite.
        """
        best = getattr(result, "best", None)
        if best is None:
            return None, None
        if not model_step > 0:
            raise ValueError("model_step must be positive.")
        t0 = float(best.t0)
        teff = float(best.teff)
        w = teff_coeff * teff
        if not (np.isfinite(t0) and np.isfinite(w)):
            raise ValueError("Anomaly window is not finite (check result.best.t0/teff).")
        # CPU arrays
        t_np = np.asarray(result.time)
        r_np = np.asarray(result.residual)
        e_np = np.asarray(result.ferr)
        # Uniform model grid over exactly [t0 - w, t0 + w]. Unlike a float
        # arange, linspace cannot overshoot the endpoint. The size is capped so
        # wide windows do not allocate an unbounded number of points.
        span = 2.0 * w
        if span > 0:
            cap = max(2, int(model_max_points))
            n_plot = int(min(np.ceil(span / model_step) + 1.0, cap))
        else:
            n_plot = 2
        t_plot_np = np.linspace(t0 - w, t0 + w, n_plot)
        # data selected for chi2 evaluation
        mask = (t_np >= (t0 - w)) & (t_np <= (t0 + w))
        mask_j = jnp.asarray(mask)
        # JAX arrays for prediction
        t = jnp.asarray(t_np)
        r = jnp.asarray(r_np)
        e = jnp.asarray(e_np)
        w_j = 1.0 / (e ** 2)  # inverse-variance weights
        t_plot = jnp.asarray(t_plot_np)
        y_flat = None
        y_anom = None
        if show_flat:
            # Flat template evaluated on the grid from residuals/weights in the mask.
            y_flat = np.asarray(jax.device_get(get_flat_plot_model_masked(t_plot, r, w_j, mask_j)))
        if show_anom:
            y_anom_j, _ = get_anom_plot_model_masked(t_plot, t0, teff, t, r, w_j, mask_j)
            y_anom = np.asarray(jax.device_get(y_anom_j))
        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = ax.figure
        if use_errorbar:
            ax.errorbar(t_np, r_np, yerr=e_np, fmt=".", zorder=2)
        else:
            ax.plot(t_np, r_np, ".", zorder=2)
        # Draw model lines ONLY within the chi2 window
        if y_flat is not None:
            ax.plot(t_plot_np, y_flat, label="flat", c="C1", zorder=3)
        if y_anom is not None:
            ax.plot(t_plot_np, y_anom, label="anomaly", c="r", zorder=3)
        # range
        if xlim is None:
            hw = a * teff
            xlim = (t0 - hw, t0 + hw)
        ax.set_xlim(xlim)
        # Set ylim from the visible x-window only so out-of-window points
        # do not inflate the vertical range.
        vis = (t_np >= xlim[0]) & (t_np <= xlim[1])
        if vis.any():
            y_samples = [r_np[vis] - e_np[vis], r_np[vis] + e_np[vis]]
            line_vis = (t_plot_np >= xlim[0]) & (t_plot_np <= xlim[1])
            if line_vis.any():
                if y_flat is not None:
                    y_samples.append(y_flat[line_vis])
                if y_anom is not None:
                    y_samples.append(y_anom[line_vis])
            ylim = self._robust_ylim(np.concatenate(y_samples))
            if ylim is not None:
                ax.set_ylim(*ylim)
        ax.set_xlabel("time")
        ax.set_ylabel("residual")
        # Avoid matplotlib's "No artists with labels" warning when both
        # model overlays are disabled.
        handles, _labels = ax.get_legend_handles_labels()
        if handles:
            ax.legend()
        if show:
            plt.show()
        return fig, ax

    # ----------------------------
    # tripanel plot
    # ----------------------------
    def plot_result(
        self,
        result,
        *,
        center: str = "pspl",  # "pspl" or "anomaly"
        width_mode: str = "pspl",  # "pspl" or "anomaly" or "custom"
        a: float = 3.0,
        xlim: Tuple[float, float] | None = None,
        half_width: float | None = None,
        min_hw: float = 30.0,
        show: bool = True,
        figsize=(10, 8),
        height_ratios=(3, 1, 1),
        show_anomaly_window: bool = False,
        teff_coeff: float = 3.0,
        model_base_points: int = 128,
        model_max_flux_step: float | None = None,
        model_min_time_step: float | None = None,
        model_max_points: int = 12000,
    ):
        """
        3-panel plot with a shared time axis:
          1) data + single-lens best-fit model
          2) residual (optionally shading the anomaly chi2 window
             |t - best.t0| <= teff_coeff * best.teff)
          3) clusters (column 0 of ``result.clusters_all`` vs column 2, dchi2)

        Range control:
          - center/width_mode/a OR xlim/half_width (see ``_compute_xlim``)
          - min_hw: minimum half-width in time units (default 30);
            prevents degenerate u0≈0 fits from collapsing the window.

        The flux and residual y-limits come from the 1st/99th percentiles of
        the visible data ± errors, so a tall model peak cannot hide the data.

        Returns (fig, axes).
        """
        t = np.asarray(result.time)
        f = np.asarray(result.flux)
        e = np.asarray(result.ferr)
        res = np.asarray(result.residual)
        clusters = np.asarray(result.clusters_all)
        xl = self._compute_xlim(
            result, center=center, width_mode=width_mode, a=a, xlim=xlim,
            half_width=half_width, min_hw=min_hw,
        )
        t_plot, m_plot = _adaptive_single_lens_curve(
            result.fit,
            xl,
            base_points=model_base_points,
            max_flux_step=model_max_flux_step,
            min_time_step=model_min_time_step,
            max_points=model_max_points,
        )
        fig, axes = plt.subplots(
            3, 1, figsize=figsize, sharex=True, height_ratios=height_ratios
        )
        # 1) data + model
        ax = axes[0]
        ax.errorbar(t, f, yerr=e, fmt="o", markersize=2, alpha=0.7, label="data", zorder=0)
        ax.plot(t_plot, m_plot, lw=2, label="best model", zorder=1)
        ax.set_xlim(xl)
        ax.set_ylabel("flux")
        ax.minorticks_on()
        ax.legend()
        # 2) residual
        ax = axes[1]
        ax.errorbar(t, res, yerr=e, fmt="o", markersize=2, alpha=1.0, zorder=0)
        ax.axhline(0.0, lw=1, zorder=1, c="C1")
        ax.set_xlim(xl)
        ax.set_ylabel("residual")
        ax.minorticks_on()
        if show_anomaly_window and (getattr(result, "best", None) is not None):
            t0c = float(result.best.t0)
            w = float(teff_coeff * result.best.teff)
            ax.axvline(t0c, lw=1)
            ax.axvspan(t0c - w, t0c + w, alpha=0.1)
        # 3) clusters
        ax = axes[2]
        if clusters.size:
            # A single cluster stored as a flat row would break the column
            # indexing below; atleast_2d makes it a one-row table.
            # NOTE(review): assumes each row is (t0, ..., dchi2) as plotted here.
            clusters_2d = np.atleast_2d(clusters)
            ax.scatter(clusters_2d[:, 0], clusters_2d[:, 2], s=60, marker="x", c="r")
        ax.set_xlim(xl)
        ax.set_xlabel("time")
        ax.set_ylabel("dchi2")
        ax.minorticks_on()
        # Set data-based ylim for flux and residual panels.
        # Prevents the model spike (which is huge when u0≈0) from
        # dominating the y-axis and hiding all the data.
        vis = (t >= xl[0]) & (t <= xl[1])
        if vis.any():
            e_vis = e[vis]
            f_vis = f[vis]
            flux_ylim = self._robust_ylim(f_vis - e_vis, f_vis + e_vis)
            if flux_ylim is not None:
                axes[0].set_ylim(*flux_ylim)
            r_vis = res[vis]
            res_ylim = self._robust_ylim(r_vis - e_vis, r_vis + e_vis)
            if res_ylim is not None:
                axes[1].set_ylim(*res_ylim)
        if show:
            plt.show()
        return fig, axes
@dataclass
class SingleLensPlotter:
    """
    Plot utilities for single lens fitting results only.

    This plotter mirrors the interface philosophy of AnomalyPlotter, but
    operates directly on a single-lens fit object (SingleLensFitResult). The
    methods read ``fit.params`` (entries 0..2 used as t0, tE, u0), ``fit.time``,
    ``fit.flux``, ``fit.ferr`` and ``fit.residual``.

    Conventions
    -----------
    - The window is always centered on t0 (there is no ``center`` option).
    - width_mode:
        - "pspl": half-width = a * |tE * u0|
        - "custom": half-width = |half_width|
    - min_hw (optional): lower bound on the half-width, so fits with u0 ≈ 0
      do not collapse the window.
    """

    # ----------------------------
    # helpers
    # ----------------------------
    def _compute_xlim(
        self,
        fit,
        *,
        width_mode: str = "pspl",
        a: float = 50.0,
        xlim: tuple[float, float] | None = None,
        half_width: float | None = None,
        min_hw: float | None = None,
    ) -> tuple[float, float]:
        """
        Compute xlim for single lens plots.

        Returns ``xlim`` unchanged when given. Otherwise returns
        ``(t0 - hw, t0 + hw)``, where hw is ``|half_width|`` for
        width_mode="custom" and ``|a * tE * u0|`` for any other mode. hw is
        then raised to ``min_hw`` if provided. Taking the absolute value keeps
        a negative multiplier from inverting the axis.

        Raises
        ------
        ValueError
            If width_mode="custom" and half_width is None.
        """
        if xlim is not None:
            return xlim
        params = np.asarray(fit.params)
        t0 = float(params[0])
        if width_mode == "custom":
            if half_width is None:
                raise ValueError("width_mode='custom' requires half_width.")
            hw = float(half_width)
        else:
            _, tE, u0 = map(float, params[:3])
            hw = float(a * abs(tE * u0))
        hw = abs(hw)
        if min_hw is not None:
            hw = max(hw, float(min_hw))
        return (t0 - hw, t0 + hw)

    # ----------------------------
    # basic plots
    # ----------------------------
    def plot_lc(
        self,
        fit,
        *,
        show: bool = True,
        ax=None,
        width_mode: str = "pspl",
        a: float = 3.0,
        xlim: tuple[float, float] | None = None,
        half_width: float | None = None,
        model_base_points: int = 128,
        model_max_flux_step: float | None = None,
        model_min_time_step: float | None = None,
        model_max_points: int = 12000,
        min_hw: float | None = None,
    ):
        """
        Plot the light curve with the single lens best-fit model.

        Mirrors AnomalyPlotter.plot_lc. The model is sampled adaptively inside
        the view window only; the ``model_*`` arguments are forwarded to the
        sampler as base_points / max_flux_step / min_time_step / max_points.

        Returns (fig, ax).
        """
        t = np.asarray(fit.time)
        f = np.asarray(fit.flux)
        e = np.asarray(fit.ferr)
        xl = self._compute_xlim(
            fit,
            width_mode=width_mode,
            a=a,
            xlim=xlim,
            half_width=half_width,
            min_hw=min_hw,
        )
        t_plot, m_plot = _adaptive_single_lens_curve(
            fit,
            xl,
            base_points=model_base_points,
            max_flux_step=model_max_flux_step,
            min_time_step=model_min_time_step,
            max_points=model_max_points,
        )
        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = ax.figure
        ax.errorbar(t, f, yerr=e, fmt=".", label="data", zorder=0)
        ax.plot(t_plot, m_plot, lw=2, label="model", zorder=1)
        ax.set_xlabel("time")
        ax.set_ylabel("flux")
        ax.legend()
        ax.minorticks_on()
        ax.set_xlim(xl)
        if show:
            plt.show()
        return fig, ax

    def plot_residual(
        self,
        fit,
        *,
        show: bool = True,
        ax=None,
        width_mode: str = "pspl",
        a: float = 3.0,
        xlim: tuple[float, float] | None = None,
        half_width: float | None = None,
        use_errorbar: bool = True,
        min_hw: float | None = None,
    ):
        """
        Plot the single lens residual (``fit.residual``) with a zero line.

        Mirrors AnomalyPlotter.plot_residual.

        Returns (fig, ax).
        """
        t = np.asarray(fit.time)
        r = np.asarray(fit.residual)
        e = np.asarray(fit.ferr)
        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = ax.figure
        if use_errorbar:
            ax.errorbar(t, r, yerr=e, fmt=".", zorder=0)
        else:
            ax.plot(t, r, ".", label="residual", zorder=0)
        ax.axhline(0.0, zorder=1, c="C1")
        ax.set_xlabel("time")
        ax.set_ylabel("residual")
        ax.minorticks_on()
        xl = self._compute_xlim(
            fit,
            width_mode=width_mode,
            a=a,
            xlim=xlim,
            half_width=half_width,
            min_hw=min_hw,
        )
        ax.set_xlim(xl)
        if show:
            plt.show()
        return fig, ax