Skip to content

HuggingFaceDataset

srai.datasets.HuggingFaceDataset

HuggingFaceDataset(
    path: str,
    version: Optional[str] = None,
    type: Optional[str] = None,
    numerical_columns: Optional[list[str]] = None,
    categorical_columns: Optional[list[str]] = None,
    target: Optional[str] = None,
    resolution: Optional[int] = None,
)

Bases: ABC

Abstract class for HuggingFace datasets.

Source code in srai/datasets/_base.py
def __init__(
    self,
    path: str,
    version: Optional[str] = None,
    type: Optional[str] = None,
    numerical_columns: Optional[list[str]] = None,
    categorical_columns: Optional[list[str]] = None,
    target: Optional[str] = None,
    resolution: Optional[int] = None,
) -> None:
    self.path = path
    self.version = version
    self.numerical_columns = numerical_columns
    self.categorical_columns = categorical_columns
    self.target = target
    self.type = type
    self.train_gdf = None
    self.test_gdf = None
    self.val_gdf = None
    self.resolution = resolution

get_h3_with_labels

abstractmethod
get_h3_with_labels() -> (
    tuple[
        gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]
    ]
)

Returns indexes with target labels from the dataset depending on dataset and task type.

RETURNS DESCRIPTION
tuple[GeoDataFrame, Optional[GeoDataFrame], Optional[GeoDataFrame]]

tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]: Train, Val, Test indexes with target labels in GeoDataFrames

Source code in srai/datasets/_base.py
@abc.abstractmethod
def get_h3_with_labels(
    self,
) -> tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]:
    """
    Returns indexes with target labels from the dataset depending on dataset and task type.

    Returns:
        tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]: \
            Train, Val, Test indexes with target labels in GeoDataFrames
    """
    # Abstract hook: concrete datasets implement the regionalization / labeling logic.
    raise NotImplementedError

load

load(
    version: Optional[Union[int, str]] = None, hf_token: Optional[str] = None
) -> dict[str, gpd.GeoDataFrame]

Method to load dataset.

PARAMETER DESCRIPTION
hf_token

If needed, a User Access Token needed to authenticate to the Hugging Face Hub. Environment variable HF_TOKEN can be also used. Defaults to None.

TYPE: str DEFAULT: None

version

version of a dataset

TYPE: str or int DEFAULT: None

RETURNS DESCRIPTION
dict[str, GeoDataFrame]

dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will contain keys "train" and "test" if available.

Source code in srai/datasets/_base.py
def load(
    self,
    version: Optional[Union[int, str]] = None,
    hf_token: Optional[str] = None,
) -> dict[str, gpd.GeoDataFrame]:
    """
    Method to load dataset.

    Args:
        hf_token (str, optional): If needed, a User Access Token needed to authenticate to
            the Hugging Face Hub. Environment variable `HF_TOKEN` can be also used.
            Defaults to None.
        version (str or int, optional): version of a dataset

    Returns:
        dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will
             contain keys "train" and "test" if available.
    """
    # Imported lazily so the heavy `datasets` dependency is only paid when loading.
    from datasets import load_dataset

    result = {}

    # Reset any previously loaded splits before loading a (possibly different) version.
    self.train_gdf, self.val_gdf, self.test_gdf = None, None, None
    dataset_name = self.path
    # NOTE(review): str(None) yields the string "None" when no version is passed —
    # confirm downstream code expects that rather than keeping None.
    self.version = str(version)

    # If the requested version string encodes an H3 resolution ("8", "9" or "10"),
    # sync self.resolution with it. Because `and` binds tighter than `or`, this reads:
    # (resolution unset AND version is resolution-like) OR
    # (version is resolution-like AND differs from the current resolution).
    if (
        self.resolution is None
        and self.version in ("8", "9", "10")
        or (self.version in ("8", "9", "10") and str(self.resolution) != self.version)
    ):
        with suppress(ValueError):
            # Try to parse version as int (e.g. "8" or "9")
            self.resolution = int(self.version)

    data = load_dataset(dataset_name, str(version), token=hf_token, trust_remote_code=True)
    train = data["train"].to_pandas()
    # Subclass hook converting the raw pandas DataFrame into a GeoDataFrame.
    processed_train = self._preprocessing(train)
    self.train_gdf = processed_train
    result["train"] = processed_train
    # Not every dataset ships a predefined test split.
    if "test" in data:
        test = data["test"].to_pandas()
        processed_test = self._preprocessing(test)
        self.test_gdf = processed_test
        result["test"] = processed_test

    return result

train_test_split

abstractmethod
train_test_split(
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 7,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]

Method to generate train/test or train/val split from GeoDataFrame.

PARAMETER DESCRIPTION
target_column

Target column name for Points, trajectories id column for trajectory datasets. Defaults to preset dataset target column.

TYPE: Optional[str] DEFAULT: None

resolution

H3 resolution; subclasses may use this argument to regionalize data. Defaults to the dataset's default value.

TYPE: int DEFAULT: None

test_size

Percentage of test set. Defaults to 0.2.

TYPE: float DEFAULT: 0.2

n_bins

Bucket number used to stratify target data.

TYPE: int DEFAULT: 7

random_state

Controls the shuffling applied to the data before applying the split. Pass an int for reproducible output across multiple function calls. Defaults to None.

TYPE: int DEFAULT: None

validation_split

If True, creates a validation split from existing train split and assigns it to self.val_gdf.

TYPE: bool DEFAULT: False

force_split

If True, forces a new split to be created, even if an existing train/test or validation split is already present. - With validation_split=False, regenerates and overwrites the test split. - With validation_split=True, regenerates and overwrites the validation split.

TYPE: bool DEFAULT: False

task

Task identifier. Subclasses may use this argument to determine stratification logic (e.g., by duration or spatial pattern). Defaults to None.

TYPE: Optional[str] DEFAULT: None

RETURNS DESCRIPTION
tuple

Train-test or Train-val split made on previous train subset.

TYPE: (GeoDataFrame, GeoDataFrame)

Source code in srai/datasets/_base.py
@abc.abstractmethod
def train_test_split(
    self,
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 7,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]:
    """
    Method to generate train/test or train/val split from GeoDataFrame.

    Args:
        target_column (Optional[str], optional): Target column name for Points, trajectories id\
            column for trajectory datasets. Defaults to preset dataset target column.
        resolution (int, optional): H3 resolution, subclasses may use this argument to\
            regionalize data. Defaults to default value from the dataset.
        test_size (float, optional): Percentage of test set. Defaults to 0.2.
        n_bins (int, optional): Bucket number used to stratify target data.
        random_state (int, optional): Controls the shuffling applied to the data before \
            applying the split.
            Pass an int for reproducible output across multiple function calls.
            Defaults to None.
        validation_split (bool): If True, creates a validation split from existing train split\
            and assigns it to self.val_gdf.
        force_split: If True, forces a new split to be created, even if an existing train/test\
            or validation split is already present.
            - With `validation_split=False`, regenerates and overwrites the test split.
            - With `validation_split=True`, regenerates and overwrites the validation split.
        task (Optional[str], optional): Task identifier. Subclasses may use this argument to
            determine stratification logic (e.g., by duration or spatial pattern).\
                Defaults to None.

    Returns:
        tuple(gpd.GeoDataFrame, gpd.GeoDataFrame): Train-test or Train-val split made on\
            previous train subset.
    """
    # Abstract hook: concrete datasets implement the stratified splitting logic.
    raise NotImplementedError