from __future__ import annotations
from dataclasses import dataclass, field, replace
from typing import Optional
import logging
import numpy as np
import jax
import jax.numpy as jnp
from .config import FinderConfig
from .criteria import CandidateCriteria
from .singlelens_fit import (
SingleLensFitResult,
PSPLFitter,
CPPPSPLFitter,
FSPLFitter,
PSPLParallaxFitter,
FSPLParallaxFitter,
)
from .plot import AnomalyPlotter
from .seasons import SeasonSplitter
from .extract import ResultExtractor
from .runner import SeasonGridRunner
from .models import AnomalyResult, BestCandidate, CandidateQuality
from .template_free import TemplateFreeScanner, TemplateFreeSearchConfig, TemplateFreeSearchResult
logger = logging.getLogger(__name__)
@dataclass
class Finder:
    """
    Main entry point of **jacscanomaly**.

    `Finder` orchestrates the full anomaly-search pipeline:

    1. Fit a single-lens microlensing model to the full light curve
       (PSPL / FSPL / ± annual parallax).
    2. Split the residual light curve into observing seasons.
    3. Perform grid scans on residuals within each season.
    4. Extract and merge statistically significant clusters.
    5. Select the best anomaly candidate, if any.

    The choice of single-lens model is controlled by :class:`FinderConfig`
    (via ``fitter_kind``), or by explicitly injecting a fitter instance.

    Parameters
    ----------
    config : FinderConfig, optional
        Configuration object controlling fitting, season splitting,
        grid scanning, and candidate selection.
    fitter : optional
        A single-lens fitter instance. If ``None``, a default fitter
        is constructed from ``config.fitter_kind``.
        Any object implementing::

            fit(time, flux, ferr, x0) -> SingleLensFitResult

        is acceptable.
    plotter : AnomalyPlotter, optional
        Plotting helper used by the ``plot_*`` convenience methods.

    Notes
    -----
    * The dimensionality of the initial parameter vector ``x0`` depends
      on the selected fitter:

      ======================= ===============================
      Model                   x0 parameters
      ======================= ===============================
      PSPL                    (t0, tE, u0)
      FSPL                    (t0, tE, u0, logrho)
      PSPL + parallax         (t0, tE, u0, piEN, piEE)
      FSPL + parallax         (t0, tE, u0, logrho, piEN, piEE)
      ======================= ===============================

    * For parallax models, ``ra_deg`` and ``dec_deg`` must be provided
      in :class:`FinderConfig`. If ``tref`` is not specified, the median
      observation time is used.
    """

    # Pipeline configuration; a fresh default FinderConfig per instance.
    config: FinderConfig = field(default_factory=FinderConfig)
    # Injected single-lens fitter; built lazily from the config when None.
    fitter: Optional[object] = None
    # Plot helper; a default AnomalyPlotter is created when None.
    plotter: Optional[AnomalyPlotter] = None
def __post_init__(self) -> None:
    """
    Finish dataclass construction by wiring up the collaborators.

    A default :class:`AnomalyPlotter` is created when none was injected,
    the season-grid runner is assembled from the current configuration,
    and the caches holding the most recent results are reset to ``None``.
    """
    if self.plotter is None:
        self.plotter = AnomalyPlotter()

    season_splitter = SeasonSplitter(gap=self.config.gap)
    cluster_extractor = ResultExtractor(
        sigma_overlap=self.config.overlap_sigma,
        min_points=self.config.min_cluster_points,
    )
    self.runner = SeasonGridRunner(
        splitter=season_splitter,
        extractor=cluster_extractor,
        config=self.config,
    )

    # Results of the latest run() / run_template_free() call, used by plot_*.
    self._last_result: Optional[AnomalyResult] = None
    self._last_template_free_result: Optional[TemplateFreeSearchResult] = None
