Highway2Vec Embedder
In [1]:
from IPython.display import display
from srai.plotting import plot_numeric_data, plot_regions
Get an area to embed
In [2]:
from srai.regionalizers import geocode_to_region_gdf
# Geocode the query and return the area's boundary as a GeoDataFrame
area_gdf = geocode_to_region_gdf("Wrocław, PL")
plot_regions(area_gdf, tiles_style="CartoDB positron")
Out[2]:
[Interactive map of the Wrocław area]
Regionalize the area with a regionalizer
In [3]:
from srai.regionalizers import H3Regionalizer
# Cover the area with H3 hexagons at resolution 9
regionalizer = H3Regionalizer(9)
regions_gdf = regionalizer.transform(area_gdf)
print(len(regions_gdf))  # number of H3 cells
display(regions_gdf.head(3))
plot_regions(regions_gdf, tiles_style="CartoDB positron")
3168
region_id | geometry |
---|---|
891e20403bbffff | POLYGON ((16.962 51.10474, 16.96203 51.10306, ... |
891e204e58bffff | POLYGON ((17.05801 51.07226, 17.05805 51.07057... |
891e2050a63ffff | POLYGON ((16.88712 51.1958, 16.88716 51.19412,... |
Out[3]:
[Interactive map of the H3 regions]
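The region IDs above are H3 cell indexes written in hexadecimal. As a quick check that `H3Regionalizer(9)` produced resolution-9 cells, the resolution can be read directly from the 64-bit index. The sketch below follows the bit layout of the H3 index specification (4-bit mode at bits 59 to 62, 4-bit resolution at bits 52 to 55, 7-bit base cell at bits 45 to 51, then fifteen 3-bit digits, with digits beyond the cell's resolution set to 7). It is a standalone illustration, not part of `srai`, and `decode_h3_cell` is a hypothetical helper name.

```python
def decode_h3_cell(index: str) -> dict:
    """Split a hexadecimal H3 cell index into its bit fields.

    Hypothetical helper following the H3 index bit layout; not part of srai.
    """
    h = int(index, 16)
    mode = (h >> 59) & 0xF         # 1 means "cell" mode
    resolution = (h >> 52) & 0xF   # 0 (coarsest) to 15 (finest)
    base_cell = (h >> 45) & 0x7F   # one of the 122 base cells
    # Fifteen 3-bit digits; digit 1 sits at bits 42-44, digit 15 at bits 0-2
    digits = [(h >> ((15 - r) * 3)) & 0x7 for r in range(1, 16)]
    return {
        "mode": mode,
        "resolution": resolution,
        "base_cell": base_cell,
        "digits": digits[:resolution],
        "unused_digits": digits[resolution:],  # all 7 for a valid cell
    }


for region_id in ["891e20403bbffff", "891e204e58bffff", "891e2050a63ffff"]:
    fields = decode_h3_cell(region_id)
    print(region_id, fields["resolution"], fields["base_cell"])
```

All three sample IDs decode to resolution 9, and their remaining six digits are 7, as the specification requires for a resolution-9 cell.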
Download the road infrastructure for the area
In [4]:
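from srai.loaders import OSMNetworkType, OSMWayLoader
# Load the drivable road network: nodes are junctions and endpoints,
# edges are road segments with one-hot encoded OSM tag features
loader = OSMWayLoader(OSMNetworkType.DRIVE)
nodes_gdf, edges_gdf = loader.load(area_gdf)
display(nodes_gdf.head(3))
display(edges_gdf.head(3))
# Plot the edges with the nodes overlaid in red
ax = edges_gdf.plot(linewidth=1, figsize=(12, 7))
nodes_gdf.plot(ax=ax, markersize=3, color="red")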
osmid | y | x | street_count | highway | railway | ref | geometry |
---|---|---|---|---|---|---|---|
95584835 | 51.083111 | 17.049513 | 4 | NaN | NaN | NaN | POINT (17.04951 51.08311) |
95584841 | 51.084699 | 17.064367 | 3 | NaN | NaN | NaN | POINT (17.06437 51.0847) |
95584850 | 51.083328 | 17.035057 | 4 | NaN | NaN | NaN | POINT (17.03506 51.08333) |
feature_id | oneway | lanes-1 | lanes-2 | lanes-3 | lanes-4 | lanes-5 | lanes-6 | lanes-7 | lanes-8 | lanes-9 | ... | bicycle-official | lit-yes | lit-no | lit-sunset-sunrise | lit-24/7 | lit-automatic | lit-disused | lit-limited | lit-interval | geometry |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04947 51.083... |
1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04933 51.083... |
2 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.05357 51.08301, 17.05335 51.082... |
3 rows × 219 columns
Out[4]:
<Axes: >
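The edge table stores OSM tags as one-hot `key-value` columns (for example `lanes-2` or `lit-yes`), which accounts for most of its 219 columns. The third edge has `lanes-2`, `lanes-3` and `lanes-4` all set, so one edge can carry several values for the same key, possibly because a simplified edge merges several OSM ways. The sketch below illustrates that encoding idea in plain Python; the small tag vocabulary and the `one_hot_tags` helper are hypothetical, and this is not `OSMWayLoader`'s actual preprocessing.

```python
from typing import Dict, List, Tuple, Union

TagValue = Union[str, List[str]]

# Hypothetical vocabulary: each OSM key with the values that get their own column
VOCABULARY: Dict[str, List[str]] = {
    "lanes": ["1", "2", "3", "4"],
    "lit": ["yes", "no"],
}


def one_hot_tags(
    edges: List[Dict[str, TagValue]], vocabulary: Dict[str, List[str]]
) -> Tuple[List[str], List[Dict[str, int]]]:
    """Turn per-edge tag dicts into 0/1 columns named "key-value"."""
    columns = [f"{key}-{value}" for key, values in vocabulary.items() for value in values]
    rows = []
    for tags in edges:
        row = dict.fromkeys(columns, 0)
        for key, raw in tags.items():
            # A merged edge may hold a list of values for one key
            values = raw if isinstance(raw, list) else [raw]
            for value in values:
                column = f"{key}-{value}"
                if column in row:  # keys/values outside the vocabulary are dropped
                    row[column] = 1
        rows.append(row)
    return columns, rows


edges = [
    {"lanes": "2", "lit": "yes"},
    {"lanes": ["2", "3", "4"], "lit": "yes"},  # several lane counts on one edge
    {"highway": "residential"},               # key not in the vocabulary
]
columns, rows = one_hot_tags(edges, VOCABULARY)
for row in rows:
    print([row[c] for c in columns])
```

Fixing the vocabulary up front keeps every edge's feature vector the same length, which is what a downstream model with a fixed input size needs.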
Find out which edges correspond to which regions
In [5]:
from srai.joiners import IntersectionJoiner
# Pair every region with every edge whose geometry intersects it
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, edges_gdf)
joint_gdf
Out[5]:
region_id | feature_id |
---|---|
891e2040b03ffff | 0 |
891e2040b03ffff | 1 |
891e2040b03ffff | 2 |
891e2040b07ffff | 2 |
891e2040b03ffff | 3 |
... | ... |
891e2047267ffff | 10312 |
891e20403dbffff | 10313 |
891e2047307ffff | 10314 |
891e2047307ffff | 10315 |
891e2047307ffff | 10316 |
15688 rows × 0 columns
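The joint table has 15688 rows, more than there are edges (feature IDs run up to 10316), because a road segment that crosses a hexagon boundary is paired with every region it touches; feature 2 above appears under two regions. The sketch below reproduces that many-to-many pairing in plain Python under simplifying assumptions: regions are axis-aligned rectangles instead of H3 hexagons, edges are polylines, and each segment is tested against a rectangle with Liang-Barsky clipping. `IntersectionJoiner` itself works on the real geometries; `segment_intersects_box` and `intersection_join` are hypothetical names.

```python
from typing import Dict, List, Tuple

Point = Tuple[float, float]
Box = Tuple[float, float, float, float]  # (xmin, ymin, xmax, ymax)


def segment_intersects_box(p0: Point, p1: Point, box: Box) -> bool:
    """Liang-Barsky test: does the segment p0-p1 touch the rectangle?"""
    (x0, y0), (x1, y1) = p0, p1
    xmin, ymin, xmax, ymax = box
    dx, dy = x1 - x0, y1 - y0
    t_enter, t_exit = 0.0, 1.0
    for p, q in ((-dx, x0 - xmin), (dx, xmax - x0), (-dy, y0 - ymin), (dy, ymax - y0)):
        if p == 0:
            if q < 0:  # parallel to this boundary and outside it
                return False
            continue
        t = q / p
        if p < 0:  # crossing into the slab
            if t > t_exit:
                return False
            t_enter = max(t_enter, t)
        else:  # crossing out of the slab
            if t < t_enter:
                return False
            t_exit = min(t_exit, t)
    return True


def intersection_join(
    regions: Dict[str, Box], edges: Dict[int, List[Point]]
) -> List[Tuple[str, int]]:
    """Return every (region_id, feature_id) pair whose geometries intersect."""
    pairs = []
    for feature_id, line in edges.items():
        for region_id, box in regions.items():
            if any(segment_intersects_box(a, b, box) for a, b in zip(line, line[1:])):
                pairs.append((region_id, feature_id))
    return pairs


regions = {"A": (0.0, 0.0, 1.0, 1.0), "B": (1.0, 0.0, 2.0, 1.0)}
edges = {
    0: [(0.2, 0.5), (0.8, 0.5)],              # inside A only
    1: [(0.5, 0.5), (1.5, 0.5)],              # crosses from A into B
    2: [(3.0, 3.0), (3.5, 3.2), (4.0, 4.0)],  # outside both
}
print(intersection_join(regions, edges))
```

Edge 1 yields two rows, just as a real road crossing a hexagon border does, while edge 2 yields none. This brute-force loop checks every region against every edge; a real joiner would use a spatial index to avoid that.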
Get the embeddings for regions based on the road infrastructure
In [6]:
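from pytorch_lightning import seed_everything
from srai.embedders import Highway2VecEmbedder
seed_everything(42)  # make the autoencoder training reproducible
embedder = Highway2VecEmbedder()
# Train an autoencoder on the edge features and return one embedding per region
embeddings = embedder.fit_transform(regions_gdf, edges_gdf, joint_gdf)
embeddings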
Seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/hostedtoolcache/Python/3.10.16/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
  | Name    | Type       | Params | Mode
-----------------------------------------------
0 | encoder | Sequential | 16.0 K | train
1 | decoder | Sequential | 16.2 K | train
-----------------------------------------------
32.1 K    Trainable params
0         Non-trainable params
32.1 K    Total params
0.128     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode
/opt/hostedtoolcache/Python/3.10.16/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=10` reached.
/opt/hostedtoolcache/Python/3.10.16/x64/lib/python3.10/site-packages/srai/embedders/highway2vec/embedder.py:75: FutureWarning: The behavior of array concatenation with empty entries is deprecated. In a future version, this will no longer exclude empty items when determining the result dtype. To retain the old behavior, exclude the empty entries before the concat operation.
  embeddings_joint = joint_gdf.join(embeddings_df)
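The model summary in the log can be checked by hand. Assume the encoder is a two-layer fully connected network mapping 218 input features (the 219 edge columns minus `geometry`) to a hidden layer of 64 units and then to the 30-dimensional embedding (matching the 30 output columns), with a mirrored decoder and parameter-free activations in between. Under those assumptions, which are inferred from the logged counts rather than read from the `srai` source, the weights and biases add up exactly to the reported 16.0 K, 16.2 K and 32.1 K parameters.

```python
from typing import List


def dense_params(sizes: List[int]) -> int:
    """Weights plus biases of a stack of fully connected layers, e.g. [218, 64, 30]."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes, sizes[1:]))


N_FEATURES = 218  # assumed: 219 edge columns minus the geometry column
HIDDEN = 64       # assumed hidden layer width
EMBEDDING = 30    # matches the 30 embedding columns in the output

encoder = dense_params([N_FEATURES, HIDDEN, EMBEDDING])
decoder = dense_params([EMBEDDING, HIDDEN, N_FEATURES])
total = encoder + decoder
size_mb = total * 4 / 1e6  # float32 parameters take 4 bytes each

print(
    f"encoder {encoder / 1000:.1f} K, decoder {decoder / 1000:.1f} K, "
    f"total {total / 1000:.1f} K, {size_mb:.3f} MB"
)
```

The encoder and decoder differ by 188 parameters only because their final biases have different lengths (30 versus 218).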
Out[6]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
891e2040007ffff | -0.518636 | -0.150762 | 0.099129 | -0.128340 | 0.374203 | -0.171351 | 0.041406 | -0.597257 | -0.195334 | 0.766336 | ... | 0.356429 | -0.536595 | 0.197296 | 0.073133 | -0.009141 | 0.007996 | 0.163186 | 0.358240 | 0.068875 | 0.389990 |
891e2040013ffff | -0.092134 | -0.089167 | 0.161766 | -0.252995 | 0.625243 | -0.048758 | -0.005048 | -0.604901 | -0.377803 | 0.827473 | ... | 0.408355 | -0.701997 | 0.276326 | -0.338328 | -0.026943 | 0.038455 | 0.445025 | 0.350000 | 0.059989 | 0.359978 |
891e2040017ffff | -0.269235 | -0.247315 | 0.095021 | -0.159002 | 0.452461 | -0.194560 | 0.155457 | -0.782245 | -0.163456 | 0.642140 | ... | 0.242925 | -0.570346 | 0.270923 | -0.005590 | 0.012166 | 0.059096 | 0.450231 | 0.240332 | 0.039077 | 0.411482 |
891e2040023ffff | -0.269447 | -0.389121 | 0.234750 | -0.012543 | 0.499266 | -0.159563 | 0.116168 | -0.389050 | -0.060858 | 0.640575 | ... | 0.141512 | -0.507868 | 0.229550 | 0.035368 | -0.087487 | 0.180899 | 0.390609 | 0.074013 | -0.112671 | 0.482980 |
891e2040027ffff | -0.608874 | -0.281542 | 0.008080 | -0.270795 | 0.636253 | 0.038253 | 0.047164 | -0.680029 | -0.296917 | 0.791156 | ... | 0.014026 | -0.608172 | 0.238070 | -0.086131 | -0.071340 | 0.264048 | 0.172930 | 0.203559 | -0.080134 | 0.629472 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.529938 | -0.073519 | 0.055865 | -0.067914 | 0.599557 | -0.306125 | 0.104335 | -0.591797 | -0.283700 | 0.820726 | ... | 0.595174 | -0.536751 | 0.052352 | -0.299551 | -0.352925 | 0.222843 | 0.112538 | 0.336185 | 0.152496 | 0.819504 |
891e2055bc7ffff | -0.375381 | 0.038806 | 0.021660 | -0.117423 | 0.906208 | -0.511946 | 0.177024 | -0.290537 | -0.285046 | 0.844595 | ... | 0.403720 | -0.540731 | 0.123075 | -0.316030 | -0.472408 | 0.234737 | -0.037275 | 0.260438 | 0.039564 | 0.532431 |
891e2055bcbffff | -0.704404 | -0.131264 | 0.110740 | -0.045771 | 0.578663 | -0.148514 | 0.125499 | -0.802698 | -0.266073 | 0.794319 | ... | 0.782213 | -0.468432 | -0.109421 | -0.393788 | -0.378235 | 0.311282 | 0.179118 | 0.400895 | 0.281861 | 1.125080 |
891e205a967ffff | -0.661377 | 0.122641 | 0.132611 | -0.243137 | 0.318846 | 0.192291 | -0.041654 | -0.365727 | -0.418835 | 0.537724 | ... | -0.260569 | -0.359826 | 0.372024 | 0.006376 | 0.215719 | -0.017789 | 0.240693 | 0.270300 | 0.051628 | 0.586003 |
891e205a9a7ffff | -1.442428 | -1.533561 | 0.356087 | -1.009694 | 0.991603 | -1.077378 | -0.044753 | -0.590862 | -0.393576 | 1.966672 | ... | 0.720633 | -0.636218 | 0.009046 | 0.005114 | -0.503919 | 0.161550 | 0.564600 | 0.998339 | -0.031163 | 1.029532 |
2037 rows × 30 columns
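The table has 2037 rows while the regionalizer produced 3168 cells, presumably because only regions that appear in the joint table, that is, regions crossed by at least one drivable edge, receive an embedding. Each region's vector summarises the embeddings of the edges joined to it. The sketch below assumes that summary is an element-wise mean, which is consistent with the `joint_gdf.join(embeddings_df)` line in the warning above but not confirmed by it; the toy data and the `mean_region_embeddings` name are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def mean_region_embeddings(
    pairs: List[Tuple[str, int]], edge_embeddings: Dict[int, List[float]]
) -> Dict[str, List[float]]:
    """Average the embeddings of all edges joined to each region.

    pairs: (region_id, feature_id) tuples, like the rows of the joint table.
    edge_embeddings: feature_id -> embedding vector.
    """
    sums: Dict[str, List[float]] = {}
    counts: Dict[str, int] = defaultdict(int)
    for region_id, feature_id in pairs:
        vector = edge_embeddings[feature_id]
        if region_id not in sums:
            sums[region_id] = [0.0] * len(vector)
        sums[region_id] = [s + v for s, v in zip(sums[region_id], vector)]
        counts[region_id] += 1
    # Regions that never appear in `pairs` get no entry at all
    return {r: [s / counts[r] for s in total] for r, total in sums.items()}


pairs = [("A", 0), ("A", 1), ("B", 1)]
edge_embeddings = {0: [1.0, 0.0], 1: [3.0, 2.0]}
print(mean_region_embeddings(pairs, edge_embeddings))
```

Because an edge shared by two regions contributes to both averages, neighbouring regions along the same road end up with correlated embeddings.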
In [7]:
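from sklearn.cluster import KMeans
# Group the region embeddings into 5 clusters
clusterizer = KMeans(n_clusters=5, random_state=42)
clusterizer.fit(embeddings)
# Store each region's cluster label as a new column (modifies `embeddings` in place)
embeddings["cluster"] = clusterizer.labels_
embeddings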
Out[7]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | cluster |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
891e2040007ffff | -0.518636 | -0.150762 | 0.099129 | -0.128340 | 0.374203 | -0.171351 | 0.041406 | -0.597257 | -0.195334 | 0.766336 | ... | -0.536595 | 0.197296 | 0.073133 | -0.009141 | 0.007996 | 0.163186 | 0.358240 | 0.068875 | 0.389990 | 1 |
891e2040013ffff | -0.092134 | -0.089167 | 0.161766 | -0.252995 | 0.625243 | -0.048758 | -0.005048 | -0.604901 | -0.377803 | 0.827473 | ... | -0.701997 | 0.276326 | -0.338328 | -0.026943 | 0.038455 | 0.445025 | 0.350000 | 0.059989 | 0.359978 | 1 |
891e2040017ffff | -0.269235 | -0.247315 | 0.095021 | -0.159002 | 0.452461 | -0.194560 | 0.155457 | -0.782245 | -0.163456 | 0.642140 | ... | -0.570346 | 0.270923 | -0.005590 | 0.012166 | 0.059096 | 0.450231 | 0.240332 | 0.039077 | 0.411482 | 1 |
891e2040023ffff | -0.269447 | -0.389121 | 0.234750 | -0.012543 | 0.499266 | -0.159563 | 0.116168 | -0.389050 | -0.060858 | 0.640575 | ... | -0.507868 | 0.229550 | 0.035368 | -0.087487 | 0.180899 | 0.390609 | 0.074013 | -0.112671 | 0.482980 | 3 |
891e2040027ffff | -0.608874 | -0.281542 | 0.008080 | -0.270795 | 0.636253 | 0.038253 | 0.047164 | -0.680029 | -0.296917 | 0.791156 | ... | -0.608172 | 0.238070 | -0.086131 | -0.071340 | 0.264048 | 0.172930 | 0.203559 | -0.080134 | 0.629472 | 1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.529938 | -0.073519 | 0.055865 | -0.067914 | 0.599557 | -0.306125 | 0.104335 | -0.591797 | -0.283700 | 0.820726 | ... | -0.536751 | 0.052352 | -0.299551 | -0.352925 | 0.222843 | 0.112538 | 0.336185 | 0.152496 | 0.819504 | 2 |
891e2055bc7ffff | -0.375381 | 0.038806 | 0.021660 | -0.117423 | 0.906208 | -0.511946 | 0.177024 | -0.290537 | -0.285046 | 0.844595 | ... | -0.540731 | 0.123075 | -0.316030 | -0.472408 | 0.234737 | -0.037275 | 0.260438 | 0.039564 | 0.532431 | 2 |
891e2055bcbffff | -0.704404 | -0.131264 | 0.110740 | -0.045771 | 0.578663 | -0.148514 | 0.125499 | -0.802698 | -0.266073 | 0.794319 | ... | -0.468432 | -0.109421 | -0.393788 | -0.378235 | 0.311282 | 0.179118 | 0.400895 | 0.281861 | 1.125080 | 4 |
891e205a967ffff | -0.661377 | 0.122641 | 0.132611 | -0.243137 | 0.318846 | 0.192291 | -0.041654 | -0.365727 | -0.418835 | 0.537724 | ... | -0.359826 | 0.372024 | 0.006376 | 0.215719 | -0.017789 | 0.240693 | 0.270300 | 0.051628 | 0.586003 | 1 |
891e205a9a7ffff | -1.442428 | -1.533561 | 0.356087 | -1.009694 | 0.991603 | -1.077378 | -0.044753 | -0.590862 | -0.393576 | 1.966672 | ... | -0.636218 | 0.009046 | 0.005114 | -0.503919 | 0.161550 | 0.564600 | 0.998339 | -0.031163 | 1.029532 | 4 |
2037 rows × 31 columns
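`KMeans` assigns each 30-dimensional region embedding to one of five clusters by alternating two steps: assign every point to its nearest centroid, then move each centroid to the mean of its assigned points. The sketch below shows that plain Lloyd iteration in pure Python on toy 2-D points. It is not scikit-learn's implementation, which by default also uses k-means++ initialisation and optimised numerical routines; the fixed starting centroids here are an assumption that keeps the example deterministic.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def kmeans(
    points: List[Vector], centroids: List[Vector], max_iter: int = 100
) -> Tuple[List[int], List[List[float]]]:
    """Plain Lloyd's algorithm starting from the given centroids."""
    centroids = [list(c) for c in centroids]
    labels: List[int] = []
    for _ in range(max_iter):
        # Assignment step: nearest centroid by Euclidean distance
        new_labels = [
            min(range(len(centroids)), key=lambda k: math.dist(p, centroids[k]))
            for p in points
        ]
        if new_labels == labels:
            break  # assignments stopped changing, so the centroids are stable
        labels = new_labels
        # Update step: each centroid moves to the mean of its assigned points
        for k in range(len(centroids)):
            members = [p for p, label in zip(points, labels) if label == k]
            if members:  # leave an empty cluster's centroid where it is
                centroids[k] = [sum(coord) / len(members) for coord in zip(*members)]
    return labels, centroids


points = [(0.0, 0.0), (0.0, 1.0), (10.0, 10.0), (10.0, 11.0), (0.5, 0.5)]
labels, centroids = kmeans(points, centroids=[points[0], points[2]])
print(labels, centroids)
```

The outcome depends on the starting centroids, which is why the notebook fixes `random_state=42` for reproducible cluster labels.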