Source code for cabaret.queries

import math
import re
import sqlite3
import threading
import warnings
from collections.abc import Callable, Sequence
from contextlib import closing
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import TypeVar

import numpy as np
from astropy.coordinates import Angle, SkyCoord
from astropy.table import Table
from astropy.time import Time
from astropy.units import Quantity

from cabaret.sources import Sources

# Public API of this module: the names exported by
# ``from cabaret.queries import *``.
__all__ = [
    "Filters",
    "GaiaSQLiteSource",
    "GaiaTAPSource",
    "GaiaQuery",
]

# Module-level generic type variable.
# NOTE(review): not referenced in the visible part of this module --
# presumably used by a generic helper further down; confirm, or remove if
# unused.
_T = TypeVar("_T")


@dataclass(frozen=True)
class GaiaSQLiteSource:
    """Location of an offline Gaia catalog stored in a SQLite database.

    Instances are immutable (``frozen=True``) and therefore hashable, so they
    can be compared and used as dictionary keys.

    Parameters
    ----------
    database : str
        Path of the SQLite database file.
    table : str, optional
        Name of the table holding the catalog rows. Defaults to
        ``"gaia_sources"``.
    """

    #: Path of the SQLite database file.
    database: str
    #: Name of the catalog table inside ``database``.
    table: str = "gaia_sources"
class Filters(Enum):
    """Allowed Gaia and 2MASS flux filter_band strings.

    Members are looked up by name (case-insensitive). Each member's value is
    the magnitude column name used for that band in query results.

    Examples
    --------
    >>> from cabaret.queries import Filters
    >>> Filters.G
    <Filters.G: 'phot_g_mean_mag'>
    >>> Filters.from_string('RP')
    <Filters.RP: 'phot_rp_mean_mag'>
    >>> Filters.is_tmass('J')
    True
    >>> Filters.options()
    ('G', 'BP', 'RP', 'J', 'H', 'KS')
    """

    G = "phot_g_mean_mag"
    """Gaia G band magnitude [Gaia Vega system]"""
    BP = "phot_bp_mean_mag"
    """Gaia BP band magnitude [Gaia Vega system]"""
    RP = "phot_rp_mean_mag"
    """Gaia RP band magnitude [Gaia Vega system]"""
    J = "j_m"
    """2MASS J-band magnitude [2MASS Vega system]"""
    H = "h_m"
    """2MASS H-band magnitude [2MASS Vega system]"""
    KS = "ks_m"
    """2MASS KS-band magnitude [2MASS Vega system]"""

    @classmethod
    def options(cls) -> tuple[str, ...]:
        """Return all valid filter_band names, in definition order."""
        return tuple(cls.__members__.keys())

    @classmethod
    def from_string(cls, value: str) -> "Filters":
        """Return the Filters member whose name matches ``value``.

        Parameters
        ----------
        value : str
            Member name, case-insensitive (e.g. ``"rp"`` or ``"KS"``).

        Returns
        -------
        Filters
            The matching member.

        Raises
        ------
        ValueError
            If ``value`` is not the name of a member.
        """
        try:
            return cls[value.upper()]
        except KeyError:
            # Suppress the internal KeyError so the traceback shows only the
            # user-facing error.
            raise ValueError(
                f"Invalid filter_band string: {value}. "
                f"Valid options are: {cls.options()}"
            ) from None

    @classmethod
    def is_tmass(cls, value: "Filters | str") -> bool:
        """Return True if ``value`` names a 2MASS band (J, H or KS).

        Strings are compared case-insensitively; unknown strings return False.

        Raises
        ------
        ValueError
            If ``value`` is neither a Filters member nor a string.
        """
        if isinstance(value, cls):
            name = value.name
        elif isinstance(value, str):
            name = value.upper()
        else:
            raise ValueError(
                f"Value must be a Filters enum or string, got {type(value)}"
            )
        return name in ("J", "H", "KS")

    @classmethod
    def ensure_enum(cls, value: "Filters | str") -> "Filters":
        """Convert a string or Filters member to a Filters member.

        Raises
        ------
        ValueError
            If ``value`` is an unknown name or neither a Filters member nor a
            string.
        """
        if isinstance(value, cls):
            return value
        elif isinstance(value, str):
            return cls.from_string(value)
        else:
            raise ValueError(
                f"Value must be a Filters enum or string, got {type(value)}"
            )

    @classmethod
    def all(cls) -> list["Filters"]:
        """Return all filter bands in definition order."""
        return list(cls)

    @classmethod
    def is_valid(cls, value: "Filters | str") -> bool:
        """Return True if ``value`` is a Filters member or a valid member name.

        Names are compared case-insensitively. Values of any other type return
        False instead of raising.
        """
        if isinstance(value, cls):
            return True
        return isinstance(value, str) and value.upper() in cls.__members__
class GaiaTAPSource(Enum):
    """TAP service endpoints for Gaia DR3 data.

    Each member's value is the endpoint URL of the service.

    Examples
    --------
    >>> from cabaret.queries import GaiaTAPSource
    >>> GaiaTAPSource.VIZIER
    <GaiaTAPSource.VIZIER: 'https://tapvizier.cds.unistra.fr/TAPVizieR/tap'>
    >>> GaiaTAPSource.ensure_enum("GAIA")
    <GaiaTAPSource.GAIA: 'https://gea.esac.esa.int/tap-server/tap'>
    >>> GaiaTAPSource.ensure_enum("https://gea.esac.esa.int/tap-server/tap")
    <GaiaTAPSource.GAIA: 'https://gea.esac.esa.int/tap-server/tap'>
    """

    GAIA = "https://gea.esac.esa.int/tap-server/tap"
    """ESA Gaia Archive TAP service."""
    VIZIER = "https://tapvizier.cds.unistra.fr/TAPVizieR/tap"
    """CDS VizieR TAP service (hosts a copy of Gaia DR3)."""

    @classmethod
    def ensure_enum(cls, value: "GaiaTAPSource | str") -> "GaiaTAPSource":
        """Convert a string or GaiaTAPSource to a GaiaTAPSource member.

        Parameters
        ----------
        value : GaiaTAPSource or str
            A member, a member name (case-insensitive, e.g. ``"vizier"``), or
            one of the members' endpoint URLs (compared case-insensitively and
            ignoring a trailing slash). Surrounding whitespace is ignored.

        Returns
        -------
        GaiaTAPSource
            The matching member.

        Raises
        ------
        ValueError
            If ``value`` matches no member, or is neither a GaiaTAPSource nor a
            string.
        """
        if isinstance(value, cls):
            return value
        if not isinstance(value, str):
            raise ValueError(
                f"Value must be a GaiaTAPSource enum or string, got {type(value)}"
            )

        text = value.strip()

        # Member name, e.g. "GAIA" / "vizier".
        member = cls.__members__.get(text.upper())
        if member is not None:
            return member

        # Endpoint URL: query methods document that a TAP URL string is an
        # accepted tap_source, so map a known endpoint back to its member.
        url = text.rstrip("/").lower()
        for candidate in cls:
            if candidate.value.rstrip("/").lower() == url:
                return candidate

        raise ValueError(
            f"Invalid TAP source {value!r}. "
            f"Valid options are: {[m.name for m in cls]}"
        )
