Model
embedders.s2vec.model ¶
S2VecModel.
This PyTorch Lightning module implements the S2Vec masked autoencoder model.
References
[1] https://arxiv.org/abs/2504.16942
[2] https://arxiv.org/abs/2111.06377
MAEDecoder ¶
Bases: Module
Masked Autoencoder Decoder.
PARAMETER | DESCRIPTION |
---|---|
decoder_dim | The dimension of the decoder. |
patch_dim | The dimension of the patches. |
depth | The number of decoder layers. |
num_heads | The number of attention heads. |
dropout_prob | The dropout probability. |
Source code in srai/embedders/s2vec/model.py
forward ¶
Forward pass of the MAEDecoder.
PARAMETER | DESCRIPTION |
---|---|
x | The input tensor. The dimensions are (batch_size, num_patches, decoder_dim). |
RETURNS | DESCRIPTION |
---|---|
Tensor | The output tensor from the decoder. |
Source code in srai/embedders/s2vec/model.py
MAEEncoder ¶
Bases: Module
Masked Autoencoder Encoder.
PARAMETER | DESCRIPTION |
---|---|
embed_dim | The dimension of the embedding. |
depth | The number of encoder layers. |
num_heads | The number of attention heads. |
dropout_prob | The dropout probability. |
Source code in srai/embedders/s2vec/model.py
forward ¶
Forward pass of the MAEEncoder.
PARAMETER | DESCRIPTION |
---|---|
x | The input tensor. The dimensions are (batch_size, num_patches, embed_dim). |
RETURNS | DESCRIPTION |
---|---|
Tensor | The output tensor from the encoder. |
Source code in srai/embedders/s2vec/model.py
S2VecModel ¶
S2VecModel(
img_size: int,
patch_size: int,
in_ch: int,
num_heads: int = 8,
encoder_layers: int = 6,
decoder_layers: int = 2,
embed_dim: int = 256,
decoder_dim: int = 128,
mask_ratio: float = 0.75,
dropout_prob: float = 0.2,
lr: float = 0.0005,
weight_decay: float = 0.001,
)
Bases: Model
S2Vec Model.
This class implements the S2Vec model. It is based on the masked autoencoder architecture. The model is described in [1]. It takes a rasterized image as input (counts of features per region) and outputs dense embeddings.
PARAMETER | DESCRIPTION |
---|---|
img_size | The size of the input image. |
patch_size | The size of the patches. |
in_ch | The number of input channels. |
num_heads | The number of attention heads. Defaults to 8. |
encoder_layers | The number of encoder layers. Defaults to 6. |
decoder_layers | The number of decoder layers. Defaults to 2. |
embed_dim | The dimension of the encoder. Defaults to 256. |
decoder_dim | The dimension of the decoder. Defaults to 128. |
mask_ratio | The ratio of masked patches. Defaults to 0.75. |
dropout_prob | The dropout probability. Defaults to 0.2. |
lr | The learning rate. Defaults to 5e-4. |
weight_decay | The weight decay. Defaults to 1e-3. |
Source code in srai/embedders/s2vec/model.py
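The constructor arguments `img_size`, `patch_size`, `in_ch`, and `mask_ratio` together determine how many patches the model sees. The following is a minimal arithmetic sketch, assuming the usual ViT/MAE convention of a square grid of non-overlapping patches; these formulas are an assumption and are not taken from the srai source. The values of `img_size`, `patch_size`, and `in_ch` are illustrative only.

```python
# Sketch only: the patch-grid formulas follow the standard ViT/MAE convention
# and are an assumption, not a statement about srai's implementation.
img_size, patch_size, in_ch = 16, 4, 8  # illustrative values, not defaults
mask_ratio = 0.75                       # documented default

assert img_size % patch_size == 0, "img_size should be divisible by patch_size"

patches_per_side = img_size // patch_size
num_patches = patches_per_side ** 2                # patches per input image
patch_dim = patch_size * patch_size * in_ch        # values inside one flattened patch
num_visible = int(num_patches * (1 - mask_ratio))  # patches passed to the encoder
num_masked = num_patches - num_visible             # patches the decoder must reconstruct

print(num_patches, patch_dim, num_visible, num_masked)
```

With a `mask_ratio` of 0.75, only a quarter of the patches reach the encoder under this convention, which is what keeps masked autoencoder pre-training cheap.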
configure_optimizers ¶
Configure the optimizers. This is called by PyTorch Lightning.
RETURNS | DESCRIPTION |
---|---|
dict[str, Any] | The optimizer configuration. |
Source code in srai/embedders/s2vec/model.py
decode ¶
Forward pass of the decoder.
PARAMETER | DESCRIPTION |
---|---|
x | The input tensor. The dimensions are (batch_size, num_patches, embed_dim). |
ids_restore | The indices to restore the original order. |
RETURNS | DESCRIPTION |
---|---|
Tensor | The output tensor from the decoder. |
Source code in srai/embedders/s2vec/model.py
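The role of `ids_restore` can be illustrated without tensors. The sketch below assumes the scheme from the masked autoencoder paper [2]: placeholder mask tokens are appended to the visible tokens and the sequence is then reordered back to the original patch order. `restore_order_sketch` and `mask_token` are hypothetical names, the example handles a single sample with plain lists, and the actual `decode` method may differ.

```python
def restore_order_sketch(visible, ids_restore, mask_token):
    """Rebuild the full patch sequence for ONE sample (hypothetical helper)."""
    num_patches = len(ids_restore)
    # Shuffled order: visible tokens first, then one mask token per hidden patch.
    shuffled = list(visible) + [mask_token] * (num_patches - len(visible))
    # Original patch i sits at position ids_restore[i] of the shuffled order.
    return [shuffled[pos] for pos in ids_restore]


visible = ["p2", "p0"]      # encoder outputs, in shuffled order
ids_restore = [1, 2, 0, 3]  # patch 0 -> slot 1, patch 1 -> slot 2, ...
full = restore_order_sketch(visible, ids_restore, mask_token="[M]")
print(full)
```

After restoring, every position of the original sequence holds either an encoded token or the mask placeholder, so the decoder can predict a value for each patch.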
encode ¶
Forward pass of the encoder.
PARAMETER | DESCRIPTION |
---|---|
x | The input tensor. The dimensions are (batch_size, num_patches, embed_dim). |
mask_ratio | The ratio of masked patches. |
RETURNS | DESCRIPTION |
---|---|
tuple[torch.Tensor, torch.Tensor, torch.Tensor] | The encoded tensor, the mask, and the indices to restore the original order. |
Source code in srai/embedders/s2vec/model.py
forward ¶
Forward pass of the S2Vec model.
PARAMETER | DESCRIPTION |
---|---|
inputs | The input tensor. The dimensions are (batch_size, num_patches, num_features). |
RETURNS | DESCRIPTION |
---|---|
tuple[torch.Tensor, torch.Tensor, torch.Tensor] | The reconstructed tensor, the target tensor, and the mask. |
Source code in srai/embedders/s2vec/model.py
get_config ¶
Get the model configuration.
RETURNS | DESCRIPTION |
---|---|
dict[str, Union[int, float]] | The model configuration. |
Source code in srai/embedders/s2vec/model.py
load ¶
classmethod
Load model from a file.
PARAMETER | DESCRIPTION |
---|---|
path | Path to the file. |
**kwargs | Additional kwargs to pass to the model constructor. |
Source code in srai/embedders/_base.py
random_masking ¶
random_masking(
x: torch.Tensor, mask_ratio: float
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Randomly mask patches in the input tensor.
This function randomly selects a subset of patches to mask and returns the masked tensor, the mask, and the indices to restore the original order. The mask is a binary tensor indicating which patches are masked (1) and which are not (0).
PARAMETER | DESCRIPTION |
---|---|
x | The input tensor. The dimensions are (batch_size, num_patches, embed_dim). |
mask_ratio | The ratio of masked patches. |
RETURNS | DESCRIPTION |
---|---|
tuple[torch.Tensor, torch.Tensor, torch.Tensor] | The masked tensor, the mask, and the indices to restore the original order. |
Source code in srai/embedders/s2vec/model.py
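The masking procedure described above can be sketched in a dependency-free way. The example below works on a single sample represented as a list of patches and follows the noise-and-sort approach from the masked autoencoder paper [2]; this is an assumption about the technique, not a copy of the srai implementation, which operates on batched `torch.Tensor` inputs. `random_masking_sketch` and the `rng` argument are hypothetical names.

```python
import random


def random_masking_sketch(patches, mask_ratio, rng):
    """Mask a share of patches for ONE sample (hypothetical, list-based)."""
    num_patches = len(patches)
    len_keep = int(num_patches * (1 - mask_ratio))

    # One noise value per patch; sorting by noise yields a random permutation.
    noise = [rng.random() for _ in range(num_patches)]
    ids_shuffle = sorted(range(num_patches), key=noise.__getitem__)

    # ids_restore[i] is the position of original patch i in the shuffled order.
    ids_restore = [0] * num_patches
    for shuffled_pos, original_idx in enumerate(ids_shuffle):
        ids_restore[original_idx] = shuffled_pos

    # Keep the first len_keep patches of the shuffled order.
    visible = [patches[i] for i in ids_shuffle[:len_keep]]

    # Mask in ORIGINAL patch order: 1 = masked, 0 = kept.
    mask = [0 if ids_restore[i] < len_keep else 1 for i in range(num_patches)]
    return visible, mask, ids_restore


rng = random.Random(0)
patches = [[float(i)] * 2 for i in range(8)]  # 8 patches, embed_dim = 2
visible, mask, ids_restore = random_masking_sketch(patches, 0.75, rng)
print(len(visible), sum(mask))
```

Returning `ids_restore` alongside the mask is what lets a later step put the reconstructed patches back in their original positions.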
save ¶
Save the model to a directory.
PARAMETER | DESCRIPTION |
---|---|
path | Path to the directory. |
training_step ¶
Perform a training step. This is called by PyTorch Lightning.
One training step consists of a forward pass, a loss calculation, and a backward pass.
PARAMETER | DESCRIPTION |
---|---|
batch | The batch of data. |
batch_idx | The index of the batch. |
RETURNS | DESCRIPTION |
---|---|
Tensor | The loss value. |
Source code in srai/embedders/s2vec/model.py
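This page does not state which loss the training step computes. As an illustration only, the sketch below assumes the reconstruction loss used in the masked autoencoder paper [2]: mean squared error averaged over masked patches, computed from the three values that `forward` returns (reconstruction, target, mask). `masked_mse_sketch` is a hypothetical helper working on plain lists for a single sample.

```python
def masked_mse_sketch(pred, target, mask):
    """Mean squared error averaged over the masked patches of ONE sample."""
    total, count = 0.0, 0
    for p, t, m in zip(pred, target, mask):
        if m == 1:  # only masked patches contribute to the loss
            total += sum((a - b) ** 2 for a, b in zip(p, t)) / len(p)
            count += 1
    return total / count if count else 0.0


pred = [[1.0, 1.0], [0.0, 2.0], [5.0, 5.0]]
target = [[1.0, 3.0], [0.0, 0.0], [0.0, 0.0]]
mask = [1, 1, 0]  # the third patch is visible, so its error is ignored
loss = masked_mse_sketch(pred, target, mask)
print(loss)
```

Restricting the loss to masked patches means the model is scored only on content it could not see, rather than on copying visible input.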
validation_step ¶
Perform a validation step. This is called by PyTorch Lightning.
PARAMETER | DESCRIPTION |
---|---|
batch | The batch of data. |
batch_idx | The index of the batch. |
RETURNS | DESCRIPTION |
---|---|
Tensor | The loss value. |