embedders.s2vec

S2Vec.

S2VecDataset

S2VecDataset(data: pd.DataFrame, img_patch_joint_gdf: gpd.GeoDataFrame)

Bases: Dataset['torch.Tensor'], Generic[T]

Dataset for the S2 masked autoencoder.

It works by returning a 3D tensor of square S2 regions.

PARAMETER DESCRIPTION
data

Data to use for training. Raw counts of features in regions.

TYPE: DataFrame

img_patch_joint_gdf

GeoDataFrame with the image and patch S2 indices.

TYPE: GeoDataFrame

Source code in srai/embedders/s2vec/dataset.py
def __init__(
    self,
    data: pd.DataFrame,
    img_patch_joint_gdf: gpd.GeoDataFrame,
):
    """
    Initialize the S2VecDataset.

    Args:
        data (pd.DataFrame): Data to use for training. Raw counts of features in regions.
        img_patch_joint_gdf (gpd.GeoDataFrame): GeoDataFrame with the image and patch
            S2 indices.
    """
    import_optional_dependencies(dependency_group="torch", modules=["torch"])
    import torch

    # number of columns in the dataset
    self._N: int = data.shape[1]
    # store the data as a torch tensor
    self._data_torch = torch.Tensor(data.to_numpy(dtype=np.float32))

    self.patch_s2_ids = data.index.tolist()

    self._input_ids = [
        [data.index.get_loc(index) for index in group.index.get_level_values(1)]
        for _, group in tqdm(img_patch_joint_gdf.groupby(level=0), disable=FORCE_TERMINAL)
    ]
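
For orientation, a minimal sketch of the two inputs (the S2 ids, feature columns and geometries below are illustrative placeholders, not values produced by srai):

import geopandas as gpd
import numpy as np
import pandas as pd
from shapely.geometry import Point

from srai.embedders.s2vec.dataset import S2VecDataset

# Raw feature counts per patch, indexed by patch S2 id.
data = pd.DataFrame(
    np.zeros((8, 3), dtype=np.float32),
    index=[f"patch_{i}" for i in range(8)],
    columns=["amenity_bench", "amenity_cafe", "shop_bakery"],
)

# Two-level (image id, patch id) index: two images with four patches each.
joint_index = pd.MultiIndex.from_tuples(
    [("img_0", f"patch_{i}") for i in range(4)]
    + [("img_1", f"patch_{i}") for i in range(4, 8)]
)
img_patch_joint_gdf = gpd.GeoDataFrame(geometry=[Point(0, 0)] * 8, index=joint_index)

dataset = S2VecDataset(data, img_patch_joint_gdf)
assert len(dataset) == 2           # one item per image
assert dataset[0].shape == (4, 3)  # (patches per image, features)

Each item stacks the feature rows of one image's patches; batching these items yields the 3D tensor described above.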

__getitem__

__getitem__(index: Any) -> torch.Tensor

Return a single item from the dataset.

PARAMETER DESCRIPTION
index

The index of dataset item to return

TYPE: Any

RETURNS DESCRIPTION
Tensor

torch.Tensor: The dataset item

Source code in srai/embedders/s2vec/dataset.py
def __getitem__(self, index: Any) -> "torch.Tensor":
    """
    Return a single item from the dataset.

    Args:
        index (Any): The index of dataset item to return

    Returns:
        torch.Tensor: The dataset item
    """
    patch_idxs = self._input_ids[index]
    return self._data_torch[patch_idxs]

__len__

__len__() -> int

Returns the number of inputs in the dataset.

RETURNS DESCRIPTION
int

Number of inputs in the dataset.

TYPE: int

Source code in srai/embedders/s2vec/dataset.py
def __len__(self) -> int:
    """
    Returns the number of inputs in the dataset.

    Returns:
        int: Number of inputs in the dataset.
    """
    return len(self._input_ids)

S2VecEmbedder

S2VecEmbedder(
    target_features: Union[list[str], OsmTagsFilter, GroupedOsmTagsFilter],
    count_subcategories: bool = True,
    batch_size: Optional[int] = 64,
    img_res: int = 8,
    patch_res: int = 12,
    num_heads: int = 8,
    encoder_layers: int = 6,
    decoder_layers: int = 2,
    embedding_dim: int = 256,
    decoder_dim: int = 128,
    mask_ratio: float = 0.75,
    dropout_prob: float = 0.2,
)

Bases: CountEmbedder

S2Vec Embedder.

PARAMETER DESCRIPTION
target_features

The features that are to be used in the embedding. Should be in "flat" format, i.e. "<super-tag>_<sub-tag>", or use an OsmTagsFilter object.

TYPE: Union[List[str], OsmTagsFilter, GroupedOsmTagsFilter]

count_subcategories

Whether to count all subcategories individually or count features only on the highest level based on features column name. Defaults to True.

TYPE: bool DEFAULT: True

batch_size

Batch size. Defaults to 64.

TYPE: int DEFAULT: 64

img_res

Image resolution. Defaults to 8.

TYPE: int DEFAULT: 8

patch_res

Patch resolution. Defaults to 12.

TYPE: int DEFAULT: 12

num_heads

Number of heads in the transformer. Defaults to 8.

TYPE: int DEFAULT: 8

encoder_layers

Number of encoder layers in the transformer. Defaults to 6.

TYPE: int DEFAULT: 6

decoder_layers

Number of decoder layers in the transformer. Defaults to 2.

TYPE: int DEFAULT: 2

embedding_dim

Embedding dimension. Defaults to 256.

TYPE: int DEFAULT: 256

decoder_dim

Decoder dimension. Defaults to 128.

TYPE: int DEFAULT: 128

mask_ratio

Mask ratio for the transformer. Defaults to 0.75.

TYPE: float DEFAULT: 0.75

dropout_prob

The dropout probability. Defaults to 0.2.

TYPE: float DEFAULT: 0.2

Source code in srai/embedders/s2vec/embedder.py
def __init__(
    self,
    target_features: Union[list[str], OsmTagsFilter, GroupedOsmTagsFilter],
    count_subcategories: bool = True,
    batch_size: Optional[int] = 64,
    img_res: int = 8,
    patch_res: int = 12,
    num_heads: int = 8,
    encoder_layers: int = 6,
    decoder_layers: int = 2,
    embedding_dim: int = 256,
    decoder_dim: int = 128,
    mask_ratio: float = 0.75,
    dropout_prob: float = 0.2,
) -> None:
    """
    Initialize S2Vec Embedder.

    Args:
        target_features (Union[List[str], OsmTagsFilter, GroupedOsmTagsFilter]): The features
            that are to be used in the embedding. Should be in "flat" format,
            i.e. "<super-tag>_<sub-tag>", or use OsmTagsFilter object.
        count_subcategories (bool, optional): Whether to count all subcategories individually
            or count features only on the highest level based on features column name.
            Defaults to True.
        batch_size (int, optional): Batch size. Defaults to 64.
        img_res (int, optional): Image resolution. Defaults to 8.
        patch_res (int, optional): Patch resolution. Defaults to 12.
        num_heads (int, optional): Number of heads in the transformer. Defaults to 8.
        encoder_layers (int, optional): Number of encoder layers in the transformer.
            Defaults to 6.
        decoder_layers (int, optional): Number of decoder layers in the transformer.
            Defaults to 2.
        embedding_dim (int, optional): Embedding dimension. Defaults to 256.
        decoder_dim (int, optional): Decoder dimension. Defaults to 128.
        mask_ratio (float, optional): Mask ratio for the transformer. Defaults to 0.75.
        dropout_prob (float, optional): The dropout probability. Defaults to 0.2.
    """
    import_optional_dependencies(
        dependency_group="torch", modules=["torch", "pytorch_lightning", "timm"]
    )

    super().__init__(
        expected_output_features=target_features,
        count_subcategories=count_subcategories,
    )

    assert 0.0 <= mask_ratio <= 1.0, "Mask ratio must be between 0 and 1."
    assert 0.0 <= dropout_prob <= 1.0, "Dropout probability must be between 0 and 1."

    self._model: Optional[S2VecModel] = None
    self._is_fitted = False
    self._img_res = img_res
    self._patch_res = patch_res
    self.img_size = 2 ** (patch_res - img_res)
    self._num_heads = num_heads
    self._encoder_layers = encoder_layers
    self._decoder_layers = decoder_layers
    self._embedding_dim = embedding_dim
    self._decoder_dim = decoder_dim
    self._mask_ratio = mask_ratio
    self._dropout_prob = dropout_prob

    self._batch_size = batch_size

    self._dataset: DataLoader = None

fit

fit(
    regions_gdf: gpd.GeoDataFrame,
    features_gdf: gpd.GeoDataFrame,
    learning_rate: float = 0.001,
    trainer_kwargs: Optional[dict[str, Any]] = None,
) -> None

Fit the model to the data.

PARAMETER DESCRIPTION
regions_gdf

Region indexes and geometries.

TYPE: GeoDataFrame

features_gdf

Feature indexes, geometries and feature values.

TYPE: GeoDataFrame

learning_rate

Learning rate. Defaults to 0.001.

TYPE: float DEFAULT: 0.001

trainer_kwargs

Trainer kwargs. This is where the number of epochs can be set. Defaults to None.

TYPE: Optional[Dict[str, Any]] DEFAULT: None

Source code in srai/embedders/s2vec/embedder.py
def fit(
    self,
    regions_gdf: gpd.GeoDataFrame,
    features_gdf: gpd.GeoDataFrame,
    learning_rate: float = 0.001,
    trainer_kwargs: Optional[dict[str, Any]] = None,
) -> None:
    """
    Fit the model to the data.

    Args:
        regions_gdf (gpd.GeoDataFrame): Region indexes and geometries.
        features_gdf (gpd.GeoDataFrame): Feature indexes, geometries and feature values.
        learning_rate (float, optional): Learning rate. Defaults to 0.001.
        trainer_kwargs (Optional[Dict[str, Any]], optional): Trainer kwargs.
            This is where the number of epochs can be set. Defaults to None.
    """
    import pytorch_lightning as pl

    trainer_kwargs = self._prepare_trainer_kwargs(trainer_kwargs)
    counts_df, dataloader, dataset = self._prepare_dataset(  # type: ignore
        regions_gdf,
        features_gdf,
        self._batch_size,
        shuffle=True,
        is_fitting=True,
    )

    self._prepare_model(counts_df, learning_rate)

    trainer = pl.Trainer(**trainer_kwargs)
    trainer.fit(self._model, dataloader)
    self._is_fitted = True
    self._dataset = dataset
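
Because trainer_kwargs is forwarded verbatim to pytorch_lightning.Trainer, any standard Trainer argument applies; a short sketch (embedder, regions_gdf and features_gdf as in the sketch above):

embedder.fit(
    regions_gdf,
    features_gdf,
    learning_rate=1e-3,
    trainer_kwargs={"max_epochs": 10, "accelerator": "auto"},
)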

fit_transform

fit_transform(
    regions_gdf: gpd.GeoDataFrame,
    features_gdf: gpd.GeoDataFrame,
    learning_rate: float = 0.001,
    trainer_kwargs: Optional[dict[str, Any]] = None,
) -> pd.DataFrame

Fit the model to the data and create region embeddings.

PARAMETER DESCRIPTION
regions_gdf

Region indexes and geometries.

TYPE: GeoDataFrame

features_gdf

Feature indexes, geometries and feature values.

TYPE: GeoDataFrame

learning_rate

Learning rate. Defaults to 0.001.

TYPE: float DEFAULT: 0.001

trainer_kwargs

Trainer kwargs. This is where the number of epochs can be set. Defaults to None.

TYPE: Optional[Dict[str, Any]] DEFAULT: None

RETURNS DESCRIPTION
DataFrame

pd.DataFrame: Region embeddings.

Source code in srai/embedders/s2vec/embedder.py
def fit_transform(
    self,
    regions_gdf: gpd.GeoDataFrame,
    features_gdf: gpd.GeoDataFrame,
    learning_rate: float = 0.001,
    trainer_kwargs: Optional[dict[str, Any]] = None,
) -> pd.DataFrame:
    """
    Fit the model to the data and create region embeddings.

    Args:
        regions_gdf (gpd.GeoDataFrame): Region indexes and geometries.
        features_gdf (gpd.GeoDataFrame): Feature indexes, geometries and feature values.
        learning_rate (float, optional): Learning rate. Defaults to 0.001.
        trainer_kwargs (Optional[Dict[str, Any]], optional): Trainer kwargs. This is where the
            number of epochs can be set. Defaults to None.

    Returns:
        pd.DataFrame: Region embeddings.
    """
    self.fit(
        regions_gdf=regions_gdf,
        features_gdf=features_gdf,
        learning_rate=learning_rate,
        trainer_kwargs=trainer_kwargs,
    )
    assert self._dataset is not None  # for mypy
    return self._transform(dataset=self._dataset)

load

classmethod
load(path: Union[Path, str]) -> S2VecEmbedder

Load the model from a directory.

PARAMETER DESCRIPTION
path

Path to the directory.

TYPE: Union[Path, str]

RETURNS DESCRIPTION
S2VecEmbedder

S2VecEmbedder object.

TYPE: S2VecEmbedder

Source code in srai/embedders/s2vec/embedder.py
@classmethod
def load(cls, path: Union[Path, str]) -> "S2VecEmbedder":
    """
    Load the model from a directory.

    Args:
        path (Union[Path, str]): Path to the directory.

    Returns:
        S2VecEmbedder: S2VecEmbedder object.
    """
    return cls._load(path, S2VecModel)

save

save(path: Union[str, Path]) -> None

Save the S2VecEmbedder model to a directory.

PARAMETER DESCRIPTION
path

Path to the directory.

TYPE: Union[str, Path]

Source code in srai/embedders/s2vec/embedder.py
def save(self, path: Union[str, Path]) -> None:
    """
    Save the S2VecEmbedder model to a directory.

    Args:
        path (Union[str, Path]): Path to the directory.
    """
    embedder_config = {
        "target_features": cast("pd.Series", self.expected_output_features).to_json(
            orient="records"
        ),
        "count_subcategories": self.count_subcategories,
        "batch_size": self._batch_size,
        "img_res": self._img_res,
        "patch_res": self._patch_res,
        "num_heads": self._num_heads,
        "encoder_layers": self._encoder_layers,
        "decoder_layers": self._decoder_layers,
        "embedding_dim": self._embedding_dim,
        "decoder_dim": self._decoder_dim,
        "mask_ratio": self._mask_ratio,
        "dropout_prob": self._dropout_prob,
    }

    normalisation_config = {
        "feature_means": self._feature_means.tolist(),
        "feature_stds": self._feature_stds.tolist(),
        "empty_features_mask": self._empty_features_mask.tolist(),
    }

    self._save(path, embedder_config, normalisation_config)
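
A sketch of a save/load round trip (the directory name is arbitrary):

embedder.save("s2vec_embedder")                  # config + model weights
restored = S2VecEmbedder.load("s2vec_embedder")
embeddings = restored.transform(regions_gdf, features_gdf)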

transform

transform(
    regions_gdf: gpd.GeoDataFrame, features_gdf: gpd.GeoDataFrame
) -> pd.DataFrame

Create region embeddings.

PARAMETER DESCRIPTION
regions_gdf

Region indexes and geometries.

TYPE: GeoDataFrame

features_gdf

Feature indexes, geometries and feature values.

TYPE: GeoDataFrame

RETURNS DESCRIPTION
DataFrame

pd.DataFrame: Region embeddings.

Source code in srai/embedders/s2vec/embedder.py
def transform(  # type: ignore[override]
    self,
    regions_gdf: gpd.GeoDataFrame,
    features_gdf: gpd.GeoDataFrame,
) -> pd.DataFrame:
    """
    Create region embeddings.

    Args:
        regions_gdf (gpd.GeoDataFrame): Region indexes and geometries.
        features_gdf (gpd.GeoDataFrame): Feature indexes, geometries and feature values.

    Returns:
        pd.DataFrame: Region embeddings.
    """
    self._check_is_fitted()

    _, dataloader, self._dataset = self._prepare_dataset(
        regions_gdf,
        features_gdf,
        self._batch_size,
        shuffle=False,
        is_fitting=False,
    )

    return self._transform(dataset=self._dataset, dataloader=dataloader)

S2VecModel

S2VecModel(
    img_size: int,
    patch_size: int,
    in_ch: int,
    num_heads: int = 8,
    encoder_layers: int = 6,
    decoder_layers: int = 2,
    embed_dim: int = 256,
    decoder_dim: int = 128,
    mask_ratio: float = 0.75,
    dropout_prob: float = 0.2,
    lr: float = 0.0005,
    weight_decay: float = 0.001,
)

Bases: Model

S2Vec Model.

This class implements the S2Vec model. It is based on the masked autoencoder architecture. The model is described in [1]. It takes a rasterized image as input (counts of features per region) and outputs dense embeddings.

PARAMETER DESCRIPTION
img_size

The size of the input image.

TYPE: int

patch_size

The size of the patches.

TYPE: int

in_ch

The number of input channels.

TYPE: int

num_heads

The number of attention heads.

TYPE: int DEFAULT: 8

encoder_layers

The number of encoder layers. Defaults to 6.

TYPE: int DEFAULT: 6

decoder_layers

The number of decoder layers. Defaults to 2.

TYPE: int DEFAULT: 2

embed_dim

The dimension of the encoder. Defaults to 256.

TYPE: int DEFAULT: 256

decoder_dim

The dimension of the decoder. Defaults to 128.

TYPE: int DEFAULT: 128

mask_ratio

The ratio of masked patches. Defaults to 0.75.

TYPE: float DEFAULT: 0.75

dropout_prob

The dropout probability. Defaults to 0.2.

TYPE: float DEFAULT: 0.2

lr

The learning rate. Defaults to 5e-4.

TYPE: float DEFAULT: 0.0005

weight_decay

The weight decay. Defaults to 1e-3.

TYPE: float DEFAULT: 0.001

Source code in srai/embedders/s2vec/model.py
def __init__(
    self,
    img_size: int,
    patch_size: int,
    in_ch: int,
    num_heads: int = 8,
    encoder_layers: int = 6,
    decoder_layers: int = 2,
    embed_dim: int = 256,
    decoder_dim: int = 128,
    mask_ratio: float = 0.75,
    dropout_prob: float = 0.2,
    lr: float = 5e-4,
    weight_decay: float = 1e-3,
):
    """
    Initialize the S2Vec model.

    Args:
        img_size (int): The size of the input image.
        patch_size (int): The size of the patches.
        in_ch (int): The number of input channels.
        num_heads (int): The number of attention heads.
        encoder_layers (int): The number of encoder layers. Defaults to 6.
        decoder_layers (int): The number of decoder layers. Defaults to 2.
        embed_dim (int): The dimension of the encoder. Defaults to 256.
        decoder_dim (int): The dimension of the decoder. Defaults to 128.
        mask_ratio (float): The ratio of masked patches. Defaults to 0.75.
        dropout_prob (float): The dropout probability. Defaults to 0.2.
        lr (float): The learning rate. Defaults to 5e-4.
        weight_decay (float): The weight decay. Defaults to 1e-3.
    """
    if img_size <= 0:
        raise ValueError("img_size must be a positive integer.")
    if patch_size <= 0 or img_size % patch_size != 0:
        raise ValueError("patch_size must be a positive integer and divide img_size evenly.")
    if in_ch <= 0:
        raise ValueError("in_ch must be a positive integer.")
    if num_heads <= 0:
        raise ValueError("num_heads must be a positive integer.")
    if encoder_layers <= 0:
        raise ValueError("encoder_layers must be a positive integer.")
    if decoder_layers <= 0:
        raise ValueError("decoder_layers must be a positive integer.")
    if embed_dim <= 0:
        raise ValueError("embed_dim must be a positive integer.")
    if decoder_dim <= 0:
        raise ValueError("decoder_dim must be a positive integer.")
    if not (0.0 < mask_ratio < 1.0):
        raise ValueError("mask_ratio must be between 0 and 1 (exclusive).")
    if not (0.0 <= dropout_prob <= 1.0):
        raise ValueError("dropout_prob must be between 0 and 1 (inclusive).")

    import_optional_dependencies(
        dependency_group="torch", modules=["timm", "torch", "pytorch_lightning"]
    )
    from torch import nn

    super().__init__()

    self.img_size = img_size
    self.patch_size = patch_size
    self.in_ch = in_ch
    self.embed_dim = embed_dim
    num_patches = (img_size // patch_size) ** 2
    patch_dim = patch_size * patch_size * in_ch
    self.grid_size = img_size // patch_size
    self.patch_embed = nn.Linear(patch_dim, embed_dim)
    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.num_heads = num_heads
    self.encoder_layers = encoder_layers
    self.decoder_layers = decoder_layers
    self.decoder_dim = decoder_dim
    self.dropout_prob = dropout_prob
    self.encoder = MAEEncoder(embed_dim, encoder_layers, num_heads, dropout_prob)
    self.decoder_embed = nn.Linear(embed_dim, decoder_dim)
    self.decoder = MAEDecoder(decoder_dim, patch_dim, decoder_layers, num_heads, dropout_prob)
    self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
    self.mask_ratio = mask_ratio
    pos_embed = get_2d_sincos_pos_embed(embed_dim, self.grid_size, cls_token=True)
    self.pos_embed = nn.Parameter(
        torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False
    )
    self.pos_embed.data.copy_(pos_embed.float())
    decoder_pos_embed = get_2d_sincos_pos_embed(decoder_dim, self.grid_size, cls_token=True)
    self.decoder_pos_embed = nn.Parameter(
        torch.zeros(1, num_patches + 1, decoder_dim), requires_grad=False
    )
    self.decoder_pos_embed.data.copy_(decoder_pos_embed.float())
    self.patch_dim = patch_dim
    self.lr = lr
    self.weight_decay = weight_decay

    torch.nn.init.normal_(self.cls_token, std=0.02)
    torch.nn.init.normal_(self.mask_token, std=0.02)

    torch.nn.init.xavier_uniform_(self.patch_embed.weight)
    torch.nn.init.xavier_uniform_(self.decoder_embed.weight)
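
A shape-level sketch of the model on random data (the concrete sizes are illustrative; note the input is already patchified into patch_size * patch_size * in_ch values per patch, not a raw image):

import torch

from srai.embedders.s2vec.model import S2VecModel

model = S2VecModel(img_size=16, patch_size=4, in_ch=32)
num_patches = (16 // 4) ** 2  # 16 patches on a 4x4 grid
patch_dim = 4 * 4 * 32        # 512 values per flattened patch

x = torch.randn(2, num_patches, patch_dim)  # batch of two "images"
pred, target, mask = model(x)
assert pred.shape == (2, num_patches, patch_dim)  # reconstructed patches
assert mask.shape == (2, num_patches)             # 1 = masked, 0 = visible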

configure_optimizers

configure_optimizers() -> dict[str, Any]

Configure the optimizers. This is called by PyTorch Lightning.

RETURNS DESCRIPTION
dict[str, Any]

dict[str, Any]: The optimizer and learning-rate scheduler configuration.

Source code in srai/embedders/s2vec/model.py
def configure_optimizers(self) -> dict[str, Any]:
    """
    Configure the optimizers. This is called by PyTorch Lightning.

    Returns:
        dict[str, Any]: The optimizer and learning-rate scheduler configuration.
    """
    opt: torch.optim.Optimizer = torch.optim.AdamW(
        self.parameters(),
        lr=self.lr,
        weight_decay=self.weight_decay,
        betas=(0.9, 0.95),
    )

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100)
    return {
        "optimizer": opt,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": "epoch",
            "frequency": 1,
        },
    }

decode

decode(x: torch.Tensor, ids_restore: torch.Tensor) -> torch.Tensor

Forward pass of the decoder.

PARAMETER DESCRIPTION
x

The input tensor. The dimensions are (batch_size, num_patches, embed_dim).

TYPE: Tensor

ids_restore

The indices to restore the original order.

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

torch.Tensor: The output tensor from the decoder.

Source code in srai/embedders/s2vec/model.py
def decode(self, x: "torch.Tensor", ids_restore: "torch.Tensor") -> "torch.Tensor":
    """
    Forward pass of the decoder.

    Args:
        x (torch.Tensor): The input tensor. The dimensions are
            (batch_size, num_patches, embed_dim).
        ids_restore (torch.Tensor): The indices to restore the original order.

    Returns:
        torch.Tensor: The output tensor from the decoder.
    """
    x = self.decoder_embed(x)  # Project to decoder dimension
    # One mask token for every patch dropped by the encoder
    mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)

    # Append mask tokens, then unshuffle patches back to their original order
    x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
    x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
    x = torch.cat([x[:, :1, :], x_], dim=1)  # Re-attach class token

    x = x + self.decoder_pos_embed  # Add decoder positional embedding

    x = self.decoder(x)

    x = x[:, 1:, :]  # Exclude class token
    return x

encode

encode(
    x: torch.Tensor, mask_ratio: float
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Forward pass of the encoder.

PARAMETER DESCRIPTION
x

The input tensor. The dimensions are (batch_size, num_patches, embed_dim).

TYPE: Tensor

mask_ratio

The ratio of masked patches.

TYPE: float

RETURNS DESCRIPTION
tuple[Tensor, Tensor, Tensor]

tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The encoded tensor, the mask, and the indices to restore the original order.

Source code in srai/embedders/s2vec/model.py
def encode(
    self, x: "torch.Tensor", mask_ratio: float
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
    """
    Forward pass of the encoder.

    Args:
        x (torch.Tensor): The input tensor. The dimensions are
            (batch_size, num_patches, embed_dim).
        mask_ratio (float): The ratio of masked patches.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The encoded tensor, the mask, and the
        indices to restore the original order.
    """
    x = self.patch_embed(x)
    x = x + self.pos_embed[:, 1:, :]  # Add positional embedding, excluding class token

    x, mask, ids_restore = self.random_masking(x, mask_ratio)

    cls_token = self.cls_token + self.pos_embed[:, :1, :]  # Class token
    cls_tokens = cls_token.expand(x.shape[0], -1, -1)  # Expand class token to batch size

    x = torch.cat([cls_tokens, x], dim=1)  # Concatenate class token

    return self.encoder(x), mask, ids_restore

forward

forward(
    inputs: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Forward pass of the S2Vec model.

PARAMETER DESCRIPTION
inputs

The input tensor. The dimensions are (batch_size, num_patches, num_features).

TYPE: Tensor

RETURNS DESCRIPTION
tuple[Tensor, Tensor, Tensor]

tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The reconstructed tensor, the target tensor, and the mask.

Source code in srai/embedders/s2vec/model.py
def forward(
    self, inputs: "torch.Tensor"
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
    """
    Forward pass of the S2Vec model.

    Args:
        inputs (torch.Tensor): The input tensor. The dimensions are
            (batch_size, num_patches, num_features).

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The reconstructed tensor,
        the target tensor, and the mask.
    """
    latent, mask, ids_restore = self.encode(inputs, self.mask_ratio)
    pred = self.decode(latent, ids_restore)
    target = inputs

    return pred, target, mask

get_config

get_config() -> dict[str, Union[int, float]]

Get the model configuration.

RETURNS DESCRIPTION
dict[str, Union[int, float]]

Dict[str, Union[int, float]]: The model configuration.

Source code in srai/embedders/s2vec/model.py
def get_config(self) -> dict[str, Union[int, float]]:
    """
    Get the model configuration.

    Returns:
        Dict[str, Union[int, float]]: The model configuration.
    """
    return {
        "img_size": self.img_size,
        "patch_size": self.patch_size,
        "in_ch": self.in_ch,
        "num_heads": self.num_heads,
        "embed_dim": self.embed_dim,
        "decoder_dim": self.decoder_dim,
        "encoder_layers": self.encoder_layers,
        "decoder_layers": self.decoder_layers,
        "mask_ratio": self.mask_ratio,
        "dropout_prob": self.dropout_prob,
        "lr": self.lr,
    }

load

classmethod
load(path: Union[Path, str], **kwargs: Any) -> Model

Load model from a file.

PARAMETER DESCRIPTION
path

Path to the file.

TYPE: Union[Path, str]

**kwargs

Additional kwargs to pass to the model constructor.

TYPE: dict DEFAULT: {}

RETURNS DESCRIPTION
Model

Model: The loaded model.

Source code in srai/embedders/_base.py
@classmethod
def load(cls, path: Union[Path, str], **kwargs: Any) -> "Model":
    """
    Load model from a file.

    Args:
        path (Union[Path, str]): Path to the file.
        **kwargs (dict): Additional kwargs to pass to the model constructor.

    Returns:
        Model: The loaded model.
    """
    import torch

    if isinstance(path, str):
        path = Path(path)

    model = cls(**kwargs)
    model.load_state_dict(torch.load(path))
    return model

random_masking

random_masking(
    x: torch.Tensor, mask_ratio: float
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Randomly mask patches in the input tensor.

This function randomly selects a subset of patches to mask and returns the masked tensor, the mask, and the indices to restore the original order. The mask is a binary tensor indicating which patches are masked (1) and which are not (0).

PARAMETER DESCRIPTION
x

The input tensor. The dimensions are (batch_size, num_patches, embed_dim).

TYPE: Tensor

mask_ratio

The ratio of masked patches.

TYPE: float

RETURNS DESCRIPTION
tuple[Tensor, Tensor, Tensor]

tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The masked tensor, the mask, and the indices to restore the original order.

Source code in srai/embedders/s2vec/model.py
def random_masking(
    self, x: "torch.Tensor", mask_ratio: float
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
    """
    Randomly mask patches in the input tensor.

    This function randomly selects a subset of patches to mask and returns the masked
    tensor, the mask, and the indices to restore the original order.
    The mask is a binary tensor indicating which patches are masked (1) and which are not (0).

    Args:
        x (torch.Tensor): The input tensor. The dimensions are
            (batch_size, num_patches, embed_dim).
        mask_ratio (float): The ratio of masked patches.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The masked tensor, the mask, and the
        indices to restore the original order.
    """
    B, N, D = x.shape

    if mask_ratio == 0.0:
        mask = torch.zeros([B, N], device=x.device)
        ids_restore = torch.arange(N, device=x.device).unsqueeze(0).repeat(B, 1)
        return x, mask, ids_restore
    len_keep = int(N * (1 - mask_ratio))

    # Random permutation per sample: sort random noise to shuffle patch ids
    noise = torch.rand(B, N, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation
    ids_keep = ids_shuffle[:, :len_keep]

    # Keep only the first len_keep patches of the shuffled sequence
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

    # Binary mask in the original patch order: 0 = kept, 1 = masked
    mask = torch.ones([B, N], device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return x_masked, mask, ids_restore
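
A tiny sketch of the masking semantics (sizes are illustrative; only the shapes and the 1 = masked convention matter):

import torch

from srai.embedders.s2vec.model import S2VecModel

model = S2VecModel(img_size=4, patch_size=2, in_ch=3)  # 2x2 grid -> N=4 patches
x = torch.randn(1, 4, model.embed_dim)  # already-embedded patches

x_masked, mask, ids_restore = model.random_masking(x, mask_ratio=0.5)
assert x_masked.shape == (1, 2, model.embed_dim)  # int(4 * (1 - 0.5)) = 2 kept
assert mask.shape == (1, 4)         # binary, in original patch order
assert ids_restore.shape == (1, 4)  # inverse of the random shuffle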

save

save(path: Union[Path, str]) -> None

Save the model to a file.

PARAMETER DESCRIPTION
path

Path to the file.

TYPE: Union[Path, str]

Source code in srai/embedders/_base.py
def save(self, path: Union[Path, str]) -> None:
    """
    Save the model to a file.

    Args:
        path (Union[Path, str]): Path to the file.
    """
    import torch

    torch.save(self.state_dict(), path)

training_step

training_step(batch: list[torch.Tensor], batch_idx: int) -> torch.Tensor

Perform a training step. This is called by PyTorch Lightning.

One training step consists of a forward pass, a loss calculation, and a backward pass.

PARAMETER DESCRIPTION
batch

The batch of data.

TYPE: List[Tensor]

batch_idx

The index of the batch.

TYPE: int

RETURNS DESCRIPTION
Tensor

torch.Tensor: The loss value.

Source code in srai/embedders/s2vec/model.py
def training_step(self, batch: list["torch.Tensor"], batch_idx: int) -> "torch.Tensor":
    """
    Perform a training step. This is called by PyTorch Lightning.

    One training step consists of a forward pass, a loss calculation, and a backward pass.

    Args:
        batch (List[torch.Tensor]): The batch of data.
        batch_idx (int): The index of the batch.

    Returns:
        torch.Tensor: The loss value.
    """
    rec, target, mask = self(batch)

    loss = (rec - target).pow(2).mean(dim=-1)  # MSE per patch
    loss = (loss * mask).sum() / mask.sum()  # Only on masked patches

    self.log("train_loss", loss, on_step=True, on_epoch=True)
    return loss
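
The masked-MSE objective in isolation, on toy tensors with assumed values; with unit error on every patch and half the patches masked, the loss is exactly 1:

import torch

rec = torch.ones(1, 4, 3)      # pretend reconstruction
target = torch.zeros(1, 4, 3)  # pretend ground truth
mask = torch.tensor([[1.0, 0.0, 1.0, 0.0]])  # 1 = masked patch

loss = (rec - target).pow(2).mean(dim=-1)  # per-patch MSE -> shape (1, 4)
loss = (loss * mask).sum() / mask.sum()    # average over masked patches only
assert loss.item() == 1.0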

validation_step

validation_step(batch: list[torch.Tensor], batch_idx: int) -> torch.Tensor

Perform a validation step. This is called by PyTorch Lightning.

PARAMETER DESCRIPTION
batch

The batch of data.

TYPE: List[Tensor]

batch_idx

The index of the batch.

TYPE: int

RETURNS DESCRIPTION
Tensor

torch.Tensor: The loss value.

Source code in srai/embedders/s2vec/model.py
def validation_step(self, batch: list["torch.Tensor"], batch_idx: int) -> "torch.Tensor":
    """
    Perform a validation step. This is called by PyTorch Lightning.

    Args:
        batch (List[torch.Tensor]): The batch of data.
        batch_idx (int): The index of the batch.

    Returns:
        torch.Tensor: The loss value.
    """
    rec, target, mask = self(batch)

    loss = (rec - target).pow(2).mean(dim=-1)
    loss = (loss * mask).sum() / mask.sum()
    self.log("validation_loss", loss, on_step=True, on_epoch=True)
    return loss