# Per-source ADQL building blocks. All sources expose the same normalised # column names via AS aliases so the rest of the code needs no changes. _TAP_CONFIG: dict[GaiaTAPSource, dict] = { GaiaTAPSource.GAIA: { "from": "gaiadr3.gaia_source AS gaia", "ra": "gaia.ra", "dec": "gaia.dec", "pmra": "gaia.pmra", "pmdec": "gaia.pmdec", "g_mag": "phot_g_mean_mag", "bp_mag": "phot_bp_mean_mag", "rp_mag": "phot_rp_mean_mag", "tmass_joins": [ "INNER JOIN gaiadr3.tmass_psc_xsc_best_neighbour AS tmass_match" " ON tmass_match.source_id = gaia.source_id", "INNER JOIN external.tmass_psc AS tmass" " ON tmass.designation = tmass_match.original_ext_source_id", ], "j_col": "tmass.j_m", "h_col": "tmass.h_m", "ks_col": "tmass.ks_m", }, GaiaTAPSource.VIZIER: { "from": '"I/355/gaiadr3" AS g', "ra": "g.RA_ICRS", "dec": "g.DE_ICRS", "pmra": "g.pmRA", "pmdec": "g.pmDE", "g_mag": "g.Gmag", "bp_mag": "g.BPmag", "rp_mag": "g.RPmag", "tmass_joins": [ 'INNER JOIN "II/246/out" AS t ON g."2MASS" = t."2MASS"', ], "j_col": "t.Jmag", "h_col": "t.Hmag", "ks_col": "t.Kmag", }, } # Map Filters enum names to the per-source 2MASS column key in _TAP_CONFIG. _TMASS_COL_KEY = {"J": "j_col", "H": "h_col", "KS": "ks_col"}
[docs] class GaiaQuery: """Class to query Gaia DR3 data and retrieve sources. The class provides methods to query a configurable TAP service and return either the raw Astropy Table, a flux-normalised multi-band Table, or a Sources instance with RA-DEC coordinates and fluxes. The query backend is selected via ``GaiaQuery.DEFAULT_TAP_SOURCE`` (class level) or the per-call ``tap_source`` argument. Accepts a ``GaiaTAPSource`` for an online TAP service, a ``GaiaSQLiteSource`` for an offline SQLite database, a TAP endpoint URL, a ``sqlite:///...`` URI, or a direct path to a ``.db``/``.sqlite``/``.sqlite3`` file. The default is ``GaiaTAPSource.VIZIER`` (CDS VizieR), which hosts a copy of Gaia DR3. Examples -------- >>> from cabaret.queries import GaiaQuery >>> from astropy.coordinates import SkyCoord >>> center = SkyCoord(ra=10.68458, dec=41.269, unit="deg") The Astropy Table from Gaia can be obtained with: >>> table = GaiaQuery.query(center, radius=0.05, limit=10, timeout=30) Whereas a Sources instance carrying coordinates and fluxes can be queried for with: >>> sources = GaiaQuery.get_sources(center, radius=0.05, limit=10, timeout=30) """ DEFAULT_TAP_SOURCE: GaiaTAPSource | GaiaSQLiteSource | str = GaiaTAPSource.VIZIER """Default TAP service used when ``tap_source=None`` is passed to query methods."""
@staticmethod
def query(
    center: tuple[float, float] | SkyCoord | None = None,
    radius: float | Angle | None = None,
    bounds: tuple[float, float, float, float] | None = None,
    filter_bands: Filters | str | Sequence[Filters | str] = Filters.G,
    limit: int = 100000,
    timeout: float | None = None,
    tap_source: GaiaTAPSource | GaiaSQLiteSource | str | None = None,
    allow_nulls: bool = False,
    filter_band: Filters | str | None = None,
) -> Table:
    """Query Gaia DR3 sources within a circular or rectangular sky region.

    Exactly one geometry must be given: a cone (``center`` + ``radius``) or
    an RA/Dec box (``bounds``). The query runs against a TAP service or a
    local SQLite catalog depending on ``tap_source``.

    Parameters
    ----------
    center : tuple, astropy.coordinates.SkyCoord, or None
        The sky coordinates of the center of the FOV; required for
        radius-based queries. If a tuple is given, it should contain the RA
        and DEC in degrees.
    radius : float, astropy.units.Quantity, or None
        The radius of the FOV. A float is interpreted in degrees. A Quantity
        (e.g. an ``Angle``) is converted to degrees and must therefore have
        angular units. Mutually exclusive with ``bounds``.
    bounds : tuple or None
        Field-of-view bounds specified as
        ``(ra_min, ra_max, dec_min, dec_max)`` in degrees. RA values are
        interpreted modulo 360 (degrees) and may be given so that
        ``ra_min <= ra_max`` for a contiguous RA interval, or
        ``ra_min > ra_max`` to indicate an interval that wraps across RA=0
        (for example ``(350.0, 10.0, -5.0, 5.0)``). Declination bounds must
        satisfy ``-90 <= dec_min <= dec_max <= 90``.
    filter_bands : Filters, str, or sequence thereof, optional
        One or more filter bands to include as magnitude columns. Accepts a
        single ``Filters`` member or its name as a string, or a list of
        either. Pass ``"all"`` to request every available band
        (``Filters.all()``). Default is ``Filters.G``. When multiple bands
        are requested, the ``ORDER BY`` is determined by the first band
        (brightest-first, ASC for all magnitude columns). If any 2MASS band
        is included the 2MASS cross-match join is added.
    limit : int, optional
        The maximum number of sources to retrieve. Default is 100000.
    timeout : float or None, optional
        The maximum time to wait for the query to complete, in seconds.
        If None (default), there is no timeout.
    tap_source : GaiaTAPSource, GaiaSQLiteSource, str, or None, optional
        Query backend to use. Accepts a ``GaiaTAPSource`` for an online TAP
        service, a ``GaiaSQLiteSource`` for an offline SQLite database, a TAP
        endpoint URL as a string, a ``sqlite:///...`` URI, or a direct path
        to a SQLite database file ending in ``.db``, ``.sqlite``, or
        ``.sqlite3``. If None, ``GaiaQuery.DEFAULT_TAP_SOURCE`` is used.
    allow_nulls : bool, optional
        If False (default), only rows where all requested band columns are
        non-NULL are returned (``IS NOT NULL`` filter per band). Set to True
        to allow rows with missing magnitude values through.
    filter_band : Filters or str, optional
        Deprecated. Use ``filter_bands`` instead. When given, it replaces
        ``filter_bands`` and a ``DeprecationWarning`` is emitted.

    Returns
    -------
    astropy.table.Table
        The query result, with columns normalised to ``ra``, ``dec``,
        ``pmra``, ``pmdec``, and one magnitude column per requested band
        named after ``filter_band.value`` (e.g. ``phot_g_mean_mag``,
        ``h_m``).

    Raises
    ------
    ValueError
        If both or neither of ``radius`` and ``bounds`` are given, if the
        geometry is invalid, or if ``tap_source`` cannot be resolved.
    astropy.units.UnitConversionError
        If ``radius`` is a Quantity whose unit cannot be converted to degrees.

    Examples
    --------
    >>> from cabaret.queries import GaiaQuery
    >>> from astropy.coordinates import SkyCoord
    >>> center = SkyCoord(ra=10.68458, dec=41.26917, unit='deg')
    >>> table = GaiaQuery.query(center, radius=0.1, limit=10, timeout=30)
    """
    if filter_band is not None:
        warnings.warn(
            "filter_band is deprecated; use filter_bands instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        filter_bands = filter_band

    # Remember whether the "all" keyword was used: normalisation turns it into
    # a plain list, and the SQLite backend receives this flag separately.
    requested_all = isinstance(filter_bands, str) and filter_bands.upper() == "ALL"
    bands = GaiaQuery._normalize_bands(filter_bands)

    # Convert a Quantity/Angle radius to degrees here: _normalize_region reads
    # ``Quantity.value`` without unit conversion, so e.g. 6 arcmin would
    # otherwise be used as 6 degrees.
    if isinstance(radius, Quantity):
        radius = radius.to_value("deg")

    center_coords, radius_deg, bounds_tuple = GaiaQuery._normalize_region(
        center=center,
        radius=radius,
        bounds=bounds,
    )

    backend = GaiaQuery._resolve_query_source(tap_source)
    if isinstance(backend, GaiaTAPSource):
        return GaiaQuery._query_tap(
            tap_source=backend,
            bands=bands,
            center=center_coords,
            radius_deg=radius_deg,
            bounds=bounds_tuple,
            limit=limit,
            timeout=timeout,
            allow_nulls=allow_nulls,
        )
    return GaiaQuery._query_sqlite(
        sqlite_source=backend,
        requested_bands=bands,
        requested_all=requested_all,
        center=center_coords,
        radius_deg=radius_deg,
        bounds=bounds_tuple,
        limit=limit,
        timeout=timeout,
        allow_nulls=allow_nulls,
    )
@staticmethod
def get_flux_table(
    center: tuple[float, float] | SkyCoord | None = None,
    radius: float | Angle | None = None,
    bounds: tuple[float, float, float, float] | None = None,
    filter_bands: Filters | str | Sequence[Filters | str] = Filters.G,
    dateobs: datetime | None = None,
    limit: int = 100000,
    timeout: float | None = None,
    tap_source: GaiaTAPSource | GaiaSQLiteSource | str | None = None,
    allow_nulls: bool = False,
    keep_mag: bool = False,
) -> Table:
    """Query a sky region and return its magnitudes as photon fluxes.

    Runs :meth:`query`, optionally propagates positions to ``dateobs`` with
    the catalog proper motions, then converts every requested magnitude
    column to photons/s/m² with :meth:`_mag_to_photons`. Each converted
    column is named ``"<band>_flux"`` after the lower-cased band name, e.g.
    ``"phot_g_mean_mag"`` -> ``"g_flux"`` and ``"j_m"`` -> ``"j_flux"``.

    Parameters
    ----------
    center : tuple or astropy.coordinates.SkyCoord
        The sky coordinates of the center of the FOV.
    radius : float or astropy.units.Quantity
        The radius of the FOV in degrees.
    bounds : tuple or None
        Field-of-view bounds specified as
        ``(ra_min, ra_max, dec_min, dec_max)`` in degrees. RA values are
        interpreted modulo 360 (degrees) and may be given so that
        ``ra_min <= ra_max`` for a contiguous RA interval, or
        ``ra_min > ra_max`` to indicate an interval that wraps across RA=0
        (for example ``(350.0, 10.0, -5.0, 5.0)``). Declination bounds must
        satisfy ``-90 <= dec_min <= dec_max <= 90``.
    filter_bands : Filters, str, or sequence thereof, optional
        One or more filter bands. Repeated bands are converted once.
        Default is ``Filters.G``.
    dateobs : datetime.datetime or None, optional
        Observation date for proper-motion correction. Default is None
        (no correction).
    limit : int, optional
        Maximum number of sources to retrieve. Default is 100000.
    timeout : float or None, optional
        Query timeout in seconds. Default is None (no timeout).
    tap_source : GaiaTAPSource, GaiaSQLiteSource, str, or None, optional
        Query backend to use. Accepts a ``GaiaTAPSource``,
        ``GaiaSQLiteSource``, TAP URL, ``sqlite:///...`` URI, or
        ``.db``/``.sqlite``/``.sqlite3`` path. Default is
        ``GaiaQuery.DEFAULT_TAP_SOURCE``.
    allow_nulls : bool, optional
        Forwarded to :meth:`query`. If True, rows with NULL magnitude values
        are included in the result. Default is False.
    keep_mag : bool, optional
        If True, keep each original magnitude column next to its flux
        column. Default is False (magnitude columns are removed).

    Returns
    -------
    astropy.table.Table
        Table with columns ``ra``, ``dec``, ``pmra``, ``pmdec``, and one flux
        column per requested band, in photons/s/m². With ``keep_mag=True``
        the magnitude columns are present as well.

    Examples
    --------
    >>> from cabaret.queries import GaiaQuery, Filters
    >>> from astropy.coordinates import SkyCoord
    >>> center = SkyCoord(ra=10.68458, dec=41.26917, unit='deg')
    >>> table = GaiaQuery.get_flux_table(
    ...     center, radius=0.1, filter_bands=[Filters.G, Filters.H, Filters.KS],
    ...     limit=10, timeout=30,
    ... )  # doctest: +SKIP
    """
    requested = GaiaQuery._normalize_bands(filter_bands)
    result = GaiaQuery.query(
        center=center,
        radius=radius,
        bounds=bounds,
        filter_bands=requested,
        limit=limit,
        timeout=timeout,
        tap_source=tap_source,
        allow_nulls=allow_nulls,
    )
    if dateobs is not None:
        result = GaiaQuery._apply_proper_motion(result, dateobs)

    # dict.fromkeys drops repeated bands while keeping first-seen order.
    for band in dict.fromkeys(requested):
        mag_column = band.value
        # Gaia ("phot_g_mean_mag") and 2MASS ("j_m", "ks_m") columns alike map
        # to "<lower-case band name>_flux", e.g. "g_flux", "ks_flux".
        flux_column = f"{band.name.lower()}_flux"
        magnitudes = np.ma.filled(result[mag_column].value, np.nan)  # type: ignore
        result[flux_column] = GaiaQuery._mag_to_photons(magnitudes, band)
        if not keep_mag:
            result.remove_column(mag_column)
    return result
@staticmethod
def get_sources(
    center: tuple[float, float] | SkyCoord | None = None,
    radius: float | Angle | None = None,
    bounds: tuple[float, float, float, float] | None = None,
    filter_band: Filters | str = Filters.G,
    dateobs: datetime | None = None,
    limit: int = 100000,
    timeout: float | None = None,
    tap_source: GaiaTAPSource | GaiaSQLiteSource | str | None = None,
) -> Sources:
    """Query Gaia DR3 and return source positions with photon fluxes.

    Retrieves the RA-DEC coordinates of stars in a cone (``center`` +
    ``radius``) or RA/Dec box (``bounds``), together with their flux in a
    single band. Sources whose flux cannot be computed (NaN) are dropped.

    Parameters
    ----------
    center : tuple or astropy.coordinates.SkyCoord
        The sky coordinates of the center of the FOV. If a tuple is given, it
        should contain the RA and DEC in degrees.
    radius : float or astropy.units.Quantity or None
        The radius of the FOV in degrees. If a Quantity is given, it must be
        convertible to degrees.
    bounds : tuple or None
        Field-of-view bounds specified as
        ``(ra_min, ra_max, dec_min, dec_max)`` in degrees. RA values are
        interpreted modulo 360 (degrees) and may be given so that
        ``ra_min <= ra_max`` for a contiguous RA interval, or
        ``ra_min > ra_max`` to indicate an interval that wraps across RA=0
        (for example ``(350.0, 10.0, -5.0, 5.0)``). Declination bounds must
        satisfy ``-90 <= dec_min <= dec_max <= 90``.
    filter_band : Filters or str, optional
        The filter to use for the flux column. Default is Filters.G.
    dateobs : datetime.datetime, optional
        The date of the observation. If given, the proper motions of the
        sources are taken into account. Default is None.
    limit : int, optional
        The maximum number of sources to retrieve. Default is 100000.
    timeout : float, optional
        The maximum time to wait for the query to complete, in seconds.
        If None (default), there is no timeout.
    tap_source : GaiaTAPSource, GaiaSQLiteSource, str, or None, optional
        Query backend to use. Accepts a ``GaiaTAPSource`` for an online TAP
        service, a ``GaiaSQLiteSource`` for an offline SQLite database, a TAP
        endpoint URL as a string, a ``sqlite:///...`` URI, or a direct path
        to a SQLite database file ending in ``.db``, ``.sqlite``, or
        ``.sqlite3``. If None, the default TAP backend is used.

    Returns
    -------
    Sources
        A Sources instance containing the coordinates and fluxes of the
        retrieved sources.

    Notes
    -----
    Fluxes are always returned in photons/s/m² via :meth:`_mag_to_photons`.

    Raises
    ------
    ImportError
        If the astroquery package is not installed.

    Examples
    --------
    >>> from cabaret.queries import GaiaQuery
    >>> from astropy.coordinates import SkyCoord
    >>> center = SkyCoord(ra=10.68458, dec=41.26917, unit='deg')
    >>> sources = GaiaQuery.get_sources(
    ...     center, radius=0.1, timeout=30, limit=10
    ... )  # doctest: +SKIP
    """
    filter_band = Filters.ensure_enum(filter_band)
    table = GaiaQuery.query(
        center=center,
        radius=radius,
        bounds=bounds,
        limit=limit,
        timeout=timeout,
        filter_bands=filter_band,
        tap_source=tap_source,
    )
    if dateobs is not None:
        table = GaiaQuery._apply_proper_motion(table, dateobs)

    fluxes = GaiaQuery._mag_to_photons(
        np.ma.filled(table[filter_band.value].value, np.nan),  # type: ignore
        filter_band,
    )
    # Drop sources whose flux is NaN (presumably magnitudes that were masked
    # and filled with NaN above); build the mask once and apply it to both the
    # table rows and the flux array so they stay aligned.
    invalid = np.isnan(fluxes)
    table.remove_rows(invalid)
    fluxes = fluxes[~invalid]

    # np.ma.getdata yields a plain ndarray for both masked and unmasked
    # columns; ``Column.value.data`` would be a memoryview for an unmasked one.
    return Sources.from_arrays(
        ra=np.ma.getdata(table["ra"].value),  # type: ignore
        dec=np.ma.getdata(table["dec"].value),  # type: ignore
        fluxes=fluxes,
    )
@staticmethod
def _resolve_query_source(
    tap_source: GaiaTAPSource | GaiaSQLiteSource | str | None,
) -> GaiaTAPSource | GaiaSQLiteSource:
    """Resolve a query source into either TAP or SQLite backend config.

    ``None`` selects ``GaiaQuery.DEFAULT_TAP_SOURCE``. Backend config objects
    are returned unchanged. For strings (after stripping whitespace), a
    ``sqlite:///`` prefix or a ``.db``/``.sqlite``/``.sqlite3`` suffix (both
    case-insensitive) selects the SQLite backend. Any other string is passed
    to ``GaiaTAPSource.ensure_enum``.

    Raises
    ------
    ValueError
        If ``tap_source`` has an unsupported type, or is a SQLite URI with no
        database path.
    """
    if tap_source is None:
        return GaiaQuery.DEFAULT_TAP_SOURCE
    if isinstance(tap_source, GaiaSQLiteSource | GaiaTAPSource):
        return tap_source
    if not isinstance(tap_source, str):
        raise ValueError(
            "tap_source must be a GaiaTAPSource, GaiaSQLiteSource, or str, "
            f"got {type(tap_source)}"
        )

    source_text = tap_source.strip()
    sqlite_prefix = "sqlite:///"
    if source_text.lower().startswith(sqlite_prefix):
        db_path = source_text[len(sqlite_prefix) :]
        if not db_path:
            raise ValueError("SQLite URI must include a database path.")
        return GaiaSQLiteSource(database=db_path)
    if source_text.lower().endswith((".db", ".sqlite", ".sqlite3")):
        return GaiaSQLiteSource(database=source_text)
    return GaiaTAPSource.ensure_enum(source_text)

@staticmethod
def _normalize_region(
    center: tuple[float, float] | SkyCoord | None,
    radius: float | Angle | None,
    bounds: tuple[float, float, float, float] | None,
) -> tuple[
    tuple[float, float] | None,
    float | None,
    tuple[float, float, float, float] | None,
]:
    """Normalize circle/bounds geometry inputs.

    Exactly one of ``radius`` or ``bounds`` must be provided.

    ``radius`` may be a plain number (degrees) or an astropy ``Quantity`` or
    ``Angle``. A Quantity is converted to degrees. ``center`` is required for
    radius queries. A ``SkyCoord`` center is transformed to ICRS, and a tuple
    is read as ``(ra, dec)`` in degrees.

    The ``bounds`` tuple should be ``(ra_min, ra_max, dec_min, dec_max)`` in
    degrees. RA values are normalised modulo 360. If ``ra_min <= ra_max`` the
    RA interval is contiguous. Otherwise ``ra_min > ra_max`` indicates an
    interval that wraps across RA=0. A raw RA span of 360 degrees or more,
    e.g. ``(0, 360)``, means the full circle and is returned as
    ``(0.0, 360.0)``. Declination bounds must satisfy
    ``-90 <= dec_min <= dec_max <= 90``.

    Returns
    -------
    tuple
        ``(center_coords, radius_deg, bounds_tuple)``, where only the
        applicable geometry is populated and the others are ``None``.

    Raises
    ------
    ValueError
        If the geometry arguments are missing, inconsistent, or out of range.
    """
    if (radius is None) == (bounds is None):
        raise ValueError(
            "Exactly one geometry must be specified: either radius or bounds."
        )

    center_coords: tuple[float, float] | None = None
    radius_deg: float | None = None
    bounds_tuple: tuple[float, float, float, float] | None = None

    if radius is not None:
        if center is None:
            raise ValueError("center must be provided for radius-based queries.")
        # A Quantity/Angle may carry any angular unit (e.g. arcmin). Convert
        # it to degrees rather than reading the bare number.
        radius_value = (
            float(radius.to_value("deg"))
            if isinstance(radius, Quantity)
            else float(radius)
        )
        # Written as ``not > 0`` so that NaN is rejected too.
        if not radius_value > 0:
            raise ValueError("radius must be > 0.")
        radius_deg = radius_value

        if isinstance(center, SkyCoord):
            # Queries are expressed in ICRS. Transform other frames first
            # (a no-op for coordinates that are already ICRS).
            icrs_center = center.icrs
            center_coords = (
                float(icrs_center.ra.deg % 360.0),  # type: ignore
                float(icrs_center.dec.deg),  # type: ignore
            )
        else:
            center_coords = (float(center[0] % 360.0), float(center[1]))
    else:
        assert bounds is not None
        if len(bounds) != 4:
            raise ValueError(
                "bounds must be a tuple of (ra_min, ra_max, dec_min, dec_max)."
            )
        ra_min, ra_max, dec_min, dec_max = (float(value) for value in bounds)
        if not all(math.isfinite(v) for v in (ra_min, ra_max, dec_min, dec_max)):
            raise ValueError("bounds must contain finite values.")
        if dec_min > dec_max:
            raise ValueError("dec_min must be <= dec_max.")
        if dec_min < -90.0 or dec_max > 90.0:
            raise ValueError("declination bounds must stay within [-90, 90].")
        if ra_max - ra_min >= 360.0:
            # Applying the modulo here would collapse e.g. (0, 360) onto the
            # zero-width interval (0, 0).
            bounds_tuple = (0.0, 360.0, dec_min, dec_max)
        else:
            bounds_tuple = (ra_min % 360.0, ra_max % 360.0, dec_min, dec_max)

    return center_coords, radius_deg, bounds_tuple

@staticmethod
def _query_tap(
    tap_source: GaiaTAPSource,
    bands: list[Filters],
    center: tuple[float, float] | None,
    radius_deg: float | None,
    bounds: tuple[float, float, float, float] | None,
    limit: int,
    timeout: float | None,
    allow_nulls: bool,
) -> Table:
    """Execute a Gaia query against a TAP backend.

    Builds an ADQL ``SELECT TOP`` query from the column mapping stored in
    ``_TAP_CONFIG[tap_source]``. The 2MASS join clauses are added when any
    requested band is a 2MASS band. Results are ordered by the first
    requested band. The query is run via ``_launch_job_with_timeout``.
    """
    cfg = _TAP_CONFIG[tap_source]
    select_cols = [
        f"{cfg['ra']} AS ra",
        f"{cfg['dec']} AS dec",
        f"{cfg['pmra']} AS pmra",
        f"{cfg['pmdec']} AS pmdec",
    ]
    where: list[str] = []
    joins: list[str] = []
    need_tmass_join = False

    for band in bands:
        if Filters.is_tmass(band.name):
            col_expr = cfg[_TMASS_COL_KEY[band.name]]
            need_tmass_join = True
        else:
            col_expr = cfg[band.name.lower() + "_mag"]
        select_cols.append(f"{col_expr} AS {band.value}")
        if not allow_nulls:
            where.append(f"{col_expr} IS NOT NULL")

    if need_tmass_join:
        joins.extend(cfg["tmass_joins"])

    first = bands[0]
    order_by = f"{first.value} ASC"

    if radius_deg is not None:
        assert center is not None
        where.append(
            f"1=CONTAINS("
            f"POINT('ICRS', {cfg['ra']}, {cfg['dec']}), "
            f"CIRCLE('ICRS', {center[0]}, {center[1]}, {radius_deg}))"
        )
    else:
        assert bounds is not None
        where.append(
            GaiaQuery._build_tap_bounds_where(
                cfg["ra"],
                cfg["dec"],
                bounds,
            )
        )

    select_clause = ", ".join(select_cols)
    joins_clause = "\n".join(joins)
    where_clause = " AND ".join(where)

    adql = f"""
    SELECT TOP {limit} {select_clause}
    FROM {cfg["from"]}
    {joins_clause}
    WHERE {where_clause}
    ORDER BY {order_by}
    """

    return GaiaQuery._launch_job_with_timeout(
        adql, tap_source=tap_source, timeout=timeout
    )

@staticmethod
def _query_sqlite(
    sqlite_source: GaiaSQLiteSource,
    requested_bands: list[Filters],
    requested_all: bool,
    center: tuple[float, float] | None,
    radius_deg: float | None,
    bounds: tuple[float, float, float, float] | None,
    limit: int,
    timeout: float | None,
    allow_nulls: bool,
) -> Table:
    """Execute a Gaia-like query against a local SQLite catalog.

    The catalog is either the table named ``sqlite_source.table`` or, when
    that table is absent, a set of declination-ring shard tables (see
    ``_select_sqlite_tables``).

    Box queries are filtered entirely in SQL. Circle queries use a
    rectangular SQL prefilter followed by an exact on-sky separation mask.
    Missing ``pmra``/``pmdec`` columns read as NULL. SQL ``NULL`` and
    empty-string values come back as NaN. Rows are sorted by the first
    selected band (NaN last) and truncated to ``limit``.

    Raises
    ------
    ValueError
        If the database file does not exist, if ``ra``/``dec`` columns are
        missing, or if requested band columns are absent.
    TimeoutError
        If ``timeout`` elapses. ``Connection.interrupt`` is then called to
        abort the running statement.
    """
    import os

    # sqlite3.connect() would silently create an empty database file for a
    # mistyped path, and the error would only surface later as "table not
    # found".
    if not os.path.isfile(sqlite_source.database):
        raise ValueError(f"SQLite database not found: {sqlite_source.database}")

    connection = sqlite3.connect(sqlite_source.database, check_same_thread=False)
    connection.row_factory = sqlite3.Row

    def _run_query() -> Table:
        with closing(connection):
            cursor = connection.cursor()
            selected_tables, schema_table = GaiaQuery._select_sqlite_tables(
                cursor=cursor,
                sqlite_source=sqlite_source,
                center=center,
                radius_deg=radius_deg,
                bounds=bounds,
            )

            schema_table_name = GaiaQuery._quote_sql_identifier(schema_table)
            available_columns = {
                str(row[1])
                for row in cursor.execute(f"PRAGMA table_info({schema_table_name})")
            }

            missing_required = sorted({"ra", "dec"} - available_columns)
            if missing_required:
                raise ValueError(
                    "SQLite source is missing required columns: "
                    f"{', '.join(missing_required)}"
                )

            if requested_all:
                selected_bands = [
                    band for band in requested_bands if band.value in available_columns
                ]
                if not selected_bands:
                    raise ValueError(
                        "SQLite source does not contain any supported "
                        "magnitude columns."
                    )
            else:
                missing_requested = [
                    band.value
                    for band in requested_bands
                    if band.value not in available_columns
                ]
                if missing_requested:
                    raise ValueError(
                        "SQLite source is missing requested band columns: "
                        f"{', '.join(sorted(missing_requested))}"
                    )
                selected_bands = requested_bands

            select_cols = [
                '"ra" AS ra',
                '"dec" AS dec',
                '"pmra" AS pmra' if "pmra" in available_columns else "NULL AS pmra",
                (
                    '"pmdec" AS pmdec'
                    if "pmdec" in available_columns
                    else "NULL AS pmdec"
                ),
            ]
            for band in selected_bands:
                quoted_col = GaiaQuery._quote_sql_identifier(band.value)
                select_cols.append(f"{quoted_col} AS {band.value}")

            where_sql: list[str] = []
            params: list[float] = []
            if not allow_nulls:
                for band in selected_bands:
                    quoted_col = GaiaQuery._quote_sql_identifier(band.value)
                    where_sql.append(f"{quoted_col} IS NOT NULL")

            if bounds is not None:
                bounds_sql, bounds_params = GaiaQuery._build_sqlite_bounds_where(
                    bounds
                )
                where_sql.append(bounds_sql)
                params.extend(bounds_params)
            else:
                assert center is not None
                assert radius_deg is not None
                circle_sql, circle_params = (
                    GaiaQuery._build_sqlite_circle_prefilter_where(
                        center=center,
                        radius_deg=radius_deg,
                    )
                )
                where_sql.append(circle_sql)
                params.extend(circle_params)

            first_band = selected_bands[0]
            output_names = ["ra", "dec", "pmra", "pmdec"] + [
                band.value for band in selected_bands
            ]

            output_rows: list[tuple] = []
            for shard in selected_tables:
                query_parts = [
                    f"SELECT {', '.join(select_cols)}",
                    f"FROM {GaiaQuery._quote_sql_identifier(shard)}",
                ]
                if where_sql:
                    query_parts.append("WHERE " + " AND ".join(where_sql))
                sql = "\n".join(query_parts)
                output_rows.extend(
                    tuple(
                        float("nan")
                        if row[name] is None or row[name] == ""
                        else row[name]
                        for name in output_names
                    )
                    for row in cursor.execute(sql, tuple(params))
                )

            if not output_rows:
                # With ``rows=[]`` astropy infers zero columns and rejects the
                # six ``names``. Build the empty table from names/dtype instead.
                return Table(names=output_names, dtype=[float] * len(output_names))

            table = Table(rows=output_rows, names=output_names)

            if radius_deg is not None:
                assert center is not None
                inside_circle = GaiaQuery._on_sky_circle_mask(
                    ra=np.asarray(table["ra"], dtype=float),
                    dec=np.asarray(table["dec"], dtype=float),
                    center_ra=center[0],
                    center_dec=center[1],
                    radius_deg=radius_deg,
                )
                table = table[inside_circle]

            if len(table) > 0:
                sort_values = np.asarray(
                    np.ma.filled(table[first_band.value], np.inf),  # type: ignore[index]
                    dtype=float,
                )
                table = table[np.argsort(sort_values, kind="stable")]

            return Table(table[:limit])

    return GaiaQuery._run_callable_with_timeout(
        func=_run_query,
        timeout=timeout,
        timeout_message=(
            "SQLite Gaia query timed out. "
            "You may want to increase the timeout or reduce the query size."
        ),
        on_timeout=connection.interrupt,
    )

@staticmethod
def _build_tap_bounds_where(
    ra_col: str,
    dec_col: str,
    bounds: tuple[float, float, float, float],
) -> str:
    """Build an ADQL bounds predicate with RA wrap handling.

    If ``ra_min <= ra_max`` the RA interval is contiguous. Otherwise it wraps
    across RA=0 and is expressed as ``ra >= ra_min OR ra <= ra_max``.
    """
    ra_min, ra_max, dec_min, dec_max = bounds
    dec_clause = f"({dec_col} >= {dec_min} AND {dec_col} <= {dec_max})"
    if ra_min <= ra_max:
        ra_clause = f"({ra_col} >= {ra_min} AND {ra_col} <= {ra_max})"
    else:
        ra_clause = f"({ra_col} >= {ra_min} OR {ra_col} <= {ra_max})"
    return f"{dec_clause} AND {ra_clause}"

@staticmethod
def _select_sqlite_tables(
    cursor: sqlite3.Cursor,
    sqlite_source: GaiaSQLiteSource,
    center: tuple[float, float] | None,
    radius_deg: float | None,
    bounds: tuple[float, float, float, float] | None,
) -> tuple[list[str], str]:
    """Resolve SQLite tables for a query, including dec-ring shards.

    If ``sqlite_source.table`` exists, it alone is queried and supplies the
    schema. Otherwise, the tables whose names parse as declination rings
    (e.g. ``"-25_-24"``) are selected if their dec range overlaps the
    query's dec range, ordered by their lower dec bound. The ring with the
    lowest lower bound supplies the column schema.

    Returns
    -------
    tuple
        ``(tables_to_query, schema_table)``.

    Raises
    ------
    ValueError
        If neither the named table nor any ring table exists.
    """
    all_tables = GaiaQuery._list_sqlite_tables(cursor)
    if sqlite_source.table in all_tables:
        return [sqlite_source.table], sqlite_source.table

    ring_tables: list[tuple[str, float, float]] = []
    for name in all_tables:
        dec_range = GaiaQuery._parse_dec_ring_table_name(name)
        if dec_range is not None:
            ring_tables.append((name, dec_range[0], dec_range[1]))

    if not ring_tables:
        raise ValueError(
            f"SQLite table '{sqlite_source.table}' not found and no sharded "
            "declination-ring tables were detected."
        )

    if bounds is not None:
        dec_min, dec_max = bounds[2], bounds[3]
    else:
        assert center is not None
        assert radius_deg is not None
        dec_min = max(-90.0, center[1] - radius_deg)
        dec_max = min(90.0, center[1] + radius_deg)

    # A stable sort keeps the listing order among rings with equal lower bounds.
    ordered_rings = sorted(ring_tables, key=lambda ring: ring[1])
    selected = [
        name
        for name, ring_min, ring_max in ordered_rings
        if ring_max >= dec_min and ring_min <= dec_max
    ]
    schema_table = ordered_rings[0][0]
    return selected, schema_table

@staticmethod
def _list_sqlite_tables(cursor: sqlite3.Cursor) -> list[str]:
    """List user tables in a SQLite database (internal ``sqlite_*`` excluded)."""
    return [
        str(row[0])
        for row in cursor.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type='table' AND name NOT LIKE 'sqlite_%'"
        )
    ]

