Highway2Vec Embedder
In [1]:
from IPython.display import display
from srai.plotting import plot_numeric_data, plot_regions
Get an area to embed
In [2]:
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Wrocław, PL")
plot_regions(area_gdf, tiles_style="CartoDB positron")
Out[2]:
[Interactive map: the Wrocław area polygon]
Regionalize the area with an H3 regionalizer
In [3]:
from srai.regionalizers import H3Regionalizer
# Cover the area with H3 hexagons at resolution 9
regionalizer = H3Regionalizer(9)
regions_gdf = regionalizer.transform(area_gdf)
print(len(regions_gdf))
display(regions_gdf.head(3))
plot_regions(regions_gdf, tiles_style="CartoDB positron")
3168
region_id | geometry |
---|---|
891e2040e43ffff | POLYGON ((16.98227 51.09526, 16.98231 51.09357... |
891e204e413ffff | POLYGON ((17.02548 51.06877, 17.02552 51.06708... |
891e2040d7bffff | POLYGON ((17.04212 51.11219, 17.04215 51.11051... |
Out[3]:
[Interactive map: the H3 resolution-9 regions covering the area]
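The `region_id` values are H3 cell indexes written as 15 hexadecimal digits. Assuming the standard H3 index bit layout (mode in bits 59 to 62, resolution in bits 52 to 55, base cell in bits 45 to 51, then fifteen 3-bit digits with unused digits set to 7), the resolution passed to `H3Regionalizer` can be read straight back from the ids without the h3 library. The helper names below are hypothetical and only illustrate the bit layout.

```python
def h3_field(cell_id: str, shift: int, width: int) -> int:
    """Return the `width`-bit field that starts at bit `shift` of an H3 index."""
    return (int(cell_id, 16) >> shift) & ((1 << width) - 1)


def h3_mode(cell_id: str) -> int:
    # Bits 59-62: index mode; 1 means an ordinary cell index.
    return h3_field(cell_id, 59, 4)


def h3_resolution(cell_id: str) -> int:
    # Bits 52-55: resolution, from 0 (coarsest) to 15 (finest).
    return h3_field(cell_id, 52, 4)


def h3_base_cell(cell_id: str) -> int:
    # Bits 45-51: one of the 122 base cells.
    return h3_field(cell_id, 45, 7)


def unused_digits_are_seven(cell_id: str) -> bool:
    # Each of the 15 digits takes 3 bits; digits finer than the resolution
    # are set to 7 (binary 111), which produces the trailing "ffff".
    unused_bits = 3 * (15 - h3_resolution(cell_id))
    mask = (1 << unused_bits) - 1
    return (int(cell_id, 16) & mask) == mask


# Region ids taken from the table above.
for cell_id in ["891e2040e43ffff", "891e204e413ffff", "891e2040d7bffff"]:
    print(cell_id, h3_mode(cell_id), h3_resolution(cell_id),
          h3_base_cell(cell_id), unused_digits_are_seven(cell_id))
```

Every id shown in this notebook starts with `89`, which under this layout decodes to mode 1 and resolution 9.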
Download the road infrastructure for the area
In [4]:
from srai.loaders import OSMNetworkType, OSMWayLoader
# Load the drivable road network as graph nodes and edges
loader = OSMWayLoader(OSMNetworkType.DRIVE)
nodes_gdf, edges_gdf = loader.load(area_gdf)
display(nodes_gdf.head(3))
display(edges_gdf.head(3))
# Plot the edges and overlay the nodes in red
ax = edges_gdf.plot(linewidth=1, figsize=(12, 7))
nodes_gdf.plot(ax=ax, markersize=3, color="red")
/root/development/srai/srai/loaders/osm_way_loader/osm_way_loader.py:229: FutureWarning: The clean_periphery argument has been deprecated and will be removed in the v2.0.0 release. Future behavior will be as though clean_periphery=True. See the OSMnx v2 migration guide: https://github.com/gboeing/osmnx/issues/1123
  G_directed = ox.graph_from_polygon(
/root/development/srai/.venv/lib/python3.10/site-packages/osmnx/_overpass.py:350: FutureWarning: `settings.timeout` is deprecated and will be removed in the v2.0.0 release: use `settings.requests_timeout` instead. See the OSMnx v2 migration guide: https://github.com/gboeing/osmnx/issues/1123
  overpass_settings = _make_overpass_settings()
/root/development/srai/.venv/lib/python3.10/site-packages/osmnx/_overpass.py:360: FutureWarning: `settings.timeout` is deprecated and will be removed in the v2.0.0 release: use `settings.requests_timeout` instead. See the OSMnx v2 migration guide: https://github.com/gboeing/osmnx/issues/1123
  yield _overpass_request(data={"data": query_str})
/root/development/srai/.venv/lib/python3.10/site-packages/osmnx/_overpass.py:442: FutureWarning: `settings.timeout` is deprecated and will be removed in the v2.0.0 release: use `settings.requests_timeout` instead. See the OSMnx v2 migration guide: https://github.com/gboeing/osmnx/issues/1123
  this_pause = _get_overpass_pause(overpass_endpoint)
/root/development/srai/srai/loaders/osm_way_loader/osm_way_loader.py:237: FutureWarning: The `get_undirected` function is deprecated and will be removed in the v2.0.0 release. Replace it with `convert.to_undirected` instead. See the OSMnx v2 migration guide: https://github.com/gboeing/osmnx/issues/1123
  G_undirected = ox.utils_graph.get_undirected(G_directed)
osmid | y | x | street_count | highway | ref | geometry |
---|---|---|---|---|---|---|
95584835 | 51.083111 | 17.049513 | 4 | NaN | NaN | POINT (17.04951 51.08311) |
95584841 | 51.084699 | 17.064367 | 3 | NaN | NaN | POINT (17.06437 51.08470) |
95584850 | 51.083328 | 17.035057 | 4 | NaN | NaN | POINT (17.03506 51.08333) |
feature_id | oneway | lanes-1 | lanes-2 | lanes-3 | lanes-4 | lanes-5 | lanes-6 | lanes-7 | lanes-8 | lanes-9 | ... | bicycle-official | lit-yes | lit-no | lit-sunset-sunrise | lit-24/7 | lit-automatic | lit-disused | lit-limited | lit-interval | geometry |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04947 51.083... |
1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04933 51.083... |
2 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.05357 51.08301, 17.05335 51.082... |
3 rows × 219 columns
Out[4]:
[Figure: drivable road edges with their nodes drawn in red]
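The edge table appears to one-hot encode OSM tag values into `<key>-<value>` columns such as `lanes-2` or `lit-yes`. Feature 2 has `lanes-2`, `lanes-3` and `lanes-4` all set, which suggests one edge can carry several values for a tag, presumably when several OSM ways were merged into one edge. The sketch below shows that kind of encoding over a fixed vocabulary; it is a hypothetical illustration, not the `OSMWayLoader` implementation, and the tiny vocabulary is made up.

```python
from typing import Dict, List, Union

# A tag value is either a single string or, for merged edges, a list of strings.
TagValue = Union[str, List[str]]


def one_hot_tags(tags: Dict[str, TagValue],
                 vocabulary: Dict[str, List[str]]) -> Dict[str, int]:
    """Encode tags as 0/1 columns named '<key>-<value>' over a fixed vocabulary."""
    row: Dict[str, int] = {}
    for key, known_values in vocabulary.items():
        raw = tags.get(key, [])
        # Normalize a single value to a one-element list.
        present = [raw] if isinstance(raw, str) else list(raw)
        for value in known_values:
            row[f"{key}-{value}"] = int(value in present)
    return row


# Hypothetical, tiny vocabulary; the real edge table has 219 columns.
vocabulary = {"lanes": ["1", "2", "3", "4"], "lit": ["yes", "no"]}
edge_tags = {"lanes": ["2", "3", "4"], "lit": "yes"}
print(one_hot_tags(edge_tags, vocabulary))
# {'lanes-1': 0, 'lanes-2': 1, 'lanes-3': 1, 'lanes-4': 1, 'lit-yes': 1, 'lit-no': 0}
```

Because every edge is mapped onto the same vocabulary, all rows share the same columns, which is what lets them be fed to a single model later.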
Find out which edges correspond to which regions
In [5]:
from srai.joiners import IntersectionJoiner
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, edges_gdf)
joint_gdf
Out[5]:
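region_id | feature_id |
---|---|
891e2040e43ffff | 1865 |
891e2040e43ffff | 1012 |
891e2040e43ffff | 1013 |
891e2040e43ffff | 8145 |
891e2040e43ffff | 8144 |
... | ... |
891e204087bffff | 1261 |
891e204087bffff | 1262 |
891e204087bffff | 4294 |
891e204087bffff | 1299 |
891e204087bffff | 9014 |
15647 rows × 0 columns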
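`IntersectionJoiner` returns one `(region_id, feature_id)` pair for every region and edge whose geometries intersect. An edge that crosses several hexagons is therefore listed once per hexagon, and the result is an index of pairs with no data columns (hence `15647 rows × 0 columns`). The sketch below reproduces the shape of that output using axis-aligned bounding boxes only. This is a simplification: bounding-box overlap is the cheap candidate test a spatial index runs first, and it can report pairs whose actual geometries do not touch. All names and the toy squares are hypothetical.

```python
from typing import Dict, List, Tuple

Box = Tuple[float, float, float, float]  # (min_x, min_y, max_x, max_y)
Point = Tuple[float, float]


def bbox_of(line: List[Point]) -> Box:
    """Axis-aligned bounding box of a polyline."""
    xs = [x for x, _ in line]
    ys = [y for _, y in line]
    return min(xs), min(ys), max(xs), max(ys)


def boxes_overlap(a: Box, b: Box) -> bool:
    """True if two boxes share at least one point (touching counts)."""
    return a[0] <= b[2] and b[0] <= a[2] and a[1] <= b[3] and b[1] <= a[3]


def candidate_pairs(regions: Dict[str, Box],
                    edges: Dict[int, List[Point]]) -> List[Tuple[str, int]]:
    """One (region_id, feature_id) pair per region/edge whose boxes overlap."""
    pairs = []
    for region_id, region_box in regions.items():
        for feature_id, line in edges.items():
            if boxes_overlap(region_box, bbox_of(line)):
                pairs.append((region_id, feature_id))
    return pairs


# Two unit squares side by side and three toy edges.
regions = {"a": (0.0, 0.0, 1.0, 1.0), "b": (1.0, 0.0, 2.0, 1.0)}
edges = {
    0: [(0.2, 0.5), (1.5, 0.5)],  # crosses from "a" into "b"
    1: [(0.1, 0.1), (0.3, 0.2)],  # inside "a"
    2: [(5.0, 5.0), (6.0, 6.0)],  # outside both
}
print(candidate_pairs(regions, edges))  # [('a', 0), ('a', 1), ('b', 0)]
```

Edge 0 appears twice, once per square it crosses, just as a road can appear under several `region_id` values in the joint table.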
Get the embeddings for regions based on the road infrastructure
In [6]:
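from pytorch_lightning import seed_everything
from srai.embedders import Highway2VecEmbedder
# Fix the random seeds so that training the embedder is reproducible
seed_everything(42)
embedder = Highway2VecEmbedder()
# Train the embedder and get a table of region embeddings indexed by region_id
embeddings = embedder.fit_transform(regions_gdf, edges_gdf, joint_gdf)
embeddings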
Seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/root/development/srai/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

      | Name    | Type       | Params
    ---------------------------------------
    0 | encoder | Sequential | 16.0 K
    1 | decoder | Sequential | 16.2 K
    ---------------------------------------
    32.1 K    Trainable params
    0         Non-trainable params
    32.1 K    Total params
    0.128     Total estimated model params size (MB)

/root/development/srai/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=10` reached.
/root/development/srai/srai/embedders/highway2vec/embedder.py:75: FutureWarning: The behavior of array concatenation with empty entries is deprecated. In a future version, this will no longer exclude empty items when determining the result dtype. To retain the old behavior, exclude the empty entries before the concat operation.
  embeddings_joint = joint_gdf.join(embeddings_df)
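The model summary in the training log reports about 16.0 K encoder and 16.2 K decoder parameters. Those counts match a fully connected autoencoder 218 → 64 → 30 → 64 → 218, where 218 is the number of edge feature columns (219 minus `geometry`), 30 is the embedding width visible in the output, and 64 is a hidden size inferred from the counts rather than read from the srai source. A sketch of the arithmetic, assuming two `Linear` layers per half and float32 weights:

```python
from typing import List


def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


def mlp_params(layer_sizes: List[int]) -> int:
    """Total parameters of consecutive Linear layers; activations add none."""
    return sum(linear_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:]))


n_features = 218      # 219 edge columns minus "geometry"
hidden_size = 64      # assumption, inferred from the parameter counts
embedding_size = 30   # width of the embeddings in the output

encoder_params = mlp_params([n_features, hidden_size, embedding_size])
decoder_params = mlp_params([embedding_size, hidden_size, n_features])
total_params = encoder_params + decoder_params
size_mb = total_params * 4 / 1e6  # float32: 4 bytes per parameter

print(encoder_params, decoder_params, total_params, round(size_mb, 3))
# 15966 16154 32120 0.128
```

Rounded to one decimal in thousands, these give the 16.0 K, 16.2 K and 32.1 K in the log, and 0.128 MB for the estimated size.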
Out[6]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
891e2040007ffff | -0.504556 | -0.155268 | 0.248344 | -0.168650 | 0.372300 | -0.108547 | 0.060736 | -0.567865 | -0.246932 | 0.704372 | ... | 0.351554 | -0.545886 | 0.100016 | 0.071949 | 0.036495 | 0.044149 | 0.138407 | 0.349743 | 0.143677 | 0.377119 |
891e2040013ffff | -0.063886 | -0.108321 | 0.347798 | -0.290935 | 0.602296 | -0.059722 | -0.044114 | -0.559339 | -0.445059 | 0.776166 | ... | 0.484868 | -0.795263 | 0.164593 | -0.346911 | 0.031975 | 0.071917 | 0.416641 | 0.402075 | 0.191143 | 0.371276 |
891e2040017ffff | -0.228718 | -0.275933 | 0.255260 | -0.190627 | 0.432342 | -0.215445 | 0.133701 | -0.700056 | -0.202570 | 0.596845 | ... | 0.287095 | -0.573274 | 0.138941 | -0.006875 | 0.066737 | 0.056983 | 0.448730 | 0.247664 | 0.155226 | 0.399925 |
891e2040023ffff | -0.234212 | -0.375628 | 0.367217 | -0.050228 | 0.514135 | -0.154354 | 0.117761 | -0.352391 | -0.092301 | 0.592975 | ... | 0.164884 | -0.505890 | 0.135076 | 0.015516 | -0.047556 | 0.226682 | 0.383343 | 0.077571 | -0.039865 | 0.448460 |
891e2040027ffff | -0.575319 | -0.284612 | 0.151406 | -0.326954 | 0.650756 | 0.059048 | 0.051050 | -0.612110 | -0.325108 | 0.762787 | ... | 0.046838 | -0.616633 | 0.154819 | -0.128314 | -0.016028 | 0.298513 | 0.157719 | 0.189196 | 0.009537 | 0.601834 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.454497 | -0.025613 | 0.235932 | -0.094423 | 0.557578 | -0.266869 | 0.049226 | -0.635413 | -0.269236 | 0.767842 | ... | 0.591636 | -0.581389 | -0.040448 | -0.234895 | -0.278276 | 0.281510 | 0.030581 | 0.357593 | 0.298373 | 0.764828 |
891e2055bc7ffff | -0.257726 | 0.119923 | 0.215199 | -0.174138 | 0.857351 | -0.446207 | 0.197454 | -0.412605 | -0.361281 | 0.772351 | ... | 0.420934 | -0.605371 | 0.024094 | -0.226583 | -0.369977 | 0.297929 | -0.087533 | 0.319609 | 0.256416 | 0.522280 |
891e2055bcbffff | -0.635451 | -0.077627 | 0.333176 | -0.028319 | 0.522318 | -0.090063 | 0.017219 | -0.862788 | -0.221086 | 0.758205 | ... | 0.803460 | -0.498680 | -0.225858 | -0.347445 | -0.263835 | 0.403450 | 0.074003 | 0.436040 | 0.425636 | 1.040164 |
891e205a967ffff | -0.648195 | 0.104950 | 0.235534 | -0.266269 | 0.352077 | 0.191819 | -0.106896 | -0.260897 | -0.384404 | 0.531822 | ... | -0.268065 | -0.331405 | 0.331181 | -0.005222 | 0.203001 | 0.011407 | 0.175702 | 0.221664 | 0.043823 | 0.566290 |
891e205a9a7ffff | -1.458649 | -1.413952 | 0.511791 | -1.238288 | 0.946567 | -0.871188 | -0.053677 | -0.542523 | -0.548428 | 1.703901 | ... | 0.655956 | -0.768179 | -0.194526 | -0.075900 | -0.438048 | 0.374782 | 0.549085 | 0.899736 | 0.075891 | 1.142536 |
2035 rows × 30 columns
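The warning from `embedder.py` shows the embedder joining `joint_gdf` with a per-edge `embeddings_df`. Assuming it then averages those edge embeddings per region, a region with no intersecting edge gets no embedding at all, which would be consistent with 2035 output rows versus 3168 regions in `regions_gdf`. The sketch below shows that mean aggregation with plain lists standing in for the DataFrames; the function name and toy data are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def mean_region_embeddings(
    pairs: List[Tuple[str, int]],
    edge_embeddings: Dict[int, List[float]],
) -> Dict[str, List[float]]:
    """Average the embeddings of all edges paired with each region."""
    sums: Dict[str, List[float]] = {}
    counts: Dict[str, int] = defaultdict(int)
    for region_id, feature_id in pairs:
        vector = edge_embeddings[feature_id]
        if region_id not in sums:
            sums[region_id] = [0.0] * len(vector)
        for i, value in enumerate(vector):
            sums[region_id][i] += value
        counts[region_id] += 1
    return {r: [s / counts[r] for s in sums[r]] for r in sums}


# Toy data: region "c" has no edges, so it gets no embedding at all.
pairs = [("a", 0), ("a", 1), ("b", 1)]
edge_embeddings = {0: [1.0, 0.0], 1: [0.0, 2.0]}
all_regions = ["a", "b", "c"]

region_embeddings = mean_region_embeddings(pairs, edge_embeddings)
print(region_embeddings)  # {'a': [0.5, 1.0], 'b': [0.0, 2.0]}
print([r for r in all_regions if r not in region_embeddings])  # ['c']
```

If every region needs a row, the missing ones have to be filled explicitly afterwards, for example with a constant vector.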
In [7]:
from sklearn.cluster import KMeans
# Group the region embeddings into 5 clusters
clusterizer = KMeans(n_clusters=5, random_state=42)
clusterizer.fit(embeddings)
# Add each region's cluster label as a new column
embeddings["cluster"] = clusterizer.labels_
embeddings
Out[7]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | cluster |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
891e2040007ffff | -0.504556 | -0.155268 | 0.248344 | -0.168650 | 0.372300 | -0.108547 | 0.060736 | -0.567865 | -0.246932 | 0.704372 | ... | -0.545886 | 0.100016 | 0.071949 | 0.036495 | 0.044149 | 0.138407 | 0.349743 | 0.143677 | 0.377119 | 3 |
891e2040013ffff | -0.063886 | -0.108321 | 0.347798 | -0.290935 | 0.602296 | -0.059722 | -0.044114 | -0.559339 | -0.445059 | 0.776166 | ... | -0.795263 | 0.164593 | -0.346911 | 0.031975 | 0.071917 | 0.416641 | 0.402075 | 0.191143 | 0.371276 | 0 |
891e2040017ffff | -0.228718 | -0.275933 | 0.255260 | -0.190627 | 0.432342 | -0.215445 | 0.133701 | -0.700056 | -0.202570 | 0.596845 | ... | -0.573274 | 0.138941 | -0.006875 | 0.066737 | 0.056983 | 0.448730 | 0.247664 | 0.155226 | 0.399925 | 0 |
891e2040023ffff | -0.234212 | -0.375628 | 0.367217 | -0.050228 | 0.514135 | -0.154354 | 0.117761 | -0.352391 | -0.092301 | 0.592975 | ... | -0.505890 | 0.135076 | 0.015516 | -0.047556 | 0.226682 | 0.383343 | 0.077571 | -0.039865 | 0.448460 | 4 |
891e2040027ffff | -0.575319 | -0.284612 | 0.151406 | -0.326954 | 0.650756 | 0.059048 | 0.051050 | -0.612110 | -0.325108 | 0.762787 | ... | -0.616633 | 0.154819 | -0.128314 | -0.016028 | 0.298513 | 0.157719 | 0.189196 | 0.009537 | 0.601834 | 3 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.454497 | -0.025613 | 0.235932 | -0.094423 | 0.557578 | -0.266869 | 0.049226 | -0.635413 | -0.269236 | 0.767842 | ... | -0.581389 | -0.040448 | -0.234895 | -0.278276 | 0.281510 | 0.030581 | 0.357593 | 0.298373 | 0.764828 | 2 |
891e2055bc7ffff | -0.257726 | 0.119923 | 0.215199 | -0.174138 | 0.857351 | -0.446207 | 0.197454 | -0.412605 | -0.361281 | 0.772351 | ... | -0.605371 | 0.024094 | -0.226583 | -0.369977 | 0.297929 | -0.087533 | 0.319609 | 0.256416 | 0.522280 | 2 |
891e2055bcbffff | -0.635451 | -0.077627 | 0.333176 | -0.028319 | 0.522318 | -0.090063 | 0.017219 | -0.862788 | -0.221086 | 0.758205 | ... | -0.498680 | -0.225858 | -0.347445 | -0.263835 | 0.403450 | 0.074003 | 0.436040 | 0.425636 | 1.040164 | 1 |
891e205a967ffff | -0.648195 | 0.104950 | 0.235534 | -0.266269 | 0.352077 | 0.191819 | -0.106896 | -0.260897 | -0.384404 | 0.531822 | ... | -0.331405 | 0.331181 | -0.005222 | 0.203001 | 0.011407 | 0.175702 | 0.221664 | 0.043823 | 0.566290 | 3 |
891e205a9a7ffff | -1.458649 | -1.413952 | 0.511791 | -1.238288 | 0.946567 | -0.871188 | -0.053677 | -0.542523 | -0.548428 | 1.703901 | ... | -0.768179 | -0.194526 | -0.075900 | -0.438048 | 0.374782 | 0.549085 | 0.899736 | 0.075891 | 1.142536 | 1 |
2035 rows × 31 columns
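`KMeans` groups the 30-dimensional region embeddings into 5 clusters. Its core is Lloyd's algorithm: assign each point to the nearest centroid, move each centroid to the mean of its points, and repeat until the assignments stop changing. The sketch below implements that loop in plain Python. For simplicity it uses a deterministic farthest-point initialization, whereas scikit-learn defaults to k-means++, so its labels will not match the notebook's. In both cases the cluster numbers are arbitrary IDs with no order.

```python
from typing import List, Optional, Sequence, Tuple

Point = Tuple[float, ...]


def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def farthest_point_init(points: List[Point], k: int) -> List[Point]:
    """Start from the first point, then add the point farthest from all chosen centroids."""
    centroids = [points[0]]
    while len(centroids) < k:
        centroids.append(max(points, key=lambda p: min(sq_dist(p, c) for c in centroids)))
    return centroids


def kmeans(points: List[Point], k: int, max_iter: int = 100) -> Tuple[List[int], List[Point]]:
    centroids = farthest_point_init(points, k)
    labels: Optional[List[int]] = None
    for _ in range(max_iter):
        # Assignment step: index of the nearest centroid for every point.
        new_labels = [min(range(k), key=lambda j: sq_dist(p, centroids[j])) for p in points]
        if new_labels == labels:
            break  # converged: assignments no longer change
        labels = new_labels
        # Update step: move each centroid to the mean of its members.
        new_centroids: List[Point] = []
        for j in range(k):
            members = [p for p, label in zip(points, labels) if label == j]
            if members:
                new_centroids.append(tuple(sum(axis) / len(members) for axis in zip(*members)))
            else:
                new_centroids.append(centroids[j])  # keep an empty cluster's centroid
        centroids = new_centroids
    assert labels is not None, "max_iter must be at least 1"
    return labels, centroids


points = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (10.0, 10.0), (10.0, 11.0), (11.0, 10.0)]
labels, centroids = kmeans(points, k=2)
print(labels)  # [0, 0, 0, 1, 1, 1]
```

The same loop works unchanged for 30-dimensional vectors; only the cost of each distance computation grows.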
In [8]:
plot_numeric_data(regions_gdf, "cluster", embeddings)
Out[8]:
[Interactive map: regions colored by cluster]