Highway2Vec Embedder
In [1]:
from IPython.display import display
from srai.plotting import plot_numeric_data, plot_regions
Get an area to embed
In [2]:
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Wrocław, PL")
plot_regions(area_gdf, tiles_style="CartoDB positron")
Out[2]:
[Interactive map: boundary of the selected area]
Regionalize the area with a regionalizer
In [3]:
from srai.regionalizers import H3Regionalizer
regionalizer = H3Regionalizer(9)
regions_gdf = regionalizer.transform(area_gdf)
print(len(regions_gdf))
display(regions_gdf.head(3))
plot_regions(regions_gdf, tiles_style="CartoDB positron")
3168
| region_id | geometry |
|---|---|
| 891e204e56fffff | POLYGON ((17.06835 51.0574, 17.06839 51.05572,... |
| 891e2042c6bffff | POLYGON ((16.99086 51.16114, 16.9909 51.15946,... |
| 891e2047367ffff | POLYGON ((17.12237 51.11454, 17.12241 51.11286... |
Out[3]:
[Interactive map: H3 regions covering the area]
Download the road infrastructure for the area
In [4]:
from srai.loaders import OSMNetworkType, OSMWayLoader
loader = OSMWayLoader(OSMNetworkType.DRIVE)
nodes_gdf, edges_gdf = loader.load(area_gdf)
display(nodes_gdf.head(3))
display(edges_gdf.head(3))
ax = edges_gdf.plot(linewidth=1, figsize=(12, 7))
nodes_gdf.plot(ax=ax, markersize=3, color="red")
| osmid | y | x | street_count | highway | railway | ref | geometry |
|---|---|---|---|---|---|---|---|
| 95584835 | 51.083111 | 17.049513 | 4 | NaN | NaN | NaN | POINT (17.04951 51.08311) |
| 95584841 | 51.084686 | 17.064329 | 3 | NaN | NaN | NaN | POINT (17.06433 51.08469) |
| 95584850 | 51.083328 | 17.035057 | 4 | NaN | NaN | NaN | POINT (17.03506 51.08333) |
| feature_id | oneway | lanes-1 | lanes-2 | lanes-3 | lanes-4 | lanes-5 | lanes-6 | lanes-7 | lanes-8 | lanes-9 | ... | bicycle-official | lit-yes | lit-no | lit-sunset-sunrise | lit-24/7 | lit-automatic | lit-disused | lit-limited | lit-interval | geometry |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04932 51.083... |
| 1 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04947 51.083... |
| 2 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.05357 51.08301, 17.05335 51.082... |
3 rows × 219 columns
Out[4]:
<Axes: >
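The edge table above stores road attributes as 0/1 indicator columns named `<tag>-<value>` (for example `lanes-2` or `lit-yes`). As a rough illustration of how such columns can arise, the sketch below encodes a tiny, made-up tag table with `pandas.get_dummies`. The `raw_tags` frame and its values are hypothetical, and this is not necessarily how `OSMWayLoader` builds its features; in particular, the real output can set several indicators of one tag at once (feature 1 has both `lanes-2` and `lanes-3` set), which this single-value sketch does not reproduce.

```python
import pandas as pd

# Hypothetical raw OSM-style tags for three road segments (not real data).
raw_tags = pd.DataFrame(
    {"lanes": ["2", "2", "3"], "lit": ["yes", "no", "yes"]},
    index=pd.Index([0, 1, 2], name="feature_id"),
)

# One indicator column per (tag, value) pair, named "<tag>-<value>".
one_hot = pd.get_dummies(raw_tags, prefix_sep="-", dtype=int)
print(one_hot)
```

Every row of `one_hot` then contains exactly one 1 per original tag, which is the simplest form of the wide, sparse layout shown in the 219-column table.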
Find out which edges correspond to which regions
In [5]:
from srai.joiners import IntersectionJoiner
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, edges_gdf)
joint_gdf
Out[5]:
| region_id | feature_id |
|---|---|
| 891e2040b03ffff | 0 |
| | 1 |
| | 2 |
| 891e2040b07ffff | 2 |
| 891e2040b03ffff | 3 |
| ... | ... |
| 891e20403dbffff | 10227 |
| 891e2047307ffff | 10228 |
| | 10229 |
| | 10230 |
| 891e2040c33ffff | 10231 |
15602 rows × 0 columns
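The joiner output has no data columns: all of its information lives in the `(region_id, feature_id)` index, where an edge that crosses a cell boundary appears once per region it intersects. A minimal sketch of how that index can be queried follows, using a toy frame modelled on the first rows of the output (the assumption here is that rows printed without a `region_id` repeat the value above them, as pandas does when it displays a MultiIndex).

```python
import pandas as pd

# Toy region/edge pairs modelled on the first rows of the joiner output.
pairs = [
    ("891e2040b03ffff", 0),
    ("891e2040b03ffff", 1),
    ("891e2040b03ffff", 2),
    ("891e2040b07ffff", 2),  # edge 2 intersects two regions
    ("891e2040b03ffff", 3),
]
toy_joint = pd.DataFrame(
    index=pd.MultiIndex.from_tuples(pairs, names=["region_id", "feature_id"])
)

# Count how many edges fall in each region, and how many regions each edge touches.
edges_per_region = toy_joint.index.get_level_values("region_id").value_counts()
regions_per_edge = toy_joint.index.get_level_values("feature_id").value_counts()

print(edges_per_region)
print(regions_per_edge)
```

The same two lines should work on `joint_gdf` itself, since they only rely on the index level names shown in the output.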
Get the embeddings for regions based on the road infrastructure
In [6]:
from pytorch_lightning import seed_everything
from srai.embedders import Highway2VecEmbedder
seed_everything(42)
embedder = Highway2VecEmbedder()
embeddings = embedder.fit_transform(regions_gdf, edges_gdf, joint_gdf)
embeddings
Seed set to 42
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name    | Type       | Params | Mode
-----------------------------------------------
0 | encoder | Sequential | 16.0 K | train
1 | decoder | Sequential | 16.2 K | train
-----------------------------------------------
32.1 K    Trainable params
0         Non-trainable params
32.1 K    Total params
0.128     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode
/opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=10` reached.
/opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/srai/embedders/highway2vec/embedder.py:75: FutureWarning: The behavior of array concatenation with empty entries is deprecated. In a future version, this will no longer exclude empty items when determining the result dtype. To retain the old behavior, exclude the empty entries before the concat operation.
  embeddings_joint = joint_gdf.join(embeddings_df)
Out[6]:
| region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 891e2040007ffff | -0.508066 | -0.177025 | 0.156437 | -0.071603 | 0.339832 | -0.141631 | 0.045181 | -0.668599 | -0.218437 | 0.771740 | ... | 0.358396 | -0.577295 | 0.193811 | 0.018114 | 0.010343 | 0.039192 | 0.218984 | 0.359508 | 0.053295 | 0.372048 |
| 891e2040013ffff | -0.069767 | -0.139462 | 0.192447 | -0.226635 | 0.568134 | -0.027965 | 0.009412 | -0.727394 | -0.412899 | 0.820051 | ... | 0.386026 | -0.746125 | 0.277594 | -0.383546 | 0.019188 | 0.056929 | 0.543255 | 0.415436 | 0.072133 | 0.399711 |
| 891e2040017ffff | -0.215969 | -0.260626 | 0.151140 | -0.136405 | 0.411986 | -0.143113 | 0.164531 | -0.870192 | -0.148687 | 0.674008 | ... | 0.271089 | -0.629138 | 0.251839 | -0.092176 | 0.036656 | 0.067176 | 0.535852 | 0.292186 | 0.022294 | 0.353723 |
| 891e2040023ffff | -0.245717 | -0.395154 | 0.296367 | 0.021404 | 0.516553 | -0.154588 | 0.134588 | -0.454115 | -0.046521 | 0.624609 | ... | 0.179349 | -0.549597 | 0.229970 | 0.003512 | -0.050325 | 0.210689 | 0.444287 | 0.102506 | -0.114143 | 0.460099 |
| 891e2040027ffff | -0.563019 | -0.319663 | 0.126843 | -0.231352 | 0.642825 | 0.053629 | 0.041703 | -0.745535 | -0.258027 | 0.829416 | ... | 0.024639 | -0.644661 | 0.204146 | -0.141410 | -0.008804 | 0.296371 | 0.276304 | 0.229202 | -0.024074 | 0.580097 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 891e2055bc3ffff | -0.467639 | -0.133415 | 0.033876 | -0.060866 | 0.539893 | -0.377114 | 0.055854 | -0.758585 | -0.324896 | 0.826015 | ... | 0.587118 | -0.568958 | 0.063796 | -0.317409 | -0.306000 | 0.238316 | 0.210633 | 0.292501 | 0.091417 | 0.794386 |
| 891e2055bc7ffff | -0.312607 | -0.051650 | 0.021861 | -0.129742 | 0.763106 | -0.515685 | 0.125849 | -0.520134 | -0.327023 | 0.858875 | ... | 0.449067 | -0.594045 | 0.074953 | -0.334261 | -0.357999 | 0.194893 | 0.071183 | 0.259940 | -0.017499 | 0.485172 |
| 891e2055bcbffff | -0.629814 | -0.184191 | 0.048094 | -0.030607 | 0.549817 | -0.285326 | 0.076246 | -0.962848 | -0.322616 | 0.769946 | ... | 0.714414 | -0.497399 | -0.042653 | -0.413841 | -0.372848 | 0.353870 | 0.291589 | 0.313341 | 0.206297 | 1.100539 |
| 891e205a967ffff | -0.652834 | 0.109684 | 0.208043 | -0.235541 | 0.351340 | 0.185104 | -0.044626 | -0.337786 | -0.378774 | 0.569485 | ... | -0.213431 | -0.363390 | 0.359284 | -0.045358 | 0.209417 | 0.036405 | 0.260836 | 0.291675 | 0.056622 | 0.543634 |
| 891e205a9a7ffff | -1.329819 | -1.523456 | 0.467204 | -0.870951 | 0.884991 | -0.907677 | -0.044607 | -0.809217 | -0.356116 | 1.873452 | ... | 0.690767 | -0.678770 | -0.026932 | 0.007893 | -0.481888 | 0.104526 | 0.684110 | 0.902423 | -0.061302 | 1.024393 |
2038 rows × 30 columns
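The result has 2038 rows although the regionalizer produced 3168 cells. A plausible reading is that only regions with at least one intersecting edge receive an embedding. The warning in the log shows that the embedder joins edge-level vectors onto the joiner output (`joint_gdf.join(embeddings_df)`); the sketch below mimics that step on toy data and then averages per region. The averaging is an assumption made for illustration, not something the output confirms, and all names and values here are hypothetical.

```python
import pandas as pd

# Hypothetical 2-dimensional edge vectors standing in for the encoder output.
edge_vectors = pd.DataFrame(
    {0: [1.0, 3.0, 5.0], 1: [0.0, 2.0, 4.0]},
    index=pd.Index([10, 11, 12], name="feature_id"),
)

# Region/edge pairs shaped like joint_gdf; region "c" has no edges at all.
joint = pd.DataFrame(
    index=pd.MultiIndex.from_tuples(
        [("a", 10), ("a", 11), ("b", 12)], names=["region_id", "feature_id"]
    )
)

# Attach each edge vector to its pair, then average per region (assumed aggregation).
region_vectors = joint.join(edge_vectors).groupby(level="region_id").mean()

# Re-align to the full list of regions; regions without edges become NaN rows.
all_regions = pd.Index(["a", "b", "c"], name="region_id")
padded = region_vectors.reindex(all_regions)
print(padded)
```

If a downstream step needs one row per region, a call such as `embeddings.reindex(regions_gdf.index)` should give the same kind of padded frame, with the road-free cells left as NaN.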
In [7]:
from sklearn.cluster import KMeans
clusterizer = KMeans(n_clusters=5, random_state=42)
clusterizer.fit(embeddings)
embeddings["cluster"] = clusterizer.labels_
embeddings
Out[7]:
| region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | cluster |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 891e2040007ffff | -0.508066 | -0.177025 | 0.156437 | -0.071603 | 0.339832 | -0.141631 | 0.045181 | -0.668599 | -0.218437 | 0.771740 | ... | -0.577295 | 0.193811 | 0.018114 | 0.010343 | 0.039192 | 0.218984 | 0.359508 | 0.053295 | 0.372048 | 3 |
| 891e2040013ffff | -0.069767 | -0.139462 | 0.192447 | -0.226635 | 0.568134 | -0.027965 | 0.009412 | -0.727394 | -0.412899 | 0.820051 | ... | -0.746125 | 0.277594 | -0.383546 | 0.019188 | 0.056929 | 0.543255 | 0.415436 | 0.072133 | 0.399711 | 1 |
| 891e2040017ffff | -0.215969 | -0.260626 | 0.151140 | -0.136405 | 0.411986 | -0.143113 | 0.164531 | -0.870192 | -0.148687 | 0.674008 | ... | -0.629138 | 0.251839 | -0.092176 | 0.036656 | 0.067176 | 0.535852 | 0.292186 | 0.022294 | 0.353723 | 1 |
| 891e2040023ffff | -0.245717 | -0.395154 | 0.296367 | 0.021404 | 0.516553 | -0.154588 | 0.134588 | -0.454115 | -0.046521 | 0.624609 | ... | -0.549597 | 0.229970 | 0.003512 | -0.050325 | 0.210689 | 0.444287 | 0.102506 | -0.114143 | 0.460099 | 4 |
| 891e2040027ffff | -0.563019 | -0.319663 | 0.126843 | -0.231352 | 0.642825 | 0.053629 | 0.041703 | -0.745535 | -0.258027 | 0.829416 | ... | -0.644661 | 0.204146 | -0.141410 | -0.008804 | 0.296371 | 0.276304 | 0.229202 | -0.024074 | 0.580097 | 3 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 891e2055bc3ffff | -0.467639 | -0.133415 | 0.033876 | -0.060866 | 0.539893 | -0.377114 | 0.055854 | -0.758585 | -0.324896 | 0.826015 | ... | -0.568958 | 0.063796 | -0.317409 | -0.306000 | 0.238316 | 0.210633 | 0.292501 | 0.091417 | 0.794386 | 2 |
| 891e2055bc7ffff | -0.312607 | -0.051650 | 0.021861 | -0.129742 | 0.763106 | -0.515685 | 0.125849 | -0.520134 | -0.327023 | 0.858875 | ... | -0.594045 | 0.074953 | -0.334261 | -0.357999 | 0.194893 | 0.071183 | 0.259940 | -0.017499 | 0.485172 | 2 |
| 891e2055bcbffff | -0.629814 | -0.184191 | 0.048094 | -0.030607 | 0.549817 | -0.285326 | 0.076246 | -0.962848 | -0.322616 | 0.769946 | ... | -0.497399 | -0.042653 | -0.413841 | -0.372848 | 0.353870 | 0.291589 | 0.313341 | 0.206297 | 1.100539 | 0 |
| 891e205a967ffff | -0.652834 | 0.109684 | 0.208043 | -0.235541 | 0.351340 | 0.185104 | -0.044626 | -0.337786 | -0.378774 | 0.569485 | ... | -0.363390 | 0.359284 | -0.045358 | 0.209417 | 0.036405 | 0.260836 | 0.291675 | 0.056622 | 0.543634 | 3 |
| 891e205a9a7ffff | -1.329819 | -1.523456 | 0.467204 | -0.870951 | 0.884991 | -0.907677 | -0.044607 | -0.809217 | -0.356116 | 1.873452 | ... | -0.678770 | -0.026932 | 0.007893 | -0.481888 | 0.104526 | 0.684110 | 0.902423 | -0.061302 | 1.024393 | 0 |
2038 rows × 31 columns
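The notebook fixes `n_clusters=5` without comment. One common way to check such a choice is to compare silhouette scores across several values of k. The sketch below does this on synthetic data from `make_blobs`, used here as a stand-in for the embedding matrix; the data, the range of k, and the `n_init` value are assumptions for illustration only.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score

# Synthetic stand-in for the embedding matrix: three well-separated groups.
X, _ = make_blobs(
    n_samples=300,
    centers=[[0.0, 0.0], [10.0, 10.0], [-10.0, 10.0]],
    cluster_std=0.5,
    random_state=42,
)

# Fit KMeans for each candidate k and record the mean silhouette score.
scores = {}
for k in range(2, 6):
    labels = KMeans(n_clusters=k, n_init=10, random_state=42).fit_predict(X)
    scores[k] = silhouette_score(X, labels)

# The k with the highest score is the best-separated clustering by this measure.
best_k = max(scores, key=scores.get)
print(best_k, round(scores[best_k], 3))
```

When applying this to the real embeddings, note that the cell above writes the labels back into `embeddings`, so re-running the clustering would treat the `cluster` column as a feature; passing `embeddings.drop(columns="cluster")` to the clusterer avoids that.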
In [8]:
plot_numeric_data(regions_gdf, "cluster", embeddings)
Out[8]:
[Interactive map: regions colored by cluster]