@staticmethod
def _parse_dec_ring_table_name(table_name: str) -> tuple[float, float] | None:
    """Parse sharded dec-ring table names such as '-25_-24'.

    Returns ``(dec_low, dec_high)`` with the bounds ordered, or ``None`` if
    the name does not match the ``<number>_<number>`` pattern.
    """
    match = re.match(r"^(-?\d+(?:\.\d+)?)_(-?\d+(?:\.\d+)?)$", table_name)
    if match is None:
        return None
    first = float(match.group(1))
    second = float(match.group(2))
    return (second, first) if first > second else (first, second)

@staticmethod
def _build_sqlite_bounds_where(
    bounds: tuple[float, float, float, float],
) -> tuple[str, tuple[float, ...]]:
    """Build SQL bounds predicate with parameters and RA wrap handling.

    Returns ``(sql, params)``, where ``params`` is always
    ``(dec_min, dec_max, ra_min, ra_max)``, matching the ``?`` placeholders.
    """
    ra_min, ra_max, dec_min, dec_max = bounds
    if ra_min <= ra_max:
        return (
            '("dec" >= ? AND "dec" <= ?) AND ("ra" >= ? AND "ra" <= ?)',
            (dec_min, dec_max, ra_min, ra_max),
        )
    return (
        '("dec" >= ? AND "dec" <= ?) AND (("ra" >= ?) OR ("ra" <= ?))',
        (dec_min, dec_max, ra_min, ra_max),
    )

@staticmethod
def _build_sqlite_circle_prefilter_where(
    center: tuple[float, float],
    radius_deg: float,
) -> tuple[str, tuple[float, ...]]:
    """Build a rectangular SQL prefilter around a circle query.

    The box must contain the whole circle. Exact membership is decided
    afterwards by ``_on_sky_circle_mask``.

    The RA half-width is the exact extent of a small circle,
    ``asin(sin(r) / cos(dec0))``. The common shortcut ``r / cos(dec0)`` is
    always smaller and drops stars near the circle's east/west edges,
    noticeably so at high declination.
    """
    center_ra, center_dec = center
    dec_min = max(-90.0, center_dec - radius_deg)
    dec_max = min(90.0, center_dec + radius_deg)

    # A circle reaching a pole spans every RA, so only the dec bounds apply.
    if dec_min <= -90.0 or dec_max >= 90.0:
        return '"dec" >= ? AND "dec" <= ?', (dec_min, dec_max)

    sin_ratio = math.sin(math.radians(radius_deg)) / math.cos(
        math.radians(center_dec)
    )
    if sin_ratio >= 1.0:
        # Defensive: cannot happen when the circle avoids both poles.
        return '"dec" >= ? AND "dec" <= ?', (dec_min, dec_max)
    ra_half_width = math.degrees(math.asin(sin_ratio))

    ra_min = (center_ra - ra_half_width) % 360.0
    ra_max = (center_ra + ra_half_width) % 360.0
    return GaiaQuery._build_sqlite_bounds_where((ra_min, ra_max, dec_min, dec_max))

@staticmethod
def _on_sky_circle_mask(
    ra: np.ndarray,
    dec: np.ndarray,
    center_ra: float,
    center_dec: float,
    radius_deg: float,
) -> np.ndarray:
    """Return mask for points within an angular radius on the sphere.

    Uses the spherical law of cosines. The boundary is inclusive (``<=``).
    All angles are in degrees.
    """
    ra_rad = np.radians(ra)
    dec_rad = np.radians(dec)
    center_ra_rad = math.radians(center_ra)
    center_dec_rad = math.radians(center_dec)

    cos_sep = np.sin(dec_rad) * math.sin(center_dec_rad) + np.cos(
        dec_rad
    ) * math.cos(center_dec_rad) * np.cos(ra_rad - center_ra_rad)
    # Clip rounding excursions outside [-1, 1] before arccos.
    cos_sep = np.clip(cos_sep, -1.0, 1.0)
    sep_deg = np.degrees(np.arccos(cos_sep))
    return sep_deg <= radius_deg

@staticmethod
def _quote_sql_identifier(identifier: str) -> str:
    """Quote a SQLite identifier safely by doubling embedded double quotes."""
    if not identifier:
        raise ValueError("SQL identifier cannot be empty.")
    return '"' + identifier.replace('"', '""') + '"'

