base
datasets._base
Base classes for Datasets.
HuggingFaceDataset
HuggingFaceDataset(
path: str,
version: Optional[str] = None,
type: Optional[str] = None,
numerical_columns: Optional[list[str]] = None,
categorical_columns: Optional[list[str]] = None,
target: Optional[str] = None,
resolution: Optional[int] = None,
)
Bases: ABC
Abstract class for HuggingFace datasets.
Source code in srai/datasets/_base.py
get_h3_with_labels
abstractmethod
get_h3_with_labels() -> (
tuple[
gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]
]
)
Returns indexes with target labels from the dataset, depending on the dataset and task type.
RETURNS | DESCRIPTION |
---|---|
tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]] | Train, val, and test indexes with target labels, as GeoDataFrames. |
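Because HuggingFaceDataset derives from ABC and marks get_h3_with_labels as abstract, a subclass must implement it before it can be instantiated, and callers should expect the val and test entries to be None when those splits are absent. The sketch below mirrors only that contract with a hypothetical stand-in: `MiniDataset`, `ToyDataset` and the `test_gdf` attribute are illustrative names, and plain lists of dicts replace GeoDataFrames so it runs without srai or geopandas.

```python
from abc import ABC, abstractmethod
from typing import Optional

# Hypothetical stand-in for a GeoDataFrame: a list of row dicts.
Frame = list[dict]


class MiniDataset(ABC):
    """Hypothetical, simplified mirror of the get_h3_with_labels contract."""

    def __init__(self) -> None:
        # Splits start empty; subclasses or split methods fill them in.
        self.train_gdf: Optional[Frame] = None
        self.val_gdf: Optional[Frame] = None
        self.test_gdf: Optional[Frame] = None

    @abstractmethod
    def get_h3_with_labels(self) -> tuple[Frame, Optional[Frame], Optional[Frame]]:
        """Return (train, val, test); val and test may be None."""


class ToyDataset(MiniDataset):
    """Concrete subclass in which only a train split exists."""

    def __init__(self, rows: Frame) -> None:
        super().__init__()
        self.train_gdf = rows

    def get_h3_with_labels(self) -> tuple[Frame, Optional[Frame], Optional[Frame]]:
        # The return type promises a Frame for train, so guard against a missing split.
        if self.train_gdf is None:
            raise ValueError("train split is not available")
        return self.train_gdf, self.val_gdf, self.test_gdf


train, val, test = ToyDataset([{"h3": "cell_a", "label": 1.0}]).get_h3_with_labels()
# Callers must handle the optional splits being None.
splits = {"train": train, "val": val, "test": test}
available = [name for name, frame in splits.items() if frame is not None]
```

Trying to instantiate `MiniDataset` directly raises `TypeError`, which is the same guard Python's `abc` module applies to any class with unimplemented abstract methods.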
load
load(
version: Optional[Union[int, str]] = None, hf_token: Optional[str] = None
) -> dict[str, gpd.GeoDataFrame]
Load the dataset.
PARAMETER | DESCRIPTION |
---|---|
hf_token | If needed, a User Access Token used to authenticate to the Hugging Face Hub. The token can also be supplied through an environment variable. TYPE: Optional[str] |
version | Version of the dataset. TYPE: Optional[Union[int, str]] |
RETURNS | DESCRIPTION |
---|---|
dict[str, gpd.GeoDataFrame] | Dictionary with all splits loaded from the dataset. Contains the keys "train" and "test" if available. |
train_test_split
abstractmethod
train_test_split(
target_column: Optional[str] = None,
resolution: Optional[int] = None,
test_size: float = 0.2,
n_bins: int = 7,
random_state: Optional[int] = None,
validation_split: bool = False,
force_split: bool = False,
task: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]
Generate a train/test split, or a train/val split, from the GeoDataFrame.
PARAMETER | DESCRIPTION |
---|---|
target_column | Target column name for point datasets, or the trajectory id column for trajectory datasets. Defaults to the preset dataset target column. TYPE: Optional[str] |
resolution | H3 resolution; subclasses may use this argument to regionalize data. Defaults to the default value from the dataset. TYPE: Optional[int] |
test_size | Fraction of the data used as the test set. Defaults to 0.2. TYPE: float |
n_bins | Number of buckets used to stratify the target data. TYPE: int |
random_state | Controls the shuffling applied to the data before the split. Pass an int for reproducible output across multiple function calls. Defaults to None. TYPE: Optional[int] |
validation_split | If True, creates a validation split from the existing train split and assigns it to self.val_gdf. TYPE: bool |
force_split | If True, forces a new split to be created, even if an existing train/test or validation split is already present. TYPE: bool |
task | Task identifier. Subclasses may use this argument to determine stratification logic (e.g., by duration or spatial pattern). Defaults to None. TYPE: Optional[str] |
RETURNS | DESCRIPTION |
---|---|
tuple | Train/test split, or a train/val split made on the previous train subset. TYPE: tuple[gpd.GeoDataFrame, gpd.GeoDataFrame] |
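A minimal sketch of the stratification idea behind `n_bins`, `test_size` and `random_state`, written in plain Python rather than on GeoDataFrames: a continuous target is cut into `n_bins` equal-frequency buckets by rank, and roughly `test_size` of each bucket is sampled into the test set. The helper names `quantile_bins` and `stratified_split` are hypothetical, and this is not srai's implementation; it only illustrates the technique.

```python
import random
from typing import Optional


def quantile_bins(values: list[float], n_bins: int) -> list[int]:
    """Assign each value to one of n_bins equal-frequency buckets by rank."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    bins = [0] * len(values)
    for rank, idx in enumerate(order):
        # rank < len(values), so the bucket id stays in [0, n_bins - 1].
        bins[idx] = rank * n_bins // len(values)
    return bins


def stratified_split(
    values: list[float],
    test_size: float = 0.2,
    n_bins: int = 7,
    random_state: Optional[int] = None,
) -> tuple[list[int], list[int]]:
    """Return (train_idx, test_idx), taking about test_size of every bucket for test."""
    rng = random.Random(random_state)  # same seed -> same split
    buckets: dict[int, list[int]] = {}
    for idx, bucket in enumerate(quantile_bins(values, n_bins)):
        buckets.setdefault(bucket, []).append(idx)

    train_idx: list[int] = []
    test_idx: list[int] = []
    for members in buckets.values():
        rng.shuffle(members)
        n_test = round(len(members) * test_size)
        test_idx.extend(members[:n_test])
        train_idx.extend(members[n_test:])
    return sorted(train_idx), sorted(test_idx)


targets = [float(v) for v in range(100)]
train_idx, test_idx = stratified_split(targets, test_size=0.2, n_bins=5, random_state=42)
# A train/val split can be drawn the same way from the train subset only.
train_targets = [targets[i] for i in train_idx]
sub_train, val = stratified_split(train_targets, test_size=0.2, n_bins=5, random_state=42)
```

Ranking instead of fixed value thresholds keeps the buckets equally sized even for skewed targets; tied values may land in neighbouring buckets, which is harmless for stratification.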
PointDataset
PointDataset(
path: str,
version: Optional[str] = None,
type: Optional[str] = None,
numerical_columns: Optional[list[str]] = None,
categorical_columns: Optional[list[str]] = None,
target: Optional[str] = None,
resolution: Optional[int] = None,
)
Bases: HuggingFaceDataset
Abstract class for HuggingFace datasets with Point Data.
get_h3_with_labels
get_h3_with_labels() -> (
tuple[
gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]
]
)
Returns H3 indexes with target labels from the dataset.
Points are aggregated into H3 hexes and the target column values are averaged within each hex. If the target column is None, the number of points within each hex is counted instead and scaled to [0, 1].
RETURNS | DESCRIPTION |
---|---|
tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]] | Train, val, and test hexes with target labels, as GeoDataFrames. |
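The aggregation described above can be sketched in plain Python, assuming each point has already been assigned to an H3 cell (for example with h3-py's `latlng_to_cell` in h3 4.x). The function names are hypothetical, cell ids are placeholder strings rather than real H3 indexes, and dividing by the maximum count is only one way to scale counts to [0, 1] (min-max scaling is another); the text does not say which one srai uses.

```python
from collections import Counter, defaultdict


def mean_target_per_hex(points: list[tuple[str, float]]) -> dict[str, float]:
    """Average the target value of all points that fall into the same hex."""
    sums: dict[str, float] = defaultdict(float)
    counts: Counter = Counter()
    for cell, value in points:
        sums[cell] += value
        counts[cell] += 1
    return {cell: sums[cell] / counts[cell] for cell in counts}


def scaled_count_per_hex(cells: list[str]) -> dict[str, float]:
    """Count points per hex and scale the counts to [0, 1] by dividing by the max."""
    counts = Counter(cells)
    if not counts:
        return {}
    top = max(counts.values())
    return {cell: n / top for cell, n in counts.items()}


means = mean_target_per_hex([("cell_a", 1.0), ("cell_a", 3.0), ("cell_b", 5.0)])
scaled = scaled_count_per_hex(
    ["cell_a", "cell_a", "cell_b", "cell_c", "cell_c", "cell_c", "cell_c"]
)
```

Here `cell_a` averages to 2.0, and with four points `cell_c` becomes the reference count, so it scales to 1.0 while `cell_a` and `cell_b` scale to 0.5 and 0.25.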
load
load(
version: Optional[Union[int, str]] = None, hf_token: Optional[str] = None
) -> dict[str, gpd.GeoDataFrame]
Load the dataset.
PARAMETER | DESCRIPTION |
---|---|
hf_token | If needed, a User Access Token used to authenticate to the Hugging Face Hub. The token can also be supplied through an environment variable. TYPE: Optional[str] |
version | Version of the dataset. TYPE: Optional[Union[int, str]] |
RETURNS | DESCRIPTION |
---|---|
dict[str, gpd.GeoDataFrame] | Dictionary with all splits loaded from the dataset. Contains the keys "train" and "test" if available. |
train_test_split
train_test_split(
target_column: Optional[str] = None,
resolution: Optional[int] = None,
test_size: float = 0.2,
n_bins: int = 7,
random_state: Optional[int] = None,
validation_split: bool = False,
force_split: bool = False,
task: Optional[str] = None,
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]
Generate splits from the GeoDataFrame, based on the target_column values.
PARAMETER | DESCRIPTION |
---|---|
target_column | Target column name. If None, the split is generated based on the number of points within each hex of the given resolution. Defaults to the preset dataset target column. TYPE: Optional[str] |
resolution | H3 resolution used to regionalize data. Defaults to the default value from the dataset. TYPE: Optional[int] |
test_size | Fraction of the data used as the test set. Defaults to 0.2. TYPE: float |
n_bins | Number of buckets used to stratify the target data. Defaults to 7. TYPE: int |
random_state | Controls the shuffling applied to the data before the split. Pass an int for reproducible output across multiple function calls. Defaults to None. TYPE: Optional[int] |
validation_split | If True, creates a validation split from the existing train split and assigns it to self.val_gdf. TYPE: bool |
force_split | If True, forces a new split to be created, even if an existing train/test or validation split is already present. TYPE: bool |
task | Currently not supported; ignored in this subclass. TYPE: Optional[str] |
RETURNS | DESCRIPTION |
---|---|
tuple | Train/test split, or a train/val split made on the previous train subset. TYPE: tuple[gpd.GeoDataFrame, gpd.GeoDataFrame] |
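When target_column is None, the split is driven by point counts per hex. A minimal sketch of one way to realize that, under stated assumptions: hexes (not individual points) are stratified by how many points they contain, a share of the hexes in each count bucket goes to the test set, and every point follows its hex, so points from one hex never end up in both splits. `split_points_by_hex` is a hypothetical helper, the cell ids are placeholders, and `test_size` here is a fraction of hexes rather than of points; srai's actual procedure may differ.

```python
import random
from collections import Counter
from typing import Optional


def split_points_by_hex(
    point_cells: list[str],
    test_size: float = 0.2,
    n_bins: int = 7,
    random_state: Optional[int] = None,
) -> tuple[list[int], list[int]]:
    """Split point indices so that all points in one hex land in the same split."""
    counts = Counter(point_cells)
    # Rank hexes by point count (ties broken by cell id for determinism).
    hexes = sorted(counts, key=lambda cell: (counts[cell], cell))
    bins: dict[int, list[str]] = {}
    for rank, cell in enumerate(hexes):
        bins.setdefault(rank * n_bins // len(hexes), []).append(cell)

    rng = random.Random(random_state)
    test_cells: set[str] = set()
    for members in bins.values():
        rng.shuffle(members)
        test_cells.update(members[: round(len(members) * test_size)])

    # Each point inherits the split of the hex it belongs to.
    train_idx = [i for i, c in enumerate(point_cells) if c not in test_cells]
    test_idx = [i for i, c in enumerate(point_cells) if c in test_cells]
    return train_idx, test_idx


# Hex "h{i}" holds i + 1 points, giving ten hexes with distinct counts.
cells = [f"h{i}" for i in range(10) for _ in range(i + 1)]
train_idx, test_idx = split_points_by_hex(cells, test_size=0.5, n_bins=5, random_state=0)
```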
TrajectoryDataset
TrajectoryDataset(
path: str,
version: Optional[str] = None,
type: Optional[str] = None,
numerical_columns: Optional[list[str]] = None,
categorical_columns: Optional[list[str]] = None,
target: Optional[str] = None,
resolution: Optional[int] = None,
)
Bases: HuggingFaceDataset
Abstract class for HuggingFace datasets with Trajectory data.
get_h3_with_labels
get_h3_with_labels() -> (
tuple[
gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]
]
)
Returns trajectory ids and H3 index sequences with target labels from the dataset.
Points are aggregated into hex trajectories, and target column values are calculated for each trajectory (time duration for the TTE task, future movement sequence for the HMP task).
RETURNS | DESCRIPTION |
---|---|
tuple[gpd.GeoDataFrame, Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]] | Train, val, and test hex sequences with target labels, as GeoDataFrames. |
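A sketch of the trajectory aggregation, assuming input points arrive as hypothetical (trajectory_id, timestamp, h3_cell) tuples with cells already computed. Consecutive points in the same hex are collapsed, the TTE-style label is the time between the first and last point, and the HMP-style label is modeled here as the last `horizon` hexes of the sequence. `horizon`, the record keys and both function names are illustrative assumptions, not srai parameters.

```python
from collections import defaultdict
from itertools import groupby


def build_hex_trajectories(points: list[tuple[str, float, str]]) -> dict[str, dict]:
    """Group (trajectory_id, timestamp, h3_cell) points into per-trajectory records."""
    by_traj: dict[str, list[tuple[float, str]]] = defaultdict(list)
    for traj_id, ts, cell in points:
        by_traj[traj_id].append((ts, cell))

    records: dict[str, dict] = {}
    for traj_id, rows in by_traj.items():
        rows.sort()  # chronological order
        # Collapse runs of consecutive points that fall into the same hex.
        sequence = [hex_id for hex_id, _ in groupby(c for _, c in rows)]
        records[traj_id] = {
            "h3_sequence": sequence,
            "duration": rows[-1][0] - rows[0][0],  # TTE-style label
        }
    return records


def hmp_split(sequence: list[str], horizon: int) -> tuple[list[str], list[str]]:
    """HMP-style target: observed prefix and the next `horizon` hexes to predict."""
    if not 0 < horizon < len(sequence):
        raise ValueError("horizon must be between 1 and len(sequence) - 1")
    return sequence[:-horizon], sequence[-horizon:]


trajs = build_hex_trajectories([
    ("t1", 0.0, "cell_a"), ("t1", 10.0, "cell_a"), ("t1", 25.0, "cell_b"),
    ("t1", 40.0, "cell_c"), ("t2", 5.0, "cell_x"), ("t2", 1.0, "cell_y"),
])
observed, future = hmp_split(trajs["t1"]["h3_sequence"], horizon=1)
```

Only consecutive repeats are collapsed, so a trajectory that leaves a hex and later returns to it keeps both visits in its sequence.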
load
load(
version: Optional[Union[int, str]] = None, hf_token: Optional[str] = None
) -> dict[str, gpd.GeoDataFrame]
Load the dataset.
PARAMETER | DESCRIPTION |
---|---|
hf_token | If needed, a User Access Token used to authenticate to the Hugging Face Hub. The token can also be supplied through an environment variable. TYPE: Optional[str] |
version | Version of the dataset. TYPE: Optional[Union[int, str]] |
RETURNS | DESCRIPTION |
---|---|
dict[str, gpd.GeoDataFrame] | Dictionary with all splits loaded from the dataset. Contains the keys "train" and "test" if available. |
train_test_split
train_test_split(
target_column: Optional[str] = None,
resolution: Optional[int] = None,
test_size: float = 0.2,
n_bins: int = 4,
random_state: Optional[int] = None,
validation_split: bool = False,
force_split: bool = False,
task: Optional[str] = "TTE",
) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]
Generate a train/test split, or a train/val split, from the trajectory GeoDataFrame.
Either split is generated by splitting train_gdf.
PARAMETER | DESCRIPTION |
---|---|
target_column | Column identifying each trajectory (contains trajectory ids). TYPE: Optional[str] |
test_size | Fraction of the data used as the test set. TYPE: float |
n_bins | Number of stratification bins. TYPE: int |
random_state | Controls the shuffling applied to the data before the split. Pass an int for reproducible output across multiple function calls. Defaults to None. TYPE: Optional[int] |
validation_split | If True, creates a validation split from the existing train split and assigns it to self.val_gdf. TYPE: bool |
force_split | If True, forces a new split to be created, even if an existing train/test or validation split is already present. TYPE: bool |
resolution | H3 resolution used to regionalize data. Currently ignored in this subclass; splits at different resolutions are not supported yet. Defaults to the default value from the dataset. TYPE: Optional[int] |
task | Task type: stratifies by trajectory duration (TTE) or by hex sequence length (HMP). TYPE: Optional[str] |
RETURNS | DESCRIPTION |
---|---|
tuple[gpd.GeoDataFrame, gpd.GeoDataFrame] | Train/test or train/val GeoDataFrames. |
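The task-dependent stratification can be sketched at the trajectory level: whole trajectory ids are assigned to a split, so the points of one trajectory never straddle train and test, and the stratification key is the duration for TTE or the hex sequence length for HMP. `split_trajectories` and the record layout (a dict with `duration` and `h3_sequence`) are hypothetical; only the defaults `n_bins=4` and `task="TTE"` follow the signature above.

```python
import random
from typing import Optional


def split_trajectories(
    trajectories: dict[str, dict],
    task: str = "TTE",
    test_size: float = 0.2,
    n_bins: int = 4,
    random_state: Optional[int] = None,
) -> tuple[list[str], list[str]]:
    """Split trajectory ids, stratified by duration (TTE) or hex sequence length (HMP)."""
    if task == "TTE":
        strat = {tid: t["duration"] for tid, t in trajectories.items()}
    elif task == "HMP":
        strat = {tid: len(t["h3_sequence"]) for tid, t in trajectories.items()}
    else:
        raise ValueError(f"Unsupported task: {task!r}")

    # Rank ids by the stratification key (ties broken by id) and bucket by rank.
    ids = sorted(strat, key=lambda tid: (strat[tid], tid))
    bins: dict[int, list[str]] = {}
    for rank, tid in enumerate(ids):
        bins.setdefault(rank * n_bins // len(ids), []).append(tid)

    rng = random.Random(random_state)
    train_ids: list[str] = []
    test_ids: list[str] = []
    for members in bins.values():
        rng.shuffle(members)
        n_test = round(len(members) * test_size)
        test_ids.extend(members[:n_test])
        train_ids.extend(members[n_test:])
    return sorted(train_ids), sorted(test_ids)


trajs = {
    f"t{i:02d}": {"duration": float(i), "h3_sequence": ["cell"] * (i % 5 + 1)}
    for i in range(20)
}
tte_train, tte_test = split_trajectories(trajs, "TTE", test_size=0.25, random_state=1)
hmp_train, hmp_test = split_trajectories(trajs, "HMP", test_size=0.25, random_state=1)
```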