import io
import os
import warnings
from contextlib import contextmanager
from pathlib import Path
try:
from astropy.coordinates import ICRS
except ImportError:
ICRS = None
try:
from astropy.coordinates.representation import CartesianRepresentation
except ImportError:
CartesianRepresentation = None
try:
from astropy.coordinates.representation import CartesianDifferential
except ImportError:
CartesianDifferential = None
import yaml
import asdf
from .. import generic_io, versioning
from ..asdf import AsdfFile, get_asdf_library_info
from ..block import Block
from ..constants import YAML_TAG_PREFIX
from ..exceptions import AsdfConversionWarning
from ..extension import default_extensions
from ..resolver import Resolver, ResolverChain
from ..tags.core import AsdfObject
from ..versioning import (
AsdfVersion,
asdf_standard_development_version,
get_version_map,
split_tag_version,
supported_versions,
)
from .httpserver import RangeHTTPServer
try:
from pytest_remotedata.disable_internet import INTERNET_OFF
except ImportError:
INTERNET_OFF = False
__all__ = [
"get_test_data_path",
"assert_tree_match",
"assert_roundtrip_tree",
"yaml_to_asdf",
"get_file_sizes",
"display_warnings",
]
def get_test_data_path(name, module=None):
    """
    Return the path to a named test data file.

    Parameters
    ----------
    name : str or None
        Filename relative to the data directory.  When `None` or empty,
        the path of the data directory itself is returned.
    module : module, optional
        Module whose directory contains the data.  Defaults to the
        bundled ``asdf.tests.data`` package.

    Returns
    -------
    str
        Absolute path to the requested file or directory.
    """
    if module is None:
        from . import data as test_data

        module = test_data

    data_root = Path(module.__file__).parent

    if name is None or name == "":
        return str(data_root)
    return str(data_root / name)
def assert_tree_match(old_tree, new_tree, ctx=None, funcname="assert_equal", ignore_keys=None):
    """
    Assert that two ASDF trees match.

    Parameters
    ----------
    old_tree : ASDF tree
    new_tree : ASDF tree
    ctx : ASDF file context
        Used to look up the set of types in effect.
    funcname : `str` or `callable`
        The name of a method on members of old_tree and new_tree that
        will be used to compare custom objects. The default of
        ``assert_equal`` handles Numpy arrays.
    ignore_keys : list of str
        List of keys to ignore
    """
    visited = set()

    if ignore_keys is None:
        ignore_keys = ["asdf_library", "history"]
    ignore_keys = set(ignore_keys)

    if ctx is None:
        version_string = str(versioning.default_version)
        ctx = default_extensions.extension_list
    else:
        version_string = ctx.version_string

    def compare(a, b):
        # Guard against reference cycles in the trees.
        if id(a) in visited or id(b) in visited:
            return
        visited.update((id(a), id(b)))

        a_type = ctx.type_index.from_custom_type(type(a), version_string)
        b_type = ctx.type_index.from_custom_type(type(b), version_string)

        use_custom_compare = (
            a_type is not None
            and b_type is not None
            and a_type is b_type
            and (callable(funcname) or hasattr(a_type, funcname))
        )

        if use_custom_compare:
            if callable(funcname):
                funcname(a, b)
            else:
                getattr(a_type, funcname)(a, b)
        elif isinstance(a, dict) and isinstance(b, dict):
            assert {k for k in a if k not in ignore_keys} == {k for k in b if k not in ignore_keys}
            for k in a:
                if k not in ignore_keys:
                    compare(a[k], b[k])
        elif isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
            assert len(a) == len(b)
            for item_a, item_b in zip(a, b):
                compare(item_a, item_b)
        # The astropy classes CartesianRepresentation, CartesianDifferential,
        # and ICRS do not define equality in a way that is meaningful for unit
        # tests. We explicitly compare the fields that we care about in order
        # to enable our unit testing. It is possible that in the future it will
        # be necessary or useful to account for fields that are not currently
        # compared.
        elif CartesianRepresentation is not None and isinstance(a, CartesianRepresentation):
            assert a.x == b.x and a.y == b.y and a.z == b.z
        elif CartesianDifferential is not None and isinstance(a, CartesianDifferential):
            assert a.d_x == b.d_x and a.d_y == b.d_y and a.d_z == b.d_z
        elif ICRS is not None and isinstance(a, ICRS):
            assert a.ra == b.ra and a.dec == b.dec
        else:
            assert a == b

    compare(old_tree, new_tree)
