Neighbour dataset

NeighbourDataset.

This dataset is used to train a model to predict whether two regions are neighbours, as defined in the Hex2Vec paper [1].

References

[1] https://dl.acm.org/doi/10.1145/3486635.3491076

NeighbourDatasetItem

Bases: NamedTuple

Neighbour dataset item.

ATTRIBUTE DESCRIPTION
X_anchor

Anchor regions.

TYPE: Tensor

X_positive

Positive regions. Data for the regions that are neighbours of regions in X_anchor.

TYPE: Tensor

X_negative

Negative regions. Data for the regions that are NOT neighbours of the regions in X_anchor.

TYPE: Tensor
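
Because the item is a NamedTuple, its fields can be read by name or unpacked by position. The sketch below illustrates this with a stand-in class; `TripletItem` and the plain lists used in place of tensors are hypothetical and only mirror the three documented attributes.

```python
from typing import NamedTuple


class TripletItem(NamedTuple):
    """Hypothetical stand-in with the same three fields as NeighbourDatasetItem."""

    X_anchor: list
    X_positive: list
    X_negative: list


# Plain lists stand in for the feature tensors of three regions.
item = TripletItem([1.0, 0.0], [0.0, 2.0], [3.0, 1.0])

# Access by field name...
anchor_by_name = item.X_anchor

# ...or unpack positionally, in (anchor, positive, negative) order.
anchor, positive, negative = item
```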

NeighbourDataset(
    data, neighbourhood, negative_sample_k_distance=2
)

Bases: Dataset[NeighbourDatasetItem], Generic[T]

Dataset for training a model to predict neighbours.

It works by returning triplets of regions: anchor, positive and negative. A model can be trained to predict that the anchor region is a neighbour of the positive region, and that it is not a neighbour of the negative region.
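
As a rough illustration of how such a triplet might be consumed, the sketch below scores one (anchor, positive, negative) triplet with a dot product and a logistic loss, in the style of skip-gram with negative sampling. It is an assumption about how a model could use the triplets, not the library's training code: the vectors stand in for embeddings a model would produce, plain Python lists replace tensors, and the names `dot`, `log_sigmoid` and `triplet_loss` are hypothetical.

```python
import math


def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def triplet_loss(anchor, positive, negative):
    """Low when anchor and positive are similar and anchor and negative are dissimilar."""
    positive_term = log_sigmoid(dot(anchor, positive))
    negative_term = log_sigmoid(-dot(anchor, negative))
    return -(positive_term + negative_term)


# Toy embeddings (hypothetical values).
anchor = [1.0, 0.0]
positive = [0.9, 0.1]
negative = [-1.0, 0.2]

loss_correct = triplet_loss(anchor, positive, negative)
loss_swapped = triplet_loss(anchor, negative, positive)
```

With these toy values, swapping the positive and negative regions gives a larger loss, which is the signal a model would learn from.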

PARAMETER DESCRIPTION
data

Data to use for training. Raw counts of features in regions.

TYPE: DataFrame

neighbourhood

Neighbourhood to use for training. It has to be initialized with the same data as the data argument.

TYPE: Neighbourhood[T]

negative_sample_k_distance

Hop distance within which regions are excluded from negative sampling. For example, if k=2, then the negative regions will be sampled from regions that are at least 3 hops away from the anchor region. Has to be >= 2.

TYPE: int DEFAULT: 2

RAISES DESCRIPTION
ValueError

If negative_sample_k_distance < 2.
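
To make the effect of negative_sample_k_distance concrete, the sketch below computes the allowed negative regions on a toy adjacency graph with a breadth-first search. It does not use the Neighbourhood class; the adjacency dictionary and the helpers `regions_within_k_hops` and `negative_candidates` are hypothetical and only illustrate the rule stated above.

```python
from collections import deque


def regions_within_k_hops(adjacency, start, k):
    """Return every region reachable from start in at most k hops (start included)."""
    hops = {start: 0}
    queue = deque([start])
    while queue:
        region = queue.popleft()
        if hops[region] == k:
            continue  # do not expand beyond k hops
        for neighbour in adjacency[region]:
            if neighbour not in hops:
                hops[neighbour] = hops[region] + 1
                queue.append(neighbour)
    return set(hops)


def negative_candidates(adjacency, anchor, negative_sample_k_distance=2):
    """Regions that may be sampled as negatives for the given anchor."""
    if negative_sample_k_distance < 2:
        raise ValueError("negative_sample_k_distance must be at least 2.")
    excluded = regions_within_k_hops(adjacency, anchor, negative_sample_k_distance)
    return set(adjacency) - excluded


# A chain of six regions: 0 - 1 - 2 - 3 - 4 - 5
chain = {i: [j for j in (i - 1, i + 1) if 0 <= j <= 5] for i in range(6)}

# With k=2, regions 1 and 2 are too close to region 0; negatives start at 3 hops.
candidates = negative_candidates(chain, anchor=0, negative_sample_k_distance=2)
```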

Source code in srai/embedders/hex2vec/neighbour_dataset.py
def __init__(
    self,
    data: pd.DataFrame,
    neighbourhood: Neighbourhood[T],
    negative_sample_k_distance: int = 2,
):
    """
    Initialize NeighbourDataset.

    Args:
        data (pd.DataFrame): Data to use for training. Raw counts of features in regions.
        neighbourhood (Neighbourhood[T]): Neighbourhood to use for training.
            It has to be initialized with the same data as the data argument.
        negative_sample_k_distance (int): Hop distance within which regions are excluded
            from negative sampling. For example, if k=2, then the negative regions will be
            sampled from regions that are at least 3 hops away from the anchor region.
            Has to be >= 2.

    Raises:
        ValueError: If negative_sample_k_distance < 2.
    """
    import_optional_dependencies(dependency_group="torch", modules=["torch"])
    import torch

    self._data = torch.Tensor(data.to_numpy())
    self._assert_negative_sample_k_distance_correct(negative_sample_k_distance)
    self._negative_sample_k_distance = negative_sample_k_distance

    self._anchor_df_locs_lookup: np.ndarray
    self._positive_df_locs_lookup: np.ndarray
    self._excluded_from_negatives: dict[int, set[int]] = {}

    self._region_index_to_df_loc: dict[T, int] = {
        region_index: i for i, region_index in enumerate(data.index)
    }
    self._df_loc_to_region_index: dict[int, T] = {
        i: region_index for region_index, i in self._region_index_to_df_loc.items()
    }

    self._build_lookup_tables(data, neighbourhood)

__len__()

Return the number of anchor-positive pairs available in the dataset.

RETURNS DESCRIPTION
int

The number of pairs.

TYPE: int

Source code in srai/embedders/hex2vec/neighbour_dataset.py
def __len__(self) -> int:
    """
    Return the number of anchor-positive pairs available in the dataset.

    Returns:
        int: The number of pairs.
    """
    return len(self._anchor_df_locs_lookup)
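
The body of `_build_lookup_tables` is not shown on this page, so the following is only a sketch of how parallel anchor and positive lookups could be built, and why the dataset length equals the number of anchor-positive pairs rather than the number of regions. The function `build_pair_lookups` and the toy adjacency dictionary are hypothetical.

```python
def build_pair_lookups(region_ids, adjacency):
    """Build parallel lists of row positions, one entry per (anchor, neighbour) pair."""
    region_to_loc = {region: loc for loc, region in enumerate(region_ids)}
    anchor_locs = []
    positive_locs = []
    for region in region_ids:
        for neighbour in adjacency.get(region, ()):
            # Skip neighbours that have no row in the data.
            if neighbour in region_to_loc:
                anchor_locs.append(region_to_loc[region])
                positive_locs.append(region_to_loc[neighbour])
    return anchor_locs, positive_locs


# Three regions in a row: a - b - c
region_ids = ["a", "b", "c"]
adjacency = {"a": ["b"], "b": ["a", "c"], "c": ["b"]}

anchor_locs, positive_locs = build_pair_lookups(region_ids, adjacency)

# Three regions, but four directed anchor-positive pairs.
n_pairs = len(anchor_locs)
```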

__getitem__(data_row_index)

Return a single dataset item (anchor, positive, negative).

PARAMETER DESCRIPTION
data_row_index

The index of the dataset item to return.

TYPE: Any

RETURNS DESCRIPTION
NeighbourDatasetItem

The dataset item. This includes the anchor region, positive region and a randomly sampled negative region.

TYPE: NeighbourDatasetItem

Source code in srai/embedders/hex2vec/neighbour_dataset.py
def __getitem__(self, data_row_index: Any) -> NeighbourDatasetItem:
    """
    Return a single dataset item (anchor, positive, negative).

    Args:
        data_row_index (Any): The index of the dataset item to return.

    Returns:
        NeighbourDatasetItem: The dataset item.
            This includes the anchor region, positive region
            and a randomly sampled negative region.
    """
    anchor_df_loc = self._anchor_df_locs_lookup[data_row_index]
    positive_df_loc = self._positive_df_locs_lookup[data_row_index]
    negative_df_loc = self._get_random_negative_df_loc(anchor_df_loc)

    anchor_region = self._data[anchor_df_loc]
    positive_region = self._data[positive_df_loc]
    negative_region = self._data[negative_df_loc]

    return NeighbourDatasetItem(anchor_region, positive_region, negative_region)
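
The helper `_get_random_negative_df_loc` is not shown here. One plausible way to draw a negative row, assuming a per-anchor set of excluded row positions such as `_excluded_from_negatives`, is rejection sampling: draw a random row and retry while it falls in the excluded set. The function `sample_negative_loc` below is a hypothetical sketch of that idea, not the library's implementation.

```python
import random


def sample_negative_loc(n_rows, excluded, rng):
    """Draw a random row position in [0, n_rows) that is not in the excluded set."""
    if all(loc in excluded for loc in range(n_rows)):
        raise ValueError("Every row is excluded; no negative region can be sampled.")
    candidate = rng.randrange(n_rows)
    while candidate in excluded:
        candidate = rng.randrange(n_rows)  # retry until an allowed row is drawn
    return candidate


# Six rows; rows 0, 1 and 2 are too close to the anchor and must not be sampled.
rng = random.Random(0)
samples = [sample_negative_loc(6, {0, 1, 2}, rng) for _ in range(50)]
```

Since the negative is drawn on every access, reading the same index twice can return different negative regions, while the anchor and positive regions stay fixed.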