@staticmethod
def _run_callable_with_timeout(
    func: Callable[[], _T],
    timeout: float | None,
    timeout_message: str,
    on_timeout: Callable[[], None] | None = None,
) -> _T:
    """Run a callable with optional timeout using a daemon thread.

    With ``timeout=None``, ``func`` runs directly in the calling thread.
    Otherwise it runs in a daemon thread, and any exception it raises is
    re-raised in the caller.

    If the thread is still alive after ``timeout`` seconds, ``on_timeout``
    (if given) is called and ``TimeoutError(timeout_message)`` is raised.
    The worker thread is not killed.
    """
    if timeout is None:
        return func()

    result: list[_T] = []
    exc: list[BaseException] = []

    def _run_safe() -> None:
        try:
            result.append(func())
        except BaseException as e:  # re-raised in the calling thread below
            exc.append(e)

    t = threading.Thread(target=_run_safe, daemon=True)
    t.start()
    t.join(timeout=timeout)
    if t.is_alive():
        if on_timeout is not None:
            try:
                on_timeout()
            except Exception as cleanup_error:
                # The hook may race with the worker finishing, e.g. by
                # interrupting an already-closed connection. Keep
                # TimeoutError as the error callers see.
                raise TimeoutError(timeout_message) from cleanup_error
        raise TimeoutError(timeout_message)
    if exc:
        raise exc[0]
    return result[0]