def assert_roundtrip_tree(*args, **kwargs):
    """
    Assert that a given tree saves to ASDF and, when loaded back,
    the tree matches the original tree.

    Parameters
    ----------
    tree : ASDF tree
    tmp_path : `str` or `pathlib.Path`
        Path to temporary directory to save file
    tree_match_func : `str` or `callable`
        Passed to `assert_tree_match` and used to compare two objects in the
        tree.
    raw_yaml_check_func : callable, optional
        Will be called with the raw YAML content as a string to
        perform any additional checks.
    asdf_check_func : callable, optional
        Will be called with the reloaded ASDF file to perform any
        additional checks.
    """
    # An AsdfConversionWarning during the roundtrip is unexpected, so
    # temporarily promote it to an error while the checks run.
    with warnings.catch_warnings():
        warnings.filterwarnings("error", category=AsdfConversionWarning)
        _assert_roundtrip_tree(*args, **kwargs)
def _assert_roundtrip_tree(
    tree,
    tmp_path,
    *,
    asdf_check_func=None,
    raw_yaml_check_func=None,
    write_options=None,
    init_options=None,
    extensions=None,
    tree_match_func="assert_equal",
):
    """
    Implementation of `assert_roundtrip_tree`.

    Writes ``tree`` out and reads it back through several channels (BytesIO
    buffer, real file, HTTP range server, with and without a block index,
    with and without lazy loading) and asserts the reloaded tree matches.

    Parameters
    ----------
    tree : ASDF tree
        The tree to roundtrip.
    tmp_path : `str` or `pathlib.Path`
        Temporary directory in which files are created.
    asdf_check_func : callable, optional
        Called with each reloaded `AsdfFile` for additional checks.
    raw_yaml_check_func : callable, optional
        Called with the raw YAML content for additional checks.
    write_options : dict, optional
        Keyword arguments forwarded to ``AsdfFile.write_to``.
    init_options : dict, optional
        Keyword arguments forwarded to the `AsdfFile` constructor.
    extensions : list, optional
        Extensions used when writing and reading.
    tree_match_func : `str` or `callable`, optional
        Comparison function passed to `assert_tree_match`.
    """
    # Copy the option dicts.  The previous mutable-default arguments
    # (``write_options={}``) were shared across calls, and this function
    # mutates write_options below ("include_block_index"), which corrupted
    # both the shared default and any caller-supplied dict.
    write_options = dict(write_options) if write_options else {}
    init_options = dict(init_options) if init_options else {}

    fname = os.path.join(str(tmp_path), "test.asdf")

    # First, test writing/reading a BytesIO buffer
    buff = io.BytesIO()
    AsdfFile(tree, extensions=extensions, **init_options).write_to(buff, **write_options)
    assert not buff.closed
    buff.seek(0)
    with asdf.open(buff, mode="rw", extensions=extensions) as ff:
        assert not buff.closed
        assert isinstance(ff.tree, AsdfObject)
        assert "asdf_library" in ff.tree
        assert ff.tree["asdf_library"] == get_asdf_library_info()
        assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
        if asdf_check_func:
            asdf_check_func(ff)

    buff.seek(0)
    ff = AsdfFile(extensions=extensions, **init_options)
    content = AsdfFile._open_impl(ff, buff, mode="r", _get_yaml_content=True)
    buff.close()
    # We *never* want to get any raw python objects out
    assert b"!!python" not in content
    assert b"!core/asdf" in content
    assert content.startswith(b"%YAML 1.1")
    if raw_yaml_check_func:
        raw_yaml_check_func(content)

    # Then, test writing/reading to a real file
    ff = AsdfFile(tree, extensions=extensions, **init_options)
    ff.write_to(fname, **write_options)
    with asdf.open(fname, mode="rw", extensions=extensions) as ff:
        assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
        if asdf_check_func:
            asdf_check_func(ff)

    # Make sure everything works without a block index
    write_options["include_block_index"] = False
    buff = io.BytesIO()
    AsdfFile(tree, extensions=extensions, **init_options).write_to(buff, **write_options)
    assert not buff.closed
    buff.seek(0)
    with asdf.open(buff, mode="rw", extensions=extensions) as ff:
        assert not buff.closed
        assert isinstance(ff.tree, AsdfObject)
        assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
        if asdf_check_func:
            asdf_check_func(ff)

    # Now try everything on an HTTP range server
    if not INTERNET_OFF:
        server = RangeHTTPServer()
        try:
            ff = AsdfFile(tree, extensions=extensions, **init_options)
            ff.write_to(os.path.join(server.tmpdir, "test.asdf"), **write_options)
            with asdf.open(server.url + "test.asdf", mode="r", extensions=extensions) as ff:
                assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
                if asdf_check_func:
                    asdf_check_func(ff)
        finally:
            server.finalize()

    # Now don't be lazy and check that nothing breaks
    with io.BytesIO() as buff:
        AsdfFile(tree, extensions=extensions, **init_options).write_to(buff, **write_options)
        buff.seek(0)
        ff = asdf.open(buff, extensions=extensions, copy_arrays=True, lazy_load=False)
        # Ensure that all the blocks are loaded
        for block in ff.blocks._internal_blocks:
            assert isinstance(block, Block)
            assert block._data is not None
    # The underlying file is closed at this time and everything should still work
    assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
    if asdf_check_func:
        asdf_check_func(ff)

    # Now repeat with copy_arrays=False and a real file to test mmap()
    AsdfFile(tree, extensions=extensions, **init_options).write_to(fname, **write_options)
    with asdf.open(fname, mode="rw", extensions=extensions, copy_arrays=False, lazy_load=False) as ff:
        for block in ff.blocks._internal_blocks:
            assert isinstance(block, Block)
            assert block._data is not None
        assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
        if asdf_check_func:
            asdf_check_func(ff)
