Highway2Vec Embedder
In [1]:
from IPython.display import display
from srai.plotting import plot_regions, plot_numeric_data
Get an area to embed
In [2]:
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Wrocław, PL")
plot_regions(area_gdf, tiles_style="CartoDB positron")
Out[2]:
[Interactive map: boundary of the geocoded area]
Regionalize the area with a regionalizer
In [3]:
from srai.regionalizers import H3Regionalizer
regionalizer = H3Regionalizer(9)
regions_gdf = regionalizer.transform(area_gdf)
print(len(regions_gdf))
display(regions_gdf.head(3))
plot_regions(regions_gdf, tiles_style="CartoDB positron")
3168
region_id | geometry |
---|---|
891e20415b7ffff | POLYGON ((16.89927 51.10535, 16.89931 51.10367... |
891e204335bffff | POLYGON ((16.85404 51.10899, 16.85408 51.10731... |
891e2047043ffff | POLYGON ((17.08685 51.13371, 17.08688 51.13202... |
Out[3]:
[Interactive map: H3 regions covering the area]
Download the road infrastructure for the area
In [4]:
from srai.loaders import OSMNetworkType, OSMWayLoader
loader = OSMWayLoader(OSMNetworkType.DRIVE)
nodes_gdf, edges_gdf = loader.load(area_gdf)
display(nodes_gdf.head(3))
display(edges_gdf.head(3))
ax = edges_gdf.plot(linewidth=1, figsize=(12, 7))
nodes_gdf.plot(ax=ax, markersize=3, color="red")
/opt/hostedtoolcache/Python/3.10.13/x64/lib/python3.10/site-packages/srai/loaders/osm_way_loader/osm_way_loader.py:219: UserWarning: The clean_periphery argument has been deprecated and will be removed in a future release. Future behavior will be as though clean_periphery=True.
  G_directed = ox.graph_from_polygon(
osmid | y | x | street_count | highway | ref | geometry |
---|---|---|---|---|---|---|
95584835 | 51.083111 | 17.049513 | 4 | NaN | NaN | POINT (17.04951 51.08311) |
95584841 | 51.084699 | 17.064367 | 3 | NaN | NaN | POINT (17.06437 51.08470) |
95584850 | 51.083328 | 17.035057 | 4 | NaN | NaN | POINT (17.03506 51.08333) |
feature_id | oneway | lanes-1 | lanes-2 | lanes-3 | lanes-4 | lanes-5 | lanes-6 | lanes-7 | lanes-8 | lanes-9 | ... | bicycle-official | lit-yes | lit-no | lit-sunset-sunrise | lit-24/7 | lit-automatic | lit-disused | lit-limited | lit-interval | geometry |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04947 51.083... |
1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04933 51.083... |
2 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.05357 51.08301, 17.05335 51.082... |
3 rows × 219 columns
Out[4]:
<Axes: >
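The edge table above has columns named like `lanes-1`, `lit-yes` or `bicycle-official`, which look like 0/1 indicators built from OSM `tag=value` pairs. The sketch below illustrates that naming scheme only; it is not the loader's actual code. The function `encode_tags`, the toy `vocabulary`, and the handling of `;`-separated multi-values are assumptions made for illustration.

```python
def encode_tags(tags, vocabulary):
    """Build {"<tag>-<value>": 0 or 1} for every (tag, value) in the vocabulary.

    Hypothetical helper: mirrors the column naming seen in edges_gdf,
    not the internals of OSMWayLoader.
    """
    row = {}
    for tag, values in vocabulary.items():
        # Assumption: several values on one tag are separated by ';'.
        present = set(str(tags.get(tag, "")).split(";"))
        for value in values:
            row[f"{tag}-{value}"] = int(value in present)
    return row


# Toy vocabulary and a single made-up way; tags outside the vocabulary are ignored.
vocabulary = {"lanes": ["1", "2", "3"], "lit": ["yes", "no"]}
way = {"lanes": "2", "lit": "yes", "name": "Example Street"}

features = encode_tags(way, vocabulary)
print(features)
```

Each road segment then becomes a fixed-length numeric row, which is the kind of input an autoencoder can be trained on.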
Find out which edges correspond to which regions
In [5]:
from srai.joiners import IntersectionJoiner
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, edges_gdf)
joint_gdf
Out[5]:
region_id | feature_id |
---|---|
891e20415b7ffff | 7390 |
891e204064fffff | 7390 |
891e2040647ffff | 7390 |
891e20415b3ffff | 7390 |
891e204067bffff | 7390 |
... | ... |
891e2040c67ffff | 1921 |
891e2040c67ffff | 486 |
891e2040c67ffff | 488 |
891e2040c67ffff | 485 |
891e2040c67ffff | 5140 |
15448 rows × 0 columns
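The joiner output has no data columns: all the information sits in the `(region_id, feature_id)` index pairs, and the relation is many-to-many. Edge 7390 is listed under five regions, and region `891e2040c67ffff` lists several edges. A minimal sketch with plain tuples standing in for the MultiIndex shows how to read it; with the real frame, a pandas group-by on the `region_id` index level should give the same counts.

```python
from collections import defaultdict

# (region_id, feature_id) pairs copied from the joiner output above.
# Rows printed with a blank region_id are read as repeating the region
# above them, since pandas hides repeated MultiIndex levels when printing.
pairs = [
    ("891e20415b7ffff", 7390),
    ("891e204064fffff", 7390),
    ("891e2040647ffff", 7390),
    ("891e20415b3ffff", 7390),
    ("891e204067bffff", 7390),
    ("891e2040c67ffff", 1921),
    ("891e2040c67ffff", 486),
    ("891e2040c67ffff", 488),
    ("891e2040c67ffff", 485),
    ("891e2040c67ffff", 5140),
]

edges_by_region = defaultdict(list)
regions_by_edge = defaultdict(list)
for region_id, feature_id in pairs:
    edges_by_region[region_id].append(feature_id)
    regions_by_edge[feature_id].append(region_id)

# One long edge can cross several hexagons ...
print(len(regions_by_edge[7390]))
# ... and one hexagon usually contains several edges.
print(edges_by_region["891e2040c67ffff"])
```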
Get the embeddings for regions based on the road infrastructure
In [6]:
from srai.embedders import Highway2VecEmbedder
from pytorch_lightning import seed_everything
seed_everything(42)
embedder = Highway2VecEmbedder()
embeddings = embedder.fit_transform(regions_gdf, edges_gdf, joint_gdf)
embeddings
Seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/hostedtoolcache/Python/3.10.13/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 16.0 K
1 | decoder | Sequential | 16.2 K
---------------------------------------
32.1 K    Trainable params
0         Non-trainable params
32.1 K    Total params
0.128     Total estimated model params size (MB)
/opt/hostedtoolcache/Python/3.10.13/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=10` reached.
Out[6]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
891e2040007ffff | -0.545477 | -0.106568 | 0.251093 | -0.290328 | 0.322375 | -0.065176 | -0.093790 | -0.635599 | -0.317438 | 0.616847 | ... | 0.158255 | -0.494278 | 0.092005 | 0.066845 | 0.070154 | 0.115346 | 0.121379 | 0.341841 | 0.169828 | 0.360102 |
891e2040013ffff | -0.132893 | -0.064373 | 0.382712 | -0.441649 | 0.575632 | -0.147125 | -0.088596 | -0.648478 | -0.466033 | 0.743535 | ... | 0.387957 | -0.809826 | 0.171693 | -0.282940 | 0.032247 | 0.145852 | 0.485494 | 0.479922 | 0.199453 | 0.306769 |
891e2040017ffff | -0.232802 | -0.189104 | 0.209663 | -0.269215 | 0.277312 | -0.140585 | 0.085083 | -0.542930 | -0.312318 | 0.444084 | ... | 0.143012 | -0.473011 | 0.140142 | -0.017009 | -0.073534 | 0.089504 | 0.375983 | 0.357249 | 0.131582 | 0.274925 |
891e2040023ffff | -0.249668 | -0.340604 | 0.431228 | -0.125087 | 0.523017 | -0.160585 | 0.063273 | -0.398711 | -0.132480 | 0.590603 | ... | 0.153102 | -0.562343 | 0.144597 | -0.013573 | -0.050746 | 0.290633 | 0.417697 | 0.113516 | -0.024347 | 0.440837 |
891e2040027ffff | -0.584167 | -0.274672 | 0.215587 | -0.438272 | 0.676389 | 0.097052 | -0.007293 | -0.651651 | -0.333542 | 0.709880 | ... | -0.017163 | -0.645789 | 0.144641 | -0.157250 | -0.023485 | 0.399435 | 0.233828 | 0.202338 | -0.025808 | 0.544677 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.592934 | -0.021718 | 0.237075 | -0.314544 | 0.527947 | -0.259927 | 0.097363 | -0.761237 | -0.337783 | 0.694507 | ... | 0.436733 | -0.623624 | -0.004663 | -0.253114 | -0.313103 | 0.430642 | 0.185390 | 0.396637 | 0.286581 | 0.720450 |
891e2055bc7ffff | -0.386026 | 0.125749 | 0.265474 | -0.369776 | 0.830519 | -0.435244 | 0.186564 | -0.583105 | -0.435166 | 0.672402 | ... | 0.320550 | -0.651989 | 0.038463 | -0.251759 | -0.405288 | 0.452297 | 0.082378 | 0.379454 | 0.207687 | 0.457752 |
891e2055bcbffff | -0.805282 | -0.088701 | 0.301848 | -0.287935 | 0.498316 | -0.055782 | 0.130608 | -0.990803 | -0.297827 | 0.663824 | ... | 0.575841 | -0.548880 | -0.175065 | -0.379298 | -0.291409 | 0.592270 | 0.294457 | 0.434415 | 0.443008 | 0.955718 |
891e205a967ffff | -0.655053 | 0.103695 | 0.218895 | -0.360505 | 0.352006 | 0.209300 | -0.107797 | -0.283862 | -0.413920 | 0.509821 | ... | -0.275728 | -0.342822 | 0.353349 | -0.009116 | 0.190810 | 0.057366 | 0.206863 | 0.235443 | 0.010426 | 0.534805 |
891e205a9a7ffff | -1.380033 | -1.405728 | 0.762203 | -1.193572 | 1.043738 | -0.831230 | 0.051409 | -0.830489 | -0.477262 | 1.660393 | ... | 0.691987 | -0.638992 | -0.095699 | -0.061690 | -0.319811 | 0.465980 | 0.842168 | 1.080342 | 0.085706 | 0.870983 |
2033 rows × 30 columns
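The embedding table has 2033 rows, while the regionalizer produced 3168 regions, which suggests that regions intersecting no road edge receive no vector. The sketch below shows one way such a result can arise, assuming the region vector is the mean of the encoded vectors of its edges. That aggregation is an assumption here, since the notebook does not show the embedder's internals, and `aggregate_by_region`, `edge_vectors` and the toy ids are hypothetical.

```python
from collections import defaultdict
from statistics import fmean


def aggregate_by_region(pairs, edge_vectors):
    """Average per-edge vectors over each region (assumed aggregation rule)."""
    grouped = defaultdict(list)
    for region_id, feature_id in pairs:
        grouped[region_id].append(edge_vectors[feature_id])
    # zip(*vectors) iterates dimension by dimension across a region's edges.
    return {
        region_id: [fmean(dim) for dim in zip(*vectors)]
        for region_id, vectors in grouped.items()
    }


# Toy data: three encoded edges, three regions, region "C" has no roads.
edge_vectors = {0: [1.0, 0.0], 1: [3.0, 2.0], 2: [0.0, 4.0]}
pairs = [("A", 0), ("A", 1), ("B", 2)]
all_regions = ["A", "B", "C"]

region_vectors = aggregate_by_region(pairs, edge_vectors)
missing = [r for r in all_regions if r not in region_vectors]

print(region_vectors)
print(missing)
```

With the real frames, `regions_gdf.index.difference(embeddings.index)` would list the regions that have no embedding, which matters when the result is joined back to `regions_gdf` for plotting.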
In [7]:
from sklearn.cluster import KMeans
clusterizer = KMeans(n_clusters=5, random_state=42)
clusterizer.fit(embeddings)
embeddings["cluster"] = clusterizer.labels_
embeddings
/opt/hostedtoolcache/Python/3.10.13/x64/lib/python3.10/site-packages/sklearn/cluster/_kmeans.py:1412: FutureWarning: The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the value of `n_init` explicitly to suppress the warning
  super()._check_params_vs_input(X, default_n_init=10)
Out[7]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | cluster |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
891e2040007ffff | -0.545477 | -0.106568 | 0.251093 | -0.290328 | 0.322375 | -0.065176 | -0.093790 | -0.635599 | -0.317438 | 0.616847 | ... | -0.494278 | 0.092005 | 0.066845 | 0.070154 | 0.115346 | 0.121379 | 0.341841 | 0.169828 | 0.360102 | 0 |
891e2040013ffff | -0.132893 | -0.064373 | 0.382712 | -0.441649 | 0.575632 | -0.147125 | -0.088596 | -0.648478 | -0.466033 | 0.743535 | ... | -0.809826 | 0.171693 | -0.282940 | 0.032247 | 0.145852 | 0.485494 | 0.479922 | 0.199453 | 0.306769 | 1 |
891e2040017ffff | -0.232802 | -0.189104 | 0.209663 | -0.269215 | 0.277312 | -0.140585 | 0.085083 | -0.542930 | -0.312318 | 0.444084 | ... | -0.473011 | 0.140142 | -0.017009 | -0.073534 | 0.089504 | 0.375983 | 0.357249 | 0.131582 | 0.274925 | 0 |
891e2040023ffff | -0.249668 | -0.340604 | 0.431228 | -0.125087 | 0.523017 | -0.160585 | 0.063273 | -0.398711 | -0.132480 | 0.590603 | ... | -0.562343 | 0.144597 | -0.013573 | -0.050746 | 0.290633 | 0.417697 | 0.113516 | -0.024347 | 0.440837 | 3 |
891e2040027ffff | -0.584167 | -0.274672 | 0.215587 | -0.438272 | 0.676389 | 0.097052 | -0.007293 | -0.651651 | -0.333542 | 0.709880 | ... | -0.645789 | 0.144641 | -0.157250 | -0.023485 | 0.399435 | 0.233828 | 0.202338 | -0.025808 | 0.544677 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.592934 | -0.021718 | 0.237075 | -0.314544 | 0.527947 | -0.259927 | 0.097363 | -0.761237 | -0.337783 | 0.694507 | ... | -0.623624 | -0.004663 | -0.253114 | -0.313103 | 0.430642 | 0.185390 | 0.396637 | 0.286581 | 0.720450 | 2 |
891e2055bc7ffff | -0.386026 | 0.125749 | 0.265474 | -0.369776 | 0.830519 | -0.435244 | 0.186564 | -0.583105 | -0.435166 | 0.672402 | ... | -0.651989 | 0.038463 | -0.251759 | -0.405288 | 0.452297 | 0.082378 | 0.379454 | 0.207687 | 0.457752 | 2 |
891e2055bcbffff | -0.805282 | -0.088701 | 0.301848 | -0.287935 | 0.498316 | -0.055782 | 0.130608 | -0.990803 | -0.297827 | 0.663824 | ... | -0.548880 | -0.175065 | -0.379298 | -0.291409 | 0.592270 | 0.294457 | 0.434415 | 0.443008 | 0.955718 | 4 |
891e205a967ffff | -0.655053 | 0.103695 | 0.218895 | -0.360505 | 0.352006 | 0.209300 | -0.107797 | -0.283862 | -0.413920 | 0.509821 | ... | -0.342822 | 0.353349 | -0.009116 | 0.190810 | 0.057366 | 0.206863 | 0.235443 | 0.010426 | 0.534805 | 0 |
891e205a9a7ffff | -1.380033 | -1.405728 | 0.762203 | -1.193572 | 1.043738 | -0.831230 | 0.051409 | -0.830489 | -0.477262 | 1.660393 | ... | -0.638992 | -0.095699 | -0.061690 | -0.319811 | 0.465980 | 0.842168 | 1.080342 | 0.085706 | 0.870983 | 4 |
2033 rows × 31 columns
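To make the clustering step less of a black box, here is a simplified pure-Python sketch of the k-means idea: alternate between assigning each point to its nearest centroid and moving each centroid to the mean of its points. It is not scikit-learn's implementation: it uses a naive initialisation (the first `k` points) instead of k-means++, runs a single start, and works on made-up 2-D points rather than the 30-dimensional embeddings.

```python
from math import dist


def kmeans(points, k, n_iter=20):
    """Plain Lloyd iterations with naive initialisation (illustrative only)."""
    centroids = [tuple(p) for p in points[:k]]
    labels = [0] * len(points)
    for _ in range(n_iter):
        # Assignment step: index of the nearest centroid for every point.
        labels = [
            min(range(k), key=lambda c: dist(p, centroids[c])) for p in points
        ]
        # Update step: move each centroid to the mean of its members.
        new_centroids = []
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:
                new_centroids.append(
                    tuple(sum(col) / len(members) for col in zip(*members))
                )
            else:
                # Keep an empty cluster's centroid where it was.
                new_centroids.append(centroids[c])
        if new_centroids == centroids:
            break  # converged: no centroid moved
        centroids = new_centroids
    return labels, centroids


# Two well-separated toy groups, interleaved on purpose.
points = [(0.0, 0.0), (0.1, 0.2), (5.0, 5.0), (5.2, 4.9), (0.2, 0.1), (4.9, 5.1)]
labels, centroids = kmeans(points, k=2)
print(labels)
```

One practical point about the notebook cell above: once the `cluster` column has been added to `embeddings`, running the cell again would fit on 31 columns, labels included. Fitting on a copy of the 30 embedding columns avoids that.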
In [8]:
plot_numeric_data(regions_gdf, "cluster", embeddings)
Out[8]:
[Interactive map: regions colored by cluster]