@staticmethod
def _launch_job_with_timeout(
    query: str,
    tap_source: GaiaTAPSource,
    timeout: float | None = None,
    **kwargs,
) -> Table:
    """
    Launch a TAP job and return its results, optionally enforcing a timeout.

    Parameters
    ----------
    query : str
        The ADQL query string.
    tap_source : GaiaTAPSource
        The TAP service to use.
    timeout : float or None, optional
        Maximum number of seconds to wait for the job to complete.
        If None, the job is run in the calling thread (no thread overhead).
    **kwargs
        Additional keyword arguments forwarded to ``TapPlus.launch_job``.

    Returns
    -------
    astropy.table.Table
        The results of ``job.get_results()`` wrapped in a ``Table``.

    Raises
    ------
    TimeoutError
        If `timeout` is not None and the call does not complete within
        `timeout`.
    ImportError
        If astroquery is not installed.
    """
    from astroquery.utils.tap.core import TapPlus

    def _run() -> Table:
        tap = TapPlus(url=tap_source.value)
        job = tap.launch_job(query, **kwargs)
        return Table(job.get_results())  # type: ignore[arg-type]

    return GaiaQuery._run_callable_with_timeout(
        func=_run,
        timeout=timeout,
        timeout_message=(
            "Gaia query timed out. "
            "You may want to increase the timeout or reduce the query size. "
            f"Query was: {query}"
        ),
    )

# Vega zero-points for all supported bands.
# See _derive_band_properties for how these were obtained.
_BAND_PROPS = {
    "G": {"dlam_lam": 0.7117, "flux_m0_Jy": 4031.5},
    "BP": {"dlam_lam": 0.5131, "flux_m0_Jy": 3683.21},
    "RP": {"dlam_lam": 0.3741, "flux_m0_Jy": 5040.41},
    "J": {"dlam_lam": 0.1312, "flux_m0_Jy": 1594.0},
    "H": {"dlam_lam": 0.151, "flux_m0_Jy": 1024.0},
    "KS": {"dlam_lam": 0.1214, "flux_m0_Jy": 666.8},
}