def _ensure_fitter(self, t_ref) -> None:
"""
Instantiate the default single-lens fitter from the current configuration.
Notes
-----
- If `config.fitter_kind` selects a parallax model, `ra_deg` and `dec_deg`
must be provided. If `tref` is not set, it defaults to `median(time)`.
"""
if self.fitter is not None:
return
k = self.config.fitter_kind
# -----------------------------
# 1) Validate model selection
# -----------------------------
valid = {"pspl", "fspl", "pspl_parallax", "fspl_parallax"}
if k not in valid:
raise ValueError(
f"Unknown fitter_kind '{k}'. "
f"Valid options are: {sorted(valid)}"
)
# -----------------------------
# 2) Validate model requirements
# -----------------------------
needs_parallax = k.endswith("_parallax")
if needs_parallax:
if self.config.ra_deg is None or self.config.dec_deg is None:
raise ValueError(
f"{k} requires ra_deg and dec_deg in FinderConfig "
"(sky coordinates are required for annual parallax)."
)
# -----------------------------
# 3) Build fitter
# -----------------------------
if k == "pspl":
if self.config.single_fit_backend == "cpp":
self.fitter = CPPPSPLFitter(
u0_min=float(self.config.pspl_fit_u0_min),
min_t0_support_points=int(self.config.pspl_fit_min_t0_support_points),
t0_support_tE_coeff=float(self.config.pspl_fit_t0_support_tE_coeff),
)
else:
self.fitter = PSPLFitter()
return
if k == "fspl":
self.fitter = FSPLFitter()
return
# Parallax variants
tref = self.config.tref
if tref is None:
tref = t_ref
if k == "pspl_parallax":
self.fitter = PSPLParallaxFitter(
RA=self.config.ra_deg,
Dec=self.config.dec_deg,
tref=tref,
)
return
# k == "fspl_parallax"
self.fitter = FSPLParallaxFitter(
RA=self.config.ra_deg,
Dec=self.config.dec_deg,
tref=tref,
)
# ------------------------------------------------------------------
# Public APIs
# ------------------------------------------------------------------
def fit_single_lens(
    self,
    time,
    flux,
    ferr,
    x0=None,
) -> SingleLensFitResult:
    """
    Fit the configured single-lens model without running the anomaly search.

    Parameters
    ----------
    time, flux, ferr : array-like
        One-dimensional light-curve arrays.
    x0 : array-like, optional
        Starting point for the nonlinear model parameters. When it is
        left out, starting points are derived automatically and the fit
        with the lowest chi-square is returned.

    Returns
    -------
    SingleLensFitResult
        Result of the single-lens fit.
    """
    t_dev, f_dev, e_dev, guess, t_host, _flux_host, _ferr_host = self._to_arrays(
        time, flux, ferr, x0
    )

    # The median epoch serves as the default parallax reference time.
    self._ensure_fitter(float(np.median(t_host)))

    if guess is not None:
        return self.fitter.fit(t_dev, f_dev, e_dev, guess)
    return self._fit_from_auto_initial_guesses(t_dev, f_dev, e_dev, t_host)
def run(
    self,
    time,
    flux,
    ferr,
    x0=None,
    *,
    verbose: bool = True,
    log: Optional[logging.Logger] = None,
) -> AnomalyResult:
    """
    Execute the complete anomaly search on one light curve.

    The single-lens model is fitted first; its residuals are then scanned
    season by season and the strongest cluster is selected.

    Parameters
    ----------
    time, flux, ferr : array-like
        One-dimensional light-curve arrays.
    x0 : array-like, optional
        Starting point for the single-lens parameters. When omitted,
        several starting points are estimated and the best fit is kept.
    verbose : bool, optional
        Forwarded to the season runner; when automatic initial values are
        estimated it also triggers an informational log message.
    log : logging.Logger, optional
        Logger for progress reporting. The module logger is used for the
        initial-value message when this is ``None``.

    Returns
    -------
    AnomalyResult
        The single-lens fit, residuals, per-season clusters and grid
        metrics, and the best anomaly candidate. The same object is cached
        for the ``plot_*`` helpers.
    """
    (
        t_dev,
        f_dev,
        e_dev,
        guess,
        t_host,
        f_host,
        e_host,
    ) = self._to_arrays(time, flux, ferr, x0)
    self._ensure_fitter(float(np.median(t_host)))

    if guess is not None:
        fit = self.fitter.fit(t_dev, f_dev, e_dev, guess)
    else:
        if verbose:
            active_log = logger if log is None else log
            active_log.info("Estimating single-lens initial values.")
        fit = self._fit_from_auto_initial_guesses(t_dev, f_dev, e_dev, t_host)

    # Keep the device-side residual for the scanner; pull host copies of
    # everything stored on the result in a single transfer.
    resid_dev = fit.residual
    model_dev = fit.model_flux
    host_values = jax.device_get((resid_dev, model_dev, fit.chi2_dof))
    resid_host = np.asarray(host_values[0], dtype=float)
    model_host = np.asarray(host_values[1], dtype=float)
    reduced_chi2 = float(host_values[2])

    seasons, clusters_all, grid_metrics_all = self.runner.run(
        time_j=t_dev,
        residual_j=resid_dev,
        ferr_j=e_dev,
        time_np=t_host,
        verbose=verbose,
        log=log,
    )

    outcome = AnomalyResult(
        time=t_host,
        flux=f_host,
        ferr=e_host,
        fit=fit,
        residual=resid_host,
        model_flux=model_host,
        chi2_dof=reduced_chi2,
        seasons=seasons,
        clusters_all=clusters_all,
        grid_metrics_all=grid_metrics_all,
        best=self._pick_best_candidate(clusters_all, grid_metrics_all),
    )
    self._last_result = outcome
    return outcome
def run_template_free(
    self,
    time,
    flux,
    ferr,
    x0=None,
    *,
    fit: Optional[SingleLensFitResult] = None,
    config: Optional[TemplateFreeSearchConfig] = None,
) -> TemplateFreeSearchResult:
    """
    Search the single-lens residuals for anomalies without a template.

    This path is independent of the bell-template pipeline in :meth:`run`
    and does not touch its cached result.

    Parameters
    ----------
    time, flux, ferr : array-like
        One-dimensional light-curve arrays.
    x0 : array-like, optional
        Starting point for the single-lens fit; only used when ``fit`` is
        not supplied.
    fit : SingleLensFitResult, optional
        A precomputed single-lens fit. When given, no fitting is done and
        its ``residual`` is scanned directly.
    config : TemplateFreeSearchConfig, optional
        Scanner configuration. Defaults to one built with the finder's
        ``gap`` setting.

    Returns
    -------
    TemplateFreeSearchResult
        Result of the scan; also cached for :meth:`plot_template_free`.
    """
    t_dev, f_dev, e_dev, guess, t_host, _flux_host, e_host = self._to_arrays(
        time, flux, ferr, x0
    )

    if fit is None:
        self._ensure_fitter(float(np.median(t_host)))
        if guess is not None:
            fit = self.fitter.fit(t_dev, f_dev, e_dev, guess)
        else:
            fit = self._fit_from_auto_initial_guesses(t_dev, f_dev, e_dev, t_host)

    resid_host = np.asarray(jax.device_get(fit.residual), dtype=float)

    if config is None:
        active_config = TemplateFreeSearchConfig(gap=self.config.gap)
    else:
        active_config = config

    outcome = TemplateFreeScanner(active_config).run(t_host, resid_host, e_host)
    self._last_template_free_result = outcome
    return outcome
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _to_arrays(self, time, flux, ferr, x0):
"""
Validate inputs and convert them to both NumPy and JAX arrays.
"""
time_np = np.asarray(time, dtype=float)
flux_np = np.asarray(flux, dtype=float)
ferr_np = np.asarray(ferr, dtype=float)
if time_np.ndim != 1 or flux_np.ndim != 1 or ferr_np.ndim != 1:
raise ValueError("time/flux/ferr must be 1D arrays.")
if not (len(time_np) == len(flux_np) == len(ferr_np)):
raise ValueError("time/flux/ferr must have the same length.")
if np.any(~np.isfinite(time_np)) or np.any(~np.isfinite(flux_np)) or np.any(~np.isfinite(ferr_np)):
raise ValueError("time/flux/ferr must be finite.")
if np.any(ferr_np <= 0):
raise ValueError("ferr must be positive.")
time_j = jnp.asarray(time_np)
flux_j = jnp.asarray(flux_np)
ferr_j = jnp.asarray(ferr_np)
x0_j = None if x0 is None else jnp.asarray(x0, dtype=time_j.dtype)
return time_j, flux_j, ferr_j, x0_j, time_np, flux_np, ferr_np
def _fit_from_auto_initial_guesses(
self,
time_j: jnp.ndarray,
flux_j: jnp.ndarray,
ferr_j: jnp.ndarray,
time_np: np.ndarray,
) -> SingleLensFitResult:
guesses = self._estimate_single_lens_initial_guesses(
time_j=time_j,
flux_j=flux_j,
ferr_j=ferr_j,
time_np=time_np,
)
best_fit = None
best_chi2 = np.inf
errors = []
for x0 in guesses:
try:
fit = self.fitter.fit(time_j, flux_j, ferr_j, jnp.asarray(x0, dtype=time_j.dtype))
chi2 = float(jax.device_get(fit.chi2))
except Exception as exc:
errors.append(exc)
continue
if np.isfinite(chi2) and chi2 < best_chi2:
best_chi2 = chi2
best_fit = fit
if best_fit is None:
msg = "All automatic single-lens initial guesses failed."
if errors:
msg += f" First error: {errors[0]}"
raise RuntimeError(msg)
return best_fit
def _estimate_single_lens_initial_guesses(
    self,
    *,
    time_j: jnp.ndarray,
    flux_j: jnp.ndarray,
    ferr_j: jnp.ndarray,
    time_np: np.ndarray,
) -> np.ndarray:
    """
    Derive starting parameter vectors for the single-lens fit.

    The season grid scanner is run on the flux itself (passed in place of
    residuals) with a grid built from the ``auto_init_*`` settings. The
    resulting clusters are ranked, and each retained cluster ``(t0, teff)``
    is combined with a log-spaced grid of ``tE`` values via
    ``u0 = teff / tE``. Combinations whose ``u0`` falls outside
    ``[auto_init_u0_min, auto_init_u0_max]`` are dropped.

    Fallbacks:

    * If no usable cluster exists, a single pseudo-cluster is placed at
      the time of maximum flux with ``teff`` set to 10% of the time span,
      clipped to the configured ``teff`` range.
    * If every ``(cluster, tE)`` combination is rejected, one guess is
      built from the top cluster with ``tE`` and ``u0`` clipped into range.

    Parameters
    ----------
    time_j, flux_j, ferr_j : jnp.ndarray
        Light-curve arrays as JAX arrays.
    time_np : np.ndarray
        Observation times as a NumPy array.

    Returns
    -------
    np.ndarray
        Array with one row per initial guess, laid out as produced by
        ``_build_initial_vector`` for the configured ``fitter_kind``.

    Raises
    ------
    ValueError
        If any of the ``auto_init_teff_*``, ``auto_init_u0_*`` or
        ``auto_init_tE_*`` ranges is non-positive or inverted.
    """
    cfg = self.config

    # ------------------------------------------------------------------
    # 1) Validate every range up front, before the expensive grid scan.
    # ------------------------------------------------------------------
    if cfg.auto_init_teff_min <= 0 or cfg.auto_init_teff_max <= 0:
        raise ValueError("auto_init_teff_min and auto_init_teff_max must be positive.")
    if cfg.auto_init_teff_max < cfg.auto_init_teff_min:
        raise ValueError("auto_init_teff_max must be >= auto_init_teff_min.")
    if cfg.auto_init_u0_min <= 0 or cfg.auto_init_u0_max <= 0:
        raise ValueError("auto_init_u0_min and auto_init_u0_max must be positive.")
    if cfg.auto_init_u0_max < cfg.auto_init_u0_min:
        raise ValueError("auto_init_u0_max must be >= auto_init_u0_min.")
    if cfg.auto_init_tE_min <= 0 or cfg.auto_init_tE_max <= 0:
        raise ValueError("auto_init_tE_min and auto_init_tE_max must be positive.")
    if cfg.auto_init_tE_max < cfg.auto_init_tE_min:
        raise ValueError("auto_init_tE_max must be >= auto_init_tE_min.")

    # ------------------------------------------------------------------
    # 2) Scan the flux with a geometric teff grid spanning [min, max].
    # ------------------------------------------------------------------
    n_teff = max(1, int(cfg.auto_init_teff_grid_n))
    ratio = 1.0
    if n_teff > 1 and cfg.auto_init_teff_max > cfg.auto_init_teff_min:
        ratio = float((cfg.auto_init_teff_max / cfg.auto_init_teff_min) ** (1.0 / (n_teff - 1)))

    init_config = replace(
        cfg,
        teff_init=float(cfg.auto_init_teff_min),
        common_ratio=ratio,
        teff_grid_n=n_teff,
        dt0_coeff=float(cfg.auto_init_dt0_coeff),
    )
    init_runner = SeasonGridRunner(
        splitter=SeasonSplitter(gap=cfg.gap),
        extractor=ResultExtractor(sigma_overlap=cfg.overlap_sigma, min_points=1),
        config=init_config,
    )
    _, clusters, grid_metrics = init_runner.run(
        time_j=time_j,
        residual_j=flux_j,
        ferr_j=ferr_j,
        time_np=time_np,
        verbose=False,
    )

    # ------------------------------------------------------------------
    # 3) Keep finite clusters with positive dchi2 (column 2) and rank them.
    # ------------------------------------------------------------------
    clusters = np.asarray(clusters, dtype=float)
    if clusters.size:
        clusters = clusters[np.isfinite(clusters).all(axis=1)]
        clusters = clusters[clusters[:, 2] > 0]

    # Re-check emptiness: the filters above may have removed every row, and
    # an empty input may be 1-D. Column indexing below would raise
    # IndexError in either case instead of reaching the fallback.
    if clusters.size:
        grid_metrics = np.asarray(grid_metrics, dtype=float)
        if grid_metrics.size:
            qualities = np.asarray(
                [
                    self._grid_quality_for_cluster(float(row[0]), float(row[1]), grid_metrics)
                    for row in clusters
                ],
                dtype=float,
            )
            pass_eff = qualities[:, 0] >= float(cfg.auto_init_min_n_eff)
            if np.any(pass_eff):
                # Prefer clusters with enough effective points, strongest first.
                clusters = clusters[pass_eff]
                qualities = qualities[pass_eff]
                order = np.argsort(clusters[:, 2])[::-1]
            else:
                # Nothing passes: rank by the quality value first, then dchi2.
                order = np.lexsort((-clusters[:, 2], -qualities[:, 0]))
        else:
            order = np.argsort(clusters[:, 2])[::-1]
        clusters = clusters[order[: max(1, int(cfg.auto_init_max_clusters))]]

    if clusters.size == 0:
        # Fallback: centre a single pseudo-cluster on the brightest point.
        flux_np = np.asarray(jax.device_get(flux_j), dtype=float)
        i_peak = int(np.nanargmax(flux_np))
        span = float(np.nanmax(time_np) - np.nanmin(time_np))
        teff = min(max(0.1 * span, float(cfg.auto_init_teff_min)), float(cfg.auto_init_teff_max))
        clusters = np.asarray([[float(time_np[i_peak]), teff, 0.0]], dtype=float)

    # ------------------------------------------------------------------
    # 4) Combine each cluster with a log-spaced tE grid.
    # ------------------------------------------------------------------
    n_tE = max(1, int(cfg.auto_init_tE_grid_n))
    if n_tE == 1:
        tE_grid = np.asarray([float(cfg.auto_init_tE_max)], dtype=float)
    else:
        tE_grid = np.exp(
            np.linspace(
                np.log(float(cfg.auto_init_tE_min)),
                np.log(float(cfg.auto_init_tE_max)),
                n_tE,
            )
        )

    u0_min = float(cfg.auto_init_u0_min)
    u0_max = float(cfg.auto_init_u0_max)
    guesses = []
    for row in clusters:
        # Only the first two columns (t0, teff) are needed here.
        t0 = float(row[0])
        teff = float(row[1])
        for tE in tE_grid:
            u0 = teff / float(tE)
            if not (u0_min <= u0 <= u0_max):
                continue
            guesses.append(self._build_initial_vector(t0, float(tE), float(u0)))

    if not guesses:
        # Last resort: clip tE and u0 into range for the top-ranked cluster.
        t0 = float(clusters[0, 0])
        teff = float(clusters[0, 1])
        tE = min(max(teff / 0.1, float(cfg.auto_init_tE_min)), float(cfg.auto_init_tE_max))
        u0 = min(max(teff / tE, u0_min), u0_max)
        guesses.append(self._build_initial_vector(t0, tE, u0))

    return np.asarray(guesses, dtype=float)
