Highway2Vec Embedder
```python
from IPython.display import display
from srai.plotting import plot_numeric_data, plot_regions
```
Get an area to embed
```python
from srai.regionalizers import geocode_to_region_gdf

# Geocode the city name into a GeoDataFrame holding its boundary polygon
area_gdf = geocode_to_region_gdf("Wrocław, PL")
plot_regions(area_gdf, tiles_style="CartoDB positron")
```
Regionalize the area with a regionalizer
```python
from srai.regionalizers import H3Regionalizer

# Split the area into H3 hexagons at resolution 9
regionalizer = H3Regionalizer(9)
regions_gdf = regionalizer.transform(area_gdf)
print(len(regions_gdf))
display(regions_gdf.head(3))
plot_regions(regions_gdf, tiles_style="CartoDB positron")
```
3168

| region_id | geometry |
|---|---|
| 891e204e113ffff | POLYGON ((17.07848 51.05265, 17.07851 51.05096... |
| 891e204556fffff | POLYGON ((17.13782 51.09479, 17.13786 51.09311... |
| 891e204046bffff | POLYGON ((16.92143 51.1237, 16.92146 51.12202,... |
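The region IDs in the index are H3 cell indexes written in hexadecimal. As a quick sanity check that `H3Regionalizer(9)` produced resolution-9 cells, a few fields can be decoded with plain bit operations, relying on the 64-bit index layout described in the H3 documentation (mode in bits 59 to 62, resolution in bits 52 to 55, base cell in bits 45 to 51, then fifteen 3-bit digits with the unused ones set to 7). The helper `h3_fields` is hypothetical and only illustrates the layout; a dedicated H3 library would normally be used instead.

```python
from typing import Dict


def h3_fields(cell_id: str) -> Dict[str, object]:
    """Decode a few fields of an H3 cell index (hypothetical helper, for illustration)."""
    h = int(cell_id, 16)
    resolution = (h >> 52) & 0xF  # 4-bit resolution field
    # Digits finer than the cell's resolution are unused and must all be 7 (0b111);
    # digit d (1..15) occupies the 3 bits starting at bit 3 * (15 - d)
    unused_are_seven = all(
        ((h >> (3 * (15 - d))) & 0b111) == 7 for d in range(resolution + 1, 16)
    )
    return {
        "mode": (h >> 59) & 0xF,  # 1 means "cell" mode
        "resolution": resolution,
        "base_cell": (h >> 45) & 0x7F,  # index of the base cell
        "unused_digits_are_seven": unused_are_seven,
    }


for cell_id in ["891e204e113ffff", "891e204556fffff", "891e204046bffff"]:
    print(cell_id, h3_fields(cell_id))
```

All three IDs shown above decode to mode 1 and resolution 9, matching the regionalizer's argument.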
Download the road infrastructure for the area
```python
from srai.loaders import OSMNetworkType, OSMWayLoader

# Download the drivable road network as graph nodes and edges (road segments)
loader = OSMWayLoader(OSMNetworkType.DRIVE)
nodes_gdf, edges_gdf = loader.load(area_gdf)
display(nodes_gdf.head(3))
display(edges_gdf.head(3))
# Plot the edges and overlay the nodes in red
ax = edges_gdf.plot(linewidth=1, figsize=(12, 7))
nodes_gdf.plot(ax=ax, markersize=3, color="red")
```
| osmid | y | x | street_count | highway | railway | ref | geometry |
|---|---|---|---|---|---|---|---|
| 95584835 | 51.083111 | 17.049513 | 4 | NaN | NaN | NaN | POINT (17.04951 51.08311) |
| 95584841 | 51.084686 | 17.064329 | 3 | NaN | NaN | NaN | POINT (17.06433 51.08469) |
| 95584850 | 51.083328 | 17.035057 | 4 | NaN | NaN | NaN | POINT (17.03506 51.08333) |
| feature_id | oneway | lanes-1 | lanes-2 | lanes-3 | lanes-4 | lanes-5 | lanes-6 | lanes-7 | lanes-8 | lanes-9 | ... | bicycle-official | lit-yes | lit-no | lit-sunset-sunrise | lit-24/7 | lit-automatic | lit-disused | lit-limited | lit-interval | geometry |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04932 51.083... |
| 1 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04947 51.083... |
| 2 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.05357 51.08301, 17.05335 51.082... |

3 rows × 219 columns
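The edge attributes appear as binary `<tag>-<value>` columns such as `lanes-2` or `lit-yes`. Some rows have several `lanes-*` columns set at once (feature 2 has `lanes-2`, `lanes-3` and `lanes-4`), which suggests a multi-hot rather than a strictly one-hot encoding. The sketch below shows that idea in plain Python; `multi_hot` and the tiny tag vocabulary are hypothetical and are not srai's implementation.

```python
from typing import Dict, List, Union

TagValue = Union[str, List[str]]


def multi_hot(tags: Dict[str, TagValue], vocabulary: Dict[str, List[str]]) -> Dict[str, int]:
    """Encode OSM-style tags into binary "<key>-<value>" columns (hypothetical sketch)."""
    row: Dict[str, int] = {}
    for key, known_values in vocabulary.items():
        raw = tags.get(key, [])
        # A tag may hold a single value or a list of values
        present = {raw} if isinstance(raw, str) else set(raw)
        for value in known_values:
            row[f"{key}-{value}"] = int(value in present)
    return row


vocabulary = {"lanes": ["1", "2", "3", "4"], "lit": ["yes", "no"]}
print(multi_hot({"lanes": ["2", "3"], "lit": "yes"}, vocabulary))
# → {'lanes-1': 0, 'lanes-2': 1, 'lanes-3': 1, 'lanes-4': 0, 'lit-yes': 1, 'lit-no': 0}
```

Tags missing from an edge simply yield all-zero columns for that key.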
Find out which edges correspond to which regions
```python
from srai.joiners import IntersectionJoiner

# Pair every region with every road edge whose geometry intersects it
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, edges_gdf)
joint_gdf
```
| region_id | feature_id |
|---|---|
| 891e2040b03ffff | 0 |
| | 1 |
| | 2 |
| 891e2040b07ffff | 2 |
| 891e2040b03ffff | 3 |
| ... | ... |
| 891e2047267ffff | 10243 |
| 891e20403dbffff | 10244 |
| 891e2047307ffff | 10245 |
| | 10246 |
| | 10247 |

15620 rows × 0 columns
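The joint table pairs each region with every edge that intersects it, so the relation is many-to-many: feature 2 is listed under both 891e2040b03ffff and 891e2040b07ffff, and there are more pairs (15620) than edges (feature IDs run up to 10247). The toy version below illustrates the idea with the standard library only. As a simplifying assumption, regions are axis-aligned rectangles instead of H3 hexagons, and a polyline intersects a region if any of its segments does, tested with Liang-Barsky clipping. srai's `IntersectionJoiner` works on the real geometries; nothing here mirrors its internals, and `intersection_join` is a hypothetical name.

```python
from typing import Dict, List, Tuple

Point = Tuple[float, float]
Box = Tuple[float, float, float, float]  # (xmin, ymin, xmax, ymax)


def segment_intersects_box(p0: Point, p1: Point, box: Box) -> bool:
    """Liang-Barsky test: does the segment p0-p1 touch the rectangle?"""
    xmin, ymin, xmax, ymax = box
    dx, dy = p1[0] - p0[0], p1[1] - p0[1]
    t_enter, t_exit = 0.0, 1.0
    for p, q in ((-dx, p0[0] - xmin), (dx, xmax - p0[0]), (-dy, p0[1] - ymin), (dy, ymax - p0[1])):
        if p == 0:
            if q < 0:  # parallel to this boundary and outside it
                return False
        else:
            t = q / p
            if p < 0:  # entering the half-plane
                if t > t_exit:
                    return False
                t_enter = max(t_enter, t)
            else:  # leaving the half-plane
                if t < t_enter:
                    return False
                t_exit = min(t_exit, t)
    return True


def intersection_join(regions: Dict[str, Box], edges: Dict[int, List[Point]]) -> List[Tuple[str, int]]:
    """Return every (region_id, feature_id) pair whose geometries intersect."""
    pairs = []
    for region_id, box in regions.items():
        for feature_id, line in edges.items():
            if any(segment_intersects_box(a, b, box) for a, b in zip(line, line[1:])):
                pairs.append((region_id, feature_id))
    return pairs


regions = {"A": (0.0, 0.0, 1.0, 1.0), "B": (1.0, 0.0, 2.0, 1.0)}
edges = {
    0: [(0.2, 0.5), (0.8, 0.5)],  # inside A only
    1: [(0.5, 0.5), (1.5, 0.5)],  # crosses from A into B
    2: [(3.0, 3.0), (4.0, 4.0)],  # outside both
}
print(intersection_join(regions, edges))  # → [('A', 0), ('A', 1), ('B', 1)]
```

Edge 1 ends up in two pairs, just like feature 2 in the real output.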
Get the embeddings for regions based on the road infrastructure
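```python
from pytorch_lightning import seed_everything
from srai.embedders import Highway2VecEmbedder

seed_everything(42)  # fix random seeds so training is reproducible
embedder = Highway2VecEmbedder()
# Train on the edge features and aggregate the result per region
embeddings = embedder.fit_transform(regions_gdf, edges_gdf, joint_gdf)
embeddings
```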
Seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/hostedtoolcache/Python/3.10.17/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

| | Name | Type | Params | Mode |
|---|---|---|---|---|
| 0 | encoder | Sequential | 16.0 K | train |
| 1 | decoder | Sequential | 16.2 K | train |

32.1 K Trainable params
0 Non-trainable params
32.1 K Total params
0.128 Total estimated model params size (MB)
8 Modules in train mode
0 Modules in eval mode

/opt/hostedtoolcache/Python/3.10.17/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=10` reached.
/opt/hostedtoolcache/Python/3.10.17/x64/lib/python3.10/site-packages/srai/embedders/highway2vec/embedder.py:75: FutureWarning: The behavior of array concatenation with empty entries is deprecated. In a future version, this will no longer exclude empty items when determining the result dtype. To retain the old behavior, exclude the empty entries before the concat operation.
  embeddings_joint = joint_gdf.join(embeddings_df)
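The model summary in the training log reports about 16.0 K parameters for the encoder, 16.2 K for the decoder and 32.1 K in total. These numbers can be reproduced by hand under a few assumptions that this notebook does not state: the embedder sees 218 input features (the 219 columns of `edges_gdf` minus `geometry`), the encoder and the decoder are each two fully connected layers around a 64-unit hidden layer, and the embedding has 30 dimensions, matching the 30 output columns. A dense layer with `n_in` inputs and `n_out` outputs has `n_in * n_out` weights plus `n_out` biases.

```python
def linear_params(n_in: int, n_out: int) -> int:
    # Weights (n_in * n_out) plus one bias per output unit
    return n_in * n_out + n_out


N_FEATURES = 218  # assumption: 219 edges_gdf columns minus "geometry"
HIDDEN = 64       # assumption: width of the hidden layer
EMBEDDING = 30    # matches the 30 embedding columns in the output

encoder = linear_params(N_FEATURES, HIDDEN) + linear_params(HIDDEN, EMBEDDING)
decoder = linear_params(EMBEDDING, HIDDEN) + linear_params(HIDDEN, N_FEATURES)
total = encoder + decoder
size_mb = total * 4 / 1e6  # assumption: 32-bit floats, 4 bytes per parameter

print(encoder, decoder, total, round(size_mb, 3))  # → 15966 16154 32120 0.128
```

The agreement with the logged 16.0 K, 16.2 K, 32.1 K and 0.128 MB supports these assumptions, although parameter-free layers such as activations cannot be detected this way.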
| region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 891e2040007ffff | -0.529524 | -0.166252 | 0.188208 | -0.078263 | 0.312652 | -0.122885 | 0.046166 | -0.666709 | -0.191351 | 0.789335 | ... | 0.381418 | -0.484698 | 0.180864 | 0.030174 | 0.040711 | 0.001478 | 0.239903 | 0.383556 | 0.070571 | 0.369365 |
| 891e2040013ffff | -0.079371 | -0.105000 | 0.201465 | -0.178798 | 0.538683 | -0.047302 | 0.018895 | -0.736121 | -0.354629 | 0.885385 | ... | 0.417807 | -0.658675 | 0.296876 | -0.331781 | -0.013512 | -0.030582 | 0.527361 | 0.372395 | 0.081116 | 0.366028 |
| 891e2040017ffff | -0.230446 | -0.220944 | 0.174758 | -0.139742 | 0.411353 | -0.155695 | 0.177206 | -0.892093 | -0.174453 | 0.694497 | ... | 0.274119 | -0.550626 | 0.262852 | -0.070176 | 0.041392 | 0.015679 | 0.542191 | 0.277701 | 0.043621 | 0.353432 |
| 891e2040023ffff | -0.243348 | -0.375326 | 0.297044 | 0.023123 | 0.508149 | -0.180508 | 0.132730 | -0.472526 | -0.068183 | 0.655002 | ... | 0.183288 | -0.491395 | 0.257521 | 0.005157 | -0.075450 | 0.169884 | 0.447809 | 0.085794 | -0.092300 | 0.471468 |
| 891e2040027ffff | -0.583803 | -0.299034 | 0.110338 | -0.244688 | 0.647475 | 0.013781 | 0.042422 | -0.779286 | -0.310125 | 0.832543 | ... | 0.025534 | -0.563438 | 0.251398 | -0.125340 | -0.056853 | 0.261820 | 0.265913 | 0.188490 | -0.035973 | 0.601671 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 891e2055bc3ffff | -0.521368 | -0.087746 | 0.041659 | -0.055106 | 0.573150 | -0.408106 | 0.081084 | -0.751152 | -0.312421 | 0.879668 | ... | 0.673746 | -0.499221 | 0.077826 | -0.263520 | -0.374796 | 0.197104 | 0.158045 | 0.276370 | 0.078446 | 0.737015 |
| 891e2055bc7ffff | -0.368377 | -0.028404 | 0.016222 | -0.157243 | 0.826003 | -0.549717 | 0.116973 | -0.485026 | -0.315559 | 0.915919 | ... | 0.533159 | -0.525376 | 0.102559 | -0.306852 | -0.431274 | 0.183856 | 0.044066 | 0.233206 | -0.063242 | 0.440407 |
| 891e2055bcbffff | -0.691901 | -0.121992 | 0.071182 | -0.007955 | 0.560298 | -0.308511 | 0.134867 | -0.967332 | -0.290703 | 0.846483 | ... | 0.808390 | -0.417870 | -0.059367 | -0.334908 | -0.436776 | 0.292623 | 0.216498 | 0.300700 | 0.223636 | 1.021725 |
| 891e205a967ffff | -0.649509 | 0.129532 | 0.186205 | -0.216507 | 0.339390 | 0.147296 | -0.056306 | -0.392828 | -0.437906 | 0.534310 | ... | -0.239489 | -0.320744 | 0.414384 | -0.025356 | 0.174043 | -0.002633 | 0.241043 | 0.256610 | 0.046542 | 0.573696 |
| 891e205a9a7ffff | -1.455485 | -1.602951 | 0.534991 | -0.923690 | 0.955260 | -0.951231 | -0.138807 | -0.604343 | -0.305117 | 2.103476 | ... | 0.749453 | -0.498591 | -0.030532 | -0.020145 | -0.474182 | 0.079278 | 0.649379 | 0.987396 | 0.018225 | 1.095909 |

2038 rows × 30 columns
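The embedder returns one row per region, while the FutureWarning in the training log shows it joining `joint_gdf` with per-edge embeddings (`joint_gdf.join(embeddings_df)`), so edge vectors have to be combined per region. The sketch below assumes mean pooling, that is, averaging the vectors of all edges paired with a region; the exact aggregation srai uses should be checked in its source, and `mean_pool` is a hypothetical name. The region and feature IDs come from the joint table, but the 2-dimensional edge vectors are made up. A region with no intersecting edge never appears in the pairs and so gets no row, which is consistent with the output having 2038 rows while the regionalizer produced 3168 regions.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def mean_pool(
    pairs: List[Tuple[str, int]], edge_embeddings: Dict[int, Sequence[float]]
) -> Dict[str, List[float]]:
    """Average the embeddings of all edges paired with each region (assumed aggregation)."""
    vectors_by_region: Dict[str, List[Sequence[float]]] = defaultdict(list)
    for region_id, feature_id in pairs:
        vectors_by_region[region_id].append(edge_embeddings[feature_id])
    return {
        region_id: [sum(column) / len(vectors) for column in zip(*vectors)]
        for region_id, vectors in vectors_by_region.items()
    }


# Region-edge pairs as listed at the top of joint_gdf
pairs = [
    ("891e2040b03ffff", 0),
    ("891e2040b03ffff", 1),
    ("891e2040b03ffff", 2),
    ("891e2040b07ffff", 2),
]
# Hypothetical 2-D edge embeddings (the real ones have 30 dimensions)
edge_embeddings = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [2.0, 2.0]}

print(mean_pool(pairs, edge_embeddings))
# → {'891e2040b03ffff': [1.0, 1.0], '891e2040b07ffff': [2.0, 2.0]}
```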
```python
from sklearn.cluster import KMeans

# Group the regions into 5 clusters based on their embeddings
clusterizer = KMeans(n_clusters=5, random_state=42)
clusterizer.fit(embeddings)
embeddings["cluster"] = clusterizer.labels_
embeddings
```
| region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | cluster |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 891e2040007ffff | -0.529524 | -0.166252 | 0.188208 | -0.078263 | 0.312652 | -0.122885 | 0.046166 | -0.666709 | -0.191351 | 0.789335 | ... | -0.484698 | 0.180864 | 0.030174 | 0.040711 | 0.001478 | 0.239903 | 0.383556 | 0.070571 | 0.369365 | 3 |
| 891e2040013ffff | -0.079371 | -0.105000 | 0.201465 | -0.178798 | 0.538683 | -0.047302 | 0.018895 | -0.736121 | -0.354629 | 0.885385 | ... | -0.658675 | 0.296876 | -0.331781 | -0.013512 | -0.030582 | 0.527361 | 0.372395 | 0.081116 | 0.366028 | 0 |
| 891e2040017ffff | -0.230446 | -0.220944 | 0.174758 | -0.139742 | 0.411353 | -0.155695 | 0.177206 | -0.892093 | -0.174453 | 0.694497 | ... | -0.550626 | 0.262852 | -0.070176 | 0.041392 | 0.015679 | 0.542191 | 0.277701 | 0.043621 | 0.353432 | 0 |
| 891e2040023ffff | -0.243348 | -0.375326 | 0.297044 | 0.023123 | 0.508149 | -0.180508 | 0.132730 | -0.472526 | -0.068183 | 0.655002 | ... | -0.491395 | 0.257521 | 0.005157 | -0.075450 | 0.169884 | 0.447809 | 0.085794 | -0.092300 | 0.471468 | 4 |
| 891e2040027ffff | -0.583803 | -0.299034 | 0.110338 | -0.244688 | 0.647475 | 0.013781 | 0.042422 | -0.779286 | -0.310125 | 0.832543 | ... | -0.563438 | 0.251398 | -0.125340 | -0.056853 | 0.261820 | 0.265913 | 0.188490 | -0.035973 | 0.601671 | 3 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 891e2055bc3ffff | -0.521368 | -0.087746 | 0.041659 | -0.055106 | 0.573150 | -0.408106 | 0.081084 | -0.751152 | -0.312421 | 0.879668 | ... | -0.499221 | 0.077826 | -0.263520 | -0.374796 | 0.197104 | 0.158045 | 0.276370 | 0.078446 | 0.737015 | 2 |
| 891e2055bc7ffff | -0.368377 | -0.028404 | 0.016222 | -0.157243 | 0.826003 | -0.549717 | 0.116973 | -0.485026 | -0.315559 | 0.915919 | ... | -0.525376 | 0.102559 | -0.306852 | -0.431274 | 0.183856 | 0.044066 | 0.233206 | -0.063242 | 0.440407 | 2 |
| 891e2055bcbffff | -0.691901 | -0.121992 | 0.071182 | -0.007955 | 0.560298 | -0.308511 | 0.134867 | -0.967332 | -0.290703 | 0.846483 | ... | -0.417870 | -0.059367 | -0.334908 | -0.436776 | 0.292623 | 0.216498 | 0.300700 | 0.223636 | 1.021725 | 1 |
| 891e205a967ffff | -0.649509 | 0.129532 | 0.186205 | -0.216507 | 0.339390 | 0.147296 | -0.056306 | -0.392828 | -0.437906 | 0.534310 | ... | -0.320744 | 0.414384 | -0.025356 | 0.174043 | -0.002633 | 0.241043 | 0.256610 | 0.046542 | 0.573696 | 4 |
| 891e205a9a7ffff | -1.455485 | -1.602951 | 0.534991 | -0.923690 | 0.955260 | -0.951231 | -0.138807 | -0.604343 | -0.305117 | 2.103476 | ... | -0.498591 | -0.030532 | -0.020145 | -0.474182 | 0.079278 | 0.649379 | 0.987396 | 0.018225 | 1.095909 | 1 |

2038 rows × 31 columns
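`KMeans` assigns each region to one of the `n_clusters=5` groups by alternating two steps: assign every point to its nearest centroid, then move each centroid to the mean of the points assigned to it. The plain-Python sketch below implements that loop (Lloyd's algorithm) with centroids initialized from randomly chosen data points; scikit-learn's `KMeans` defaults to k-means++ initialization, so its labels will differ. The cluster numbers are arbitrary identifiers rather than an ordering: only which regions share a label is meaningful. The function name `lloyd_kmeans` and the toy points are hypothetical.

```python
import random
from typing import List, Sequence, Tuple


def squared_distance(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def lloyd_kmeans(
    points: List[Sequence[float]], k: int, max_iter: int = 100, seed: int = 42
) -> Tuple[List[int], List[List[float]]]:
    """Plain Lloyd's k-means with centroids initialized from k random points."""
    rng = random.Random(seed)
    centroids = [list(p) for p in rng.sample(points, k)]
    labels: List[int] = []
    for _ in range(max_iter):
        # Assignment step: nearest centroid for every point
        new_labels = [
            min(range(k), key=lambda c: squared_distance(p, centroids[c])) for p in points
        ]
        if new_labels == labels:  # assignments stopped changing: converged
            break
        labels = new_labels
        # Update step: move each centroid to the mean of its assigned points
        for c in range(k):
            members = [p for p, label in zip(points, labels) if label == c]
            if members:  # keep the old centroid if a cluster ends up empty
                centroids[c] = [sum(coords) / len(members) for coords in zip(*members)]
    return labels, centroids


points = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (10.0, 10.0), (10.0, 11.0), (11.0, 10.0)]
labels, centroids = lloyd_kmeans(points, k=2)
print(labels)
```

On these two well-separated groups the first three points always share one label and the last three share the other, whichever label numbers the random start happens to produce.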
```python
# Colour each region on the map by its cluster label
plot_numeric_data(regions_gdf, "cluster", embeddings)
```