@staticmethod
def _mag_to_photons(mags: np.ndarray, filter_band: Filters) -> np.ndarray:
    """Convert Vega magnitudes to photon flux in photons/s/m².

    Applies to all supported bands (Gaia G/BP/RP and 2MASS J/H/KS).
    Formula: Δλ/λ × F₀ × 1.51×10⁷ × 10^(−0.4 × mag), where F₀ is the
    Vega zero-point flux density in Jy. NaN magnitudes yield NaN fluxes.

    Parameters
    ----------
    mags : np.ndarray
        Vega magnitudes.
    filter_band : Filters
        Any supported filter band.

    Returns
    -------
    np.ndarray
        Flux in photons/s/m².

    Raises
    ------
    ValueError
        If ``filter_band`` has no entry in ``_BAND_PROPS``.
    """
    try:
        props = GaiaQuery._BAND_PROPS[filter_band.name]
    except KeyError:
        raise ValueError(
            f"_mag_to_photons: unsupported filter {filter_band}. "
            f"Supported: {list(GaiaQuery._BAND_PROPS)}."
        ) from None
    Jy = 1.51e7  # [photons s^-1 m^-2 (Δλ/λ)^-1]
    return props["dlam_lam"] * props["flux_m0_Jy"] * Jy * 10 ** (-0.4 * mags)

# Gaia DR3 reference epoch J2016.0 expressed as a decimal year.
_GAIA_DR3_EPOCH = float(Time(2016.0, format="jyear").decimalyear)

@staticmethod
def _apply_proper_motion(table: Table, dateobs: datetime | Time) -> Table:
    """
    Propagate RA and DEC from the Gaia DR3 epoch to the observation date.

    A linear (small-motion) model is used, so it is not valid very close to
    the celestial poles. ``pmra`` follows the Gaia convention
    μα* = μα·cos(δ) in mas/yr, so the RA shift is divided by cos(δ).
    ``pmdec`` is in mas/yr. SQLite catalogues are assumed to use the same
    convention. Missing proper motions (NaN or masked) count as zero, so
    those sources keep their catalogue position. RA is wrapped into
    [0, 360). The table is modified in place and also returned.

    Raises
    ------
    ValueError
        If ``dateobs`` is neither a ``datetime`` nor an astropy ``Time``.
    """
    if isinstance(dateobs, datetime):
        dateobs = Time(dateobs)
    if not isinstance(dateobs, Time):
        raise ValueError(
            f"dateobs must be an astropy.time.Time or datetime, got {type(dateobs)}"
        )

    years = float(dateobs.decimalyear) - GaiaQuery._GAIA_DR3_EPOCH
    mas_to_deg = 1.0 / 1000.0 / 3600.0

    def _zero_filled(name: str) -> np.ndarray:
        # Fill masked entries with NaN first. Otherwise np.where would read
        # the arbitrary data stored under the mask. Then treat every NaN as
        # "no motion".
        values = np.ma.filled(np.ma.asarray(table[name], dtype=float), np.nan)
        return np.where(np.isnan(values), 0.0, values)

    pmra = _zero_filled("pmra")
    pmdec = _zero_filled("pmdec")
    # Use the catalogue-epoch declination. Keep cos(dec) away from zero so a
    # source at exactly |dec| = 90 does not divide by zero.
    cos_dec = np.maximum(np.cos(np.radians(np.asarray(table["dec"], dtype=float))), 1e-12)

    # Assign (=) rather than add in place (+=). The right-hand sides mix a
    # Column with plain ndarrays.
    table["ra"] = (table["ra"] + years * pmra * mas_to_deg / cos_dec) % 360.0  # type: ignore
    table["dec"] = table["dec"] + years * pmdec * mas_to_deg  # type: ignore
    return table

@staticmethod
def _normalize_bands(
    filter_bands: Filters | str | Sequence[Filters | str],
) -> list["Filters"]:
    """Normalize filter_bands argument to a deduplicated list[Filters].

    Passing the string ``"all"`` (case-insensitive) expands to every
    available filter, equivalent to ``Filters.all()``. Order of first
    appearance is preserved when removing duplicates.

    Raises
    ------
    ValueError
        If the argument is not a Filters, str, or sequence thereof, or if
        the sequence is empty.
    """
    if isinstance(filter_bands, str) and filter_bands.upper() == "ALL":
        return Filters.all()
    if isinstance(filter_bands, Filters | str):
        filter_bands = [filter_bands]
    if not isinstance(filter_bands, Sequence):
        raise ValueError(
            f"filter_bands must be a Filters, str, or sequence thereof, "
            f"got {type(filter_bands)}"
        )
    if not filter_bands:
        raise ValueError("At least one filter_band must be specified.")
    return list(dict.fromkeys(Filters.ensure_enum(b) for b in filter_bands))

@staticmethod
def _derive_band_properties():
    """
    Derives physical band properties for Gaia DR3 and 2MASS.

    This method serves as a reference for how the Vega zero-point fluxes
    and Δλ/λ values were obtained for the supported bands.

    Sources
    -------
    2MASS: Derived from Isophotal Fluxes
        Cohen et al. 2003 AJ 126 1090 (Table 2) | Bibcode: 2003AJ....126.1090C
        https://doi.org/10.1086/376474

    Gaia: Derived from Magnitude Zero Points
        Riello et al. 2021 A&A 649 A3 (Table 3) | Bibcode: 2021A&A...649A...3R
        https://doi.org/10.1051/0004-6361/202039587

    Example
    -------
    >>> from cabaret.queries import GaiaQuery
    >>> properties = GaiaQuery._derive_band_properties()
    >>> properties == GaiaQuery._BAND_PROPS
    True
    """
    AB_ZERO_POINT_JY = 3631.0

    # GAIA DR3 (Riello et al. 2021 A&A 649 A3, Table 3)
    # Format: {Band: [ZP_VEG, ZP_AB, FWHM_nm, LAM_0_nm]}
    gaia_table = {
        "G": [25.6874, 25.8010, 454.82, 639.07],
        "BP": [25.3385, 25.3540, 265.90, 518.26],
        "RP": [24.7479, 25.1040, 292.75, 782.51],
    }

    # 2MASS (Martin Cohen et al. 2003 AJ 126 1090, Table 2)
    # Format: {Band: [Flux_Jy, Bandwidth_um, Lam_Iso_um]}
    # Note: KS refers to the 2MASS "Short" K-band, which has a narrower
    # bandwidth and shorter pivot wavelength than standard Johnson K.
    tmass_table = {
        "J": [1594.0, 0.162, 1.235],
        "H": [1024.0, 0.251, 1.662],
        "KS": [666.8, 0.262, 2.159],
    }

    results = {}

    for band, (zp_v, zp_a, fwhm, l0) in gaia_table.items():
        # dlam_lam = FWHM / Pivot Wavelength
        dlam_lam = fwhm / l0
        # Flux density of Vega = 3631 * 10^(0.4 * (ZP_AB - ZP_VEG))
        flux_jy = AB_ZERO_POINT_JY * (10 ** (0.4 * (zp_a - zp_v)))
        results[band] = {
            "dlam_lam": round(dlam_lam, 4),
            "flux_m0_Jy": round(flux_jy, 2),
        }

    for band, (flux_jy, width, l_iso) in tmass_table.items():
        # dlam_lam = Width / Isophotal Wavelength
        dlam_lam = width / l_iso
        results[band] = {
            "dlam_lam": round(dlam_lam, 4),
            "flux_m0_Jy": round(flux_jy, 2),
        }

    return results