Source code for astropy.timeseries.core

# Licensed under a 3-clause BSD style license - see LICENSE.rst

from contextlib import contextmanager
from functools import wraps
from types import FunctionType

from astropy.table import QTable

# Public API of this module.
__all__ = [
    "BaseTimeSeries",
    "autocheck_required_columns",
]

# Table methods that autocheck_required_columns wraps so that
# _check_required_columns() runs after each successful call.
COLUMN_RELATED_METHODS = [
    # adding columns
    "add_column",
    "add_columns",
    # keeping / removing columns
    "keep_columns",
    "remove_column",
    "remove_columns",
    # renaming columns
    "rename_column",
]


def autocheck_required_columns(cls):
    """
    Class decorator that re-validates required columns after column edits.

    Each method named in ``COLUMN_RELATED_METHODS`` is replaced on ``cls`` by
    a wrapper that first calls the original method and then calls
    ``self._check_required_columns()``. This way every method that might
    change the columns of the table is followed by a consistency check
    against the ``_required_columns`` attribute.

    Parameters
    ----------
    cls : type
        The class to decorate. It must define (or inherit) every method listed
        in ``COLUMN_RELATED_METHODS`` as a plain Python function. Its instances
        must provide a ``_check_required_columns`` method.

    Returns
    -------
    cls : type
        The same class object, modified in place.

    Raises
    ------
    ValueError
        If one of the methods in ``COLUMN_RELATED_METHODS`` is missing on
        ``cls`` or is not a plain Python function.
    """

    def decorator_method(method):
        @wraps(method)
        def wrapper(self, *args, **kwargs):
            result = method(self, *args, **kwargs)
            # Only reached if ``method`` returned normally; a call that raises
            # is not followed by a check.
            self._check_required_columns()
            return result

        return wrapper

    for name in COLUMN_RELATED_METHODS:
        # Single lookup: a missing attribute yields ``None``, which fails the
        # ``FunctionType`` test exactly like a non-function attribute does.
        method = getattr(cls, name, None)
        if not isinstance(method, FunctionType):
            raise ValueError(f"{name} is not a valid method")
        setattr(cls, name, decorator_method(method))

    return cls
class BaseTimeSeries(QTable):
    """
    Base class for time series tables that enforce a set of leading columns.

    Subclasses set ``_required_columns`` to the sequence of column names that
    must appear, in order, as the first columns of the table.
    ``_check_required_columns`` validates this and raises `ValueError` when
    the table does not start with those columns.
    """

    # Names of the columns that must come first, in order; ``None`` disables
    # the check entirely.
    _required_columns = None

    # When False, ``_check_required_columns`` returns without checking
    # anything (used by ``_delay_required_column_checks``).
    _required_columns_enabled = True

    # If _required_columns_relax is True, we don't require the columns to be
    # present but we do require them to be the correct ones IF present. Note
    # that this is a temporary state - as soon as the required columns
    # are all present, we toggle this to False
    _required_columns_relax = False

    def _check_required_columns(self):
        """
        Check that the table starts with the required columns.

        Does nothing if checks are disabled or ``_required_columns`` is
        ``None``. In relaxed mode, only the columns present so far are
        compared against the matching prefix of ``_required_columns``. Once
        the full set of required columns is present, relaxed mode is switched
        off.

        Raises
        ------
        ValueError
            If the leading column names do not match the required ones, or if
            the table has no columns at all while not in relaxed mode.
        """

        def as_scalar_or_list_str(obj):
            # Format a single name as 'name' and several names as a list repr.
            if not hasattr(obj, "__len__"):
                return f"'{obj}'"
            elif len(obj) == 1:
                return f"'{obj[0]}'"
            else:
                return str(obj)

        if not self._required_columns_enabled or self._required_columns is None:
            return

        # Normalize to a list: in Python a list never compares equal to a
        # tuple, so a tuple of required names would otherwise never match the
        # slices of ``self.colnames`` it is compared against below.
        all_required = list(self._required_columns)

        if self._required_columns_relax:
            required_columns = all_required[: len(self.colnames)]
        else:
            required_columns = all_required

        plural = "s" if len(required_columns) > 1 else ""

        # ``required_columns`` is checked for emptiness before indexing [0],
        # so an empty requirement list cannot raise IndexError here.
        if (
            not self._required_columns_relax
            and required_columns
            and len(self.colnames) == 0
        ):
            raise ValueError(
                f"{self.__class__.__name__} object is invalid - expected"
                f" '{required_columns[0]}' as the first column{plural} but time"
                " series has no columns"
            )

        elif self.colnames[: len(required_columns)] != required_columns:
            raise ValueError(
                f"{self.__class__.__name__} object is invalid - expected"
                f" {as_scalar_or_list_str(required_columns)} as the first"
                f" column{plural} but found"
                f" {as_scalar_or_list_str(self.colnames[: len(required_columns)])}"
            )

        # All required columns are now present: leave relaxed mode for good.
        if (
            self._required_columns_relax
            and all_required == self.colnames[: len(all_required)]
        ):
            self._required_columns_relax = False

    @contextmanager
    def _delay_required_column_checks(self):
        """
        Context manager that suspends required-column checks.

        Checks are disabled for the duration of the ``with`` block, so that
        several column operations may pass through a temporarily invalid
        state. On exit the previous enabled/disabled state is restored, and
        ``_check_required_columns`` is run once if the block completed
        without raising.
        """
        previous = self._required_columns_enabled
        self._required_columns_enabled = False
        try:
            yield
        finally:
            # Restore even if the block raised; otherwise the instance would
            # be left with validation permanently disabled.
            self._required_columns_enabled = previous
        # Skipped when the block raised, so the original exception is not
        # masked by a validation error. When nested inside another delay,
        # ``previous`` is False and this call returns immediately; the
        # outermost exit performs the check.
        self._check_required_columns()