Skip to content

GTFSLoader

Bases: Loader

GTFSLoader.

This loader reads a GTFS feed and calculates time aggregations in 1-hour (1H) slots.

Source code in srai/loaders/gtfs_loader.py
class GTFSLoader(Loader):
    """
    GTFSLoader.

    This loader reads a GTFS feed and calculates time aggregations in 1-hour slots.
    """

    def __init__(self) -> None:
        """Initialize GTFS loader."""
        import_optional_dependencies(dependency_group="gtfs", modules=["gtfs_kit"])

        # Frequency string passed to gtfs_kit's ``compute_stop_time_series``.
        self.time_resolution = "1H"

    def load(
        self,
        gtfs_file: Path,
        fail_on_validation_errors: bool = True,
        skip_validation: bool = False,
    ) -> gpd.GeoDataFrame:
        """
        Load GTFS feed and calculate time aggregations for stops.

        Args:
            gtfs_file (Path): Path to the GTFS feed.
            fail_on_validation_errors (bool): Raise a ValueError if the GTFS feed fails
                validation. Ignored when skip_validation is True.
            skip_validation (bool): Skip GTFS feed validation.

        Returns:
            gpd.GeoDataFrame: GeoDataFrame (WGS84) with hourly trip counts and per-hour
                sets of directions (trip headsigns) for each stop.

        Raises:
            ValueError: If validation runs, finds errors and fail_on_validation_errors
                is True.
        """
        import gtfs_kit as gk

        feed = gk.read_feed(gtfs_file, dist_units="km")

        if not skip_validation:
            self._validate_feed(feed, fail=fail_on_validation_errors)

        trips_df = self._load_trips(feed)
        directions_df = self._load_directions(feed)

        stops_df = feed.stops[["stop_id", "stop_lat", "stop_lon"]].set_index("stop_id")
        # Vectorized point construction (x = lon, y = lat); unlike a row-wise
        # ``apply`` it also works when the feed has no stops.
        stops_df[GEOMETRY_COLUMN] = gpd.points_from_xy(
            stops_df["stop_lon"], stops_df["stop_lat"]
        )

        # Inner join: only stops that have both trip counts and coordinates are kept.
        result_gdf = gpd.GeoDataFrame(
            trips_df.merge(stops_df[GEOMETRY_COLUMN], how="inner", on="stop_id"),
            geometry=GEOMETRY_COLUMN,
            crs=WGS84_CRS,
        )

        # Left join: stops missing from the directions table get NaN direction columns.
        result_gdf = result_gdf.merge(directions_df, how="left", on="stop_id")

        result_gdf.index.name = None

        return result_gdf

    def _load_trips(self, feed: "Feed") -> pd.DataFrame:
        """
        Load trips from GTFS feed.

        Calculate the number of trips at each stop in each time slot of a single day.

        Args:
            feed (gk.Feed): GTFS feed.

        Returns:
            pd.DataFrame: DataFrame indexed by ``stop_id`` with one prefixed column per
                hour holding trip counts (missing combinations filled with 0).
        """
        # FIXME: this takes first wednesday from the feed, may not be the best,
        # but that is what I did in gtfs2vec
        # NOTE(review): indexing [2] raises IndexError if the first week returned by
        # gtfs_kit has fewer than three days — confirm whether feeds can hit this.
        date = feed.get_first_week()[2]
        ts = feed.compute_stop_time_series([date], freq=self.time_resolution)

        # Flatten the wide time series (rows: timestamps, columns: stops) into
        # (stop_id, hour, num_trips) records.
        records = [
            (stop_id, timestamp.hour, num_trips)
            for timestamp, row in ts.iterrows()
            for stop_id, num_trips in row["num_trips"].items()
        ]

        df = pd.DataFrame(records, columns=["stop_id", "hour", "num_trips"])
        df = df.pivot_table(index="stop_id", columns="hour", values="num_trips", fill_value=0)
        return df.add_prefix(GTFS2VEC_TRIPS_PREFIX)

    def _load_directions(self, feed: "Feed") -> pd.DataFrame:
        """
        Load directions from GTFS feed.

        Create a set of unique directions (trip headsigns) for each stop and time slot.

        Args:
            feed (gk.Feed): GTFS feed.

        Returns:
            pd.DataFrame: DataFrame indexed by ``stop_id`` with one prefixed column per
                hour holding the set of trip headsigns departing in that hour.
        """
        df = feed.stop_times.merge(feed.trips, on="trip_id")
        df = df.merge(feed.stops, on="stop_id")

        # Take an explicit copy so adding the ``hour`` column below writes to an
        # independent frame rather than a filtered slice (avoids SettingWithCopyWarning).
        df = df.loc[df["departure_time"].notna()].copy()

        df["hour"] = df["departure_time"].apply(self._parse_departure_time)

        pivoted = df.pivot_table(
            values="trip_headsign", index="stop_id", columns="hour", aggfunc=set
        )
        return pivoted.add_prefix(GTFS2VEC_DIRECTIONS_PREFIX)

    def _validate_feed(self, feed: "Feed", fail: bool = True) -> None:
        """
        Validate GTFS feed.

        When the validation report contains any entry of type "error", a RuntimeWarning
        with the full report is emitted.

        Args:
            feed (gk.Feed): GTFS feed.
            fail (bool): Raise an exception if the feed has validation errors.

        Raises:
            ValueError: If the feed has validation errors and ``fail`` is True.
        """
        validation_result = feed.validate()

        if (validation_result["type"] == "error").any():
            import warnings

            warnings.warn(f"Invalid GTFS feed: \n{validation_result}", RuntimeWarning, stacklevel=2)
            if fail:
                raise ValueError("Invalid GTFS feed.")

    def _parse_departure_time(self, departure_time: str) -> int:
        """
        Parse departure time and extract hour from it.

        In GTFS feed, departure time is in format HH:MM:SS (a single-digit hour, H:MM:SS,
        is handled as well). The hour can be greater than 23, so it is wrapped into the
        0-23 range.

        Args:
            departure_time (str): Departure time in format HH:MM:SS.

        Returns:
            int: Departure hour in the 0-23 range.
        """
        # Everything before the first colon is the hour, regardless of its digit count.
        hour_part = departure_time.split(":", 1)[0]
        return int(hour_part) % 24

