#!/usr/bin/python3
"""
Find electron spin-resonance crossings and write them to an SDDS file.

Resonance condition:
    nu_s = k * N + l * nu_x + m * nu_y

For electrons in a planar ring:
    nu_s = a_e * gamma

so the crossing gammas are:
    gamma = (k * N + l * nu_x + m * nu_y) / a_e

Input tune options:
  1) Provide --nu-x and --nu-y directly, or
  2) Provide --twiss-file and read SDDS parameters nux and nuy from it.
     A tune also given on the command line overrides the value from the file.

Output:
  One-page SDDS file with columns:
      gamma, nuSpin, k, mx, my, order, kind
  where nuSpin is nu_s and mx, my hold the integers l, m above, plus
  parameters describing the search inputs.

Notes:
  * Uses the SDDS Python module distributed as soliday.sdds on PyPI.
  * Depending on installation, the import may be either:
        import soliday.sdds as sdds
    or
        import sdds
"""

from __future__ import annotations

import argparse
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

# Try both common import styles for the soliday.sdds package.
try:
    import soliday.sdds as sdds
except ImportError:
    import sdds  # type: ignore


# Electron anomalous magnetic moment a_e = (g - 2) / 2, dimensionless
# (CODATA 2018 value). In a planar ring the spin tune is nu_s = a_e * gamma.
A_E = 1.15965218128e-3


@dataclass(frozen=True)
class Resonance:
    """One spin-resonance crossing satisfying nu_s = k*N + l*nu_x + m*nu_y.

    Instances are immutable (frozen dataclass). The field order below is part
    of the positional constructor signature and must not be changed.
    """

    # Lorentz factor at which a_e * gamma equals nu_s.
    gamma: float
    # Spin tune at the crossing.
    nu_s: float
    # Integer multiplier of the superperiodicity N.
    k: int
    # Coefficient of the horizontal tune nu_x.
    l: int  # noqa: E741 -- name mirrors the resonance formula
    # Coefficient of the vertical tune nu_y.
    m: int
    # Classification label as produced by resonance_kind(l, m).
    kind: str
    # Resonance order; generate_resonances fills in |l| + |m|.
    order: int


def resonance_kind(l: int, m: int) -> str:
    """Classify a resonance by its betatron-tune coefficients.

    Returns "imperfection" when both ``l`` and ``m`` are zero (the condition
    then involves only k*N) and "intrinsic" for every other combination.
    """
    if l != 0 or m != 0:
        return "intrinsic"
    return "imperfection"


def read_tunes_from_sdds(filename: str, page: int) -> Tuple[float, float]:
    """Read the ``nux`` and ``nuy`` parameters from one page of an SDDS file.

    Args:
        filename: Path of an SDDS file (e.g. an elegant twiss_output file)
            that defines the parameters ``nux`` and ``nuy``.
        page: 1-based page number to read.

    Returns:
        ``(nu_x, nu_y)`` as plain floats.

    Raises:
        ValueError: if ``page`` is less than 1.
        RuntimeError: if the parameters cannot be read from the requested
            page, or the values read cannot be turned into floats.
    """
    # Pages are numbered from 1 (the CLI default); reject other values here
    # rather than relying on the SDDS module to do so.
    if page < 1:
        raise ValueError(f"SDDS page numbers start at 1; got {page}")

    obj = sdds.SDDS(0)
    obj.load(filename)

    try:
        nu_x = obj.getParameterValue("nux", page=page)
        nu_y = obj.getParameterValue("nuy", page=page)
    except Exception as exc:
        # Kept broad: the SDDS module's failure types for a missing name or a
        # bad page are not pinned down here. The cause is chained.
        raise RuntimeError(
            f"Failed to read parameters 'nux' and 'nuy' from {filename!r} "
            f"on page {page}"
        ) from exc

    try:
        # The SDDS API may return scalars or 1-element containers; unwrapping
        # sits inside the try so an empty container also gets this message.
        return float(unwrap_scalar(nu_x)), float(unwrap_scalar(nu_y))
    except (TypeError, ValueError, OverflowError) as exc:
        raise RuntimeError(
            f"Could not convert nux={nu_x!r} and nuy={nu_y!r} to float"
        ) from exc


def unwrap_scalar(value):
    """Return the sole element of a 1-element list/tuple, otherwise ``value``.

    Values that are not lists or tuples, and lists/tuples with more than one
    element, are returned unchanged.

    Raises:
        ValueError: if ``value`` is an empty list or tuple.
    """
    if not isinstance(value, (list, tuple)):
        return value
    size = len(value)
    if size == 0:
        raise ValueError("Empty container returned where scalar was expected")
    return value[0] if size == 1 else value