def yaml_to_asdf(yaml_content, yaml_headers=True, standard_version=None):
    """
    Given a string of YAML content, adds the extra pre-
    and post-amble to make it an ASDF file.

    Parameters
    ----------
    yaml_content : string or bytes
        The YAML tree content.  A `str` is encoded as UTF-8.
    yaml_headers : bool, optional
        When True (default) add the standard ASDF YAML headers.
    standard_version : str or AsdfVersion, optional
        ASDF Standard version used to select the file format, YAML
        version, and core/asdf tag version written in the headers.
        Defaults to the current default version.

    Returns
    -------
    buff : io.BytesIO()
        A file-like object containing the ASDF-like content.
    """
    if isinstance(yaml_content, str):
        yaml_content = yaml_content.encode("utf-8")

    if standard_version is None:
        standard_version = versioning.default_version
    standard_version = AsdfVersion(standard_version)

    # The version map ties the requested standard version to the header
    # values that must appear in a file of that version.
    vm = get_version_map(standard_version)
    file_format_version = vm["FILE_FORMAT"]
    yaml_version = vm["YAML_VERSION"]
    tree_version = vm["tags"]["tag:stsci.edu:asdf/core/asdf"]

    buff = io.BytesIO()
    if yaml_headers:
        buff.write(
            """#ASDF {}
#ASDF_STANDARD {}
%YAML {}
%TAG ! tag:stsci.edu:asdf/
--- !core/asdf-{}
""".format(
                file_format_version, standard_version, yaml_version, tree_version
            ).encode(
                "ascii"
            )
        )
    buff.write(yaml_content)
    if yaml_headers:
        buff.write(b"\n...\n")

    buff.seek(0)
    return buff
def get_file_sizes(dirname):
    """
    Get the file sizes in a directory.

    Parameters
    ----------
    dirname : string
        Path to a directory

    Returns
    -------
    sizes : dict
        Dictionary of (file, size) pairs.
    """
    sizes = {}
    # Only regular files are recorded; subdirectories are skipped.
    for entry in os.listdir(dirname):
        full_path = os.path.join(dirname, entry)
        if os.path.isfile(full_path):
            sizes[entry] = os.path.getsize(full_path)
    return sizes
def display_warnings(_warnings):
    """
    Return a string that displays a list of unexpected warnings

    Parameters
    ----------
    _warnings : iterable
        List of warnings to be displayed

    Returns
    -------
    msg : str
        String containing the warning messages to be displayed
    """
    if len(_warnings) == 0:
        return "No warnings occurred (was one expected?)"

    # One line per warning: "<file>:<lineno>: <category>: <message>"
    lines = ["Unexpected warning(s) occurred:"]
    for w in _warnings:
        lines.append(f"{w.filename}:{w.lineno}: {w.category.__name__}: {w.message}")
    return "\n".join(lines) + "\n"
