BaseEvaluator

srai.benchmark.BaseEvaluator

BaseEvaluator(
    task: Literal[
        "trajectory_regression",
        "regression",
        "poi_prediction",
        "mobility_prediction",
    ],
)

Bases: ABC

Abstract class for benchmark evaluators.

Source code in srai/benchmark/_base.py
def __init__(
    self,
    task: Literal[
        "trajectory_regression", "regression", "poi_prediction", "mobility_prediction"
    ],
) -> None:
    self.task = task
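
As an illustration, a concrete evaluator might be defined as sketched below. This is a minimal sketch, not srai's actual implementation: it assumes that sds aliases srai.datasets, that ground-truth values are passed through a hypothetical labels keyword argument, and that RMSE is the metric of interest.

from typing import Any, Optional

import numpy as np

import srai.datasets as sds
from srai.benchmark import BaseEvaluator


class RMSERegressionEvaluator(BaseEvaluator):
    """Hypothetical evaluator computing RMSE for the regression task."""

    def __init__(self) -> None:
        super().__init__(task="regression")

    def evaluate(
        self,
        dataset: sds.PointDataset | sds.TrajectoryDataset,
        predictions: np.ndarray,
        log_metrics: bool = True,
        hf_token: Optional[str] = None,
        **kwargs: Any,
    ) -> dict[str, float]:
        # Ground-truth values are assumed to be supplied by the caller.
        labels = np.asarray(kwargs["labels"])
        rmse = float(np.sqrt(np.mean((predictions - labels) ** 2)))
        metrics = {"RMSE": rmse}
        if log_metrics:
            print(metrics)
        return metrics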

evaluate

abstractmethod
evaluate(
    dataset: sds.PointDataset | sds.TrajectoryDataset,
    predictions: np.ndarray,
    log_metrics: bool = True,
    hf_token: Optional[str] = None,
    **kwargs: Any
) -> dict[str, float]

Evaluate predictions against the test set.

PARAMETER DESCRIPTION
dataset

Dataset to evaluate on.

TYPE: PointDataset | TrajectoryDataset

predictions

Predictions returned by your model.

TYPE: ndarray

log_metrics

If True, logs metrics to the console. Defaults to True.

TYPE: bool DEFAULT: True

hf_token

A Hugging Face User Access Token, if authentication is required. Defaults to None.

TYPE: str DEFAULT: None

**kwargs

Additional keyword arguments depending on the task.

TYPE: Any DEFAULT: {}

KEYWORD ARGUMENT DESCRIPTION
region_ids

List of region IDs. Required for region-based evaluators.

TYPE: list[str]

point_of_interests

Points of interest. Required for point-based evaluators.

TYPE: ndarray

RETURNS DESCRIPTION
dict[str, float]

dict[str, float]: Dictionary with metric values for the task.

Note

Specific subclasses may require different sets of keyword arguments.
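
For example, a region-based evaluator could be invoked as below; the evaluator, dataset, model output, and metric keys are illustrative placeholders rather than srai's API.

predictions = model.predict(test_features)  # one prediction per region
metrics = evaluator.evaluate(
    dataset=dataset,
    predictions=predictions,
    log_metrics=True,
    region_ids=region_ids,  # required by region-based evaluators
)
print(metrics)  # e.g. {"RMSE": ...}; actual keys depend on the task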

Source code in srai/benchmark/_base.py
@abc.abstractmethod
def evaluate(
    self,
    dataset: sds.PointDataset | sds.TrajectoryDataset,
    predictions: np.ndarray,
    log_metrics: bool = True,
    hf_token: Optional[str] = None,
    **kwargs: Any,
) -> dict[str, float]:
    """
    Evaluate predictions against the test set.

    Args:
        dataset (sds.PointDataset | sds.TrajectoryDataset): Dataset to evaluate on.
        predictions (np.ndarray): Predictions returned by your model.
        log_metrics (bool, optional): If True, logs metrics to the console. Defaults to True.
        hf_token (str, optional): Hugging Face User Access Token, if authentication is
            required. Defaults to None.
        **kwargs: Additional keyword arguments depending on the task.

    Keyword Args:
        region_ids (list[str], optional): List of region IDs. Required for region-based\
              evaluators.
        point_of_interests (np.ndarray, optional): Points of interest. Required for point-based\
            evaluators.

    Returns:
        dict[str, float]: Dictionary with metric values for the task.

    Note:
        Specific subclasses may require different sets of keyword arguments.
    """
    # if self.task == "regression":
    #     train_gdf, test_gdf = dataset.load(version=f"res_{resolution}", hf_token=hf_token)
    #     target_column = dataset.target if dataset.target is not None else "count"
    #     # h3_indexes, labels = self._get_labels(test, resolution, target_column)
    #     _, h3_test = dataset.get_h3_with_labels(
    #         train_gdf=train_gdf, test_gdf=test_gdf, resolution=resolution
    #     )

    #     if h3_test is None:
    #         raise ValueError("The function 'get_h3_with_labels' returned None for h3_test.")
    #     else:
    #         h3_indexes = h3_test["region_id"].to_list()
    #         labels = h3_test[target_column].to_numpy()

    #     region_to_prediction = {
    #         region_id: prediction for region_id, prediction in zip(region_ids, predictions)
    #     }

    #     # order predictions according to the order of region_ids
    #     try:
    #         ordered_predictions = [region_to_prediction[h3] for h3 in h3_indexes]
    #     except KeyError as err:
    #         raise ValueError(
    #             f"Region id for H3 index {err.args[0]} not found in region_ids."
    #         ) from err

    #     region_ids[:] = h3_indexes
    #     predictions = np.array(ordered_predictions)
    #     metrics = self._compute_metrics(predictions, labels)
    #     if log_metrics:
    #         self._log_metrics(metrics)
    #     return metrics
    # else:
    #     raise NotImplementedError
    raise NotImplementedError