@staticmethod
def _grid_quality_for_cluster(t0: float, teff: float, metrics: np.ndarray) -> tuple[float, float]:
if metrics.size == 0:
return 0.0, 0.0
i = int(np.argmin(np.abs(metrics[:, 0] - t0) + np.abs(metrics[:, 1] - teff)))
return float(metrics[i, 5]), float(metrics[i, 6])
def _build_initial_vector(self, t0: float, tE: float, u0: float) -> np.ndarray:
k = self.config.fitter_kind
if k == "pspl":
return np.asarray([t0, tE, u0], dtype=float)
if k == "fspl":
return np.asarray([t0, tE, u0, float(self.config.auto_init_logrho)], dtype=float)
if k == "pspl_parallax":
return np.asarray([t0, tE, u0, 0.0, 0.0], dtype=float)
if k == "fspl_parallax":
return np.asarray([t0, tE, u0, float(self.config.auto_init_logrho), 0.0, 0.0], dtype=float)
raise ValueError(f"Unknown fitter_kind '{k}'.")
def _pick_best_candidate(
self,
clusters_all: np.ndarray,
grid_metrics_all: np.ndarray,
) -> Optional[BestCandidate]:
"""
Select the strongest anomaly candidate from all extracted clusters.
"""
if clusters_all is None or clusters_all.size == 0:
return None
clusters_use = np.asarray(clusters_all, dtype=float)
clusters_use = clusters_use[np.isfinite(clusters_use).all(axis=1)]
if clusters_use.size == 0:
return None
max_ind = int(np.argmax(clusters_use[:, 2]))
best = clusters_use[max_ind]
others = np.delete(clusters_use, max_ind, axis=0)
if others.shape[0] >= 2:
other_dchi2 = others[:, 2]
# Estimate the background spread from the bulk of the distribution.
# A few large-dchi2 secondary peaks can otherwise inflate std and
# suppress the best-candidate score.
trim_percentile = float(self.config.best_score_trim_percentile)
if not (0.0 < trim_percentile <= 100.0):
raise ValueError(
"best_score_trim_percentile must satisfy 0 < value <= 100."
)
bulk_dchi2 = other_dchi2
if trim_percentile < 100.0:
cutoff = float(np.percentile(other_dchi2, trim_percentile))
trimmed_dchi2 = other_dchi2[other_dchi2 <= cutoff]
if trimmed_dchi2.shape[0] >= 2:
bulk_dchi2 = trimmed_dchi2
if bulk_dchi2.shape[0] < 2:
bulk_dchi2 = other_dchi2
med = float(np.median(bulk_dchi2))
std = float(np.std(bulk_dchi2))
score = (best[2] - med) / std if std > 0 else float("nan")
else:
med = std = score = float("nan")
quality = self._quality_for_point(float(best[0]), float(best[1]), grid_metrics_all)
return BestCandidate(
t0=float(best[0]),
teff=float(best[1]),
dchi2=float(best[2]),
med_others=med,
std_others=std,
score=float(score),
quality=quality,
)
@staticmethod
def _quality_for_point(t0: float, teff: float, grid_metrics_all: np.ndarray) -> CandidateQuality:
    """
    Build a :class:`CandidateQuality` from the grid point closest to a candidate.

    The closest row minimises ``|col0 - t0| + |col1 - teff|``. Columns 3 to
    8 of that row populate ``n_window``, ``n_contrib``, ``n_eff``,
    ``peak_frac``, ``rho1`` and ``longest_run`` respectively; the three
    count fields are rounded to the nearest integer. An all-zero quality
    is returned when no grid metrics are available.
    """
    have_metrics = grid_metrics_all is not None and grid_metrics_all.size != 0
    if not have_metrics:
        return CandidateQuality(
            n_window=0,
            n_contrib=0,
            n_eff=0.0,
            peak_frac=0.0,
            rho1=0.0,
            longest_run=0,
        )

    offsets = np.abs(grid_metrics_all[:, 0] - t0) + np.abs(grid_metrics_all[:, 1] - teff)
    nearest = grid_metrics_all[int(np.argmin(offsets))]

    n_window = int(round(float(nearest[3])))
    n_contrib = int(round(float(nearest[4])))
    n_eff = float(nearest[5])
    peak_frac = float(nearest[6])
    rho1 = float(nearest[7])
    longest_run = int(round(float(nearest[8])))
    return CandidateQuality(
        n_window=n_window,
        n_contrib=n_contrib,
        n_eff=n_eff,
        peak_frac=peak_frac,
        rho1=rho1,
        longest_run=longest_run,
    )
