Skip to content

base

Base class for embedders and models.

Model

Bases: LightningModule

Class for model based on LightningModule.

get_config()

Get model config.

Source code in srai/embedders/_base.py
def get_config(self) -> dict[str, Any]:
    """
    Collect the model's public instance attributes into a config dict.

    Every entry of `vars(self)` is kept unless its name starts with an underscore
    or is one of the explicitly skipped names listed in the body.

    Returns:
        dict[str, Any]: Mapping of attribute name to attribute value.
    """
    # Attribute names that are never part of the config.
    # NOTE(review): presumably bookkeeping flags set by the LightningModule base
    # class rather than model hyperparameters - confirm.
    skipped_names = (
        "training",
        "prepare_data_per_node",
        "allow_zero_length_dataloader_with_multiple_devices",
    )

    config: dict[str, Any] = {}
    for name, value in vars(self).items():
        if name[0] == "_" or name in skipped_names:
            continue
        config[name] = value

    return config

save(path)

Save the model's state dict to a file.

PARAMETER DESCRIPTION
path

Path to the file.

TYPE: Union[Path, str]

Source code in srai/embedders/_base.py
def save(self, path: Union[Path, str]) -> None:
    """
    Save the model's state dict to a file.

    Args:
        path (Union[Path, str]): Path to the file the state dict is written to.
    """
    # torch is imported inside the method, so it is only needed when saving.
    import torch

    weights = self.state_dict()
    torch.save(weights, path)

load(path, **kwargs)

classmethod

Load model from a file.

PARAMETER DESCRIPTION
path

Path to the file.

TYPE: Union[Path, str]

**kwargs

Additional kwargs to pass to the model constructor.

TYPE: dict DEFAULT: {}

Source code in srai/embedders/_base.py
@classmethod
def load(cls, path: Union[Path, str], **kwargs: Any) -> "Model":
    """
    Build a model and restore its weights from a file.

    Args:
        path (Union[Path, str]): Path to the file holding a saved state dict.
        **kwargs (dict): Additional kwargs to pass to the model constructor.

    Returns:
        Model: New instance of `cls` with the loaded state dict applied.
    """
    import torch

    # String paths are normalised to Path objects; anything else is used as given.
    weights_path = Path(path) if isinstance(path, str) else path

    instance = cls(**kwargs)
    # NOTE(review): torch.load relies on pickle-based deserialisation - only load
    # files that come from a trusted source.
    state = torch.load(weights_path)
    instance.load_state_dict(state)
    return instance

Embedder

Bases: ABC

Abstract class for embedders.

transform(regions_gdf, features_gdf, joint_gdf)

abstractmethod

Embed regions using features.

PARAMETER DESCRIPTION
regions_gdf

Region indexes and geometries.

TYPE: GeoDataFrame

features_gdf

Feature indexes, geometries and feature values.

TYPE: GeoDataFrame

joint_gdf

Joiner result with region-feature multi-index.

TYPE: GeoDataFrame

RETURNS DESCRIPTION
DataFrame

pd.DataFrame: Embedding and geometry index for each region in regions_gdf.

RAISES DESCRIPTION
ValueError

If any of the gdfs index names is None.

ValueError

If joint_gdf.index is not of type pd.MultiIndex or doesn't have 2 levels.

ValueError

If index levels in gdfs don't overlap correctly.

Source code in srai/embedders/_base.py
@abc.abstractmethod
def transform(
    self,
    regions_gdf: gpd.GeoDataFrame,
    features_gdf: gpd.GeoDataFrame,
    joint_gdf: gpd.GeoDataFrame,
) -> pd.DataFrame:  # pragma: no cover
    """
    Embed regions using features.

    This is the abstract declaration of the method: the body here only raises
    `NotImplementedError`. The arguments, return value and `ValueError` cases below
    describe the contract that concrete implementations are expected to follow.

    Args:
        regions_gdf (gpd.GeoDataFrame): Region indexes and geometries.
        features_gdf (gpd.GeoDataFrame): Feature indexes, geometries and feature values.
        joint_gdf (gpd.GeoDataFrame): Joiner result with region-feature multi-index.

    Returns:
        pd.DataFrame: Embedding and geometry index for each region in regions_gdf.

    Raises:
        ValueError: If any of the gdfs index names is None.
        ValueError: If joint_gdf.index is not of type pd.MultiIndex or doesn't have 2 levels.
        ValueError: If index levels in gdfs don't overlap correctly.
        NotImplementedError: Always, when this base implementation is called directly.
    """
    raise NotImplementedError