datasets

This module contains datasets used to load data containing spatial information.

Datasets can be loaded using the .load() method. Some of them may require a version name.
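
For example, a minimal loading sketch (assuming the classes documented below are importable from srai.datasets):

from srai.datasets import AirbnbMulticityDataset

dataset = AirbnbMulticityDataset()
splits = dataset.load(version=8)  # version selects a pre-made train-test split
train_gdf = splits["train"]       # GeoDataFrame with point geometries
test_gdf = splits.get("test")     # present only if the version ships a test split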

AirbnbMulticityDataset

AirbnbMulticityDataset()

Bases: PointDataset

AirbnbMulticity dataset.

Dataset description will be added.

Source code in srai/datasets/airbnb_multicity.py
def __init__(self) -> None:
    """Create the dataset."""
    categorical_columns = ["name", "host_name", "neighborhood", "room_type", "city"]
    numerical_columns = [
        "number_of_reviews",
        "minimum_nights",
        "availability_365",
        "calculated_host_listings_count",
        "number_of_reviews_ltm",
    ]
    target = "price"
    type = "point"
    super().__init__(
        "kraina/airbnb_multicity",
        type=type,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        target=target,
    )

get_h3_with_labels

get_h3_with_labels() -> (
    tuple[
        gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]
    ]
)

Returns h3 indexes with target labels from the dataset.

Points are aggregated to hexes and target column values are averaged; if the target column is None, the number of points within each hex is calculated and scaled to [0, 1].

RETURNS DESCRIPTION
tuple[GeoDataFrame, Optional[GeoDataFrame], Optional[GeoDataFrame]]

tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]: Train, Val, Test hexes with target labels in GeoDataFrames

Source code in srai/datasets/_base.py
def get_h3_with_labels(
    self,
    # resolution: Optional[int] = None,
    # target_column: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]:
    """
    Returns h3 indexes with target labels from the dataset.

    Points are aggregated to hexes and target column values are averaged or, if target column \
    is None, the number of points is calculated within a hex and scaled to [0,1].

    Returns:
        tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]:\
            Train, Val, Test hexes with target labels in GeoDataFrames
    """
    # if target_column is None:
    #     target_column = "count"

    # resolution = resolution if resolution is not None else self.resolution

    assert self.train_gdf is not None
    # If resolution is still None, raise an error
    if self.resolution is None:
        raise ValueError(
            "No preset resolution for the dataset in self.resolution. Please"
            "provide a resolution."
        )
    # elif self.resolution is not None and resolution != self.resolution:
    #     raise ValueError(
    #         "Resolution provided is different from the preset resolution for the"
    #         "dataset. This may result in a data leak between splits."
    #     )

    _train_gdf = self._aggregate_hexes(self.train_gdf, self.resolution, self.target)

    if self.test_gdf is not None:
        _test_gdf = self._aggregate_hexes(self.test_gdf, self.resolution, self.target)
    else:
        _test_gdf = None

    if self.val_gdf is not None:
        _val_gdf = self._aggregate_hexes(self.val_gdf, self.resolution, self.target)
    else:
        _val_gdf = None

    # Scale the "count" column to [0, 1] if it is the target column
    if self.target == "count":
        scaler = MinMaxScaler()
        # Fit the scaler on the train dataset and transform
        _train_gdf["count"] = scaler.fit_transform(_train_gdf[["count"]])
        if _test_gdf is not None:
            _test_gdf["count"] = scaler.transform(_test_gdf[["count"]])
            _test_gdf["count"] = np.clip(_test_gdf["count"], 0, 1)
        if _val_gdf is not None:
            _val_gdf["count"] = scaler.transform(_val_gdf[["count"]])
            _val_gdf["count"] = np.clip(_val_gdf["count"], 0, 1)

    return _train_gdf, _val_gdf, _test_gdf
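
A hedged usage sketch (split and column handling follow the source above; "price" is the preset target of AirbnbMulticityDataset):

dataset = AirbnbMulticityDataset()
dataset.load(version=8)

# val_hexes is None unless a validation split was created beforehand.
train_hexes, val_hexes, test_hexes = dataset.get_h3_with_labels()

# For a numeric target such as "price", values are averaged per hex.
print(train_hexes.head())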

load

load(
    version: Optional[Union[int, str]] = 8, hf_token: Optional[str] = None
) -> dict[str, gpd.GeoDataFrame]

Method to load dataset.

PARAMETER DESCRIPTION
hf_token

If needed, a User Access Token used to authenticate to the Hugging Face Hub. The environment variable HF_TOKEN can also be used. Defaults to None.

TYPE: str DEFAULT: None

version

Version of the dataset. Available: '8', '9', '10', where the number is the H3 resolution used in the train-test split. The benchmark version comprises six cities: Paris, Rome, London, Amsterdam, Melbourne, New York City. Raw, full data from ~80 cities is available as 'all'.

TYPE: str or int DEFAULT: 8

RETURNS DESCRIPTION
dict[str, GeoDataFrame]

dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will contain keys "train" and "test" if available.

Source code in srai/datasets/airbnb_multicity.py
def load(
    self, version: Optional[Union[int, str]] = 8, hf_token: Optional[str] = None
) -> dict[str, gpd.GeoDataFrame]:
    """
    Method to load dataset.

    Args:
        hf_token (str, optional): If needed, a User Access Token used to authenticate to
            the Hugging Face Hub. Environment variable `HF_TOKEN` can also be used.
            Defaults to None.
        version (str or int, optional): Version of the dataset.
            Available: '8', '9', '10', where the number is the H3 resolution used in the \
            train-test split. The benchmark version comprises six cities: Paris, Rome, \
            London, Amsterdam, Melbourne, New York City. Raw, full data from ~80 cities \
            is available as 'all'.

    Returns:
        dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will
            contain keys "train" and "test" if available.
    """
    return super().load(version=version, hf_token=hf_token)

train_test_split

train_test_split(
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 7,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]

Method to generate splits from GeoDataFrame, based on the target_column values.

PARAMETER DESCRIPTION
target_column

Target column name. If None, split is generated based on number of points within a hex of a given resolution. Defaults to preset dataset target column.

TYPE: Optional[str] DEFAULT: None

resolution

h3 resolution to regionalize data. Defaults to default value from the dataset.

TYPE: int DEFAULT: None

test_size

Fraction of the dataset to use as the test set. Defaults to 0.2.

TYPE: float DEFAULT: 0.2

n_bins

Number of buckets used to stratify target data. Defaults to 7.

TYPE: int DEFAULT: 7

random_state

Controls the shuffling applied to the data before applying the split. Pass an int for reproducible output across multiple function calls. Defaults to None.

TYPE: int DEFAULT: None

validation_split

If True, creates a validation split from existing train split and assigns it to self.val_gdf.

TYPE: bool DEFAULT: False

force_split

If True, forces a new split to be created, even if an existing train/test or validation split is already present. With validation_split=False, regenerates and overwrites the test split; with validation_split=True, regenerates and overwrites the validation split.

TYPE: bool DEFAULT: False

task

Currently not supported. Ignored in this subclass.

TYPE: Optional[str] DEFAULT: None

RETURNS DESCRIPTION
tuple

Train-test or train-val split made on previous train subset.

TYPE: (GeoDataFrame, GeoDataFrame)

Source code in srai/datasets/_base.py
def train_test_split(
    self,
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 7,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]:
    """
    Method to generate splits from GeoDataFrame, based on the target_column values.

    Args:
        target_column (Optional[str], optional): Target column name. If None, split is\
            generated based on number of points within a hex of a given resolution.\
            Defaults to preset dataset target column.
        resolution (int, optional): h3 resolution to regionalize data. Defaults to default\
            value from the dataset.
        test_size (float, optional): Fraction of the dataset to use as the test set.\
            Defaults to 0.2.
        n_bins (int, optional): Number of buckets used to stratify target data.\
            Defaults to 7.
        random_state (int, optional): Controls the shuffling applied to the data before\
            applying the split. \
            Pass an int for reproducible output across multiple function calls. Defaults to None.
        validation_split (bool): If True, creates a validation split from existing train split\
            and assigns it to self.val_gdf.
        force_split: If True, forces a new split to be created, even if an existing train/test\
            or validation split is already present.
            - With `validation_split=False`, regenerates and overwrites the test split.
            - With `validation_split=True`, regenerates and overwrites the validation split.
        task (Optional[str], optional): Currently not supported. Ignored in this subclass.

    Returns:
        tuple(gpd.GeoDataFrame, gpd.GeoDataFrame): Train-test or train-val split made on\
            previous train subset.
    """
    assert self.train_gdf is not None

    if (self.val_gdf is not None and validation_split and not force_split) or (
        self.test_gdf is not None and not validation_split and not force_split
    ):
        raise ValueError(
            "A split already exists. Use `force_split=True` to overwrite the existing "
            f"{'validation' if validation_split else 'test'} split."
        )

    resolution = resolution or self.resolution

    if resolution is None:
        raise ValueError(
            "No preset resolution for the dataset in self.resolution. Please "
            "provide a resolution."
        )
    elif self.resolution is not None and resolution != self.resolution:
        raise ValueError(
            "Resolution provided is different from the preset resolution for the "
            "dataset. This may result in a data leak between splits."
        )

    if self.resolution is None:
        self.resolution = resolution
    target_column = target_column if target_column is not None else self.target
    if target_column is None:
        target_column = "count"

    gdf = self.train_gdf
    gdf_ = gdf.copy()

    train, test = train_test_spatial_split(
        gdf_,
        parent_h3_resolution=resolution,
        target_column=target_column,
        test_size=test_size,
        n_bins=n_bins,
        random_state=random_state,
    )

    self.train_gdf = train
    if not validation_split:
        self.test_gdf = test
        test_len = len(self.test_gdf) if self.test_gdf is not None else 0
        print(
            f"Created new train_gdf and test_gdf. Train len: {len(self.train_gdf)},"
            f"test len: {test_len}"
        )
    else:
        self.val_gdf = test
        val_len = len(self.val_gdf) if self.val_gdf is not None else 0
        test_len = len(self.test_gdf) if self.test_gdf is not None else 0
        print(
            f"Created new train_gdf and val_gdf. Test split remains unchanged."
            f"Train len: {len(self.train_gdf)}, val len: {val_len},"
            f"test len: {test_len}"
        )
    return train, test
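
A sketch of the splitting workflow, using only the calls documented above:

dataset = AirbnbMulticityDataset()
dataset.load(version=8)

# Carve a validation subset out of the current train split.
train, val = dataset.train_test_split(validation_split=True, random_state=42)

# Repeating the call raises ValueError unless the existing split is
# explicitly overwritten with force_split=True.
train, val = dataset.train_test_split(
    validation_split=True, force_split=True, random_state=42
)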

ChicagoCrimeDataset

ChicagoCrimeDataset()

Bases: PointDataset

Chicago Crime dataset.

This dataset reflects reported incidents of crime (with the exception of murders where data exists for each victim) that occurred in the City of Chicago. Data is extracted from the Chicago Police Department's CLEAR (Citizen Law Enforcement Analysis and Reporting) system.

Source code in srai/datasets/chicago_crime.py
def __init__(self) -> None:
    """Create the dataset."""
    numerical_columns = ["Ward", "Community Area"]
    categorical_columns = [
        "Primary Type",
        "Description",
        "Location Description",
        "Arrest",
        "Domestic",
        "Year",
        "FBI Code",
    ]
    type = "point"
    # target = "Primary Type"
    target = "count"
    super().__init__(
        "kraina/chicago_crime",
        type=type,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        target=target,
    )

get_h3_with_labels

get_h3_with_labels() -> (
    tuple[
        gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]
    ]
)

Returns h3 indexes with target labels from the dataset.

Points are aggregated to hexes and target column values are averaged; if the target column is None, the number of points within each hex is calculated and scaled to [0, 1].

RETURNS DESCRIPTION
tuple[GeoDataFrame, Optional[GeoDataFrame], Optional[GeoDataFrame]]

tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]: Train, Val, Test hexes with target labels in GeoDataFrames

Source code in srai/datasets/_base.py
def get_h3_with_labels(
    self,
    # resolution: Optional[int] = None,
    # target_column: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]:
    """
    Returns h3 indexes with target labels from the dataset.

    Points are aggregated to hexes and target column values are averaged or, if target column \
    is None, the number of points is calculated within a hex and scaled to [0,1].

    Returns:
        tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]:\
            Train, Val, Test hexes with target labels in GeoDataFrames
    """
    # if target_column is None:
    #     target_column = "count"

    # resolution = resolution if resolution is not None else self.resolution

    assert self.train_gdf is not None
    # If resolution is still None, raise an error
    if self.resolution is None:
        raise ValueError(
            "No preset resolution for the dataset in self.resolution. Please"
            "provide a resolution."
        )
    # elif self.resolution is not None and resolution != self.resolution:
    #     raise ValueError(
    #         "Resolution provided is different from the preset resolution for the"
    #         "dataset. This may result in a data leak between splits."
    #     )

    _train_gdf = self._aggregate_hexes(self.train_gdf, self.resolution, self.target)

    if self.test_gdf is not None:
        _test_gdf = self._aggregate_hexes(self.test_gdf, self.resolution, self.target)
    else:
        _test_gdf = None

    if self.val_gdf is not None:
        _val_gdf = self._aggregate_hexes(self.val_gdf, self.resolution, self.target)
    else:
        _val_gdf = None

    # Scale the "count" column to [0, 1] if it is the target column
    if self.target == "count":
        scaler = MinMaxScaler()
        # Fit the scaler on the train dataset and transform
        _train_gdf["count"] = scaler.fit_transform(_train_gdf[["count"]])
        if _test_gdf is not None:
            _test_gdf["count"] = scaler.transform(_test_gdf[["count"]])
            _test_gdf["count"] = np.clip(_test_gdf["count"], 0, 1)
        if _val_gdf is not None:
            _val_gdf["count"] = scaler.transform(_val_gdf[["count"]])
            _val_gdf["count"] = np.clip(_val_gdf["count"], 0, 1)

    return _train_gdf, _val_gdf, _test_gdf
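
Because the preset target of this dataset is "count", the hex labels returned by get_h3_with_labels are min-max scaled with statistics fitted on the train split only; a hedged sketch:

dataset = ChicagoCrimeDataset()
dataset.load(version=9)
train_hexes, _, test_hexes = dataset.get_h3_with_labels()

# Train counts land exactly in [0, 1]; test counts are transformed with the
# train statistics and clipped, so a denser test hex still maps to 1.0.
assert train_hexes["count"].between(0, 1).all()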

load

load(
    version: Optional[Union[int, str]] = 9, hf_token: Optional[str] = None
) -> dict[str, gpd.GeoDataFrame]

Method to load dataset.

PARAMETER DESCRIPTION
hf_token

If needed, a User Access Token used to authenticate to the Hugging Face Hub. The environment variable HF_TOKEN can also be used. Defaults to None.

TYPE: str DEFAULT: None

version

Version of the dataset. Available: official spatial train-test split from the year 2022 at a chosen H3 resolution: '8', '9', '10'. Defaults to '9'. Raw data from other years is available as: '2020', '2021', '2022'.

TYPE: str or int DEFAULT: 9

RETURNS DESCRIPTION
dict[str, GeoDataFrame]

dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will contain keys "train" and "test" if available.

Source code in srai/datasets/chicago_crime.py
def load(
    self, version: Optional[Union[int, str]] = 9, hf_token: Optional[str] = None
) -> dict[str, gpd.GeoDataFrame]:
    """
    Method to load dataset.

    Args:
        hf_token (str, optional): If needed, a User Access Token used to authenticate to
            the Hugging Face Hub. Environment variable `HF_TOKEN` can also be used.
            Defaults to None.
        version (str or int, optional): Version of the dataset.
            Available: official spatial train-test split from the year 2022 at a chosen
            H3 resolution: '8', '9', '10'. Defaults to '9'. Raw data from other years
            is available as: '2020', '2021', '2022'.

    Returns:
        dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will
            contain keys "train" and "test" if available.
    """
    return super().load(hf_token=hf_token, version=version)

train_test_split

train_test_split(
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 7,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]

Method to generate splits from GeoDataFrame, based on the target_column values.

PARAMETER DESCRIPTION
target_column

Target column name. If None, split is generated based on number of points within a hex of a given resolution. Defaults to preset dataset target column.

TYPE: Optional[str] DEFAULT: None

resolution

h3 resolution to regionalize data. Defaults to default value from the dataset.

TYPE: int DEFAULT: None

test_size

Fraction of the dataset to use as the test set. Defaults to 0.2.

TYPE: float DEFAULT: 0.2

n_bins

Number of buckets used to stratify target data. Defaults to 7.

TYPE: int DEFAULT: 7

random_state

Controls the shuffling applied to the data before applying the split. Pass an int for reproducible output across multiple function calls. Defaults to None.

TYPE: int DEFAULT: None

validation_split

If True, creates a validation split from existing train split and assigns it to self.val_gdf.

TYPE: bool DEFAULT: False

force_split

If True, forces a new split to be created, even if an existing train/test or validation split is already present. With validation_split=False, regenerates and overwrites the test split; with validation_split=True, regenerates and overwrites the validation split.

TYPE: bool DEFAULT: False

task

Currently not supported. Ignored in this subclass.

TYPE: Optional[str] DEFAULT: None

RETURNS DESCRIPTION
tuple

Train-test or train-val split made on previous train subset.

TYPE: (GeoDataFrame, GeoDataFrame)

Source code in srai/datasets/_base.py
def train_test_split(
    self,
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 7,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]:
    """
    Method to generate splits from GeoDataFrame, based on the target_column values.

    Args:
        target_column (Optional[str], optional): Target column name. If None, split is\
            generated based on number of points within a hex of a given resolution.\
            Defaults to preset dataset target column.
        resolution (int, optional): h3 resolution to regionalize data. Defaults to default\
            value from the dataset.
        test_size (float, optional): Fraction of the dataset to use as the test set.\
            Defaults to 0.2.
        n_bins (int, optional): Number of buckets used to stratify target data.\
            Defaults to 7.
        random_state (int, optional): Controls the shuffling applied to the data before\
            applying the split. \
            Pass an int for reproducible output across multiple function calls. Defaults to None.
        validation_split (bool): If True, creates a validation split from existing train split\
            and assigns it to self.val_gdf.
        force_split: If True, forces a new split to be created, even if an existing train/test\
            or validation split is already present.
            - With `validation_split=False`, regenerates and overwrites the test split.
            - With `validation_split=True`, regenerates and overwrites the validation split.
        task (Optional[str], optional): Currently not supported. Ignored in this subclass.

    Returns:
        tuple(gpd.GeoDataFrame, gpd.GeoDataFrame): Train-test or train-val split made on\
            previous train subset.
    """
    assert self.train_gdf is not None

    if (self.val_gdf is not None and validation_split and not force_split) or (
        self.test_gdf is not None and not validation_split and not force_split
    ):
        raise ValueError(
            "A split already exists. Use `force_split=True` to overwrite the existing "
            f"{'validation' if validation_split else 'test'} split."
        )

    resolution = resolution or self.resolution

    if resolution is None:
        raise ValueError(
            "No preset resolution for the dataset in self.resolution. Please "
            "provide a resolution."
        )
    elif self.resolution is not None and resolution != self.resolution:
        raise ValueError(
            "Resolution provided is different from the preset resolution for the "
            "dataset. This may result in a data leak between splits."
        )

    if self.resolution is None:
        self.resolution = resolution
    target_column = target_column if target_column is not None else self.target
    if target_column is None:
        target_column = "count"

    gdf = self.train_gdf
    gdf_ = gdf.copy()

    train, test = train_test_spatial_split(
        gdf_,
        parent_h3_resolution=resolution,
        target_column=target_column,
        test_size=test_size,
        n_bins=n_bins,
        random_state=random_state,
    )

    self.train_gdf = train
    if not validation_split:
        self.test_gdf = test
        test_len = len(self.test_gdf) if self.test_gdf is not None else 0
        print(
            f"Created new train_gdf and test_gdf. Train len: {len(self.train_gdf)},"
            f"test len: {test_len}"
        )
    else:
        self.val_gdf = test
        val_len = len(self.val_gdf) if self.val_gdf is not None else 0
        test_len = len(self.test_gdf) if self.test_gdf is not None else 0
        print(
            f"Created new train_gdf and val_gdf. Test split remains unchanged."
            f"Train len: {len(self.train_gdf)}, val len: {val_len},"
            f"test len: {test_len}"
        )
    return train, test

GeolifeDataset

GeolifeDataset()

Bases: TrajectoryDataset

Geolife dataset.

GPS trajectories collected by 182 users in the Geolife project (Microsoft Research Asia).

Source code in srai/datasets/geolife.py
def __init__(self) -> None:
    """Create the dataset."""
    numerical_columns = ["altitude"]
    categorical_columns = ["mode"]
    type = "trajectory"
    target = "trajectory_id"

    self.cache_root = self._get_global_dataset_cache_path()
    self.raw_data_path = self.cache_root / "raw"
    self.processed_path = self.cache_root / "preprocessed"
    self.prepared_path = self.cache_root / "prepared"

    # target = None
    super().__init__(
        "kraina/geolife",
        type=type,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        target=target,
    )

get_h3_with_labels

get_h3_with_labels() -> (
    tuple[
        gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]
    ]
)

Returns ids and H3 index sequences with target labels from the dataset.

Points are aggregated to hex trajectories and target column values are calculated for each trajectory (time duration for TTE task, future movement sequence for HMP task).

RETURNS DESCRIPTION
tuple[GeoDataFrame, Optional[GeoDataFrame], Optional[GeoDataFrame]]

tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]: Train, Val, Test hexes sequences with target labels in GeoDataFrames

Source code in srai/datasets/_base.py
def get_h3_with_labels(
    self,
    # resolution: Optional[int] = None,
    # target_column: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]:
    """
    Returns ids and H3 index sequences with target labels from the dataset.

    Points are aggregated to hex trajectories and target column values are calculated \
        for each trajectory (time duration for TTE task, future movement sequence for HMP task).

    Returns:
        tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]: Train,\
            Val, Test hexes sequences with target labels in GeoDataFrames
    """
    # resolution = resolution if resolution is not None else self.resolution

    assert self.train_gdf is not None
    # If resolution is still None, raise an error
    if self.resolution is None:
        raise ValueError(
            "No preset resolution for the dataset in self.resolution. Please"
            "provide a resolution."
        )
    # elif self.resolution is not None and resolution != self.resolution:
    #     raise ValueError(
    #         "Resolution provided is different from the preset resolution for the"
    #         "dataset. This may result in a data leak between splits."
    #     )

    if self.version == "TTE":
        _train_gdf = self.train_gdf[[self.target, "h3_sequence", "duration"]]

        if self.test_gdf is not None:
            _test_gdf = self.test_gdf[[self.target, "h3_sequence", "duration"]]
        else:
            _test_gdf = None

        if self.val_gdf is not None:
            _val_gdf = self.val_gdf[[self.target, "h3_sequence", "duration"]]
        else:
            _val_gdf = None

    elif self.version == "HMP":
        _train_gdf = self.train_gdf[[self.target, "h3_sequence_x", "h3_sequence_y"]]

        if self.test_gdf is not None:
            _test_gdf = self.test_gdf[[self.target, "h3_sequence_x", "h3_sequence_y"]]
        else:
            _test_gdf = None

        if self.val_gdf is not None:
            _val_gdf = self.val_gdf[[self.target, "h3_sequence_x", "h3_sequence_y"]]
        else:
            _val_gdf = None

    elif self.version == "all":
        raise TypeError(
            "Could not provide target labels, as version 'all'\
        of dataset does not provide one."
        )

    return _train_gdf, _val_gdf, _test_gdf
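
A hedged sketch of inspecting the returned frames (column names as in the source above):

dataset = GeolifeDataset()
dataset.load(version="TTE")
train, val, test = dataset.get_h3_with_labels()

# TTE frames carry [target, "h3_sequence", "duration"];
# HMP frames carry [target, "h3_sequence_x", "h3_sequence_y"].
print(train.columns.tolist())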

load

load(
    version: Optional[Union[int, str]] = "HMP",
    hf_token: Optional[str] = None,
    resolution: Optional[int] = None,
) -> dict[str, gpd.GeoDataFrame]

Method to load dataset.

PARAMETER DESCRIPTION
hf_token

If needed, a User Access Token used to authenticate to the Hugging Face Hub. The environment variable HF_TOKEN can also be used. Defaults to None.

TYPE: Optional[str] DEFAULT: None

version

Version of the dataset. Available: official train-test split for the Travel Time Estimation (TTE) and Human Mobility Prediction (HMP) tasks. Raw data available as: 'all'.

TYPE: Optional[Union[int, str]] DEFAULT: 'HMP'

resolution

H3 resolution for hex trajectories. Necessary if using the 'all' version.

TYPE: Optional[int] DEFAULT: None

RETURNS DESCRIPTION
dict[str, GeoDataFrame]

dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will contain keys "train" and "test" if available.

Source code in srai/datasets/geolife.py
def load(
    self,
    version: Optional[Union[int, str]] = "HMP",
    hf_token: Optional[str] = None,
    resolution: Optional[int] = None,
) -> dict[str, gpd.GeoDataFrame]:
    """
    Method to load dataset.

    Args:
        hf_token (Optional[str]): If needed, a User Access Token used to authenticate to
            the Hugging Face Hub. Environment variable `HF_TOKEN` can also be used.
            Defaults to None.
        version (Optional[Union[int, str]]): Version of the dataset.
            Available: official train-test split for the Travel Time Estimation (TTE) and
            Human Mobility Prediction (HMP) tasks. Raw data available as: 'all'.
        resolution (Optional[int]): H3 resolution for hex trajectories.
            Necessary if using the 'all' version.

    Returns:
        dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will
            contain keys "train" and "test" if available.
    """
    self.version = str(version)
    if self.version in ("TTE", "HMP"):
        self.resolution = 9
    elif self.version == "all":
        self.resolution = resolution if resolution is not None else None
    else:
        raise NotImplementedError("Version not implemented")

    if not self.prepared_path.exists():
        self._download_geolife()
        self._geolife_preprocess()

    # Remove raw downloaded data from cache and keep only preprocessed files
    shutil.rmtree(self.raw_data_path, ignore_errors=True)

    return super().load(hf_token=hf_token, version=version)
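
A loading sketch following the version handling above (benchmark versions pin the resolution; 'all' leaves it to the caller):

dataset = GeolifeDataset()

# Benchmark split; the resolution is set to 9 internally.
dataset.load(version="HMP")
assert dataset.resolution == 9

# The raw variant needs an explicit resolution:
# dataset.load(version="all", resolution=9)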

train_test_split

train_test_split(
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 4,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = "TTE",
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]

Generate train/test split or train/val split from trajectory GeoDataFrame.

Train-test/train-val split is generated by splitting train_gdf.

PARAMETER DESCRIPTION
target_column

Column identifying each trajectory (contains trajectory ids).

TYPE: str DEFAULT: None

test_size

Fraction of data to be used as test set.

TYPE: float DEFAULT: 0.2

n_bins

Number of stratification bins.

TYPE: int DEFAULT: 4

random_state

Controls the shuffling applied to the data before applying the split. Pass an int for reproducible output across multiple function calls. Defaults to None.

TYPE: int DEFAULT: None

validation_split

If True, creates a validation split from existing train split and assigns it to self.val_gdf.

TYPE: bool DEFAULT: False

force_split

If True, forces a new split to be created, even if an existing train/test or validation split is already present. With validation_split=False, regenerates and overwrites the test split; with validation_split=True, regenerates and overwrites the validation split.

TYPE: bool DEFAULT: False

resolution

H3 resolution to regionalize data. Currently ignored in this subclass; splits at different resolutions are not yet supported. Defaults to default value from the dataset.

TYPE: int DEFAULT: None

task

Task type. Stratifies by duration (TTE) or hex length (HMP).

TYPE: Literal[TTE, HMP] DEFAULT: 'TTE'

RETURNS DESCRIPTION
tuple[GeoDataFrame, GeoDataFrame]

Tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]: Train/test or train/val GeoDataFrames.

Source code in srai/datasets/_base.py
def train_test_split(
    self,
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 4,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = "TTE",
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]:
    """
    Generate train/test split or train/val split from trajectory GeoDataFrame.

    Train-test/train-val split is generated by splitting train_gdf.

    Args:
        target_column (str): Column identifying each trajectory (contains trajectory ids).
        test_size (float): Fraction of data to be used as test set.
        n_bins (int): Number of stratification bins.
        random_state (int, optional):  Controls the shuffling applied to the data before\
            applying the split. Pass an int for reproducible output across multiple function calls.\
                Defaults to None.
        validation_split (bool): If True, creates a validation split from existing train split\
            and assigns it to self.val_gdf.
        force_split: If True, forces a new split to be created, even if an existing train/test\
            or validation split is already present.
            - With `validation_split=False`, regenerates and overwrites the test split.
            - With `validation_split=True`, regenerates and overwrites the validation split.
        resolution (int, optional): H3 resolution to regionalize data. Currently ignored in\
            this subclass; splits at different resolutions are not yet supported.\
                Defaults to default value from the dataset.
        task (Literal["TTE", "HMP"]): Task type. Stratifies by duration
            (TTE) or hex length (HMP).


    Returns:
        Tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]: Train/test or train/val GeoDataFrames.
    """
    if (self.val_gdf is not None and validation_split and not force_split) or (
        self.test_gdf is not None and not validation_split and not force_split
    ):
        raise ValueError(
            "A split already exists. Use `force_split=True` to overwrite the existing "
            f"{'validation' if validation_split else 'test'} split."
        )
    assert self.train_gdf is not None
    trajectory_id_column = target_column or self.target
    gdf_copy = self.train_gdf.copy()

    if task not in {"TTE", "HMP"}:
        raise ValueError(f"Unsupported task: {task}")

    if task == "TTE":
        self.version = "TTE"
        # Calculate duration in seconds from timestamps list

        if "duration" in gdf_copy.columns:
            gdf_copy["stratify_col"] = gdf_copy["duration"]
        elif "duration" not in gdf_copy.columns and "timestamp" in gdf_copy.columns:
            gdf_copy["stratify_col"] = gdf_copy["timestamp"].apply(
                #     lambda ts: (0.0 if len(ts) < 2 else (ts[-1] - ts[0]).total_seconds())
                # )
                lambda ts: (
                    0.0 if len(ts) < 2 else pd.Timedelta(ts[-1] - ts[0]).total_seconds()
                )
            )
        else:
            raise ValueError(
                "Duration column and timestamp column do not exist. Can't stratify the data."
            )

    elif task == "HMP":
        self.version = "HMP"

        def split_sequence(seq):
            split_idx = int(len(seq) * 0.85)
            if split_idx == len(seq):
                split_idx = len(seq) - 1
            return seq[:split_idx], seq[split_idx:]

        if "h3_sequence_x" not in gdf_copy.columns:
            split_result = gdf_copy["h3_sequence"].apply(split_sequence)
            gdf_copy["h3_sequence_x"] = split_result.apply(operator.itemgetter(0))
            gdf_copy["h3_sequence_y"] = split_result.apply(operator.itemgetter(1))

        # Calculate trajectory length in unique hexagons
        gdf_copy["x_len"] = gdf_copy["h3_sequence_x"].apply(lambda seq: len(set(seq)))
        gdf_copy["y_len"] = gdf_copy["h3_sequence_y"].apply(lambda seq: len(set(seq)))
        gdf_copy["stratify_col"] = gdf_copy.apply(
            lambda row: row["x_len"] + row["y_len"], axis=1
        )
    else:
        raise ValueError(f"Unsupported task type: {task}")

    gdf_copy["stratification_bin"] = pd.cut(gdf_copy["stratify_col"], bins=n_bins, labels=False)

    trajectory_indices = gdf_copy[trajectory_id_column].unique()
    duration_bins = (
        gdf_copy[[trajectory_id_column, "stratification_bin"]]
        .drop_duplicates()
        .set_index(trajectory_id_column)["stratification_bin"]
    )

    train_indices, test_indices = train_test_split(
        trajectory_indices,
        test_size=test_size,
        stratify=duration_bins.loc[trajectory_indices],
        random_state=random_state,
    )

    train_gdf = gdf_copy[gdf_copy[trajectory_id_column].isin(train_indices)]
    test_gdf = gdf_copy[gdf_copy[trajectory_id_column].isin(test_indices)]

    test_gdf = test_gdf.drop(
        columns=[
            col
            for col in (
                "x_len",
                "y_len",
                "stratification_bin",
                "stratify_col",
            )
            if col in test_gdf.columns
        ],
    )
    train_gdf = train_gdf.drop(
        columns=[
            col
            for col in (
                "x_len",
                "y_len",
                "stratification_bin",
                "stratify_col",
            )
            if col in train_gdf.columns
        ],
    )

    self.train_gdf = train_gdf
    if not validation_split:
        self.test_gdf = test_gdf
        test_len = len(self.test_gdf) if self.test_gdf is not None else 0
        print(
            f"Created new train_gdf and test_gdf. Train len: {len(self.train_gdf)}, "
            f"test len: {test_len}"
        )
    else:
        self.val_gdf = test_gdf
        val_len = len(self.val_gdf) if self.val_gdf is not None else 0
        test_len = len(self.test_gdf) if self.test_gdf is not None else 0
        print(
            f"Created new train_gdf and val_gdf. Test split remains unchanged. "
            f"Train len: {len(self.train_gdf)}, val len: {val_len}, "
            f"test len: {test_len}"
        )
    return train_gdf, test_gdf
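
For the HMP task, each hex sequence is cut 85/15 into an observed prefix and a target suffix; a small worked example of the split_sequence helper from the source above:

def split_sequence(seq):
    split_idx = int(len(seq) * 0.85)
    if split_idx == len(seq):
        split_idx = len(seq) - 1
    return seq[:split_idx], seq[split_idx:]

# 10 hexes -> 8 observed, 2 to predict, since int(10 * 0.85) == 8.
x, y = split_sequence(list(range(10)))
assert len(x) == 8 and len(y) == 2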

HouseSalesInKingCountyDataset

HouseSalesInKingCountyDataset()

Bases: PointDataset

House Sales in King County dataset.

This dataset contains house sale prices for King County, which includes Seattle. It includes homes sold between May 2014 and May 2015.

It's a great dataset for evaluating simple regression models.

Source code in srai/datasets/house_sales_in_king_county.py
def __init__(self) -> None:
    """Create the dataset."""
    numerical_columns = [
        "bathrooms",
        "sqft_living",
        "sqft_lot",
        "floors",
        "condition",
        "grade",
        "sqft_above",
        "sqft_basement",
        "sqft_living15",
        "sqft_lot15",
    ]
    categorical_columns = ["view", "yr_built", "yr_renovated", "waterfront"]
    type = "point"
    target = "price"
    super().__init__(
        "kraina/house_sales_in_king_county",
        type=type,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        target=target,
    )

get_h3_with_labels

get_h3_with_labels() -> (
    tuple[
        gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]
    ]
)

Returns h3 indexes with target labels from the dataset.

Points are aggregated to hexes and target column values are averaged; if the target column is None, the number of points within each hex is calculated and scaled to [0, 1].

RETURNS DESCRIPTION
tuple[GeoDataFrame, Optional[GeoDataFrame], Optional[GeoDataFrame]]

tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]: Train, Val, Test hexes with target labels in GeoDataFrames

Source code in srai/datasets/_base.py
def get_h3_with_labels(
    self,
    # resolution: Optional[int] = None,
    # target_column: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]:
    """
    Returns h3 indexes with target labels from the dataset.

    Points are aggregated to hexes and target column values are averaged or, if target column \
    is None, the number of points is calculated within a hex and scaled to [0,1].

    Returns:
        tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]:\
            Train, Val, Test hexes with target labels in GeoDataFrames
    """
    # if target_column is None:
    #     target_column = "count"

    # resolution = resolution if resolution is not None else self.resolution

    assert self.train_gdf is not None
    # If resolution is still None, raise an error
    if self.resolution is None:
        raise ValueError(
            "No preset resolution for the dataset in self.resolution. Please"
            "provide a resolution."
        )
    # elif self.resolution is not None and resolution != self.resolution:
    #     raise ValueError(
    #         "Resolution provided is different from the preset resolution for the"
    #         "dataset. This may result in a data leak between splits."
    #     )

    _train_gdf = self._aggregate_hexes(self.train_gdf, self.resolution, self.target)

    if self.test_gdf is not None:
        _test_gdf = self._aggregate_hexes(self.test_gdf, self.resolution, self.target)
    else:
        _test_gdf = None

    if self.val_gdf is not None:
        _val_gdf = self._aggregate_hexes(self.val_gdf, self.resolution, self.target)
    else:
        _val_gdf = None

    # Scale the "count" column to [0, 1] if it is the target column
    if self.target == "count":
        scaler = MinMaxScaler()
        # Fit the scaler on the train dataset and transform
        _train_gdf["count"] = scaler.fit_transform(_train_gdf[["count"]])
        if _test_gdf is not None:
            _test_gdf["count"] = scaler.transform(_test_gdf[["count"]])
            _test_gdf["count"] = np.clip(_test_gdf["count"], 0, 1)
        if _val_gdf is not None:
            _val_gdf["count"] = scaler.transform(_val_gdf[["count"]])
            _val_gdf["count"] = np.clip(_val_gdf["count"], 0, 1)

    return _train_gdf, _val_gdf, _test_gdf

load

load(
    version: Optional[Union[int, str]] = 8, hf_token: Optional[str] = None
) -> dict[str, gpd.GeoDataFrame]

Method to load dataset.

PARAMETER DESCRIPTION
hf_token

If needed, a User Access Token used to authenticate to the Hugging Face Hub. The environment variable HF_TOKEN can also be used. Defaults to None.

TYPE: str DEFAULT: None

version

Version of the dataset. Available: '8', '9', '10', where the number is the H3 resolution used in the train-test split. Defaults to '8'. Raw, full data is available as 'all'.

TYPE: str or int DEFAULT: 8

RETURNS DESCRIPTION
dict[str, GeoDataFrame]

dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will contain keys "train" and "test" if available.

Source code in srai/datasets/house_sales_in_king_county.py
def load(
    self, version: Optional[Union[int, str]] = 8, hf_token: Optional[str] = None
) -> dict[str, gpd.GeoDataFrame]:
    """
    Method to load dataset.

    Args:
        hf_token (str, optional): If needed, a User Access Token used to authenticate to
            the Hugging Face Hub. Environment variable `HF_TOKEN` can also be used.
            Defaults to None.
        version (str or int, optional): Version of the dataset.
            Available: '8', '9', '10', where the number is the H3 resolution used in the \
            train-test split. Defaults to '8'. Raw, full data is available as 'all'.

    Returns:
        dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will
            contain keys "train" and "test" if available.
    """
    return super().load(hf_token=hf_token, version=version)

train_test_split

train_test_split(
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 7,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]

Method to generate splits from GeoDataFrame, based on the target_column values.

PARAMETER DESCRIPTION
target_column

Target column name. If None, split is generated based on number of points within a hex of a given resolution. Defaults to preset dataset target column.

TYPE: Optional[str] DEFAULT: None

resolution

h3 resolution to regionalize data. Defaults to default value from the dataset.

TYPE: int DEFAULT: None

test_size

Fraction of the dataset to use as the test set. Defaults to 0.2.

TYPE: float DEFAULT: 0.2

n_bins

Number of buckets used to stratify target data. Defaults to 7.

TYPE: int DEFAULT: 7

random_state

Controls the shuffling applied to the data before applying the split. Pass an int for reproducible output across multiple function calls. Defaults to None.

TYPE: int DEFAULT: None

validation_split

If True, creates a validation split from existing train split and assigns it to self.val_gdf.

TYPE: bool DEFAULT: False

force_split

If True, forces a new split to be created, even if an existing train/test or validation split is already present. With validation_split=False, regenerates and overwrites the test split; with validation_split=True, regenerates and overwrites the validation split.

TYPE: bool DEFAULT: False

task

Currently not supported. Ignored in this subclass.

TYPE: Optional[str] DEFAULT: None

RETURNS DESCRIPTION
tuple

Train-test or train-val split made on previous train subset.

TYPE: (GeoDataFrame, GeoDataFrame)

Source code in srai/datasets/_base.py
def train_test_split(
    self,
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 7,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]:
    """
    Method to generate splits from GeoDataFrame, based on the target_column values.

    Args:
        target_column (Optional[str], optional): Target column name. If None, split is\
            generated based on number of points within a hex of a given resolution.\
            Defaults to preset dataset target column.
        resolution (int, optional): h3 resolution to regionalize data. Defaults to default\
            value from the dataset.
        test_size (float, optional): Fraction of the dataset to use as the test set.\
            Defaults to 0.2.
        n_bins (int, optional): Number of buckets used to stratify target data.\
            Defaults to 7.
        random_state (int, optional): Controls the shuffling applied to the data before\
            applying the split. \
            Pass an int for reproducible output across multiple function calls. Defaults to None.
        validation_split (bool): If True, creates a validation split from existing train split\
            and assigns it to self.val_gdf.
        force_split: If True, forces a new split to be created, even if an existing train/test\
            or validation split is already present.
            - With `validation_split=False`, regenerates and overwrites the test split.
            - With `validation_split=True`, regenerates and overwrites the validation split.
        task (Optional[str], optional): Currently not supported. Ignored in this subclass.

    Returns:
        tuple(gpd.GeoDataFrame, gpd.GeoDataFrame): Train-test or train-val split made on\
            previous train subset.
    """
    assert self.train_gdf is not None

    if (self.val_gdf is not None and validation_split and not force_split) or (
        self.test_gdf is not None and not validation_split and not force_split
    ):
        raise ValueError(
            "A split already exists. Use `force_split=True` to overwrite the existing "
            f"{'validation' if validation_split else 'test'} split."
        )

    resolution = resolution or self.resolution

    if resolution is None:
        raise ValueError(
            "No preset resolution for the dataset in self.resolution. Please "
            "provide a resolution."
        )
    elif self.resolution is not None and resolution != self.resolution:
        raise ValueError(
            "Resolution provided is different from the preset resolution for the "
            "dataset. This may result in a data leak between splits."
        )

    if self.resolution is None:
        self.resolution = resolution
    target_column = target_column if target_column is not None else self.target
    if target_column is None:
        target_column = "count"

    gdf = self.train_gdf
    gdf_ = gdf.copy()

    train, test = train_test_spatial_split(
        gdf_,
        parent_h3_resolution=resolution,
        target_column=target_column,
        test_size=test_size,
        n_bins=n_bins,
        random_state=random_state,
    )

    self.train_gdf = train
    if not validation_split:
        self.test_gdf = test
        test_len = len(self.test_gdf) if self.test_gdf is not None else 0
        print(
            f"Created new train_gdf and test_gdf. Train len: {len(self.train_gdf)},"
            f"test len: {test_len}"
        )
    else:
        self.val_gdf = test
        val_len = len(self.val_gdf) if self.val_gdf is not None else 0
        test_len = len(self.test_gdf) if self.test_gdf is not None else 0
        print(
            f"Created new train_gdf and val_gdf. Test split remains unchanged."
            f"Train len: {len(self.train_gdf)}, val len: {val_len},"
            f"test len: {test_len}"
        )
    return train, test

HuggingFaceDataset

HuggingFaceDataset(
    path: str,
    version: Optional[str] = None,
    type: Optional[str] = None,
    numerical_columns: Optional[list[str]] = None,
    categorical_columns: Optional[list[str]] = None,
    target: Optional[str] = None,
    resolution: Optional[int] = None,
)

Bases: ABC

Abstract class for HuggingFace datasets.

Source code in srai/datasets/_base.py
def __init__(
    self,
    path: str,
    version: Optional[str] = None,
    type: Optional[str] = None,
    numerical_columns: Optional[list[str]] = None,
    categorical_columns: Optional[list[str]] = None,
    target: Optional[str] = None,
    resolution: Optional[int] = None,
) -> None:
    self.path = path
    self.version = version
    self.numerical_columns = numerical_columns
    self.categorical_columns = categorical_columns
    self.target = target
    self.type = type
    self.train_gdf = None
    self.test_gdf = None
    self.val_gdf = None
    self.resolution = resolution

get_h3_with_labels

abstractmethod
get_h3_with_labels() -> (
    tuple[
        gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]
    ]
)

Returns indexes with target labels from the dataset depending on dataset and task type.

RETURNS DESCRIPTION
tuple[GeoDataFrame, Optional[GeoDataFrame], Optional[GeoDataFrame]]

tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]: Train, Val, Test indexes with target labels in GeoDataFrames

Source code in srai/datasets/_base.py
@abc.abstractmethod
def get_h3_with_labels(
    self,
    # resolution: Optional[int] = None,
    # target_column: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]:
    """
    Returns indexes with target labels from the dataset depending on dataset and task type.

    Returns:
        tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]: \
            Train, Val, Test indexes with target labels in GeoDataFrames
    """
    raise NotImplementedError

load

load(
    version: Optional[Union[int, str]] = None, hf_token: Optional[str] = None
) -> dict[str, gpd.GeoDataFrame]

Method to load dataset.

PARAMETER DESCRIPTION
hf_token

If needed, a User Access Token used to authenticate to the Hugging Face Hub. The environment variable HF_TOKEN can also be used. Defaults to None.

TYPE: str DEFAULT: None

version

Version of the dataset.

TYPE: str or int DEFAULT: None

RETURNS DESCRIPTION
dict[str, GeoDataFrame]

dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will contain keys "train" and "test" if available.

Source code in srai/datasets/_base.py
def load(
    self,
    version: Optional[Union[int, str]] = None,
    hf_token: Optional[str] = None,
) -> dict[str, gpd.GeoDataFrame]:
    """
    Method to load dataset.

    Args:
        hf_token (str, optional): If needed, a User Access Token used to authenticate to
            the Hugging Face Hub. Environment variable `HF_TOKEN` can also be used.
            Defaults to None.
        version (str or int, optional): Version of the dataset.

    Returns:
        dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will
             contain keys "train" and "test" if available.
    """
    from datasets import load_dataset

    result = {}

    self.train_gdf, self.val_gdf, self.test_gdf = None, None, None
    dataset_name = self.path
    self.version = str(version)

    if (
        self.resolution is None
        and self.version in ("8", "9", "10")
        or (self.version in ("8", "9", "10") and str(self.resolution) != self.version)
    ):
        with suppress(ValueError):
            # Try to parse version as int (e.g. "8" or "9")
            self.resolution = int(self.version)

    data = load_dataset(dataset_name, str(version), token=hf_token, trust_remote_code=True)
    train = data["train"].to_pandas()
    processed_train = self._preprocessing(train)
    self.train_gdf = processed_train
    result["train"] = processed_train
    if "test" in data:
        test = data["test"].to_pandas()
        processed_test = self._preprocessing(test)
        self.test_gdf = processed_test
        result["test"] = processed_test

    return result
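
A hedged sketch of the version handling above: benchmark versions '8', '9' and '10' double as the H3 resolution, so loading one of them also sets self.resolution.

dataset = HouseSalesInKingCountyDataset()
dataset.load(version="9")

# "9" is parsed as an int and stored as the preset resolution.
assert dataset.resolution == 9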

train_test_split

abstractmethod
train_test_split(
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 7,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]

Method to generate train/test or train/val split from GeoDataFrame.

PARAMETER DESCRIPTION
target_column

Target column name for point datasets, trajectory id column for trajectory datasets. Defaults to preset dataset target column.

TYPE: Optional[str] DEFAULT: None

resolution

H3 resolution; subclasses may use this argument to regionalize data. Defaults to default value from the dataset.

TYPE: int DEFAULT: None

test_size

Fraction of the dataset to use as the test set. Defaults to 0.2.

TYPE: float DEFAULT: 0.2

n_bins

Number of buckets used to stratify target data.

TYPE: int DEFAULT: 7

random_state

Controls the shuffling applied to the data before applying the split. Pass an int for reproducible output across multiple function calls. Defaults to None.

TYPE: int DEFAULT: None

validation_split

If True, creates a validation split from existing train split and assigns it to self.val_gdf.

TYPE: bool DEFAULT: False

force_split

If True, forces a new split to be created, even if an existing train/test or validation split is already present. With validation_split=False, regenerates and overwrites the test split; with validation_split=True, regenerates and overwrites the validation split.

TYPE: bool DEFAULT: False

task

Task identifier. Subclasses may use this argument to determine stratification logic (e.g., by duration or spatial pattern). Defaults to None.

TYPE: Optional[str] DEFAULT: None

RETURNS DESCRIPTION
tuple

Train-test or Train-val split made on previous train subset.

TYPE: (GeoDataFrame, GeoDataFrame)

Source code in srai/datasets/_base.py
@abc.abstractmethod
def train_test_split(
    self,
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 7,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]:
    """
    Method to generate train/test or train/val split from GeoDataFrame.

    Args:
        target_column (Optional[str], optional): Target column name for point datasets,\
            trajectory id column for trajectory datasets. Defaults to preset dataset target\
            column.
        resolution (int, optional): H3 resolution; subclasses may use this argument to\
            regionalize data. Defaults to default value from the dataset.
        test_size (float, optional): Fraction of the dataset to use as the test set.\
            Defaults to 0.2.
        n_bins (int, optional): Number of buckets used to stratify target data.
        random_state (int, optional): Controls the shuffling applied to the data before \
            applying the split.
            Pass an int for reproducible output across multiple function calls. Defaults to None.
        validation_split (bool): If True, creates a validation split from existing train split\
            and assigns it to self.val_gdf.
        force_split: If True, forces a new split to be created, even if an existing train/test\
            or validation split is already present.
            - With `validation_split=False`, regenerates and overwrites the test split.
            - With `validation_split=True`, regenerates and overwrites the validation split.
        task (Optional[str], optional): Task identifier. Subclasses may use this argument to
            determine stratification logic (e.g., by duration or spatial pattern).\
                Defaults to None.

    Returns:
        tuple(gpd.GeoDataFrame, gpd.GeoDataFrame): Train-test or Train-val split made on\
            previous train subset.
    """
    raise NotImplementedError
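
A minimal usage sketch of this contract, assuming a concrete subclass such as PhiladelphiaCrimeDataset (documented below) and that the class is importable from srai.datasets:

from srai.datasets import PhiladelphiaCrimeDataset

ds = PhiladelphiaCrimeDataset()
ds.load(version="8")  # populates ds.train_gdf and ds.test_gdf

# A test split already exists, so only a validation split can be carved
# out of the train split without forcing.
train, val = ds.train_test_split(validation_split=True, random_state=42)

# Re-running the same call raises ValueError unless the existing
# validation split is explicitly overwritten.
train, val = ds.train_test_split(
    validation_split=True, force_split=True, random_state=42
)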

PhiladelphiaCrimeDataset

PhiladelphiaCrimeDataset()

Bases: PointDataset

Philadelphia Crime dataset.

Crime incidents from the Philadelphia Police Department. Part I crimes include violent offenses such as aggravated assault, rape, arson, among others. Part II crimes include simple assault, prostitution, gambling, fraud, and other non-violent offenses.

Source code in srai/datasets/philadelphia_crime.py
def __init__(self) -> None:
    """Create the dataset."""
    numerical_columns = None
    categorical_columns = [
        "hour",
        "dispatch_date",
        "dispatch_time",
        "dc_dist",
        "psa",
    ]
    type = "point"
    # target = "text_general_code"
    target = "count"

    super().__init__(
        "kraina/philadelphia_crime",
        type=type,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        target=target,
    )

download_data

download_data(version: Optional[Union[int, str]] = 2023) -> None

Download and cache the Philadelphia crime dataset for a given year.

  • If the Parquet cache already exists, the download step is skipped.
  • Otherwise, the CSV is streamed from the API, converted in-memory to Parquet, and cached for future use.
PARAMETER DESCRIPTION
version

Dataset year to download (e.g., 2013-2023). If given as a short H3 resolution code ('8', '9', '10'), defaults to benchmark splits of the year 2023.

TYPE: int DEFAULT: 2023

Source code in srai/datasets/philadelphia_crime.py
def download_data(self, version: Optional[Union[int, str]] = 2023) -> None:
    """
    Download and cache the Philadelphia crime dataset for a given year.

    - If the Parquet cache already exists, the download step is skipped.
    - Otherwise, the CSV is streamed from the API, converted in-memory to Parquet,
    and cached for future use.

    Args:
        version (int): Dataset year to download (e.g., 2013-2023).
            If given as a short H3 resolution code ('8', '9', '10'),
            defaults to benchmark splits of the year 2023.
    """
    if version is None or len(str(version)) <= 3:
        version = 2023

    cache_file = self._get_cache_file(str(version))
    cache_file.parent.mkdir(parents=True, exist_ok=True)

    if not cache_file.exists():
        url = self._make_url(int(version))

        print(f"Downloading crime data for {version}...")
        duckdb.read_csv(url).to_parquet(str(cache_file), compression="zstd")
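
A short sketch of the caching behaviour (assumes network access to the Philadelphia open-data API and the optional duckdb dependency):

ds = PhiladelphiaCrimeDataset()

# First call streams the 2021 CSV and caches it as zstd-compressed Parquet.
ds.download_data(version=2021)

# Second call finds the cached Parquet file and skips the download.
ds.download_data(version=2021)

# A short value (here an H3 resolution code) falls back to the 2023 data.
ds.download_data(version="9")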

get_h3_with_labels

get_h3_with_labels() -> (
    tuple[
        gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]
    ]
)

Returns h3 indexes with target labels from the dataset.

Points are aggregated to hexes and target column values are averaged; if the target column is None, the number of points within each hex is calculated and scaled to [0, 1].

RETURNS DESCRIPTION
tuple[GeoDataFrame, Optional[GeoDataFrame], Optional[GeoDataFrame]]

tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]: Train, Val, Test hexes with target labels in GeoDataFrames

Source code in srai/datasets/_base.py
def get_h3_with_labels(
    self,
    # resolution: Optional[int] = None,
    # target_column: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]:
    """
    Returns h3 indexes with target labels from the dataset.

    Points are aggregated to hexes and target column values are averaged; if the target column \
    is None, the number of points within each hex is calculated and scaled to [0, 1].

    Returns:
        tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]:\
            Train, Val, Test hexes with target labels in GeoDataFrames
    """
    # if target_column is None:
    #     target_column = "count"

    # resolution = resolution if resolution is not None else self.resolution

    assert self.train_gdf is not None
    # If resolution is still None, raise an error
    if self.resolution is None:
        raise ValueError(
            "No preset resolution for the dataset in self.resolution. Please"
            "provide a resolution."
        )
    # elif self.resolution is not None and resolution != self.resolution:
    #     raise ValueError(
    #         "Resolution provided is different from the preset resolution for the"
    #         "dataset. This may result in a data leak between splits."
    #     )

    _train_gdf = self._aggregate_hexes(self.train_gdf, self.resolution, self.target)

    if self.test_gdf is not None:
        _test_gdf = self._aggregate_hexes(self.test_gdf, self.resolution, self.target)
    else:
        _test_gdf = None

    if self.val_gdf is not None:
        _val_gdf = self._aggregate_hexes(self.val_gdf, self.resolution, self.target)
    else:
        _val_gdf = None

    # Scale the "count" column to [0, 1] if it is the target column
    if self.target == "count":
        scaler = MinMaxScaler()
        # Fit the scaler on the train dataset and transform
        _train_gdf["count"] = scaler.fit_transform(_train_gdf[["count"]])
        if _test_gdf is not None:
            _test_gdf["count"] = scaler.transform(_test_gdf[["count"]])
            _test_gdf["count"] = np.clip(_test_gdf["count"], 0, 1)
        if _val_gdf is not None:
            _val_gdf["count"] = scaler.transform(_val_gdf[["count"]])
            _val_gdf["count"] = np.clip(_val_gdf["count"], 0, 1)

    return _train_gdf, _val_gdf, _test_gdf
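
A sketch of the typical call sequence for this dataset (assumes the official resolution-8 split loads successfully):

ds = PhiladelphiaCrimeDataset()
ds.load(version="8")  # sets ds.resolution = 8 and the train/test splits

train_hex, val_hex, test_hex = ds.get_h3_with_labels()

# The target is "count", so train counts are min-max scaled to [0, 1];
# the same train-fitted scaler is applied to test and then clipped.
print(train_hex["count"].min(), train_hex["count"].max())  # 0.0 and 1.0
print(val_hex)  # None until train_test_split(validation_split=True) is called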

load

load(
    version: Optional[Union[int, str]] = 8, hf_token: Optional[str] = None
) -> dict[str, gpd.GeoDataFrame]

Method to load dataset.

PARAMETER DESCRIPTION
hf_token

If needed, a User Access Token needed to authenticate to the Hugging Face Hub. Environment variable HF_TOKEN can be also used. Defaults to None.

TYPE: str DEFAULT: None

version

Version of the dataset. Available: official spatial train-test split from year 2023 in the chosen H3 resolution: '8', '9', '10'. Defaults to '8'. Raw data from other years available as: '2013', '2014', '2015', '2016', '2017', '2018', '2019', '2020', '2021', '2022', '2023'.

TYPE: str or int DEFAULT: 8

RETURNS DESCRIPTION
dict[str, GeoDataFrame]

dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will contain keys "train" and "test" if available.

Source code in srai/datasets/philadelphia_crime.py
def load(
    self, version: Optional[Union[int, str]] = 8, hf_token: Optional[str] = None
) -> dict[str, gpd.GeoDataFrame]:
    """
    Method to load dataset.

    Args:
        hf_token (str, optional): If needed, a User Access Token needed to authenticate to
            the Hugging Face Hub. Environment variable `HF_TOKEN` can be also used.
            Defaults to None.
        version (str or int, optional): version of a dataset.
            Available: Official spatial train-test split from year 2023 in chosen h3 resolution:
            '8', '9', '10'. Defaults to '8'. Raw data from other years available
            as: '2013', '2014', '2015', '2016', '2017', '2018','2019', '2020', '2021',
            '2022', '2023'.

    Returns:
        dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will
            contain keys "train" and "test" if available.
    """
    self.resolution = None
    self.download_data(version=version)

    from datasets import load_dataset

    result = {}

    self.train_gdf, self.val_gdf, self.test_gdf = None, None, None
    dataset_name = self.path
    self.version = str(version)

    if self.resolution is None and self.version in ("8", "9", "10"):
        with suppress(ValueError):
            # Try to parse version as int (e.g. "8" or "9")
            self.resolution = int(self.version)

    if len(str(version)) <= 3:
        data = load_dataset(dataset_name, str(version), token=hf_token, trust_remote_code=True)
    else:
        empty_dataset = Dataset.from_pandas(pd.DataFrame())
        data = {"train": empty_dataset}
    train = data["train"].to_pandas()
    processed_train = self._preprocessing(train)
    self.train_gdf = processed_train
    result["train"] = processed_train
    if "test" in data:
        test = data["test"].to_pandas()
        processed_test = self._preprocessing(test)
        self.test_gdf = processed_test
        result["test"] = processed_test

    return result
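
Both the benchmark splits and the raw yearly data can be loaded this way; a sketch (assumes Hugging Face Hub access for the benchmark configurations):

ds = PhiladelphiaCrimeDataset()

# Official 2023 spatial split at H3 resolution 8 (keys: "train", "test").
splits = ds.load(version="8")

# Raw incidents for a single year; only a "train" split is produced.
raw = ds.load(version=2019)
print(raw.keys())  # dict_keys(['train'])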

train_test_split

train_test_split(
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 7,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]

Method to generate splits from GeoDataFrame, based on the target_column values.

PARAMETER DESCRIPTION
target_column

Target column name. If None, split is generated based on number of points within a hex of a given resolution. Defaults to preset dataset target column.

TYPE: Optional[str] DEFAULT: None

resolution

H3 resolution to regionalize data. Defaults to the dataset's default value.

TYPE: int DEFAULT: None

test_size

Fraction of the data used for the test set. Defaults to 0.2.

TYPE: float DEFAULT: 0.2

n_bins

Number of buckets used to stratify target data. Defaults to 7.

TYPE: int DEFAULT: 7

random_state

Controls the shuffling applied to the data before applying the split. Pass an int for reproducible output across multiple function calls. Defaults to None.

TYPE: int DEFAULT: None

validation_split

If True, creates a validation split from existing train split and assigns it to self.val_gdf.

TYPE: bool DEFAULT: False

force_split

If True, forces a new split to be created, even if an existing train/test or validation split is already present.

  • With validation_split=False, regenerates and overwrites the test split.
  • With validation_split=True, regenerates and overwrites the validation split.

TYPE: bool DEFAULT: False

task

Currently not supported. Ignored in this subclass.

TYPE: Optional[str] DEFAULT: None

RETURNS DESCRIPTION
tuple

Train-test or train-val split made on previous train subset.

TYPE: (GeoDataFrame, GeoDataFrame)

Source code in srai/datasets/_base.py
def train_test_split(
    self,
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 7,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]:
    """
    Method to generate splits from GeoDataFrame, based on the target_column values.

    Args:
        target_column (Optional[str], optional): Target column name. If None, split is\
            generated based on number of points within a hex of a given resolution.\
            Defaults to preset dataset target column.
        resolution (int, optional): h3 resolution to regionalize data. Defaults to default\
            value from the dataset.
        test_size (float, optional): Fraction of data used for the test set. Defaults to 0.2.
        n_bins (int, optional): Number of buckets used to stratify target data.\
            Defaults to 7.
        random_state (int, optional): Controls the shuffling applied to the data before\
            applying the split. \
            Pass an int for reproducible output across multiple function calls. Defaults to None.
        validation_split (bool): If True, creates a validation split from existing train split\
            and assigns it to self.val_gdf.
        force_split: If True, forces a new split to be created, even if an existing train/test\
            or validation split is already present.
            - With `validation_split=False`, regenerates and overwrites the test split.
            - With `validation_split=True`, regenerates and overwrites the validation split.
        task (Optional[str], optional): Currently not supported. Ignored in this subclass.

    Returns:
        tuple(gpd.GeoDataFrame, gpd.GeoDataFrame): Train-test or train-val split made on\
            previous train subset.
    """
    assert self.train_gdf is not None

    if (self.val_gdf is not None and validation_split and not force_split) or (
        self.test_gdf is not None and not validation_split and not force_split
    ):
        raise ValueError(
            "A split already exists. Use `force_split=True` to overwrite the existing "
            f"{'validation' if validation_split else 'test'} split."
        )

    resolution = resolution or self.resolution

    if resolution is None:
        raise ValueError(
            "No preset resolution for the dataset in self.resolution. Please "
            "provide a resolution."
        )
    elif self.resolution is not None and resolution != self.resolution:
        raise ValueError(
            "Resolution provided is different from the preset resolution for the "
            "dataset. This may result in a data leak between splits."
        )

    if self.resolution is None:
        self.resolution = resolution
    target_column = target_column if target_column is not None else self.target
    if target_column is None:
        target_column = "count"

    gdf = self.train_gdf
    gdf_ = gdf.copy()

    train, test = train_test_spatial_split(
        gdf_,
        parent_h3_resolution=resolution,
        target_column=target_column,
        test_size=test_size,
        n_bins=n_bins,
        random_state=random_state,
    )

    self.train_gdf = train
    if not validation_split:
        self.test_gdf = test
        test_len = len(self.test_gdf) if self.test_gdf is not None else 0
        print(
            f"Created new train_gdf and test_gdf. Train len: {len(self.train_gdf)},"
            f"test len: {test_len}"
        )
    else:
        self.val_gdf = test
        val_len = len(self.val_gdf) if self.val_gdf is not None else 0
        test_len = len(self.test_gdf) if self.test_gdf is not None else 0
        print(
            f"Created new train_gdf and val_gdf. Test split remains unchanged."
            f"Train len: {len(self.train_gdf)}, val len: {val_len},"
            f"test len: {test_len}"
        )
    return train, test
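
A sketch of carving a spatial validation split out of the official train split (assumes the dataset was loaded as above):

ds = PhiladelphiaCrimeDataset()
ds.load(version="8")

# The preset resolution (8) is reused; passing a different one would raise.
train, val = ds.train_test_split(
    test_size=0.15,
    validation_split=True,
    random_state=0,
)

assert ds.val_gdf is val        # new validation split
assert ds.test_gdf is not None  # official test split untouched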

PointDataset

PointDataset(
    path: str,
    version: Optional[str] = None,
    type: Optional[str] = None,
    numerical_columns: Optional[list[str]] = None,
    categorical_columns: Optional[list[str]] = None,
    target: Optional[str] = None,
    resolution: Optional[int] = None,
)

Bases: HuggingFaceDataset

Abstract class for HuggingFace datasets with Point Data.

Source code in srai/datasets/_base.py
def __init__(
    self,
    path: str,
    version: Optional[str] = None,
    type: Optional[str] = None,
    numerical_columns: Optional[list[str]] = None,
    categorical_columns: Optional[list[str]] = None,
    target: Optional[str] = None,
    resolution: Optional[int] = None,
) -> None:
    import_optional_dependencies(dependency_group="datasets", modules=["datasets"])
    self.path = path
    self.version = version
    self.numerical_columns = numerical_columns
    self.categorical_columns = categorical_columns
    self.target = target
    self.type = type
    self.train_gdf = None
    self.test_gdf = None
    self.val_gdf = None
    self.resolution = resolution

get_h3_with_labels

get_h3_with_labels() -> (
    tuple[
        gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]
    ]
)

Returns h3 indexes with target labels from the dataset.

Points are aggregated to hexes and target column values are averaged; if the target column is None, the number of points within each hex is calculated and scaled to [0, 1].

RETURNS DESCRIPTION
tuple[GeoDataFrame, Optional[GeoDataFrame], Optional[GeoDataFrame]]

tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]: Train, Val, Test hexes with target labels in GeoDataFrames

Source code in srai/datasets/_base.py
def get_h3_with_labels(
    self,
    # resolution: Optional[int] = None,
    # target_column: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]:
    """
    Returns h3 indexes with target labels from the dataset.

    Points are aggregated to hexes and target column values are averaged; if the target column \
    is None, the number of points within each hex is calculated and scaled to [0, 1].

    Returns:
        tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]:\
            Train, Val, Test hexes with target labels in GeoDataFrames
    """
    # if target_column is None:
    #     target_column = "count"

    # resolution = resolution if resolution is not None else self.resolution

    assert self.train_gdf is not None
    # If resolution is still None, raise an error
    if self.resolution is None:
        raise ValueError(
            "No preset resolution for the dataset in self.resolution. Please"
            "provide a resolution."
        )
    # elif self.resolution is not None and resolution != self.resolution:
    #     raise ValueError(
    #         "Resolution provided is different from the preset resolution for the"
    #         "dataset. This may result in a data leak between splits."
    #     )

    _train_gdf = self._aggregate_hexes(self.train_gdf, self.resolution, self.target)

    if self.test_gdf is not None:
        _test_gdf = self._aggregate_hexes(self.test_gdf, self.resolution, self.target)
    else:
        _test_gdf = None

    if self.val_gdf is not None:
        _val_gdf = self._aggregate_hexes(self.val_gdf, self.resolution, self.target)
    else:
        _val_gdf = None

    # Scale the "count" column to [0, 1] if it is the target column
    if self.target == "count":
        scaler = MinMaxScaler()
        # Fit the scaler on the train dataset and transform
        _train_gdf["count"] = scaler.fit_transform(_train_gdf[["count"]])
        if _test_gdf is not None:
            _test_gdf["count"] = scaler.transform(_test_gdf[["count"]])
            _test_gdf["count"] = np.clip(_test_gdf["count"], 0, 1)
        if _val_gdf is not None:
            _val_gdf["count"] = scaler.transform(_val_gdf[["count"]])
            _val_gdf["count"] = np.clip(_val_gdf["count"], 0, 1)

    return _train_gdf, _val_gdf, _test_gdf

load

load(
    version: Optional[Union[int, str]] = None, hf_token: Optional[str] = None
) -> dict[str, gpd.GeoDataFrame]

Method to load dataset.

PARAMETER DESCRIPTION
hf_token

If needed, a User Access Token needed to authenticate to the Hugging Face Hub. Environment variable HF_TOKEN can be also used. Defaults to None.

TYPE: str DEFAULT: None

version

Version of the dataset.

TYPE: str or int DEFAULT: None

RETURNS DESCRIPTION
dict[str, GeoDataFrame]

dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will contain keys "train" and "test" if available.

Source code in srai/datasets/_base.py
def load(
    self,
    version: Optional[Union[int, str]] = None,
    hf_token: Optional[str] = None,
) -> dict[str, gpd.GeoDataFrame]:
    """
    Method to load dataset.

    Args:
        hf_token (str, optional): If needed, a User Access Token needed to authenticate to
            the Hugging Face Hub. Environment variable `HF_TOKEN` can be also used.
            Defaults to None.
        version (str or int, optional): version of a dataset

    Returns:
        dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will
             contain keys "train" and "test" if available.
    """
    from datasets import load_dataset

    result = {}

    self.train_gdf, self.val_gdf, self.test_gdf = None, None, None
    dataset_name = self.path
    self.version = str(version)

    if (
        self.resolution is None
        and self.version in ("8", "9", "10")
        or (self.version in ("8", "9", "10") and str(self.resolution) != self.version)
    ):
        with suppress(ValueError):
            # Try to parse version as int (e.g. "8" or "9")
            self.resolution = int(self.version)

    data = load_dataset(dataset_name, str(version), token=hf_token, trust_remote_code=True)
    train = data["train"].to_pandas()
    processed_train = self._preprocessing(train)
    self.train_gdf = processed_train
    result["train"] = processed_train
    if "test" in data:
        test = data["test"].to_pandas()
        processed_test = self._preprocessing(test)
        self.test_gdf = processed_test
        result["test"] = processed_test

    return result

train_test_split

train_test_split(
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 7,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]

Method to generate splits from GeoDataFrame, based on the target_column values.

PARAMETER DESCRIPTION
target_column

Target column name. If None, split is generated based on number of points within a hex of a given resolution. Defaults to preset dataset target column.

TYPE: Optional[str] DEFAULT: None

resolution

H3 resolution to regionalize data. Defaults to the dataset's default value.

TYPE: int DEFAULT: None

test_size

Fraction of the data used for the test set. Defaults to 0.2.

TYPE: float DEFAULT: 0.2

n_bins

Number of buckets used to stratify target data. Defaults to 7.

TYPE: int DEFAULT: 7

random_state

Controls the shuffling applied to the data before applying the split. Pass an int for reproducible output across multiple function calls. Defaults to None.

TYPE: int DEFAULT: None

validation_split

If True, creates a validation split from existing train split and assigns it to self.val_gdf.

TYPE: bool DEFAULT: False

force_split

If True, forces a new split to be created, even if an existing train/test or validation split is already present.

  • With validation_split=False, regenerates and overwrites the test split.
  • With validation_split=True, regenerates and overwrites the validation split.

TYPE: bool DEFAULT: False

task

Currently not supported. Ignored in this subclass.

TYPE: Optional[str] DEFAULT: None

RETURNS DESCRIPTION
tuple

Train-test or train-val split made on previous train subset.

TYPE: (GeoDataFrame, GeoDataFrame)

Source code in srai/datasets/_base.py
def train_test_split(
    self,
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 7,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]:
    """
    Method to generate splits from GeoDataFrame, based on the target_column values.

    Args:
        target_column (Optional[str], optional): Target column name. If None, split is\
            generated based on number of points within a hex of a given resolution.\
            Defaults to preset dataset target column.
        resolution (int, optional): h3 resolution to regionalize data. Defaults to default\
            value from the dataset.
        test_size (float, optional): Fraction of data used for the test set. Defaults to 0.2.
        n_bins (int, optional): Number of buckets used to stratify target data.\
            Defaults to 7.
        random_state (int, optional): Controls the shuffling applied to the data before\
            applying the split. \
            Pass an int for reproducible output across multiple function calls. Defaults to None.
        validation_split (bool): If True, creates a validation split from existing train split\
            and assigns it to self.val_gdf.
        force_split: If True, forces a new split to be created, even if an existing train/test\
            or validation split is already present.
            - With `validation_split=False`, regenerates and overwrites the test split.
            - With `validation_split=True`, regenerates and overwrites the validation split.
        task (Optional[str], optional): Currently not supported. Ignored in this subclass.

    Returns:
        tuple(gpd.GeoDataFrame, gpd.GeoDataFrame): Train-test or train-val split made on\
            previous train subset.
    """
    assert self.train_gdf is not None

    if (self.val_gdf is not None and validation_split and not force_split) or (
        self.test_gdf is not None and not validation_split and not force_split
    ):
        raise ValueError(
            "A split already exists. Use `force_split=True` to overwrite the existing "
            f"{'validation' if validation_split else 'test'} split."
        )

    resolution = resolution or self.resolution

    if resolution is None:
        raise ValueError(
            "No preset resolution for the dataset in self.resolution. Please "
            "provide a resolution."
        )
    elif self.resolution is not None and resolution != self.resolution:
        raise ValueError(
            "Resolution provided is different from the preset resolution for the "
            "dataset. This may result in a data leak between splits."
        )

    if self.resolution is None:
        self.resolution = resolution
    target_column = target_column if target_column is not None else self.target
    if target_column is None:
        target_column = "count"

    gdf = self.train_gdf
    gdf_ = gdf.copy()

    train, test = train_test_spatial_split(
        gdf_,
        parent_h3_resolution=resolution,
        target_column=target_column,
        test_size=test_size,
        n_bins=n_bins,
        random_state=random_state,
    )

    self.train_gdf = train
    if not validation_split:
        self.test_gdf = test
        test_len = len(self.test_gdf) if self.test_gdf is not None else 0
        print(
            f"Created new train_gdf and test_gdf. Train len: {len(self.train_gdf)},"
            f"test len: {test_len}"
        )
    else:
        self.val_gdf = test
        val_len = len(self.val_gdf) if self.val_gdf is not None else 0
        test_len = len(self.test_gdf) if self.test_gdf is not None else 0
        print(
            f"Created new train_gdf and val_gdf. Test split remains unchanged."
            f"Train len: {len(self.train_gdf)}, val len: {val_len},"
            f"test len: {test_len}"
        )
    return train, test

PoliceDepartmentIncidentsDataset

PoliceDepartmentIncidentsDataset()

Bases: PointDataset

The San Francisco Police Department's (SFPD) Incident Report Dataset.

This dataset includes incident reports filed from January 1, 2018 through March 2024. These reports are filed by officers or self-reported by members of the public using SFPD's online reporting system.

Source code in srai/datasets/police_department_incidents.py
def __init__(self) -> None:
    """Create the dataset."""
    numerical_columns = None
    categorical_columns = [
        "Incdident Year",
        "Incident Day of Week",
        "Police District",
        "Analysis Neighborhood",
        "Incident Description",
        "Incident Time",
        "Incident Code",
        "Report Type Code",
        "Police District",
        "Analysis Neighborhood",
    ]
    type = "point"
    # target = "Incident Category"
    target = "count"
    super().__init__(
        "kraina/police_department_incidents",
        type=type,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        target=target,
    )

get_h3_with_labels

get_h3_with_labels() -> (
    tuple[
        gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]
    ]
)

Returns h3 indexes with target labels from the dataset.

Points are aggregated to hexes and target column values are averaged; if the target column is None, the number of points within each hex is calculated and scaled to [0, 1].

RETURNS DESCRIPTION
tuple[GeoDataFrame, Optional[GeoDataFrame], Optional[GeoDataFrame]]

tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]: Train, Val, Test hexes with target labels in GeoDataFrames

Source code in srai/datasets/_base.py
def get_h3_with_labels(
    self,
    # resolution: Optional[int] = None,
    # target_column: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]:
    """
    Returns h3 indexes with target labels from the dataset.

    Points are aggregated to hexes and target column values are averaged; if the target column \
    is None, the number of points within each hex is calculated and scaled to [0, 1].

    Returns:
        tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]:\
            Train, Val, Test hexes with target labels in GeoDataFrames
    """
    # if target_column is None:
    #     target_column = "count"

    # resolution = resolution if resolution is not None else self.resolution

    assert self.train_gdf is not None
    # If resolution is still None, raise an error
    if self.resolution is None:
        raise ValueError(
            "No preset resolution for the dataset in self.resolution. Please"
            "provide a resolution."
        )
    # elif self.resolution is not None and resolution != self.resolution:
    #     raise ValueError(
    #         "Resolution provided is different from the preset resolution for the"
    #         "dataset. This may result in a data leak between splits."
    #     )

    _train_gdf = self._aggregate_hexes(self.train_gdf, self.resolution, self.target)

    if self.test_gdf is not None:
        _test_gdf = self._aggregate_hexes(self.test_gdf, self.resolution, self.target)
    else:
        _test_gdf = None

    if self.val_gdf is not None:
        _val_gdf = self._aggregate_hexes(self.val_gdf, self.resolution, self.target)
    else:
        _val_gdf = None

    # Scale the "count" column to [0, 1] if it is the target column
    if self.target == "count":
        scaler = MinMaxScaler()
        # Fit the scaler on the train dataset and transform
        _train_gdf["count"] = scaler.fit_transform(_train_gdf[["count"]])
        if _test_gdf is not None:
            _test_gdf["count"] = scaler.transform(_test_gdf[["count"]])
            _test_gdf["count"] = np.clip(_test_gdf["count"], 0, 1)
        if _val_gdf is not None:
            _val_gdf["count"] = scaler.transform(_val_gdf[["count"]])
            _val_gdf["count"] = np.clip(_val_gdf["count"], 0, 1)

    return _train_gdf, _val_gdf, _test_gdf

load

load(
    version: Optional[Union[int, str]] = 9, hf_token: Optional[str] = None
) -> dict[str, gpd.GeoDataFrame]

Method to load dataset.

PARAMETER DESCRIPTION
hf_token

If needed, a User Access Token needed to authenticate to the Hugging Face Hub. Environment variable HF_TOKEN can be also used. Defaults to None.

TYPE: str DEFAULT: None

version

Version of the dataset. Available: '8', '9', '10', where the number is the H3 resolution used in the train-test split. Defaults to '9'. Raw, full data available as 'all'.

TYPE: str or int DEFAULT: 9

RETURNS DESCRIPTION
dict[str, GeoDataFrame]

dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will contain keys "train" and "test" if available.

Source code in srai/datasets/police_department_incidents.py
def load(
    self, version: Optional[Union[int, str]] = 9, hf_token: Optional[str] = None
) -> dict[str, gpd.GeoDataFrame]:
    """
    Method to load dataset.

    Args:
        hf_token (str, optional): If needed, a User Access Token needed to authenticate to
            the Hugging Face Hub. Environment variable `HF_TOKEN` can be also used.
            Defaults to None.
        version (str or int, optional): version of a dataset.
            Available: '8', '9', '10', where the number is the H3 resolution used in the \
                train-test split. Defaults to '9'. Raw, full data available as 'all'.

    Returns:
        dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will
            contain keys "train" and "test" if available.
    """
    return super().load(hf_token=hf_token, version=version)
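
For example (a sketch; the 'all' version downloads the full raw data and can be large):

from srai.datasets import PoliceDepartmentIncidentsDataset

ds = PoliceDepartmentIncidentsDataset()

# Official split at H3 resolution 9 (the default version).
splits = ds.load(version="9")
train_gdf, test_gdf = splits["train"], splits["test"]

# Full raw data, no predefined split.
raw = ds.load(version="all")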

train_test_split

train_test_split(
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 7,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]

Method to generate splits from GeoDataFrame, based on the target_column values.

PARAMETER DESCRIPTION
target_column

Target column name. If None, split is generated based on number of points within a hex of a given resolution. Defaults to preset dataset target column.

TYPE: Optional[str] DEFAULT: None

resolution

H3 resolution to regionalize data. Defaults to the dataset's default value.

TYPE: int DEFAULT: None

test_size

Fraction of the data used for the test set. Defaults to 0.2.

TYPE: float DEFAULT: 0.2

n_bins

Number of buckets used to stratify target data. Defaults to 7.

TYPE: int DEFAULT: 7

random_state

Controls the shuffling applied to the data before applying the split. Pass an int for reproducible output across multiple function calls. Defaults to None.

TYPE: int DEFAULT: None

validation_split

If True, creates a validation split from existing train split and assigns it to self.val_gdf.

TYPE: bool DEFAULT: False

force_split

If True, forces a new split to be created, even if an existing train/test or validation split is already present.

  • With validation_split=False, regenerates and overwrites the test split.
  • With validation_split=True, regenerates and overwrites the validation split.

TYPE: bool DEFAULT: False

task

Currently not supported. Ignored in this subclass.

TYPE: Optional[str] DEFAULT: None

RETURNS DESCRIPTION
tuple

Train-test or train-val split made on previous train subset.

TYPE: (GeoDataFrame, GeoDataFrame)

Source code in srai/datasets/_base.py
def train_test_split(
    self,
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 7,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]:
    """
    Method to generate splits from GeoDataFrame, based on the target_column values.

    Args:
        target_column (Optional[str], optional): Target column name. If None, split is\
            generated based on number of points within a hex of a given resolution.\
            Defaults to preset dataset target column.
        resolution (int, optional): h3 resolution to regionalize data. Defaults to default\
            value from the dataset.
        test_size (float, optional): Fraction of data used for the test set. Defaults to 0.2.
        n_bins (int, optional): Number of buckets used to stratify target data.\
            Defaults to 7.
        random_state (int, optional): Controls the shuffling applied to the data before\
            applying the split. \
            Pass an int for reproducible output across multiple function calls. Defaults to None.
        validation_split (bool): If True, creates a validation split from existing train split\
            and assigns it to self.val_gdf.
        force_split: If True, forces a new split to be created, even if an existing train/test\
            or validation split is already present.
            - With `validation_split=False`, regenerates and overwrites the test split.
            - With `validation_split=True`, regenerates and overwrites the validation split.
        task (Optional[str], optional): Currently not supported. Ignored in this subclass.

    Returns:
        tuple(gpd.GeoDataFrame, gpd.GeoDataFrame): Train-test or train-val split made on\
            previous train subset.
    """
    assert self.train_gdf is not None

    if (self.val_gdf is not None and validation_split and not force_split) or (
        self.test_gdf is not None and not validation_split and not force_split
    ):
        raise ValueError(
            "A split already exists. Use `force_split=True` to overwrite the existing "
            f"{'validation' if validation_split else 'test'} split."
        )

    resolution = resolution or self.resolution

    if resolution is None:
        raise ValueError(
            "No preset resolution for the dataset in self.resolution. Please "
            "provide a resolution."
        )
    elif self.resolution is not None and resolution != self.resolution:
        raise ValueError(
            "Resolution provided is different from the preset resolution for the "
            "dataset. This may result in a data leak between splits."
        )

    if self.resolution is None:
        self.resolution = resolution
    target_column = target_column if target_column is not None else self.target
    if target_column is None:
        target_column = "count"

    gdf = self.train_gdf
    gdf_ = gdf.copy()

    train, test = train_test_spatial_split(
        gdf_,
        parent_h3_resolution=resolution,
        target_column=target_column,
        test_size=test_size,
        n_bins=n_bins,
        random_state=random_state,
    )

    self.train_gdf = train
    if not validation_split:
        self.test_gdf = test
        test_len = len(self.test_gdf) if self.test_gdf is not None else 0
        print(
            f"Created new train_gdf and test_gdf. Train len: {len(self.train_gdf)},"
            f"test len: {test_len}"
        )
    else:
        self.val_gdf = test
        val_len = len(self.val_gdf) if self.val_gdf is not None else 0
        test_len = len(self.test_gdf) if self.test_gdf is not None else 0
        print(
            f"Created new train_gdf and val_gdf. Test split remains unchanged."
            f"Train len: {len(self.train_gdf)}, val len: {val_len},"
            f"test len: {test_len}"
        )
    return train, test

PortoTaxiDataset

PortoTaxiDataset()

Bases: TrajectoryDataset

Porto Taxi dataset.

The dataset covers a year of trajectory data for taxis in Porto, Portugal. Each ride is categorized as: A) taxi central based, B) stand-based, or C) non-taxi central based. Each data point represents a completed trip initiated through the dispatch central, a taxi stand, or a random street.

Source code in srai/datasets/porto_taxi.py
def __init__(self) -> None:
    """Create the dataset."""
    numerical_columns = ["speed"]
    categorical_columns = ["call_type", "origin_call", "origin_stand", "day_type"]
    type = "trajectory"
    target = "trip_id"
    # target = None
    super().__init__(
        "kraina/porto_taxi",
        type=type,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        target=target,
    )

get_h3_with_labels

get_h3_with_labels() -> (
    tuple[
        gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]
    ]
)

Returns ids, h3 indexes sequences, with target labels from the dataset.

Points are aggregated to hex trajectories and target column values are calculated for each trajectory (time duration for TTE task, future movement sequence for HMP task).

RETURNS DESCRIPTION
tuple[GeoDataFrame, Optional[GeoDataFrame], Optional[GeoDataFrame]]

tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]: Train, Val, Test hexes sequences with target labels in GeoDataFrames

Source code in srai/datasets/_base.py
def get_h3_with_labels(
    self,
    # resolution: Optional[int] = None,
    # target_column: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]:
    """
    Returns ids, h3 indexes sequences, with target labels from the dataset.

    Points are aggregated to hex trajectories and target column values are calculated \
        for each trajectory (time duration for TTE task, future movement sequence for HMP task).

    Returns:
        tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]: Train,\
            Val, Test hexes sequences with target labels in GeoDataFrames
    """
    # resolution = resolution if resolution is not None else self.resolution

    assert self.train_gdf is not None
    # If resolution is still None, raise an error
    if self.resolution is None:
        raise ValueError(
            "No preset resolution for the dataset in self.resolution. Please"
            "provide a resolution."
        )
    # elif self.resolution is not None and resolution != self.resolution:
    #     raise ValueError(
    #         "Resolution provided is different from the preset resolution for the"
    #         "dataset. This may result in a data leak between splits."
    #     )

    if self.version == "TTE":
        _train_gdf = self.train_gdf[[self.target, "h3_sequence", "duration"]]

        if self.test_gdf is not None:
            _test_gdf = self.test_gdf[[self.target, "h3_sequence", "duration"]]
        else:
            _test_gdf = None

        if self.val_gdf is not None:
            _val_gdf = self.val_gdf[[self.target, "h3_sequence", "duration"]]
        else:
            _val_gdf = None

    elif self.version == "HMP":
        _train_gdf = self.train_gdf[[self.target, "h3_sequence_x", "h3_sequence_y"]]

        if self.test_gdf is not None:
            _test_gdf = self.test_gdf[[self.target, "h3_sequence_x", "h3_sequence_y"]]
        else:
            _test_gdf = None

        if self.val_gdf is not None:
            _val_gdf = self.val_gdf[[self.target, "h3_sequence_x", "h3_sequence_y"]]
        else:
            _val_gdf = None

    elif self.version == "all":
        raise TypeError(
            "Could not provide target labels, as version 'all'\
        of dataset does not provide one."
        )

    return _train_gdf, _val_gdf, _test_gdf
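
For the TTE version, each returned frame holds one row per trajectory; a sketch:

ds = PortoTaxiDataset()
ds.load(version="TTE")  # resolution is fixed to 9 for this version

train, val, test = ds.get_h3_with_labels()
print(train.columns.tolist())  # ['trip_id', 'h3_sequence', 'duration']

# With version="HMP" the label columns would instead be h3_sequence_x
# (observed part) and h3_sequence_y (future part to predict).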

load

load(
    version: Optional[Union[int, str]] = "TTE",
    hf_token: Optional[str] = None,
    resolution: Optional[int] = None,
) -> dict[str, gpd.GeoDataFrame]

Method to load dataset.

PARAMETER DESCRIPTION
hf_token

If needed, a User Access Token needed to authenticate to the Hugging Face Hub. Environment variable HF_TOKEN can be also used. Defaults to None.

TYPE: Optional[str] DEFAULT: None

version

Version of the dataset. Available: official train-test split for the Travel Time Estimation task (TTE) and the Human Mobility Prediction task (HMP). Raw data available as: 'all'.

TYPE: Optional[str, int] DEFAULT: 'TTE'

resolution

H3 resolution for hex trajectories. Necessary if using the 'all' version.

TYPE: Optional[int] DEFAULT: None

RETURNS DESCRIPTION
dict[str, GeoDataFrame]

dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will contain keys "train" and "test" if available.

Source code in srai/datasets/porto_taxi.py
def load(
    self,
    version: Optional[Union[int, str]] = "TTE",
    hf_token: Optional[str] = None,
    resolution: Optional[int] = None,
) -> dict[str, gpd.GeoDataFrame]:
    """
    Method to load dataset.

    Args:
        hf_token (Optional[str]): If needed, a User Access Token needed to authenticate to
            the Hugging Face Hub. Environment variable `HF_TOKEN` can be also used.
            Defaults to None.
        version (Optional[str, int]): version of a dataset.
            Available: Official train-test split for Travel Time Estimation task (TTE) and
            Human Mobility Prediction task (HMP). Raw data available as: 'all'.
        resolution (Optional[int]): H3 resolution for hex trajectories.
            Necessary if using the 'all' version.

    Returns:
        dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will
            contain keys "train" and "test" if available.
    """
    if version in ("TTE", "HMP"):
        self.resolution = 9
    elif version == "all":
        self.resolution = resolution if resolution is not None else None
    else:
        raise NotImplementedError("Version not implemented")
    return super().load(hf_token=hf_token, version=version)
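
A loading sketch for the three available versions:

ds = PortoTaxiDataset()

# Task-specific official splits; resolution is fixed to 9.
tte_splits = ds.load(version="TTE")
hmp_splits = ds.load(version="HMP")

# Raw trajectories; a resolution must be supplied for later hex operations.
raw = ds.load(version="all", resolution=9)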

train_test_split

train_test_split(
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 4,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = "TTE",
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]

Generate train/test split or train/val split from trajectory GeoDataFrame.

Train-test/train-val split is generated by splitting train_gdf.

PARAMETER DESCRIPTION
target_column

Column identifying each trajectory (contains trajectory ids).

TYPE: str DEFAULT: None

test_size

Fraction of data to be used as test set.

TYPE: float DEFAULT: 0.2

n_bins

Number of stratification bins.

TYPE: int DEFAULT: 4

random_state

Controls the shuffling applied to the data before applying the split. Pass an int for reproducible output across multiple function calls. Defaults to None.

TYPE: int DEFAULT: None

validation_split

If True, creates a validation split from existing train split and assigns it to self.val_gdf.

TYPE: bool DEFAULT: False

force_split

If True, forces a new split to be created, even if an existing train/test or validation split is already present.

  • With validation_split=False, regenerates and overwrites the test split.
  • With validation_split=True, regenerates and overwrites the validation split.

TYPE: bool DEFAULT: False

resolution

H3 resolution to regionalize data. Currently ignored in this subclass; splits at different resolutions are not supported yet. Defaults to the dataset's default value.

TYPE: int DEFAULT: None

task

Task type. Stratifies by duration (TTE) or hex length (HMP).

TYPE: Literal[TTE, HMP] DEFAULT: 'TTE'

RETURNS DESCRIPTION
tuple[GeoDataFrame, GeoDataFrame]

Tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]: Train/test or train/val GeoDataFrames.

Source code in srai/datasets/_base.py
def train_test_split(
    self,
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 4,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = "TTE",
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]:
    """
    Generate train/test split or train/val split from trajectory GeoDataFrame.

    Train-test/train-val split is generated by splitting train_gdf.

    Args:
        target_column (str): Column identifying each trajectory (contains trajectory ids).
        test_size (float): Fraction of data to be used as test set.
        n_bins (int): Number of stratification bins.
        random_state (int, optional):  Controls the shuffling applied to the data before\
            applying the split. Pass an int for reproducible output across multiple function calls.\
                Defaults to None.
        validation_split (bool): If True, creates a validation split from existing train split\
            and assigns it to self.val_gdf.
        force_split: If True, forces a new split to be created, even if an existing train/test\
            or validation split is already present.
            - With `validation_split=False`, regenerates and overwrites the test split.
            - With `validation_split=True`, regenerates and overwrites the validation split.
        resolution (int, optional): H3 resolution to regionalize data. Currently ignored in\
            this subclass; splits at different resolutions are not supported yet.\
                Defaults to default value from the dataset.
        task (Literal["TTE", "HMP"]): Task type. Stratifies by duration
            (TTE) or hex length (HMP).


    Returns:
        Tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]: Train/test or train/val GeoDataFrames.
    """
    if (self.val_gdf is not None and validation_split and not force_split) or (
        self.test_gdf is not None and not validation_split and not force_split
    ):
        raise ValueError(
            "A split already exists. Use `force_split=True` to overwrite the existing "
            f"{'validation' if validation_split else 'test'} split."
        )
    assert self.train_gdf is not None
    trajectory_id_column = target_column or self.target
    gdf_copy = self.train_gdf.copy()

    if task not in {"TTE", "HMP"}:
        raise ValueError(f"Unsupported task: {task}")

    if task == "TTE":
        self.version = "TTE"
        # Calculate duration in seconds from timestamps list

        if "duration" in gdf_copy.columns:
            gdf_copy["stratify_col"] = gdf_copy["duration"]
        elif "duration" not in gdf_copy.columns and "timestamp" in gdf_copy.columns:
            gdf_copy["stratify_col"] = gdf_copy["timestamp"].apply(
                #     lambda ts: (0.0 if len(ts) < 2 else (ts[-1] - ts[0]).total_seconds())
                # )
                lambda ts: (
                    0.0 if len(ts) < 2 else pd.Timedelta(ts[-1] - ts[0]).total_seconds()
                )
            )
        else:
            raise ValueError(
                "Duration column and timestamp column does not exist.\
                              Can't stratify it."
            )

    elif task == "HMP":
        self.version = "HMP"

        def split_sequence(seq):
            split_idx = int(len(seq) * 0.85)
            if split_idx == len(seq):
                split_idx = len(seq) - 1
            return seq[:split_idx], seq[split_idx:]

        if "h3_sequence_x" not in gdf_copy.columns:
            split_result = gdf_copy["h3_sequence"].apply(split_sequence)
            gdf_copy["h3_sequence_x"] = split_result.apply(operator.itemgetter(0))
            gdf_copy["h3_sequence_y"] = split_result.apply(operator.itemgetter(1))

        # Calculate trajectory length in unique hexagons
        gdf_copy["x_len"] = gdf_copy["h3_sequence_x"].apply(lambda seq: len(set(seq)))
        gdf_copy["y_len"] = gdf_copy["h3_sequence_y"].apply(lambda seq: len(set(seq)))
        gdf_copy["stratify_col"] = gdf_copy.apply(
            lambda row: row["x_len"] + row["y_len"], axis=1
        )
    else:
        raise ValueError(f"Unsupported task type: {task}")

    gdf_copy["stratification_bin"] = pd.cut(gdf_copy["stratify_col"], bins=n_bins, labels=False)

    trajectory_indices = gdf_copy[trajectory_id_column].unique()
    duration_bins = (
        gdf_copy[[trajectory_id_column, "stratification_bin"]]
        .drop_duplicates()
        .set_index(trajectory_id_column)["stratification_bin"]
    )

    train_indices, test_indices = train_test_split(
        trajectory_indices,
        test_size=test_size,
        stratify=duration_bins.loc[trajectory_indices],
        random_state=random_state,
    )

    train_gdf = gdf_copy[gdf_copy[trajectory_id_column].isin(train_indices)]
    test_gdf = gdf_copy[gdf_copy[trajectory_id_column].isin(test_indices)]

    test_gdf = test_gdf.drop(
        columns=[
            col
            for col in (
                "x_len",
                "y_len",
                "stratification_bin",
                "stratify_col",
            )
            if col in test_gdf.columns
        ],
    )
    train_gdf = train_gdf.drop(
        columns=[
            col
            for col in (
                "x_len",
                "y_len",
                "stratification_bin",
                "stratify_col",
            )
            if col in train_gdf.columns
        ],
    )

    self.train_gdf = train_gdf
    if not validation_split:
        self.test_gdf = test_gdf
        test_len = len(self.test_gdf) if self.test_gdf is not None else 0
        print(
            f"Created new train_gdf and test_gdf. Train len: {len(self.train_gdf)}, "
            f"test len: {test_len}"
        )
    else:
        self.val_gdf = test_gdf
        val_len = len(self.val_gdf) if self.val_gdf is not None else 0
        test_len = len(self.test_gdf) if self.test_gdf is not None else 0
        print(
            f"Created new train_gdf and val_gdf. Test split remains unchanged. "
            f"Train len: {len(self.train_gdf)}, val len: {val_len}, "
            f"test len: {test_len}"
        )
    return train_gdf, test_gdf
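
A sketch of a stratified trajectory split for the HMP task (assumes the HMP version was loaded, so a test split already exists and only a validation split can be created without forcing):

ds = PortoTaxiDataset()
ds.load(version="HMP")

# Trajectories are stratified by the number of unique hexes per sequence.
train, val = ds.train_test_split(
    task="HMP",
    test_size=0.2,
    n_bins=4,
    random_state=7,
    validation_split=True,
)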

TrajectoryDataset

TrajectoryDataset(
    path: str,
    version: Optional[str] = None,
    type: Optional[str] = None,
    numerical_columns: Optional[list[str]] = None,
    categorical_columns: Optional[list[str]] = None,
    target: Optional[str] = None,
    resolution: Optional[int] = None,
)

Bases: HuggingFaceDataset

Abstract class for HuggingFace datasets with Trajectory data.

Source code in srai/datasets/_base.py
def __init__(
    self,
    path: str,
    version: Optional[str] = None,
    type: Optional[str] = None,
    numerical_columns: Optional[list[str]] = None,
    categorical_columns: Optional[list[str]] = None,
    target: Optional[str] = None,
    resolution: Optional[int] = None,
) -> None:
    import_optional_dependencies(dependency_group="datasets", modules=["datasets"])
    self.path = path
    self.version = version
    self.numerical_columns = numerical_columns
    self.categorical_columns = categorical_columns
    self.target = target
    self.type = type
    self.train_gdf = None
    self.test_gdf = None
    self.val_gdf = None
    self.resolution = resolution

get_h3_with_labels

get_h3_with_labels() -> (
    tuple[
        gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]
    ]
)

Returns ids, h3 indexes sequences, with target labels from the dataset.

Points are aggregated to hex trajectories and target column values are calculated for each trajectory (time duration for TTE task, future movement sequence for HMP task).

RETURNS DESCRIPTION
tuple[GeoDataFrame, Optional[GeoDataFrame], Optional[GeoDataFrame]]

tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]: Train, Val, Test hexes sequences with target labels in GeoDataFrames

Source code in srai/datasets/_base.py
def get_h3_with_labels(
    self,
    # resolution: Optional[int] = None,
    # target_column: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]:
    """
    Returns ids, h3 indexes sequences, with target labels from the dataset.

    Points are aggregated to hex trajectories and target column values are calculated \
        for each trajectory (time duration for TTE task, future movement sequence for HMP task).

    Returns:
        tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]: Train,\
            Val, Test hexes sequences with target labels in GeoDataFrames
    """
    # resolution = resolution if resolution is not None else self.resolution

    assert self.train_gdf is not None
    # If resolution is still None, raise an error
    if self.resolution is None:
        raise ValueError(
            "No preset resolution for the dataset in self.resolution. Please"
            "provide a resolution."
        )
    # elif self.resolution is not None and resolution != self.resolution:
    #     raise ValueError(
    #         "Resolution provided is different from the preset resolution for the"
    #         "dataset. This may result in a data leak between splits."
    #     )

    if self.version == "TTE":
        _train_gdf = self.train_gdf[[self.target, "h3_sequence", "duration"]]

        if self.test_gdf is not None:
            _test_gdf = self.test_gdf[[self.target, "h3_sequence", "duration"]]
        else:
            _test_gdf = None

        if self.val_gdf is not None:
            _val_gdf = self.val_gdf[[self.target, "h3_sequence", "duration"]]
        else:
            _val_gdf = None

    elif self.version == "HMP":
        _train_gdf = self.train_gdf[[self.target, "h3_sequence_x", "h3_sequence_y"]]

        if self.test_gdf is not None:
            _test_gdf = self.test_gdf[[self.target, "h3_sequence_x", "h3_sequence_y"]]
        else:
            _test_gdf = None

        if self.val_gdf is not None:
            _val_gdf = self.val_gdf[[self.target, "h3_sequence_x", "h3_sequence_y"]]
        else:
            _val_gdf = None

    elif self.version == "all":
        raise TypeError(
            "Could not provide target labels, as version 'all'\
        of dataset does not provide one."
        )

    return _train_gdf, _val_gdf, _test_gdf

load

load(
    version: Optional[Union[int, str]] = None, hf_token: Optional[str] = None
) -> dict[str, gpd.GeoDataFrame]

Method to load the dataset.

PARAMETER DESCRIPTION
hf_token

User Access Token used to authenticate to the Hugging Face Hub, if needed. Environment variable HF_TOKEN can also be used. Defaults to None.

TYPE: str DEFAULT: None

version

Version of the dataset.

TYPE: str or int DEFAULT: None

RETURNS DESCRIPTION
dict[str, GeoDataFrame]

dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will contain keys "train" and "test" if available.

Source code in srai/datasets/_base.py
def load(
    self,
    version: Optional[Union[int, str]] = None,
    hf_token: Optional[str] = None,
) -> dict[str, gpd.GeoDataFrame]:
    """
    Method to load the dataset.

    Args:
        hf_token (str, optional): User Access Token used to authenticate to
            the Hugging Face Hub, if needed. Environment variable `HF_TOKEN` can also
            be used. Defaults to None.
        version (str or int, optional): Version of the dataset. Defaults to None.

    Returns:
        dict[str, gpd.GeoDataFrame]: Dictionary with all splits loaded from the dataset. Will
             contain keys "train" and "test" if available.
    """
    from datasets import load_dataset

    result = {}

    self.train_gdf, self.val_gdf, self.test_gdf = None, None, None
    dataset_name = self.path
    self.version = str(version)

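    # Versions "8", "9" and "10" name an H3 resolution directly;
    # keep self.resolution in sync with the requested version.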
    if self.version in ("8", "9", "10") and (
        self.resolution is None or str(self.resolution) != self.version
    ):
        with suppress(ValueError):
            # Try to parse version as int (e.g. "8" or "9")
            self.resolution = int(self.version)

    data = load_dataset(dataset_name, str(version), token=hf_token, trust_remote_code=True)
    train = data["train"].to_pandas()
    processed_train = self._preprocessing(train)
    self.train_gdf = processed_train
    result["train"] = processed_train
    if "test" in data:
        test = data["test"].to_pandas()
        processed_test = self._preprocessing(test)
        self.test_gdf = processed_test
        result["test"] = processed_test

    return result
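
For example, a hedged sketch (the class comes from this module; the version string "9" is an assumption and, as described above, presets the H3 resolution):

from srai.datasets import AirbnbMulticityDataset

dataset = AirbnbMulticityDataset()
splits = dataset.load(version="9")
train_gdf = splits["train"]
test_gdf = splits.get("test")  # present only if the dataset ships a test split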

train_test_split

train_test_split(
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 4,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = "TTE",
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]

Generate train/test split or train/val split from trajectory GeoDataFrame.

Train-test/train-val split is generated by splitting train_gdf.

PARAMETER DESCRIPTION
target_column

Column identifying each trajectory (contains trajectory ids).

TYPE: str DEFAULT: None

test_size

Fraction of data to be used as test set.

TYPE: float DEFAULT: 0.2

n_bins

Number of stratification bins.

TYPE: int DEFAULT: 4

random_state

Controls the shuffling applied to the data before applying the split. Pass an int for reproducible output across multiple function calls. Defaults to None.

TYPE: int DEFAULT: None

validation_split

If True, creates a validation split from existing train split and assigns it to self.val_gdf.

TYPE: bool DEFAULT: False

force_split

If True, forces a new split to be created, even if an existing train/test or validation split is already present. With validation_split=False, regenerates and overwrites the test split; with validation_split=True, regenerates and overwrites the validation split.

TYPE: bool DEFAULT: False

resolution

H3 resolution to regionalize data. Currently ignored in this subclass; splits at different resolutions are not supported yet. Defaults to the dataset's preset value.

TYPE: int DEFAULT: None

task

Task type. Stratifies by duration (TTE) or hex length (HMP).

TYPE: Literal[TTE, HMP] DEFAULT: 'TTE'

RETURNS DESCRIPTION
tuple[GeoDataFrame, GeoDataFrame]

tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]: Train/test or train/val GeoDataFrames.

Source code in srai/datasets/_base.py
def train_test_split(
    self,
    target_column: Optional[str] = None,
    resolution: Optional[int] = None,
    test_size: float = 0.2,
    n_bins: int = 4,
    random_state: Optional[int] = None,
    validation_split: bool = False,
    force_split: bool = False,
    task: Optional[str] = "TTE",
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]:
    """
    Generate train/test split or train/val split from trajectory GeoDataFrame.

    Train-test/train-val split is generated by splitting train_gdf.

    Args:
        target_column (str): Column identifying each trajectory (contains trajectory ids).
        test_size (float): Fraction of data to be used as test set.
        n_bins (int): Number of stratification bins.
        random_state (int, optional): Controls the shuffling applied to the data before\
            applying the split. Pass an int for reproducible output across multiple\
                function calls. Defaults to None.
        validation_split (bool): If True, creates a validation split from existing train split\
            and assigns it to self.val_gdf.
        force_split (bool): If True, forces a new split to be created, even if an existing\
            train/test or validation split is already present.
            - With `validation_split=False`, regenerates and overwrites the test split.
            - With `validation_split=True`, regenerates and overwrites the validation split.
        resolution (int, optional): H3 resolution to regionalize data. Currently ignored in\
            this subclass; splits at different resolutions are not supported yet.\
                Defaults to the dataset's preset value.
        task (Literal["TTE", "HMP"]): Task type. Stratifies by duration
            (TTE) or hex length (HMP).


    Returns:
        tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]: Train/test or train/val GeoDataFrames.
    """
    if (self.val_gdf is not None and validation_split and not force_split) or (
        self.test_gdf is not None and not validation_split and not force_split
    ):
        raise ValueError(
            "A split already exists. Use `force_split=True` to overwrite the existing "
            f"{'validation' if validation_split else 'test'} split."
        )
    assert self.train_gdf is not None
    trajectory_id_column = target_column or self.target
    gdf_copy = self.train_gdf.copy()

    if task not in {"TTE", "HMP"}:
        raise ValueError(f"Unsupported task: {task}")

    if task == "TTE":
        self.version = "TTE"
        # Calculate duration in seconds from timestamps list

        if "duration" in gdf_copy.columns:
            gdf_copy["stratify_col"] = gdf_copy["duration"]
        elif "duration" not in gdf_copy.columns and "timestamp" in gdf_copy.columns:
            gdf_copy["stratify_col"] = gdf_copy["timestamp"].apply(
                #     lambda ts: (0.0 if len(ts) < 2 else (ts[-1] - ts[0]).total_seconds())
                # )
                lambda ts: (
                    0.0 if len(ts) < 2 else pd.Timedelta(ts[-1] - ts[0]).total_seconds()
                )
            )
        else:
            raise ValueError(
                "Neither a duration column nor a timestamp column exists. "
                "Can't stratify the data."
            )

    elif task == "HMP":
        self.version = "HMP"

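        # Split each hex sequence into an observed prefix (~85% of points)
        # and a non-empty future suffix to be predicted.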
        def split_sequence(seq):
            split_idx = int(len(seq) * 0.85)
            if split_idx == len(seq):
                split_idx = len(seq) - 1
            return seq[:split_idx], seq[split_idx:]

        if "h3_sequence_x" not in gdf_copy.columns:
            split_result = gdf_copy["h3_sequence"].apply(split_sequence)
            gdf_copy["h3_sequence_x"] = split_result.apply(operator.itemgetter(0))
            gdf_copy["h3_sequence_y"] = split_result.apply(operator.itemgetter(1))

        # Calculate trajectory length in unique hexagons
        gdf_copy["x_len"] = gdf_copy["h3_sequence_x"].apply(lambda seq: len(set(seq)))
        gdf_copy["y_len"] = gdf_copy["h3_sequence_y"].apply(lambda seq: len(set(seq)))
        gdf_copy["stratify_col"] = gdf_copy.apply(
            lambda row: row["x_len"] + row["y_len"], axis=1
        )

    gdf_copy["stratification_bin"] = pd.cut(gdf_copy["stratify_col"], bins=n_bins, labels=False)

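    # Keep one stratification bin per trajectory id so whole trajectories,
    # not individual points, are assigned to a single split.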
    trajectory_indices = gdf_copy[trajectory_id_column].unique()
    duration_bins = (
        gdf_copy[[trajectory_id_column, "stratification_bin"]]
        .drop_duplicates()
        .set_index(trajectory_id_column)["stratification_bin"]
    )

    train_indices, test_indices = train_test_split(
        trajectory_indices,
        test_size=test_size,
        stratify=duration_bins.loc[trajectory_indices],
        random_state=random_state,
    )

    train_gdf = gdf_copy[gdf_copy[trajectory_id_column].isin(train_indices)]
    test_gdf = gdf_copy[gdf_copy[trajectory_id_column].isin(test_indices)]

    test_gdf = test_gdf.drop(
        columns=[
            col
            for col in (
                "x_len",
                "y_len",
                "stratification_bin",
                "stratify_col",
            )
            if col in test_gdf.columns
        ],
    )
    train_gdf = train_gdf.drop(
        columns=[
            col
            for col in (
                "x_len",
                "y_len",
                "stratification_bin",
                "stratify_col",
            )
            if col in train_gdf.columns
        ],
    )

    self.train_gdf = train_gdf
    if not validation_split:
        self.test_gdf = test_gdf
        test_len = len(self.test_gdf) if self.test_gdf is not None else 0
        print(
            f"Created new train_gdf and test_gdf. Train len: {len(self.train_gdf)}, "
            f"test len: {test_len}"
        )
    else:
        self.val_gdf = test_gdf
        val_len = len(self.val_gdf) if self.val_gdf is not None else 0
        test_len = len(self.test_gdf) if self.test_gdf is not None else 0
        print(
            f"Created new train_gdf and val_gdf. Test split remains unchanged. "
            f"Train len: {len(self.train_gdf)}, val len: {val_len}, "
            f"test len: {test_len}"
        )
    return train_gdf, test_gdf
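
An end-to-end sketch under stated assumptions (illustrative class name; TTE stratification needs a duration or timestamp column, HMP needs an h3_sequence column):

from srai.datasets import PortoTaxiDataset

dataset = PortoTaxiDataset()
dataset.load(version="all")
# First call creates the test split; the second carves a validation set from train.
dataset.train_test_split(task="HMP", test_size=0.2, random_state=42)
dataset.train_test_split(task="HMP", validation_split=True, random_state=42)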