# ----------------------------
# Plot sugar APIs
# ----------------------------
def _require_result(self) -> AnomalyResult:
if self._last_result is None:
raise RuntimeError("Finder.run() has not been called yet.")
return self._last_result
def plot_lc(self, **kwargs):
    """
    Draw the light curve together with the single-lens model.

    Uses the result cached by the last :meth:`run` call; keyword arguments
    are forwarded to ``plotter.plot_lc``.
    """
    last = self._require_result()
    draw = self.plotter.plot_lc
    return draw(last, **kwargs)
def plot_residual(self, **kwargs):
    """
    Draw the single-lens residuals.

    Uses the result cached by the last :meth:`run` call; keyword arguments
    are forwarded to ``plotter.plot_residual``.
    """
    last = self._require_result()
    draw = self.plotter.plot_residual
    return draw(last, **kwargs)
def plot_anomaly_window(self, **kwargs):
    """
    Draw the residuals in the neighbourhood of the best anomaly window.

    Uses the result cached by the last :meth:`run` call; keyword arguments
    are forwarded to ``plotter.plot_anomaly_window``.
    """
    last = self._require_result()
    draw = self.plotter.plot_anomaly_window
    return draw(last, **kwargs)
def plot_result(self, **kwargs):
    """
    Draw the full three-panel diagnostic figure.

    Uses the result cached by the last :meth:`run` call; keyword arguments
    are forwarded to ``plotter.plot_result``.
    """
    last = self._require_result()
    draw = self.plotter.plot_result
    return draw(last, **kwargs)
def plot_template_free(self, **kwargs):
    """
    Draw the most recent template-free search result.

    Keyword arguments are forwarded to the cached result's ``plot`` method.

    Raises
    ------
    RuntimeError
        If :meth:`run_template_free` has not produced a result yet.
    """
    if self._last_template_free_result is not None:
        return self._last_template_free_result.plot(**kwargs)
    raise RuntimeError("Finder.run_template_free() has not been called yet.")