Highway2Vec Embedder
In [1]:
from IPython.display import display
from srai.plotting import plot_numeric_data, plot_regions
Get an area to embed
In [2]:
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Wrocław, PL")
plot_regions(area_gdf, tiles_style="CartoDB positron")
Out[2]:
Regionalize the area with a regionalizer
In [3]:
from srai.regionalizers import H3Regionalizer
regionalizer = H3Regionalizer(9)  # H3 resolution 9: average cell area of about 0.1 km²
regions_gdf = regionalizer.transform(area_gdf)
print(len(regions_gdf))
display(regions_gdf.head(3))
plot_regions(regions_gdf, tiles_style="CartoDB positron")
3168
region_id | geometry |
---|---|
891e2042d77ffff | POLYGON ((17.03091 51.16738, 17.03094 51.1657,... |
891e20421abffff | POLYGON ((16.96298 51.17291, 16.96302 51.17123... |
891e2044273ffff | POLYGON ((17.16782 51.1007, 17.16785 51.09902,... |
Out[3]:
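As a rough sanity check on the 3168 regions printed above, you can multiply the count by the average area of an H3 cell at resolution 9 (about 0.105 km², taken from the H3 cell statistics table). This is a back-of-the-envelope sketch, not part of srai:

```python
# Average H3 hexagon area at resolution 9, in km² (value from the H3 cell statistics table)
AVG_CELL_AREA_KM2_RES9 = 0.1053325

n_regions = 3168  # number of regions printed by the cell above
covered_km2 = n_regions * AVG_CELL_AREA_KM2_RES9
print(f"{covered_km2:.1f} km²")  # 333.7 km²
```

The result is somewhat larger than Wrocław's area of roughly 293 km². That is expected if, as we assume here, the regionalizer keeps every hexagon that touches the city boundary rather than only those lying fully inside it.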
Download the road infrastructure for the area
In [4]:
from srai.loaders import OSMNetworkType, OSMWayLoader
loader = OSMWayLoader(OSMNetworkType.DRIVE)  # drivable public roads only
nodes_gdf, edges_gdf = loader.load(area_gdf)
display(nodes_gdf.head(3))
display(edges_gdf.head(3))
ax = edges_gdf.plot(linewidth=1, figsize=(12, 7))
nodes_gdf.plot(ax=ax, markersize=3, color="red")
osmid | y | x | street_count | highway | railway | ref | geometry |
---|---|---|---|---|---|---|---|
95584835 | 51.083111 | 17.049513 | 4 | NaN | NaN | NaN | POINT (17.04951 51.08311) |
95584841 | 51.084686 | 17.064329 | 3 | NaN | NaN | NaN | POINT (17.06433 51.08469) |
95584850 | 51.083328 | 17.035057 | 4 | NaN | NaN | NaN | POINT (17.03506 51.08333) |
feature_id | oneway | lanes-1 | lanes-2 | lanes-3 | lanes-4 | lanes-5 | lanes-6 | lanes-7 | lanes-8 | lanes-9 | ... | bicycle-official | lit-yes | lit-no | lit-sunset-sunrise | lit-24/7 | lit-automatic | lit-disused | lit-limited | lit-interval | geometry |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04932 51.083... |
1 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04947 51.083... |
2 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.05357 51.08301, 17.05335 51.082... |
3 rows × 219 columns
Out[4]:
<Axes: >
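Each row of `edges_gdf` is one road segment, and most of its 219 columns are 0/1 indicators named `key-value` (for example `lanes-2` or `lit-yes`), alongside `oneway` and `geometry`. The sketch below is a simplified, pure-Python illustration of how such indicator columns can be built from OSM tags. The helper `encode_tags`, the column subset and the sample tags are hypothetical, and srai's actual preprocessing may differ in its details (note that some edges above have several `lanes-*` flags set at once, which the sketch models by allowing a list of values per key):

```python
from typing import Dict, List, Union

TagValue = Union[str, List[str]]


def encode_tags(ways: List[Dict[str, TagValue]], columns: List[str]) -> List[Dict[str, int]]:
    """Hypothetical multi-hot encoder: one 0/1 column per `key-value` pair."""
    rows = []
    for tags in ways:
        row = dict.fromkeys(columns, 0)
        for key, value in tags.items():
            # Assumption: a merged road segment may carry several values for one key
            values = value if isinstance(value, list) else [value]
            for v in values:
                column = f"{key}-{v}"
                if column in row:  # tag values without a matching column are ignored
                    row[column] = 1
        rows.append(row)
    return rows


columns = ["lanes-1", "lanes-2", "lanes-3", "lit-yes", "lit-no"]
ways = [
    {"lanes": "2", "lit": "yes"},
    {"lanes": ["2", "3"], "lit": "yes", "surface": "asphalt"},
]
for row in encode_tags(ways, columns):
    print(row)
```

Fixing the column list up front keeps every edge vector the same length, which is what a neural model downstream needs.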
Find out which edges correspond to which regions
In [5]:
from srai.joiners import IntersectionJoiner
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, edges_gdf)
joint_gdf
Out[5]:
region_id | feature_id |
---|---|
891e2040b03ffff | 0 |
891e2040b03ffff | 1 |
891e2040b03ffff | 2 |
891e2040b07ffff | 2 |
891e2040b03ffff | 3 |
... | ... |
891e20403dbffff | 10239 |
891e2047307ffff | 10240 |
891e2047307ffff | 10241 |
891e2047307ffff | 10242 |
891e2040c33ffff | 10243 |
15619 rows × 0 columns
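`joint_gdf` has no data columns (hence "0 columns"): all of its information sits in the two-level index of `(region_id, feature_id)` pairs. The relation is many-to-many, so an edge that crosses a hexagon boundary, like feature 2 above, is listed under every region it intersects. The pure-Python sketch below uses the first few pairs from the output to read the relation in both directions:

```python
from collections import Counter, defaultdict

# (region_id, feature_id) pairs taken from the first rows of joint_gdf above
pairs = [
    ("891e2040b03ffff", 0),
    ("891e2040b03ffff", 1),
    ("891e2040b03ffff", 2),
    ("891e2040b07ffff", 2),
    ("891e2040b03ffff", 3),
]

# Number of edges that intersect each region
edges_per_region = Counter(region_id for region_id, _ in pairs)

# Set of regions that each edge intersects
regions_per_edge = defaultdict(set)
for region_id, feature_id in pairs:
    regions_per_edge[feature_id].add(region_id)

print(edges_per_region)
print(sorted(regions_per_edge[2]))
```

On the full data, the per-region count should correspond to `joint_gdf.groupby(level="region_id").size()` in pandas.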
Get the embeddings for regions based on the road infrastructure
In [6]:
from pytorch_lightning import seed_everything
from srai.embedders import Highway2VecEmbedder
seed_everything(42)  # seed Python, NumPy and PyTorch RNGs for reproducible training
embedder = Highway2VecEmbedder()
embeddings = embedder.fit_transform(regions_gdf, edges_gdf, joint_gdf)
embeddings
Seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
 | Name | Type | Params | Mode |
---|---|---|---|---|
0 | encoder | Sequential | 16.0 K | train |
1 | decoder | Sequential | 16.2 K | train |
32.1 K Trainable params
0 Non-trainable params
32.1 K Total params
0.128 Total estimated model params size (MB)
8 Modules in train mode
0 Modules in eval mode
/opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=10` reached.
/opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/srai/embedders/highway2vec/embedder.py:75: FutureWarning: The behavior of array concatenation with empty entries is deprecated. In a future version, this will no longer exclude empty items when determining the result dtype. To retain the old behavior, exclude the empty entries before the concat operation.
  embeddings_joint = joint_gdf.join(embeddings_df)
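The parameter counts in the model summary above are consistent with a small fully connected autoencoder: an encoder mapping the 218 non-geometry edge features to a 64-unit hidden layer and then to a 30-dimensional embedding, plus a mirrored decoder. The layer sizes (218, 64, 30) are our assumption, inferred from the 219 columns of `edges_gdf` and the 30-column embeddings this cell returns; the arithmetic below checks them against the logged 16.0 K, 16.2 K and 32.1 K:

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Parameters of a dense layer: a weight matrix plus one bias per output."""
    return n_in * n_out + n_out


n_features = 219 - 1  # edges_gdf columns minus `geometry` (assumed model input)
hidden_size = 64      # assumed hidden layer width
embedding_size = 30   # matches the 30 embedding columns

encoder = linear_params(n_features, hidden_size) + linear_params(hidden_size, embedding_size)
decoder = linear_params(embedding_size, hidden_size) + linear_params(hidden_size, n_features)
total = encoder + decoder

print(encoder, decoder, total)  # 15966 16154 32120
# 32-bit floats take 4 bytes each
print(f"{total * 4 / 1e6:.3f} MB")  # 0.128 MB
```

Activation layers such as ReLU have no parameters, so only the linear layers contribute to these totals.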
Out[6]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
891e2040007ffff | -0.505571 | -0.161977 | 0.244276 | -0.040582 | 0.316826 | -0.141222 | 0.060456 | -0.733364 | -0.256160 | 0.747521 | ... | 0.378488 | -0.574541 | 0.182099 | -0.021319 | 0.006439 | 0.010994 | 0.292025 | 0.367131 | 0.066439 | 0.294101 |
891e2040013ffff | -0.062365 | -0.162727 | 0.234881 | -0.219666 | 0.544893 | -0.031368 | 0.010354 | -0.783258 | -0.460515 | 0.784668 | ... | 0.377643 | -0.697939 | 0.288901 | -0.383642 | 0.000252 | 0.022039 | 0.580721 | 0.420004 | 0.119929 | 0.324456 |
891e2040017ffff | -0.217881 | -0.271846 | 0.197869 | -0.126759 | 0.380466 | -0.153410 | 0.169173 | -0.919598 | -0.171831 | 0.644836 | ... | 0.285975 | -0.596563 | 0.235816 | -0.089400 | 0.044356 | 0.009748 | 0.581687 | 0.313082 | 0.071090 | 0.291771 |
891e2040023ffff | -0.234771 | -0.392038 | 0.343199 | 0.043906 | 0.498590 | -0.143046 | 0.142766 | -0.508038 | -0.076214 | 0.591567 | ... | 0.187048 | -0.535494 | 0.214821 | -0.023328 | -0.035013 | 0.177445 | 0.493670 | 0.109602 | -0.068035 | 0.413717 |
891e2040027ffff | -0.561839 | -0.317224 | 0.183030 | -0.214215 | 0.623574 | 0.065527 | 0.062856 | -0.791319 | -0.286160 | 0.808577 | ... | 0.040294 | -0.644298 | 0.181041 | -0.170795 | -0.000483 | 0.263557 | 0.332198 | 0.234718 | 0.019180 | 0.533377 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.477219 | -0.160535 | 0.115676 | -0.062552 | 0.592524 | -0.366858 | 0.102093 | -0.737968 | -0.318296 | 0.853119 | ... | 0.638078 | -0.561715 | 0.043542 | -0.354135 | -0.324639 | 0.209370 | 0.298326 | 0.288536 | 0.137298 | 0.690278 |
891e2055bc7ffff | -0.316054 | -0.082379 | 0.110414 | -0.124763 | 0.833848 | -0.484215 | 0.180403 | -0.480617 | -0.313713 | 0.882932 | ... | 0.492551 | -0.592583 | 0.046099 | -0.387436 | -0.362791 | 0.167632 | 0.169509 | 0.259948 | 0.013781 | 0.382859 |
891e2055bcbffff | -0.645476 | -0.227068 | 0.124664 | -0.052287 | 0.634229 | -0.275640 | 0.140212 | -0.940571 | -0.310617 | 0.811473 | ... | 0.771928 | -0.470617 | -0.071422 | -0.436041 | -0.396020 | 0.324479 | 0.393317 | 0.296664 | 0.278454 | 0.975055 |
891e205a967ffff | -0.630802 | 0.111766 | 0.248090 | -0.198120 | 0.304351 | 0.190321 | -0.030354 | -0.419449 | -0.437004 | 0.538626 | ... | -0.196105 | -0.364051 | 0.353933 | -0.070906 | 0.210677 | -0.007400 | 0.294595 | 0.323297 | 0.097985 | 0.518808 |
891e205a9a7ffff | -1.353162 | -1.533194 | 0.632859 | -0.862067 | 0.974528 | -0.927505 | 0.150118 | -0.794047 | -0.281776 | 1.955861 | ... | 0.736396 | -0.696542 | -0.066819 | -0.099729 | -0.457464 | 0.022553 | 0.900467 | 0.899027 | -0.041229 | 0.738852 |
2038 rows × 30 columns
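The embeddings table has 2038 rows, fewer than the 3168 regions, because a region with no intersecting drivable edge has no entry in `joint_gdf` and therefore nothing to embed. For the remaining regions, the edge embeddings are joined onto `joint_gdf` (the line quoted in the `FutureWarning` above) and then combined per region. The sketch below assumes that combination is a plain mean; the `mean_pool` helper and the toy 2-dimensional edge embeddings are hypothetical:

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def mean_pool(
    pairs: List[Tuple[str, int]], edge_embeddings: Dict[int, List[float]]
) -> Dict[str, List[float]]:
    """Average the embeddings of all edges joined to each region (hypothetical helper)."""
    sums: Dict[str, List[float]] = {}
    counts: Dict[str, int] = defaultdict(int)
    for region_id, feature_id in pairs:
        vector = edge_embeddings[feature_id]
        if region_id not in sums:
            sums[region_id] = [0.0] * len(vector)
        sums[region_id] = [s + v for s, v in zip(sums[region_id], vector)]
        counts[region_id] += 1
    # Regions that never appear in `pairs` get no embedding at all
    return {r: [s / counts[r] for s in total] for r, total in sums.items()}


# Hypothetical edge embeddings; edge 2 is shared by both regions
edge_embeddings = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [0.5, 0.5], 3: [1.5, 0.5]}
pairs = [("region_a", 0), ("region_a", 1), ("region_a", 2), ("region_b", 2), ("region_b", 3)]

print(mean_pool(pairs, edge_embeddings))
# {'region_a': [0.5, 0.5], 'region_b': [1.0, 0.5]}
```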
In [7]:
from sklearn.cluster import KMeans
clusterizer = KMeans(n_clusters=5, random_state=42)
clusterizer.fit(embeddings)
embeddings["cluster"] = clusterizer.labels_
embeddings
Out[7]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | cluster |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
891e2040007ffff | -0.505571 | -0.161977 | 0.244276 | -0.040582 | 0.316826 | -0.141222 | 0.060456 | -0.733364 | -0.256160 | 0.747521 | ... | -0.574541 | 0.182099 | -0.021319 | 0.006439 | 0.010994 | 0.292025 | 0.367131 | 0.066439 | 0.294101 | 3 |
891e2040013ffff | -0.062365 | -0.162727 | 0.234881 | -0.219666 | 0.544893 | -0.031368 | 0.010354 | -0.783258 | -0.460515 | 0.784668 | ... | -0.697939 | 0.288901 | -0.383642 | 0.000252 | 0.022039 | 0.580721 | 0.420004 | 0.119929 | 0.324456 | 1 |
891e2040017ffff | -0.217881 | -0.271846 | 0.197869 | -0.126759 | 0.380466 | -0.153410 | 0.169173 | -0.919598 | -0.171831 | 0.644836 | ... | -0.596563 | 0.235816 | -0.089400 | 0.044356 | 0.009748 | 0.581687 | 0.313082 | 0.071090 | 0.291771 | 1 |
891e2040023ffff | -0.234771 | -0.392038 | 0.343199 | 0.043906 | 0.498590 | -0.143046 | 0.142766 | -0.508038 | -0.076214 | 0.591567 | ... | -0.535494 | 0.214821 | -0.023328 | -0.035013 | 0.177445 | 0.493670 | 0.109602 | -0.068035 | 0.413717 | 4 |
891e2040027ffff | -0.561839 | -0.317224 | 0.183030 | -0.214215 | 0.623574 | 0.065527 | 0.062856 | -0.791319 | -0.286160 | 0.808577 | ... | -0.644298 | 0.181041 | -0.170795 | -0.000483 | 0.263557 | 0.332198 | 0.234718 | 0.019180 | 0.533377 | 3 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.477219 | -0.160535 | 0.115676 | -0.062552 | 0.592524 | -0.366858 | 0.102093 | -0.737968 | -0.318296 | 0.853119 | ... | -0.561715 | 0.043542 | -0.354135 | -0.324639 | 0.209370 | 0.298326 | 0.288536 | 0.137298 | 0.690278 | 2 |
891e2055bc7ffff | -0.316054 | -0.082379 | 0.110414 | -0.124763 | 0.833848 | -0.484215 | 0.180403 | -0.480617 | -0.313713 | 0.882932 | ... | -0.592583 | 0.046099 | -0.387436 | -0.362791 | 0.167632 | 0.169509 | 0.259948 | 0.013781 | 0.382859 | 2 |
891e2055bcbffff | -0.645476 | -0.227068 | 0.124664 | -0.052287 | 0.634229 | -0.275640 | 0.140212 | -0.940571 | -0.310617 | 0.811473 | ... | -0.470617 | -0.071422 | -0.436041 | -0.396020 | 0.324479 | 0.393317 | 0.296664 | 0.278454 | 0.975055 | 0 |
891e205a967ffff | -0.630802 | 0.111766 | 0.248090 | -0.198120 | 0.304351 | 0.190321 | -0.030354 | -0.419449 | -0.437004 | 0.538626 | ... | -0.364051 | 0.353933 | -0.070906 | 0.210677 | -0.007400 | 0.294595 | 0.323297 | 0.097985 | 0.518808 | 3 |
891e205a9a7ffff | -1.353162 | -1.533194 | 0.632859 | -0.862067 | 0.974528 | -0.927505 | 0.150118 | -0.794047 | -0.281776 | 1.955861 | ... | -0.696542 | -0.066819 | -0.099729 | -0.457464 | 0.022553 | 0.900467 | 0.899027 | -0.041229 | 0.738852 | 0 |
2038 rows × 31 columns
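`KMeans` groups the 2038 region embeddings into five clusters by alternating two steps: assign every embedding to its nearest centroid, then move each centroid to the mean of its assigned embeddings. The pure-Python sketch below shows that loop (Lloyd's algorithm) on toy 2-D points. It simply takes the first `k` points as initial centroids, whereas scikit-learn uses k-means++ initialization by default and a much faster implementation, so its labels will not match scikit-learn's. Cluster IDs are arbitrary in either case: only the grouping carries meaning.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, ...]


def kmeans(points: Sequence[Point], k: int, max_iter: int = 100) -> Tuple[List[int], List[Point]]:
    """Plain Lloyd's algorithm; the first k points serve as initial centroids."""
    centroids: List[Point] = [tuple(p) for p in points[:k]]
    labels: List[int] = []
    for _ in range(max_iter):
        # Assignment step: nearest centroid by Euclidean distance
        new_labels = [min(range(k), key=lambda c: math.dist(p, centroids[c])) for p in points]
        if new_labels == labels:
            break  # assignments are stable, so the centroids will not move again
        labels = new_labels
        # Update step: move each centroid to the mean of its members
        new_centroids: List[Point] = []
        for c in range(k):
            members = [p for p, label in zip(points, labels) if label == c]
            if members:
                new_centroids.append(tuple(sum(dim) / len(members) for dim in zip(*members)))
            else:
                new_centroids.append(centroids[c])  # keep an empty cluster's centroid
        centroids = new_centroids
    return labels, centroids


points = [(0.0, 0.0), (10.0, 10.0), (0.0, 1.0), (1.0, 0.0), (10.0, 11.0), (11.0, 10.0)]
labels, centroids = kmeans(points, k=2)
print(labels)  # [0, 1, 0, 0, 1, 1]
```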
In [8]:
plot_numeric_data(regions_gdf, "cluster", embeddings)
Out[8]: