Skip to content

Model

Bases: LightningModule

Class for model based on LightningModule.

Source code in srai/embedders/_base.py
class Model(LightningModule):  # type: ignore
    """LightningModule with helpers to export its config and save/load its weights."""

    def get_config(self) -> Dict[str, Any]:
        """
        Get model config.

        The config is built from the instance's own attributes (``vars(self)``),
        skipping private names (leading underscore) and a fixed set of
        framework-managed attributes.

        Returns:
            Dict[str, Any]: Mapping of public attribute names to their values.
        """
        # Attributes that are presumably set by torch / Lightning rather than by
        # the model's own constructor, so they are not part of the model config.
        excluded = {
            "training",
            "prepare_data_per_node",
            "allow_zero_length_dataloader_with_multiple_devices",
        }
        return {
            name: value
            for name, value in vars(self).items()
            # startswith() also copes with an empty name, where name[0] would raise.
            if not name.startswith("_") and name not in excluded
        }

    def save(self, path: Union[Path, str]) -> None:
        """
        Save the model's state dict to a file.

        Args:
            path (Union[Path, str]): Path of the file to write.
        """
        import torch

        torch.save(self.state_dict(), path)

    @classmethod
    def load(cls, path: Union[Path, str], **kwargs: Any) -> "Model":
        """
        Load model from a file.

        A new instance is built with ``cls(**kwargs)`` and the state dict stored
        at ``path`` (as written by ``save``) is loaded into it.

        Args:
            path (Union[Path, str]): Path to the file.
            **kwargs (dict): Additional kwargs to pass to the model constructor.

        Returns:
            Model: The constructed model with the loaded weights.
        """
        import torch

        model = cls(**kwargs)
        # Deserialize onto the CPU: load_state_dict copies the values into the
        # model's own parameters anyway, and this lets checkpoints saved on a GPU
        # be loaded on a CPU-only machine.
        # NOTE(review): torch.load unpickles the file - only load trusted
        # checkpoints (consider weights_only=True if the supported torch allows it).
        model.load_state_dict(torch.load(path, map_location="cpu"))
        return model

get_config

get_config() -> Dict[str, Any]

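Get model config: a dict of the model's public instance attributes (names not starting with an underscore), excluding `training`, `prepare_data_per_node` and `allow_zero_length_dataloader_with_multiple_devices`.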

Source code in srai/embedders/_base.py
def get_config(self) -> Dict[str, Any]:
    """
    Get model config.

    The config is built from the instance's own attributes (``vars(self)``),
    skipping private names (leading underscore) and a fixed set of
    framework-managed attributes.

    Returns:
        Dict[str, Any]: Mapping of public attribute names to their values.
    """
    # Attributes that are presumably set by torch / Lightning rather than by
    # the model's own constructor, so they are not part of the model config.
    excluded = {
        "training",
        "prepare_data_per_node",
        "allow_zero_length_dataloader_with_multiple_devices",
    }
    return {
        name: value
        for name, value in vars(self).items()
        # startswith() also copes with an empty name, where name[0] would raise.
        if not name.startswith("_") and name not in excluded
    }

load classmethod

load(path: Union[Path, str], **kwargs: Any) -> Model

Load model from a file.

PARAMETER DESCRIPTION
path

Path to the file.

TYPE: Union[Path, str]

**kwargs

Additional kwargs to pass to the model constructor.

TYPE: dict DEFAULT: {}

Source code in srai/embedders/_base.py
@classmethod
def load(cls, path: Union[Path, str], **kwargs: Any) -> "Model":
    """
    Load model from a file.

    A new instance is built with ``cls(**kwargs)`` and the state dict stored
    at ``path`` (as written by ``save``) is loaded into it.

    Args:
        path (Union[Path, str]): Path to the file.
        **kwargs (dict): Additional kwargs to pass to the model constructor.

    Returns:
        Model: The constructed model with the loaded weights.
    """
    import torch

    model = cls(**kwargs)
    # Deserialize onto the CPU: load_state_dict copies the values into the
    # model's own parameters anyway, and this lets checkpoints saved on a GPU
    # be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file - only load trusted
    # checkpoints (consider weights_only=True if the supported torch allows it).
    model.load_state_dict(torch.load(path, map_location="cpu"))
    return model

save

save(path: Union[Path, str]) -> None

Save the model's state dict to a file.

PARAMETER DESCRIPTION
path

Path of the file to write.

TYPE: Union[Path, str]

Source code in srai/embedders/_base.py
def save(self, path: Union[Path, str]) -> None:
    """
    Save the model's state dict to a file.

    Args:
        path (Union[Path, str]): Path of the file to write.
    """
    import torch

    torch.save(self.state_dict(), path)