
Dataset

HexagonalDataset.

This dataset is used to train the hexagonal encoder model defined in the GeoVex paper [1].

References

[1] https://openreview.net/forum?id=7bvWopYY1H

HexagonalDataset(data, neighbourhood, neighbor_k_ring=6)

Bases: Dataset['torch.Tensor'], Generic[T]

Dataset for the hexagonal encoder model.

It works by returning a 3D tensor of hexagonal regions. The tensor is a cube with the target hexagonal region in the center and the rings of its neighbors surrounding it.
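
To make the "target in the center, rings around it" layout concrete, here is a minimal sketch in plain Python. It assumes axial (q, r) hexagon coordinates and a square (2k+1) x (2k+1) grid per feature; both are illustrative assumptions, and the helper names `cells_within_k_rings` and `to_grid` are hypothetical. The layout produced by the library's own `_build_tensor` may differ.

```python
def cells_within_k_rings(k: int) -> list[tuple[int, int]]:
    """Axial (q, r) offsets of every hexagon within k rings of the origin."""
    cells = []
    for q in range(-k, k + 1):
        # Bound r so the third cube coordinate s = -q - r also stays in [-k, k].
        for r in range(max(-k, -q - k), min(k, -q + k) + 1):
            cells.append((q, r))
    return cells


def to_grid(k: int) -> list[list[int]]:
    """Occupancy mask of a (2k+1) x (2k+1) grid: 1 = hexagon, 0 = padding."""
    size = 2 * k + 1
    grid = [[0] * size for _ in range(size)]
    for q, r in cells_within_k_rings(k):
        # Shift by k so the target cell (0, 0) lands in the centre of the grid.
        grid[r + k][q + k] = 1
    return grid


k = 6
cells = cells_within_k_rings(k)
grid = to_grid(k)

# A hexagon plus its k surrounding rings always holds 1 + 3k(k+1) cells.
assert len(cells) == 1 + 3 * k * (k + 1)
# The target region sits in the centre of the grid.
assert grid[k][k] == 1
```

With the default `neighbor_k_ring=6`, the target and its rings cover 127 hexagons. In this assumed layout the remaining grid positions are padding, and stacking one such grid per feature column gives the third dimension.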

PARAMETER DESCRIPTION
data

Data to use for training. Raw counts of features in regions.

TYPE: DataFrame

neighbourhood

H3Neighbourhood to use for training. It has to be initialized with the same data as the data argument.

TYPE: H3Neighbourhood

neighbor_k_ring

The number of hexagonal rings of neighbors to include in the tensor. Defaults to 6.

TYPE: int DEFAULT: 6

Source code in srai/embedders/geovex/dataset.py
def __init__(
    self,
    data: pd.DataFrame,
    neighbourhood: H3Neighbourhood,
    neighbor_k_ring: int = 6,
):
    """
    Initialize the HexagonalDataset.

    Args:
        data (pd.DataFrame): Data to use for training. Raw counts of features in regions.
        neighbourhood (H3Neighbourhood): H3Neighbourhood to use for training.
            It has to be initialized with the same data as the data argument.
        neighbor_k_ring (int, optional): The hexagonal rings of neighbors to include
            in the tensor. Defaults to 6.
    """
    import_optional_dependencies(dependency_group="torch", modules=["torch"])
    import torch

    self._assert_k_ring_correct(neighbor_k_ring)
    self._assert_h3_neighbourhood(neighbourhood)
    # store the desired k
    self._k: int = neighbor_k_ring
    # number of columns in the dataset
    self._N: int = data.shape[1]
    # store the list of valid h3 indices (have all the neighbors in the dataset)
    self._valid_cells: list[CellInfo] = []
    # store the data as a torch tensor
    self._data_torch = torch.Tensor(data.to_numpy(dtype=np.float32))
    # iterate over the data and build the valid h3 indices
    self._invalid_cells, self._valid_cells = self._seperate_valid_invalid_cells(
        data, neighbourhood, neighbor_k_ring, set(data.index)
    )
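
The comment in the constructor describes a valid cell as one that has all of its neighbors in the dataset. The sketch below illustrates that criterion in plain Python. It is not the library's implementation: `neighbours_within` is a hypothetical stand-in for the neighbourhood lookup, cells are axial (q, r) tuples rather than H3 indices, and the (invalid, valid) return order merely mirrors the assignment in the constructor.

```python
def neighbours_within(cell, k):
    """Hypothetical k-ring lookup on axial (q, r) coordinates, excluding the cell itself."""
    q0, r0 = cell
    return {
        (q0 + q, r0 + r)
        for q in range(-k, k + 1)
        for r in range(max(-k, -q - k), min(k, -q + k) + 1)
        if (q, r) != (0, 0)
    }


def separate_cells(index, k):
    """Split cells into (invalid, valid) by whether every k-ring neighbour is present."""
    present = set(index)
    invalid, valid = set(), []
    for cell in index:
        # A cell is valid only if its whole neighbourhood is a subset of the index.
        if neighbours_within(cell, k) <= present:
            valid.append(cell)
        else:
            invalid.add(cell)
    return invalid, valid


# A filled patch of radius 2 around the origin: 19 cells in total.
patch = sorted(neighbours_within((0, 0), 2) | {(0, 0)})
invalid, valid = separate_cells(patch, k=1)

# Only cells at distance <= 1 from the origin keep a complete 1-ring inside the patch.
assert len(valid) == 7
assert len(invalid) == 12
```

The practical consequence is that cells near the edge of the covered area are excluded from training, and a larger `neighbor_k_ring` excludes a wider border.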

__len__()

Returns the number of valid h3 indices in the dataset.

RETURNS DESCRIPTION
int

Number of valid h3 indices in the dataset.

TYPE: int

Source code in srai/embedders/geovex/dataset.py
def __len__(self) -> int:
    """
    Returns the number of valid h3 indices in the dataset.

    Returns:
        int: Number of valid h3 indices in the dataset.
    """
    return len(self._valid_cells)

__getitem__(index)

Return a single item from the dataset.

PARAMETER DESCRIPTION
index

The index of the dataset item to return.

TYPE: Any

RETURNS DESCRIPTION
Tensor

The dataset item

TYPE: Tensor

Source code in srai/embedders/geovex/dataset.py
def __getitem__(self, index: Any) -> "torch.Tensor":
    """
    Return a single item from the dataset.

    Args:
        index (Any): The index of dataset item to return

    Returns:
        torch.Tensor: The dataset item
    """
    _, target_idx, neighbors_idxs = self._valid_cells[index]
    return self._build_tensor(target_idx, neighbors_idxs)
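
Because `__getitem__` looks up `self._valid_cells[index]`, the index is a position in the list of valid cells, not a row of the original DataFrame. The toy class below sketches that map-style pattern without torch. The class name `ToyCellDataset`, the sample rows, and the tuple layout (cell id, target row position, neighbor row positions) are assumptions inferred from the unpacking in the source above.

```python
class ToyCellDataset:
    """Map-style dataset: __len__ plus integer-indexed __getitem__."""

    def __init__(self, rows, valid_cells):
        self._rows = rows                # one feature row per region
        self._valid_cells = valid_cells  # (cell id, target row, neighbour rows)

    def __len__(self):
        # Only valid cells count as items, not every row in the data.
        return len(self._valid_cells)

    def __getitem__(self, index):
        _, target_idx, neighbour_idxs = self._valid_cells[index]
        # Target row first, followed by the rows of its neighbours.
        return [self._rows[target_idx]] + [self._rows[i] for i in neighbour_idxs]


rows = [[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
valid = [("a", 0, [1, 2]), ("c", 2, [0, 1])]
ds = ToyCellDataset(rows, valid)

# Item 1 is the second *valid* cell ("c"), whose target row is rows[2].
items = [ds[i] for i in range(len(ds))]
assert len(ds) == 2
assert items[1][0] == [3.0, 3.0]
```

To map an item back to its H3 index, pair the position with the list returned by `get_valid_cells()`, which is built from the same `_valid_cells` list in the same order.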

get_valid_cells()

Returns the list of valid h3 indices in the dataset.

RETURNS DESCRIPTION
list[str]

List of valid h3 indices in the dataset.

Source code in srai/embedders/geovex/dataset.py
def get_valid_cells(self) -> list[str]:
    """
    Returns the list of valid h3 indices in the dataset.

    Returns:
        List[str]: List of valid h3 indices in the dataset.
    """
    return [h3_index for h3_index, _, _ in self._valid_cells]

get_invalid_cells()

Returns the list of invalid h3 indices in the dataset.

RETURNS DESCRIPTION
list[str]

List of invalid h3 indices in the dataset.

Source code in srai/embedders/geovex/dataset.py
def get_invalid_cells(self) -> list[str]:
    """
    Returns the list of invalid h3 indices in the dataset.

    Returns:
        List[str]: List of invalid h3 indices in the dataset.
    """
    return list(self._invalid_cells)