@contextmanager
def assert_no_warnings(warning_class=None):
    """
    Assert that no warnings were emitted within the context.
    Requires that pytest be installed.

    Parameters
    ----------
    warning_class : type, optional
        Assert only that no warnings of the specified class were
        emitted.
    """
    import pytest

    if warning_class is not None:
        # Record everything, then fail only if a warning of the
        # requested class was among the recorded warnings.
        with pytest.warns(Warning) as recorded:
            yield

        assert not any(isinstance(w.message, warning_class) for w in recorded), display_warnings(recorded)
    else:
        # Turn every warning into an error for the duration of the block.
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            yield
def assert_extension_correctness(extension):
    """
    Assert that an ASDF extension's types are all correctly formed and
    that the extension provides all of the required schemas.

    Parameters
    ----------
    extension : asdf.AsdfExtension
        The extension to validate
    """
    # Hide this frame from pytest tracebacks so failures point at the caller.
    __tracebackhide__ = True

    # Resolve tags the same way the extension machinery does: tag -> url.
    chain = ResolverChain(
        Resolver(extension.tag_mapping, "tag"),
        Resolver(extension.url_mapping, "url"),
    )

    for ext_type in extension.types:
        _assert_extension_type_correctness(extension, ext_type, chain)
def _assert_extension_type_correctness(extension, extension_type, resolver):
    """
    Validate a single type provided by an extension.

    Checks that the type sets the required ``name`` and ``version`` class
    attributes and that its tag (and the tags of any versioned siblings
    still referenced by a supported asdf-standard version) resolves to a
    schema that can be read.

    Parameters
    ----------
    extension : asdf.AsdfExtension
        The extension that provides ``extension_type``.
    extension_type : type
        The extension type class being validated.
    resolver : callable
        Resolver chain mapping a yaml tag to a schema location.
    """
    # Hide this helper from pytest tracebacks so failures point at the caller.
    __tracebackhide__ = True

    if extension_type.yaml_tag is not None and extension_type.yaml_tag.startswith(YAML_TAG_PREFIX):
        # Tags in the plain YAML namespace have no asdf schema to validate
        # (presumably built-in YAML types — confirm against constants usage).
        return

    if extension_type == asdf.stream.Stream:
        # Stream is a special case. It was implemented as a subclass of NDArrayType,
        # but shares a tag with that class, so it isn't really a distinct type.
        return

    assert extension_type.name is not None, f"{extension_type.__name__} must set the 'name' class attribute"

    # Currently ExtensionType sets a default version of 1.0.0,
    # but we want to encourage an explicit version on the subclass.
    assert "version" in extension_type.__dict__, "{} must set the 'version' class attribute".format(
        extension_type.__name__
    )

    # check the default version
    types_to_check = [extension_type]

    # Adding or updating a schema/type version might involve updating multiple
    # packages. This can result in types without schema and schema without types
    # for the development version of the asdf-standard. To account for this,
    # don't include versioned siblings of types with versions that are not
    # in one of the asdf-standard versions in supported_versions (excluding the
    # current development version).
    asdf_standard_versions = supported_versions.copy()
    if asdf_standard_development_version in asdf_standard_versions:
        asdf_standard_versions.remove(asdf_standard_development_version)
    for sibling in extension_type.versioned_siblings:
        tag_base, version = split_tag_version(sibling.yaml_tag)
        for asdf_standard_version in asdf_standard_versions:
            vm = get_version_map(asdf_standard_version)
            if tag_base in vm["tags"] and AsdfVersion(vm["tags"][tag_base]) == version:
                # The sibling's version is still referenced by a supported
                # asdf-standard version, so its schema must also exist.
                types_to_check.append(sibling)
                break

    for check_type in types_to_check:
        schema_location = resolver(check_type.yaml_tag)

        assert schema_location is not None, (
            f"{extension_type.__name__} supports tag, {check_type.yaml_tag}, "
            + "but tag does not resolve. Check the tag_mapping and uri_mapping "
            + f"properties on the related extension ({extension_type.__name__})."
        )

        if schema_location not in asdf.get_config().resource_manager:
            # Schema is not registered as a resource, so confirm it can be
            # fetched from its resolved location and parsed as YAML.
            try:
                with generic_io.get_file(schema_location) as f:
                    yaml.safe_load(f.read())
            except Exception:
                assert False, (
                    f"{extension_type.__name__} supports tag, {check_type.yaml_tag}, "
                    + f"which resolves to schema at {schema_location}, but "
                    + "schema cannot be read."
                )