__init__

__init__() -> None

Initialize GTFS loader.

Source code in srai/loaders/gtfs_loader.py
def __init__(self) -> None:
    """
    Set up the GTFS loader.

    Calls ``import_optional_dependencies`` for the ``gtfs`` group (module ``gtfs_kit``)
    and stores a one-hour aggregation frequency in ``time_resolution``.
    """
    required_modules = ["gtfs_kit"]
    import_optional_dependencies(dependency_group="gtfs", modules=required_modules)

    hourly_frequency = "1H"
    self.time_resolution = hourly_frequency

load

load(gtfs_file: Path, fail_on_validation_errors: bool = True, skip_validation: bool = False) -> gpd.GeoDataFrame

Load GTFS feed and calculate time aggregations for stops.

PARAMETER DESCRIPTION
gtfs_file

Path to the GTFS feed.

TYPE: Path

fail_on_validation_errors

Raise a ValueError if the GTFS feed fails validation. Ignored when skip_validation is True.

TYPE: bool DEFAULT: True

skip_validation

Skip GTFS feed validation.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
gpd.GeoDataFrame

GeoDataFrame with hourly trip counts and per-hour sets of directions (trip headsigns) for each stop.

Source code in srai/loaders/gtfs_loader.py
def load(
    self,
    gtfs_file: Path,
    fail_on_validation_errors: bool = True,
    skip_validation: bool = False,
) -> gpd.GeoDataFrame:
    """
    Load GTFS feed and calculate time aggregations for stops.

    Args:
        gtfs_file (Path): Path to the GTFS feed.
        fail_on_validation_errors (bool): Raise a ValueError if the GTFS feed fails
            validation. Ignored when skip_validation is True.
        skip_validation (bool): Skip GTFS feed validation.

    Returns:
        gpd.GeoDataFrame: GeoDataFrame (WGS84) with hourly trip counts and per-hour
            sets of directions (trip headsigns) for each stop.

    Raises:
        ValueError: If validation runs, finds errors and fail_on_validation_errors
            is True.
    """
    import gtfs_kit as gk

    feed = gk.read_feed(gtfs_file, dist_units="km")

    if not skip_validation:
        self._validate_feed(feed, fail=fail_on_validation_errors)

    trips_df = self._load_trips(feed)
    directions_df = self._load_directions(feed)

    stops_df = feed.stops[["stop_id", "stop_lat", "stop_lon"]].set_index("stop_id")
    # Vectorized point construction (x = lon, y = lat); unlike a row-wise
    # ``apply`` it also works when the feed has no stops.
    stops_df[GEOMETRY_COLUMN] = gpd.points_from_xy(
        stops_df["stop_lon"], stops_df["stop_lat"]
    )

    # Inner join: only stops that have both trip counts and coordinates are kept.
    result_gdf = gpd.GeoDataFrame(
        trips_df.merge(stops_df[GEOMETRY_COLUMN], how="inner", on="stop_id"),
        geometry=GEOMETRY_COLUMN,
        crs=WGS84_CRS,
    )

    # Left join: stops missing from the directions table get NaN direction columns.
    result_gdf = result_gdf.merge(directions_df, how="left", on="stop_id")

    result_gdf.index.name = None

    return result_gdf