In [1]:
import pandas as pd
import geopandas as gpd
from shapely.geometry import Polygon
from pytorch_lightning import seed_everything
from srai.embedders import GTFS2VecEmbedder
from srai.constants import REGIONS_INDEX
Example on artificial data
Define features and regions
In [2]:
features_gdf = gpd.GeoDataFrame(
{
"trip_count_at_6": [1, 0, 0],
"trip_count_at_7": [1, 1, 0],
"trip_count_at_8": [0, 0, 1],
"directions_at_6": [
{"A", "A1"},
{"B", "B1"},
{"C"},
],
},
geometry=gpd.points_from_xy([1, 2, 5], [1, 2, 2]),
index=pd.Index(name="stop_id", data=[1, 2, 3]),
)
features_gdf
Out[2]:
| stop_id | trip_count_at_6 | trip_count_at_7 | trip_count_at_8 | directions_at_6 | geometry |
|---|---|---|---|---|---|
| 1 | 1 | 1 | 0 | {A, A1} | POINT (1.00000 1.00000) |
| 2 | 0 | 1 | 0 | {B1, B} | POINT (2.00000 2.00000) |
| 3 | 0 | 0 | 1 | {C} | POINT (5.00000 2.00000) |
In [3]:
regions_gdf = gpd.GeoDataFrame(
geometry=[
Polygon([(0, 0), (0, 3), (3, 3), (3, 0)]),
Polygon([(4, 0), (4, 3), (7, 3), (7, 0)]),
Polygon([(8, 0), (8, 3), (11, 3), (11, 0)]),
],
index=pd.Index(name=REGIONS_INDEX, data=["ff1", "ff2", "ff3"]),
)
regions_gdf
Out[3]:
| region_id | geometry |
|---|---|
| ff1 | POLYGON ((0.00000 0.00000, 0.00000 3.00000, 3.... |
| ff2 | POLYGON ((4.00000 0.00000, 4.00000 3.00000, 7.... |
| ff3 | POLYGON ((8.00000 0.00000, 8.00000 3.00000, 11... |
In [4]:
ax = regions_gdf.plot()
features_gdf.plot(ax=ax, color="red")
Out[4]:
<Axes: >
In [5]:
joint_gdf = gpd.GeoDataFrame()
joint_gdf.index = pd.MultiIndex.from_tuples(
[("ff1", 1), ("ff1", 2), ("ff2", 3)],
names=[REGIONS_INDEX, "stop_id"],
)
joint_gdf
Out[5]:
| region_id | stop_id |
|---|---|
| ff1 | 1 |
| ff1 | 2 |
| ff2 | 3 |
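In a real pipeline this region-stop mapping would not be built by hand; srai provides joiners that intersect features with regions. A minimal sketch, assuming the `IntersectionJoiner` from `srai.joiners` is available (check the API reference for your srai version):

from srai.joiners import IntersectionJoiner

# spatially match stops to the regions that contain them,
# producing the same (region_id, stop_id) MultiIndex as above
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf)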
Get features without embedding them
In [6]:
embedder = GTFS2VecEmbedder(skip_autoencoder=True)
res = embedder.transform(regions_gdf, features_gdf, joint_gdf)
res
Out[6]:
| region_id | directions_at_6 |
|---|---|
| ff1 | 1.00 |
| ff2 | 0.25 |
| ff3 | 0.00 |
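With `skip_autoencoder=True` the embedder only aggregates and normalises the GTFS features per region, so the output can be inspected directly. A minimal sketch of attaching it back to the region geometries for inspection (plain pandas/geopandas, nothing srai-specific; the plot call is just an illustration):

# join the aggregated features onto the region polygons by region_id
regions_with_features = regions_gdf.join(res)
# colour regions by the normalised directions feature
regions_with_features.plot(column="directions_at_6", legend=True)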
Fit and train the embedder
In [7]:
seed_everything(42)
embedder = GTFS2VecEmbedder(hidden_size=2, embedding_size=4)
embedder.fit(regions_gdf, features_gdf, joint_gdf)
res = embedder.transform(regions_gdf, features_gdf, joint_gdf)
res
Global seed set to 42
/opt/hostedtoolcache/Python/3.10.12/x64/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/hostedtoolcache/Python/3.10.12/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
  warning_cache.warn(
  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 16
1 | decoder | Sequential | 13
---------------------------------------
29        Trainable params
0         Non-trainable params
29        Total params
0.000     Total estimated model params size (MB)
/opt/hostedtoolcache/Python/3.10.12/x64/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
/opt/hostedtoolcache/Python/3.10.12/x64/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:280: PossibleUserWarning: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=10` reached.
Out[7]:
| region_id | 0 | 1 | 2 | 3 |
|---|---|---|---|---|
| ff1 | 0.657301 | 0.599207 | -0.188990 | 0.438122 |
| ff2 | 0.663876 | 0.541362 | -0.220063 | 0.030094 |
| ff3 | 0.636288 | 0.457780 | -0.115227 | 0.004630 |
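The trained embedder returns one dense vector per region, indexed by region_id. The fit and transform steps above can usually be collapsed into a single call; a minimal sketch, assuming the installed srai version also exposes `fit_transform` on `GTFS2VecEmbedder` (check the API reference for your release):

seed_everything(42)
embedder = GTFS2VecEmbedder(hidden_size=2, embedding_size=4)
# train the autoencoder and return the region embeddings in one step
res = embedder.fit_transform(regions_gdf, features_gdf, joint_gdf)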