def _lm_pairs_by_order(max_order: int) -> List[Tuple[int, int]]:
    """Return every (l, m) with |l| + |m| <= max_order, lowest order first."""
    pairs = [
        (l, m)
        for l in range(-max_order, max_order + 1)
        for m in range(-max_order, max_order + 1)
        if abs(l) + abs(m) <= max_order
    ]
    # Stable sort: within one order the ascending (l, m) order is kept, and
    # lower-order tuples are visited first so they win duplicate suppression
    # in generate_resonances.
    pairs.sort(key=lambda lm: abs(lm[0]) + abs(lm[1]))
    return pairs


def generate_resonances(
    nu_x: float,
    nu_y: float,
    superperiodicity: int,
    gamma_min: float,
    gamma_max: float,
    max_order: int,
    imperfection_only: bool,
    intrinsic_only: bool,
    dedup_tol: float,
) -> List[Resonance]:
    """Enumerate spin-resonance crossings with gamma in [gamma_min, gamma_max].

    Solves A_E * gamma = k * N + l * nu_x + m * nu_y over all integers k and
    all (l, m) with |l| + |m| <= max_order.

    Args:
        nu_x, nu_y: Horizontal and vertical betatron tunes.
        superperiodicity: Ring superperiodicity N (must be positive).
        gamma_min, gamma_max: Inclusive gamma search window (positive).
        max_order: Largest |l| + |m| to include (>= 0).
        imperfection_only: Keep only l == m == 0 resonances.
        intrinsic_only: Keep only resonances with (l, m) != (0, 0).
        dedup_tol: Crossings whose gammas differ by at most this amount are
            treated as duplicates; the first one found is kept, and lower
            orders are searched first. ``0`` suppresses only exactly equal
            gammas.

    Returns:
        The resonances sorted by (gamma, order, k, l, m).

    Raises:
        ValueError: for non-positive gamma bounds, gamma_max < gamma_min,
            non-positive superperiodicity, negative max_order, negative
            dedup_tol, or both filter flags set.
    """
    if gamma_min <= 0 or gamma_max <= 0:
        raise ValueError("gamma_min and gamma_max must be positive")
    if gamma_max < gamma_min:
        raise ValueError("gamma_max must be >= gamma_min")
    if superperiodicity <= 0:
        raise ValueError("superperiodicity must be positive")
    if max_order < 0:
        raise ValueError("max_order must be >= 0")
    if imperfection_only and intrinsic_only:
        raise ValueError("Cannot use both --imperfection-only and --intrinsic-only")
    if dedup_tol < 0:
        raise ValueError("dedup_tol must be >= 0")

    nu_s_min = A_E * gamma_min
    nu_s_max = A_E * gamma_max

    results: List[Resonance] = []
    # Kept resonances, bucketed so the near-duplicate check only scans a few
    # buckets instead of every result so far.
    gamma_bins: Dict[float, List[Resonance]] = {}

    for l, m in _lm_pairs_by_order(max_order):
        kind = resonance_kind(l, m)

        if imperfection_only and kind != "imperfection":
            continue
        if intrinsic_only and kind != "intrinsic":
            continue

        offset = l * nu_x + m * nu_y
        # Integer k whose nu_s = k*N + offset can land in [nu_s_min, nu_s_max];
        # the explicit gamma check below absorbs floating-point edge effects.
        k_min = math.ceil((nu_s_min - offset) / superperiodicity)
        k_max = math.floor((nu_s_max - offset) / superperiodicity)

        for k in range(k_min, k_max + 1):
            nu_s = k * superperiodicity + offset
            gamma = nu_s / A_E

            if not (gamma_min <= gamma <= gamma_max):
                continue

            if dedup_tol > 0:
                # Two gammas within dedup_tol of each other have
                # floor(gamma / dedup_tol) values differing by at most one,
                # so the neighbouring buckets must be searched as well.
                own_bin = math.floor(gamma / dedup_tol)
                search_bins = (own_bin - 1, own_bin, own_bin + 1)
            else:
                # Zero tolerance: only exactly equal gammas are duplicates.
                own_bin = gamma
                search_bins = (own_bin,)

            if any(
                abs(r.gamma - gamma) <= dedup_tol
                for b in search_bins
                for r in gamma_bins.get(b, ())
            ):
                continue

            res = Resonance(
                gamma=gamma,
                nu_s=nu_s,
                k=k,
                l=l,
                m=m,
                kind=kind,
                order=abs(l) + abs(m),
            )
            gamma_bins.setdefault(own_bin, []).append(res)
            results.append(res)

    return sorted(results, key=lambda r: (r.gamma, r.order, r.k, r.l, r.m))


def write_resonances_to_sdds(
    filename: str,
    resonances: List[Resonance],
    nu_x: float,
    nu_y: float,
    superperiodicity: int,
    gamma_min: float,
    gamma_max: float,
    max_order: int,
    source_file: Optional[str],
    source_page: int,
    ascii_output: bool,
) -> None:
    """Save the resonance table and its search inputs as a one-page SDDS file.

    Args:
        filename: Output path handed to the SDDS object's ``save``.
        resonances: Rows to write, in the order given.
        nu_x, nu_y, superperiodicity, gamma_min, gamma_max, max_order:
            Search inputs, stored as parameters nux, nuy, superperiodicity,
            gammaMin, gammaMax and maxOrder.
        source_file: Twiss file the tunes came from, stored as sourceFile
            (empty string when None).
        source_page: Stored as parameter sourcePage.
        ascii_output: Write ASCII SDDS when true, binary otherwise.

    The constant A_E is also stored, as parameter a_e. Each resonance becomes
    one row with columns gamma, nuSpin (= nu_s), k, mx (= l), my (= m),
    order and kind.
    """
    page = 1  # everything is written to the single output page

    table = sdds.SDDS(0)
    table.setDescription(
        "Spin resonance crossings",
        "Electron spin resonances from nu_s = k*N + l*nu_x + m*nu_y",
    )
    table.mode = sdds.SDDS_BINARY if not ascii_output else sdds.SDDS_ASCII

    # (SDDS name, SDDS type, value) for each parameter, in file order.
    parameter_specs = [
        ("nux", sdds.SDDS_DOUBLE, float(nu_x)),
        ("nuy", sdds.SDDS_DOUBLE, float(nu_y)),
        ("superperiodicity", sdds.SDDS_LONG, int(superperiodicity)),
        ("gammaMin", sdds.SDDS_DOUBLE, float(gamma_min)),
        ("gammaMax", sdds.SDDS_DOUBLE, float(gamma_max)),
        ("maxOrder", sdds.SDDS_LONG, int(max_order)),
        ("a_e", sdds.SDDS_DOUBLE, float(A_E)),
        ("sourcePage", sdds.SDDS_LONG, int(source_page)),
        ("sourceFile", sdds.SDDS_STRING, source_file or ""),
    ]
    # (SDDS name, SDDS type, Resonance attribute) for each column, in file order.
    column_specs = [
        ("gamma", sdds.SDDS_DOUBLE, "gamma"),
        ("nuSpin", sdds.SDDS_DOUBLE, "nu_s"),
        ("k", sdds.SDDS_LONG, "k"),
        ("mx", sdds.SDDS_LONG, "l"),
        ("my", sdds.SDDS_LONG, "m"),
        ("order", sdds.SDDS_LONG, "order"),
        ("kind", sdds.SDDS_STRING, "kind"),
    ]

    # Declare the layout first, then fill in the page values.
    for name, sdds_type, _ in parameter_specs:
        table.defineSimpleParameter(name, sdds_type)
    for name, sdds_type, _ in column_specs:
        table.defineSimpleColumn(name, sdds_type)

    for name, _, value in parameter_specs:
        table.setParameterValue(name, value, page)
    for name, _, attr in column_specs:
        table.setColumnValueList(name, [getattr(r, attr) for r in resonances], page)

    table.save(filename)


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        The parsed namespace. ``--imperfection-only`` and ``--intrinsic-only``
        are mutually exclusive; passing both makes argparse exit with a usage
        error.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Find electron spin resonance crossings and write them to SDDS. "
            "Solves for gamma such that nus-mx*nux-my*nuy=integer*periodicity."
        )
    )

    tune_group = parser.add_argument_group("tune input")
    tune_group.add_argument(
        "--nu-x",
        type=float,
        default=None,
        help="Horizontal tune. Optional if --twiss-file is provided.",
    )
    tune_group.add_argument(
        "--nu-y",
        type=float,
        default=None,
        help="Vertical tune. Optional if --twiss-file is provided.",
    )
    tune_group.add_argument(
        "--twiss-file",
        type=str,
        default=None,
        help="SDDS twiss_output file from elegant containing parameters nux and nuy.",
    )
    tune_group.add_argument(
        "--twiss-page",
        type=int,
        default=1,
        help="Page number to read from the SDDS twiss file (default: 1).",
    )

    parser.add_argument(
        "--superperiodicity",
        type=int,
        required=True,
        help="Ring superperiodicity N.",
    )
    parser.add_argument(
        "--gamma-min",
        type=float,
        required=True,
        help="Minimum gamma to search.",
    )
    parser.add_argument(
        "--gamma-max",
        type=float,
        required=True,
        help="Maximum gamma to search.",
    )
    parser.add_argument(
        "--max-order",
        type=int,
        default=2,
        help="Maximum resonance order |l|+|m| to include (default: 2).",
    )

    # The two kind filters select disjoint sets, so asking for both can never
    # yield results; let argparse reject the combination as a usage error.
    kind_filter = parser.add_mutually_exclusive_group()
    kind_filter.add_argument(
        "--imperfection-only",
        action="store_true",
        help="Include only imperfection resonances (l = m = 0).",
    )
    kind_filter.add_argument(
        "--intrinsic-only",
        action="store_true",
        help="Include only intrinsic resonances (not l = m = 0).",
    )

    parser.add_argument(
        "--dedup-tol",
        type=float,
        default=1e-9,
        help="Tolerance in gamma for suppressing near-duplicate crossings.",
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output SDDS filename.",
    )
    parser.add_argument(
        "--ascii-output",
        action="store_true",
        help="Write ASCII SDDS instead of binary.",
    )

    return parser.parse_args(argv)


def resolve_tunes(args: argparse.Namespace) -> Tuple[float, float, Optional[str], int]:
    """Decide which tunes to use, combining command-line values and a twiss file.

    Explicit ``args.nu_x`` / ``args.nu_y`` values take precedence. Any tune
    left as None is filled from the nux / nuy parameters of
    ``args.twiss_file`` (page ``args.twiss_page``) when that file is given.

    Returns:
        ``(nu_x, nu_y, source_file, source_page)``. ``source_file`` is the
        twiss file name when one was given (else None); ``source_page`` is
        always ``args.twiss_page``.

    Raises:
        ValueError: if either tune is still unknown after both sources.
    """
    # A falsy --twiss-file (None or "") means "no file".
    source_file: Optional[str] = args.twiss_file if args.twiss_file else None
    requested = (args.nu_x, args.nu_y)

    if source_file is None:
        nu_x, nu_y = requested
    else:
        file_tunes = read_tunes_from_sdds(source_file, args.twiss_page)
        # Command-line values win; the file only fills in what is missing.
        nu_x, nu_y = (
            cli if cli is not None else from_file
            for cli, from_file in zip(requested, file_tunes)
        )

    if any(tune is None for tune in (nu_x, nu_y)):
        raise ValueError(
            "You must provide either both --nu-x and --nu-y, or --twiss-file "
            "containing parameters nux and nuy."
        )

    return float(nu_x), float(nu_y), source_file, args.twiss_page


def main() -> None:
    """Command-line entry point: parse arguments, search, write the SDDS file.

    Invalid input detected while resolving tunes or searching is reported as
    a one-line ``error: ...`` message with exit status 1 instead of a Python
    traceback. This covers missing tunes, inconsistent search bounds, and
    nux/nuy that cannot be read or converted.
    """
    args = parse_args()

    try:
        nu_x, nu_y, source_file, source_page = resolve_tunes(args)
        resonances = generate_resonances(
            nu_x=nu_x,
            nu_y=nu_y,
            superperiodicity=args.superperiodicity,
            gamma_min=args.gamma_min,
            gamma_max=args.gamma_max,
            max_order=args.max_order,
            imperfection_only=args.imperfection_only,
            intrinsic_only=args.intrinsic_only,
            dedup_tol=args.dedup_tol,
        )
    except (ValueError, RuntimeError) as exc:
        # SystemExit with a string prints it to stderr and exits with status 1.
        raise SystemExit(f"error: {exc}") from exc

    # Output failures are left to propagate with a full traceback.
    write_resonances_to_sdds(
        filename=args.output,
        resonances=resonances,
        nu_x=nu_x,
        nu_y=nu_y,
        superperiodicity=args.superperiodicity,
        gamma_min=args.gamma_min,
        gamma_max=args.gamma_max,
        max_order=args.max_order,
        source_file=source_file,
        source_page=source_page,
        ascii_output=args.ascii_output,
    )


# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
