Skip to content

GTFSLoader

srai.loaders.GTFSLoader()

Bases: Loader

GTFSLoader.

This loader reads a GTFS feed and calculates time aggregations in 1-hour (1H) slots.

Source code in srai/loaders/gtfs_loader.py
def __init__(self) -> None:
    """
    Create a GTFS loader.

    Checks the optional ``gtfs`` dependency group (module ``gtfs_kit``) via
    ``import_optional_dependencies`` and stores the time resolution used for the
    per-stop time aggregations (one-hour slots).
    """
    gtfs_modules = ["gtfs_kit"]
    import_optional_dependencies(dependency_group="gtfs", modules=gtfs_modules)

    # Size of a single aggregation slot.
    # NOTE(review): pandas >= 2.2 deprecates the "H" frequency alias in favour of "h";
    # confirm how this value is consumed before changing it.
    self.time_resolution = "1H"

load(
    gtfs_file,
    fail_on_validation_errors=True,
    skip_validation=False,
)

Load GTFS feed and calculate time aggregations for stops.

PARAMETER DESCRIPTION
gtfs_file

Path to the GTFS feed.

TYPE: Path

fail_on_validation_errors

Fail if GTFS feed is invalid. Ignored when skip_validation is True.

TYPE: bool DEFAULT: True

skip_validation

Skip GTFS feed validation.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
GeoDataFrame

gpd.GeoDataFrame: GeoDataFrame with trip counts and list of directions for stops.

Source code in srai/loaders/gtfs_loader.py
def load(
    self,
    gtfs_file: Path,
    fail_on_validation_errors: bool = True,
    skip_validation: bool = False,
) -> gpd.GeoDataFrame:
    """
    Load GTFS feed and calculate time aggregations for stops.

    The feed is read with ``gtfs_kit`` (distances in kilometres) and optionally
    validated. It is then combined per stop: trip counts come from ``_load_trips`` and
    directions from ``_load_directions``. Only stops that have trip data AND an entry in
    the feed's stops table are kept.

    Args:
        gtfs_file (Path): Path to the GTFS feed.
        fail_on_validation_errors (bool): Fail if GTFS feed is invalid. Ignored when
            skip_validation is True.
        skip_validation (bool): Skip GTFS feed validation.

    Returns:
        gpd.GeoDataFrame: GeoDataFrame with trip counts and list of directions for stops.
            Geometries are stop points in ``WGS84_CRS``; the index is named
            ``FEATURES_INDEX``.
    """
    import gtfs_kit as gk

    feed = gk.read_feed(gtfs_file, dist_units="km")

    if not skip_validation:
        self._validate_feed(feed, fail=fail_on_validation_errors)

    trips_df = self._load_trips(feed)
    directions_df = self._load_directions(feed)

    stops_df = feed.stops[["stop_id", "stop_lat", "stop_lon"]].set_index("stop_id")
    # Build all stop points in one vectorised call (x = longitude, y = latitude).
    # A row-wise ``apply(axis=1)`` would be a Python-level loop, and on a feed without
    # stops it returns a DataFrame rather than a Series, so assigning it to a single
    # column raises ``ValueError``.
    stops_df[GEOMETRY_COLUMN] = gpd.points_from_xy(
        stops_df["stop_lon"], stops_df["stop_lat"]
    )

    # Inner join: trips at stops without a location in the stops table are dropped.
    result_gdf = gpd.GeoDataFrame(
        trips_df.merge(stops_df[GEOMETRY_COLUMN], how="inner", on="stop_id"),
        geometry=GEOMETRY_COLUMN,
        crs=WGS84_CRS,
    )

    # Left join: stops without direction data are kept, with missing direction values.
    result_gdf = result_gdf.merge(directions_df, how="left", on="stop_id")

    result_gdf.index.name = FEATURES_INDEX

    return result_gdf