# Source code for pystac.extensions.ext

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Generic, Literal, TypeVar, cast

from pystac import (
    Asset,
    Catalog,
    Collection,
    Item,
    ItemAssetDefinition,
    Link,
    STACError,
)
from pystac.extensions.classification import ClassificationExtension
from pystac.extensions.datacube import DatacubeExtension
from pystac.extensions.eo import EOExtension
from pystac.extensions.file import FileExtension
from pystac.extensions.grid import GridExtension
from pystac.extensions.item_assets import ItemAssetsExtension
from pystac.extensions.mgrs import MgrsExtension
from pystac.extensions.mlm import (
    AssetDetailedMLMExtension,
    AssetGeneralMLMExtension,
    MLMExtension,
)
from pystac.extensions.pointcloud import PointcloudExtension
from pystac.extensions.projection import ProjectionExtension
from pystac.extensions.raster import RasterExtension
from pystac.extensions.render import Render, RenderExtension
from pystac.extensions.sar import SarExtension
from pystac.extensions.sat import SatExtension
from pystac.extensions.scientific import ScientificExtension
from pystac.extensions.storage import StorageExtension
from pystac.extensions.table import TableExtension
from pystac.extensions.timestamps import TimestampsExtension
from pystac.extensions.version import BaseVersionExtension, VersionExtension
from pystac.extensions.view import ViewExtension
from pystac.extensions.xarray_assets import XarrayAssetsExtension

#: Generalized version of :class:`~pystac.Asset`,
#: :class:`~pystac.ItemAssetDefinition`, or :class:`~pystac.Link`.
#: Parameterizes ``_AssetsExt``, the shared base whose ``has``/``add``/``remove``
#: delegate to the wrapped object's ``owner``.
T = TypeVar("T", Asset, ItemAssetDefinition, Link)
#: Generalized version of :class:`~pystac.Asset` or
#: :class:`~pystac.ItemAssetDefinition` (the same as ``T`` minus
#: :class:`~pystac.Link`). Parameterizes ``_AssetExt``.
U = TypeVar("U", Asset, ItemAssetDefinition)

# Identifiers accepted by the ``has``/``add``/``remove`` accessor methods.
# Each entry is expected to equal the ``name`` attribute of one extension
# class registered in EXTENSION_NAME_MAPPING below; keep the two in sync.
EXTENSION_NAMES = Literal[
    "classification",
    "cube",
    "eo",
    "file",
    "grid",
    "item_assets",
    "mgrs",
    "mlm",
    "pc",
    "proj",
    "raster",
    "render",
    "sar",
    "sat",
    "sci",
    "storage",
    "table",
    "timestamps",
    "version",
    "view",
    "xarray",
]

# Registry from extension identifier to extension class. Keys are taken from
# each class's own ``name`` attribute so they cannot drift from the class.
# Insertion order is the order ``_get_class_by_name`` lists the valid options
# in its error message.
EXTENSION_NAME_MAPPING: dict[EXTENSION_NAMES, Any] = {
    ClassificationExtension.name: ClassificationExtension,
    DatacubeExtension.name: DatacubeExtension,
    EOExtension.name: EOExtension,
    FileExtension.name: FileExtension,
    GridExtension.name: GridExtension,
    ItemAssetsExtension.name: ItemAssetsExtension,
    MgrsExtension.name: MgrsExtension,
    MLMExtension.name: MLMExtension,
    PointcloudExtension.name: PointcloudExtension,
    ProjectionExtension.name: ProjectionExtension,
    RasterExtension.name: RasterExtension,
    RenderExtension.name: RenderExtension,
    SarExtension.name: SarExtension,
    SatExtension.name: SatExtension,
    ScientificExtension.name: ScientificExtension,
    StorageExtension.name: StorageExtension,
    TableExtension.name: TableExtension,
    TimestampsExtension.name: TimestampsExtension,
    VersionExtension.name: VersionExtension,
    ViewExtension.name: ViewExtension,
    XarrayAssetsExtension.name: XarrayAssetsExtension,
}


def _get_class_by_name(name: str) -> Any:
    """Resolve an extension identifier to its registered extension class.

    Args:
        name : Extension identifier (eg: 'eo')

    Returns:
        The class stored under ``name`` in ``EXTENSION_NAME_MAPPING``.

    Raises:
        KeyError: If ``name`` is not registered. The message lists every
            valid identifier, and the original lookup error is chained as
            the cause.
    """
    key = cast(EXTENSION_NAMES, name)
    try:
        extension_cls = EXTENSION_NAME_MAPPING[key]
    except KeyError as lookup_error:
        valid_names = list(EXTENSION_NAME_MAPPING)
        raise KeyError(
            f"Extension '{name}' is not a valid extension. "
            f"Options are {valid_names}"
        ) from lookup_error
    return extension_cls


@dataclass
class CatalogExt:
    """Backs the :attr:`~pystac.Catalog.ext` accessor, giving access to the
    extension classes that apply to a :class:`~pystac.Catalog`.
    """

    stac_object: Catalog

    def has(self, name: EXTENSION_NAMES) -> bool:
        """Report whether an extension is enabled on the wrapped object.

        Args:
            name : Extension identifier (eg: 'eo')

        Returns:
            bool: ``True`` when the extension is enabled, ``False`` otherwise

        Raises:
            KeyError: If ``name`` is not a known extension identifier.
        """
        extension_cls = _get_class_by_name(name)
        enabled = extension_cls.has_extension(self.stac_object)
        return cast(bool, enabled)

    def add(self, name: EXTENSION_NAMES) -> None:
        """Enable an extension on the wrapped object.

        Args:
            name : Extension identifier (eg: 'eo')

        Raises:
            KeyError: If ``name`` is not a known extension identifier.
        """
        extension_cls = _get_class_by_name(name)
        extension_cls.add_to(self.stac_object)

    def remove(self, name: EXTENSION_NAMES) -> None:
        """Disable an extension on the wrapped object.

        Args:
            name : Extension identifier (eg: 'eo')

        Raises:
            KeyError: If ``name`` is not a known extension identifier.
        """
        extension_cls = _get_class_by_name(name)
        extension_cls.remove_from(self.stac_object)

    @property
    def version(self) -> VersionExtension[Catalog]:
        """Return ``VersionExtension.ext`` applied to the wrapped object."""
        return VersionExtension.ext(self.stac_object)
@dataclass
class CollectionExt(CatalogExt):
    """Backs the :attr:`~pystac.Collection.ext` accessor.

    ``has``/``add``/``remove`` and ``version`` are inherited from
    :class:`CatalogExt`. The properties below add the Collection-level
    accessors. Unlike the rest, ``item_assets`` and ``render`` return a
    mapping pulled off the extension wrapper rather than the wrapper itself.
    """

    stac_object: Collection

    @property
    def cube(self) -> DatacubeExtension[Collection]:
        return DatacubeExtension.ext(self.stac_object)

    @property
    def item_assets(self) -> dict[str, ItemAssetDefinition]:
        """The ``item_assets`` mapping of the collection's ItemAssetsExtension."""
        wrapper = ItemAssetsExtension.ext(self.stac_object)
        return wrapper.item_assets

    @property
    def mlm(self) -> MLMExtension[Collection]:
        return MLMExtension.ext(self.stac_object)

    @property
    def render(self) -> dict[str, Render]:
        """The ``renders`` mapping of the collection's RenderExtension."""
        wrapper = RenderExtension.ext(self.stac_object)
        return wrapper.renders

    @property
    def sci(self) -> ScientificExtension[Collection]:
        return ScientificExtension.ext(self.stac_object)

    @property
    def table(self) -> TableExtension[Collection]:
        return TableExtension.ext(self.stac_object)

    @property
    def xarray(self) -> XarrayAssetsExtension[Collection]:
        return XarrayAssetsExtension.ext(self.stac_object)
@dataclass
class ItemExt:
    """Backs the :attr:`~pystac.Item.ext` accessor, exposing the extension
    classes that apply to a :class:`~pystac.Item`.

    ``has``/``add``/``remove`` resolve an extension class by identifier.
    Every property returns ``<Extension>.ext`` applied to the wrapped item.
    """

    stac_object: Item

    def has(self, name: EXTENSION_NAMES) -> bool:
        """Report whether an extension is enabled on the wrapped item.

        Args:
            name : Extension identifier (eg: 'eo')

        Returns:
            bool: ``True`` when the extension is enabled, ``False`` otherwise

        Raises:
            KeyError: If ``name`` is not a known extension identifier.
        """
        extension_cls = _get_class_by_name(name)
        enabled = extension_cls.has_extension(self.stac_object)
        return cast(bool, enabled)

    def add(self, name: EXTENSION_NAMES) -> None:
        """Enable an extension on the wrapped item.

        Args:
            name : Extension identifier (eg: 'eo')

        Raises:
            KeyError: If ``name`` is not a known extension identifier.
        """
        extension_cls = _get_class_by_name(name)
        extension_cls.add_to(self.stac_object)

    def remove(self, name: EXTENSION_NAMES) -> None:
        """Disable an extension on the wrapped item.

        Args:
            name : Extension identifier (eg: 'eo')

        Raises:
            KeyError: If ``name`` is not a known extension identifier.
        """
        extension_cls = _get_class_by_name(name)
        extension_cls.remove_from(self.stac_object)

    @property
    def classification(self) -> ClassificationExtension[Item]:
        return ClassificationExtension.ext(self.stac_object)

    @property
    def cube(self) -> DatacubeExtension[Item]:
        return DatacubeExtension.ext(self.stac_object)

    @property
    def eo(self) -> EOExtension[Item]:
        return EOExtension.ext(self.stac_object)

    @property
    def grid(self) -> GridExtension:
        return GridExtension.ext(self.stac_object)

    @property
    def mgrs(self) -> MgrsExtension:
        return MgrsExtension.ext(self.stac_object)

    @property
    def mlm(self) -> MLMExtension[Item]:
        return MLMExtension.ext(self.stac_object)

    @property
    def pc(self) -> PointcloudExtension[Item]:
        return PointcloudExtension.ext(self.stac_object)

    @property
    def proj(self) -> ProjectionExtension[Item]:
        return ProjectionExtension.ext(self.stac_object)

    @property
    def render(self) -> RenderExtension[Item]:
        return RenderExtension.ext(self.stac_object)

    @property
    def sar(self) -> SarExtension[Item]:
        return SarExtension.ext(self.stac_object)

    @property
    def sat(self) -> SatExtension[Item]:
        return SatExtension.ext(self.stac_object)

    @property
    def sci(self) -> ScientificExtension[Item]:
        return ScientificExtension.ext(self.stac_object)

    @property
    def storage(self) -> StorageExtension[Item]:
        return StorageExtension.ext(self.stac_object)

    @property
    def table(self) -> TableExtension[Item]:
        return TableExtension.ext(self.stac_object)

    @property
    def timestamps(self) -> TimestampsExtension[Item]:
        return TimestampsExtension.ext(self.stac_object)

    @property
    def version(self) -> VersionExtension[Item]:
        return VersionExtension.ext(self.stac_object)

    @property
    def view(self) -> ViewExtension[Item]:
        return ViewExtension.ext(self.stac_object)

    @property
    def xarray(self) -> XarrayAssetsExtension[Item]:
        return XarrayAssetsExtension.ext(self.stac_object)
class _AssetsExt(Generic[T]):
    """Shared ``has``/``add``/``remove`` for objects that cannot declare
    extensions themselves.

    ``stac_object`` is an :class:`~pystac.Asset`,
    :class:`~pystac.ItemAssetDefinition` or :class:`~pystac.Link`. Extension
    management is delegated to ``stac_object.owner``. If no owner is set, a
    :class:`~pystac.STACError` is raised.
    """

    stac_object: T

    def has(self, name: EXTENSION_NAMES) -> bool:
        """Whether the given extension is enabled on the owner

        Args:
            name : Extension identifier (eg: 'eo')

        Returns:
            bool: ``True`` if extension is enabled, otherwise ``False``

        Raises:
            STACError: If the wrapped object has no owner.
            KeyError: If ``name`` is not a known extension identifier.
        """
        owner = self.stac_object.owner
        if owner is None:
            # The closing backtick was previously missing from this message.
            raise STACError(
                f"Attempted to use `.ext.has('{name}')` for an object with no "
                "owner. Use `.set_owner` and then try to check the extension "
                "again."
            )
        return cast(bool, _get_class_by_name(name).has_extension(owner))

    def add(self, name: EXTENSION_NAMES) -> None:
        """Add the given extension to the owner

        Args:
            name : Extension identifier (eg: 'eo')

        Raises:
            STACError: If the wrapped object has no owner.
            KeyError: If ``name`` is not a known extension identifier.
        """
        owner = self.stac_object.owner
        if owner is None:
            raise STACError(
                f"Attempted to add extension='{name}' for an object with no owner. "
                "Use `.set_owner` and then try to add the extension again."
            )
        _get_class_by_name(name).add_to(owner)

    def remove(self, name: EXTENSION_NAMES) -> None:
        """Remove the given extension from the owner

        Args:
            name : Extension identifier (eg: 'eo')

        Raises:
            STACError: If the wrapped object has no owner.
            KeyError: If ``name`` is not a known extension identifier.
        """
        owner = self.stac_object.owner
        if owner is None:
            raise STACError(
                f"Attempted to remove extension='{name}' for an object with no owner. "
                "Use `.set_owner` and then try to remove the extension again."
            )
        _get_class_by_name(name).remove_from(owner)


class _AssetExt(_AssetsExt[U]):
    """Accessors shared by :class:`AssetExt` and :class:`ItemAssetExt`.

    Each property returns ``<Extension>.ext`` applied to ``stac_object``.
    """

    stac_object: U

    @property
    def classification(self) -> ClassificationExtension[U]:
        return ClassificationExtension.ext(self.stac_object)

    @property
    def cube(self) -> DatacubeExtension[U]:
        return DatacubeExtension.ext(self.stac_object)

    @property
    def eo(self) -> EOExtension[U]:
        return EOExtension.ext(self.stac_object)

    @property
    def pc(self) -> PointcloudExtension[U]:
        return PointcloudExtension.ext(self.stac_object)

    @property
    def proj(self) -> ProjectionExtension[U]:
        return ProjectionExtension.ext(self.stac_object)

    @property
    def raster(self) -> RasterExtension[U]:
        return RasterExtension.ext(self.stac_object)

    @property
    def sar(self) -> SarExtension[U]:
        return SarExtension.ext(self.stac_object)

    @property
    def sat(self) -> SatExtension[U]:
        return SatExtension.ext(self.stac_object)

    @property
    def storage(self) -> StorageExtension[U]:
        return StorageExtension.ext(self.stac_object)

    @property
    def table(self) -> TableExtension[U]:
        return TableExtension.ext(self.stac_object)

    @property
    def version(self) -> BaseVersionExtension[U]:
        return BaseVersionExtension.ext(self.stac_object)

    @property
    def view(self) -> ViewExtension[U]:
        return ViewExtension.ext(self.stac_object)
@dataclass
class AssetExt(_AssetExt[Asset]):
    """Backs the :attr:`~pystac.Asset.ext` accessor.

    Adds the accessors that exist only for :class:`~pystac.Asset`
    (``file``, ``mlm``, ``timestamps``, ``xarray``) on top of those shared
    with item-asset definitions.
    """

    stac_object: Asset

    @property
    def file(self) -> FileExtension[Asset]:
        return FileExtension.ext(self.stac_object)

    @property
    def mlm(self) -> AssetGeneralMLMExtension[Asset] | AssetDetailedMLMExtension:
        """Return the MLM extension wrapper suited to this asset.

        An asset whose ``extra_fields`` contain ``mlm:name`` gets the
        detailed wrapper. Any other asset gets the general one.
        """
        asset = self.stac_object
        if "mlm:name" not in asset.extra_fields:
            return AssetGeneralMLMExtension.ext(asset)
        return AssetDetailedMLMExtension.ext(asset)

    @property
    def timestamps(self) -> TimestampsExtension[Asset]:
        return TimestampsExtension.ext(self.stac_object)

    @property
    def xarray(self) -> XarrayAssetsExtension[Asset]:
        return XarrayAssetsExtension.ext(self.stac_object)
@dataclass
class ItemAssetExt(_AssetExt[ItemAssetDefinition]):
    """Backs the :attr:`~pystac.ItemAssetDefinition.ext` accessor.

    On top of the shared asset accessors, this adds ``mlm``, which returns
    ``MLMExtension.ext`` applied to the wrapped definition.
    """

    stac_object: ItemAssetDefinition

    @property
    def mlm(self) -> MLMExtension[ItemAssetDefinition]:
        definition = self.stac_object
        return MLMExtension.ext(definition)
@dataclass
class LinkExt(_AssetsExt[Link]):
    """Backs the :attr:`~pystac.Link.ext` accessor.

    ``has``/``add``/``remove`` act on the link's owner. The only
    link-specific accessor is ``file``.
    """

    stac_object: Link

    @property
    def file(self) -> FileExtension[Link]:
        link = self.stac_object
        return FileExtension.ext(link)