Highway2Vec Embedder
In [1]:
# display() renders DataFrames inline; the srai helpers draw regions and values on maps
from IPython.display import display
from srai.plotting import plot_numeric_data, plot_regions
Get an area to embed
In [2]:
from srai.regionalizers import geocode_to_region_gdf

# Geocode the city name to a GeoDataFrame holding its boundary geometry
area_gdf = geocode_to_region_gdf("Wrocław, PL")
plot_regions(area_gdf, tiles_style="CartoDB positron")
Out[2]:
[Interactive map of the Wrocław boundary; not rendered in this export]
Regionalize the area with an H3 regionalizer
In [3]:
from srai.regionalizers import H3Regionalizer

# Split the area into H3 hexagons at resolution 9
regionalizer = H3Regionalizer(9)
regions_gdf = regionalizer.transform(area_gdf)
print(len(regions_gdf))
display(regions_gdf.head(3))
plot_regions(regions_gdf, tiles_style="CartoDB positron")
3168
region_id | geometry |
---|---|
891e2042b57ffff | POLYGON ((17.01906 51.13423, 17.0191 51.13255,... |
891e2047073ffff | POLYGON ((17.09688 51.134, 17.09692 51.13232, ... |
891e204220fffff | POLYGON ((16.9058 51.14846, 16.90584 51.14678,... |
Out[3]:
[Interactive map of the H3 regions covering Wrocław; not rendered in this export]
Download the road infrastructure for the area
In [4]:
from srai.loaders import OSMNetworkType, OSMWayLoader

# Load the drivable street network from OpenStreetMap as node and edge GeoDataFrames
loader = OSMWayLoader(OSMNetworkType.DRIVE)
nodes_gdf, edges_gdf = loader.load(area_gdf)
display(nodes_gdf.head(3))
display(edges_gdf.head(3))

# Draw the edges as lines and overlay the nodes as red points
ax = edges_gdf.plot(linewidth=1, figsize=(12, 7))
nodes_gdf.plot(ax=ax, markersize=3, color="red")
osmid | y | x | street_count | highway | railway | ref | geometry |
---|---|---|---|---|---|---|---|
95584835 | 51.083111 | 17.049513 | 4 | NaN | NaN | NaN | POINT (17.04951 51.08311) |
95584841 | 51.084699 | 17.064367 | 3 | NaN | NaN | NaN | POINT (17.06437 51.0847) |
95584850 | 51.083328 | 17.035057 | 4 | NaN | NaN | NaN | POINT (17.03506 51.08333) |
feature_id | oneway | lanes-1 | lanes-2 | lanes-3 | lanes-4 | lanes-5 | lanes-6 | lanes-7 | lanes-8 | lanes-9 | ... | bicycle-official | lit-yes | lit-no | lit-sunset-sunrise | lit-24/7 | lit-automatic | lit-disused | lit-limited | lit-interval | geometry |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04933 51.083... |
1 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04947 51.083... |
2 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.05357 51.08301, 17.05335 51.082... |
3 rows × 219 columns
Out[4]:
<Axes: >
Find out which edges correspond to which regions
In [5]:
from srai.joiners import IntersectionJoiner

# Pair every region with the edges that intersect it
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, edges_gdf)
joint_gdf
Out[5]:
region_id | feature_id |
---|---|
891e2040b03ffff | 0 |
891e2040b03ffff | 1 |
891e2040b03ffff | 2 |
891e2040b07ffff | 2 |
891e2040b03ffff | 3 |
... | ... |
891e2047267ffff | 10285 |
891e20403dbffff | 10286 |
891e2047307ffff | 10287 |
891e2047307ffff | 10288 |
891e2047307ffff | 10289 |
15659 rows × 0 columns
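The joiner output is an index of (region_id, feature_id) pairs with no data columns, which is why it reports 0 columns. As a rough illustration of how such pairs can be read, the standard-library sketch below groups a handful of pairs taken from the head of the output. The helper names `edges_by_region` and `regions_by_edge` are hypothetical and not part of srai.

```python
from collections import defaultdict

# (region_id, feature_id) pairs taken from the head of the joiner output.
pairs = [
    ("891e2040b03ffff", 0),
    ("891e2040b03ffff", 1),
    ("891e2040b03ffff", 2),
    ("891e2040b07ffff", 2),
    ("891e2040b03ffff", 3),
]


def edges_by_region(pairs):
    """Map each region to the list of edge ids joined to it (hypothetical helper)."""
    grouped = defaultdict(list)
    for region_id, feature_id in pairs:
        grouped[region_id].append(feature_id)
    return dict(grouped)


def regions_by_edge(pairs):
    """Map each edge id to the list of regions it is joined to (hypothetical helper)."""
    grouped = defaultdict(list)
    for region_id, feature_id in pairs:
        grouped[feature_id].append(region_id)
    return dict(grouped)


region_to_edges = edges_by_region(pairs)
edge_to_regions = regions_by_edge(pairs)
print(region_to_edges)
print(edge_to_regions[2])
```

Edge 2 appears under two regions, which is what one would expect for a road segment that crosses a hexagon boundary. This many-to-many relationship also explains why the joint table has more rows (15659) than there are edges.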
Get the embeddings for regions based on the road infrastructure
In [6]:
from pytorch_lightning import seed_everything
from srai.embedders import Highway2VecEmbedder

# Fix the random seeds so that repeated runs are comparable
seed_everything(42)
embedder = Highway2VecEmbedder()
# Train the model on the edge features and return one vector per region
embeddings = embedder.fit_transform(regions_gdf, edges_gdf, joint_gdf)
embeddings
Seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/hostedtoolcache/Python/3.10.16/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name    | Type       | Params | Mode
-----------------------------------------------
0 | encoder | Sequential | 16.0 K | train
1 | decoder | Sequential | 16.2 K | train
-----------------------------------------------
32.1 K    Trainable params
0         Non-trainable params
32.1 K    Total params
0.128     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode
/opt/hostedtoolcache/Python/3.10.16/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=10` reached.
/opt/hostedtoolcache/Python/3.10.16/x64/lib/python3.10/site-packages/srai/embedders/highway2vec/embedder.py:75: FutureWarning: The behavior of array concatenation with empty entries is deprecated. In a future version, this will no longer exclude empty items when determining the result dtype. To retain the old behavior, exclude the empty entries before the concat operation.
  embeddings_joint = joint_gdf.join(embeddings_df)
Out[6]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
891e2040007ffff | -0.478026 | -0.113127 | 0.126128 | -0.128388 | 0.446125 | -0.156285 | 0.088457 | -0.625201 | -0.214197 | 0.783338 | ... | 0.316884 | -0.509442 | 0.166196 | 0.075310 | -0.034963 | 0.021978 | 0.192447 | 0.390332 | 0.139028 | 0.362403 |
891e2040013ffff | -0.063860 | -0.076059 | 0.150680 | -0.266818 | 0.727081 | -0.072893 | 0.043731 | -0.607804 | -0.398895 | 0.860228 | ... | 0.340022 | -0.667463 | 0.252368 | -0.325111 | -0.086180 | 0.026638 | 0.462096 | 0.367568 | 0.112972 | 0.333759 |
891e2040017ffff | -0.245016 | -0.210927 | 0.118392 | -0.154740 | 0.534677 | -0.223711 | 0.186089 | -0.802900 | -0.177486 | 0.658790 | ... | 0.196308 | -0.544187 | 0.254516 | 0.002243 | -0.053540 | 0.073068 | 0.466818 | 0.246125 | 0.095184 | 0.374472 |
891e2040023ffff | -0.249175 | -0.345110 | 0.255491 | -0.007640 | 0.596295 | -0.187191 | 0.162686 | -0.405697 | -0.078213 | 0.651963 | ... | 0.113870 | -0.501503 | 0.224948 | 0.037005 | -0.136227 | 0.205126 | 0.390937 | 0.081955 | -0.044320 | 0.455348 |
891e2040027ffff | -0.595245 | -0.246802 | 0.041521 | -0.256869 | 0.719737 | 0.033925 | 0.080786 | -0.695064 | -0.301172 | 0.815772 | ... | -0.006877 | -0.591811 | 0.230054 | -0.092612 | -0.113522 | 0.283893 | 0.191670 | 0.223959 | -0.015963 | 0.590438 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.495092 | -0.004143 | 0.046024 | -0.071930 | 0.736199 | -0.383215 | 0.133914 | -0.643929 | -0.311601 | 0.836389 | ... | 0.560938 | -0.523382 | 0.065871 | -0.296830 | -0.439479 | 0.277858 | 0.120116 | 0.336617 | 0.200686 | 0.774149 |
891e2055bc7ffff | -0.371683 | 0.075762 | 0.052570 | -0.125076 | 0.988465 | -0.519015 | 0.206818 | -0.360221 | -0.324198 | 0.882784 | ... | 0.416953 | -0.544981 | 0.092606 | -0.298336 | -0.516742 | 0.293110 | 0.001199 | 0.301829 | 0.086672 | 0.475212 |
891e2055bcbffff | -0.648635 | -0.020019 | 0.090373 | -0.047570 | 0.759681 | -0.273785 | 0.171760 | -0.861461 | -0.302369 | 0.779831 | ... | 0.725614 | -0.438793 | -0.070551 | -0.398328 | -0.475414 | 0.384725 | 0.172158 | 0.395729 | 0.355014 | 1.064769 |
891e205a967ffff | -0.652508 | 0.142373 | 0.134960 | -0.233271 | 0.383352 | 0.164160 | -0.037181 | -0.356311 | -0.416033 | 0.549010 | ... | -0.275853 | -0.330454 | 0.397011 | 0.017968 | 0.163322 | -0.020132 | 0.220784 | 0.250572 | 0.063603 | 0.557158 |
891e205a9a7ffff | -1.331521 | -1.388956 | 0.461944 | -0.931037 | 1.022121 | -1.029820 | 0.045572 | -0.621947 | -0.418444 | 2.011283 | ... | 0.594184 | -0.638899 | -0.112642 | -0.005163 | -0.585824 | 0.221812 | 0.611300 | 1.047100 | 0.111776 | 0.974985 |
2037 rows × 30 columns
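The result has 2037 rows, although the regionalizer produced 3168 regions. A plausible reading is that regions with no intersecting road segment receive no embedding. The sketch below shows one way region vectors can be derived from per-edge vectors: averaging the vectors of all edges joined to a region. The averaging step is an assumption made for illustration, since the output above does not show how `Highway2VecEmbedder` aggregates internally. The helper `mean_pool`, the toy vectors, and the region names are all hypothetical.

```python
def mean_pool(pairs, edge_vectors):
    """Average the vectors of all edges joined to each region (illustrative assumption)."""
    sums = {}
    counts = {}
    for region_id, feature_id in pairs:
        vector = edge_vectors[feature_id]
        if region_id not in sums:
            sums[region_id] = [0.0] * len(vector)
            counts[region_id] = 0
        sums[region_id] = [s + v for s, v in zip(sums[region_id], vector)]
        counts[region_id] += 1
    return {
        region_id: [s / counts[region_id] for s in total]
        for region_id, total in sums.items()
    }


# Toy 2-dimensional edge vectors keyed by a made-up feature_id.
edge_vectors = {0: [1.0, 0.0], 1: [3.0, 2.0], 2: [0.0, 4.0]}
# Toy join result: region_a holds edges 0 and 1, region_b holds edge 2.
pairs = [("region_a", 0), ("region_a", 1), ("region_b", 2)]
# region_c has no edges at all.
all_regions = ["region_a", "region_b", "region_c"]

region_vectors = mean_pool(pairs, edge_vectors)
missing = [r for r in all_regions if r not in region_vectors]
print(region_vectors)
print(missing)
```

In the notebook itself, the regions without an embedding could be listed with `regions_gdf.index.difference(embeddings.index)`, which is worth checking before joining the embeddings back onto `regions_gdf`.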
In [7]:
from sklearn.cluster import KMeans

# Group the regions into 5 clusters by embedding similarity, then store the label of each region
clusterizer = KMeans(n_clusters=5, random_state=42)
clusterizer.fit(embeddings)
embeddings["cluster"] = clusterizer.labels_
embeddings
Out[7]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | cluster |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
891e2040007ffff | -0.478026 | -0.113127 | 0.126128 | -0.128388 | 0.446125 | -0.156285 | 0.088457 | -0.625201 | -0.214197 | 0.783338 | ... | -0.509442 | 0.166196 | 0.075310 | -0.034963 | 0.021978 | 0.192447 | 0.390332 | 0.139028 | 0.362403 | 4 |
891e2040013ffff | -0.063860 | -0.076059 | 0.150680 | -0.266818 | 0.727081 | -0.072893 | 0.043731 | -0.607804 | -0.398895 | 0.860228 | ... | -0.667463 | 0.252368 | -0.325111 | -0.086180 | 0.026638 | 0.462096 | 0.367568 | 0.112972 | 0.333759 | 3 |
891e2040017ffff | -0.245016 | -0.210927 | 0.118392 | -0.154740 | 0.534677 | -0.223711 | 0.186089 | -0.802900 | -0.177486 | 0.658790 | ... | -0.544187 | 0.254516 | 0.002243 | -0.053540 | 0.073068 | 0.466818 | 0.246125 | 0.095184 | 0.374472 | 3 |
891e2040023ffff | -0.249175 | -0.345110 | 0.255491 | -0.007640 | 0.596295 | -0.187191 | 0.162686 | -0.405697 | -0.078213 | 0.651963 | ... | -0.501503 | 0.224948 | 0.037005 | -0.136227 | 0.205126 | 0.390937 | 0.081955 | -0.044320 | 0.455348 | 1 |
891e2040027ffff | -0.595245 | -0.246802 | 0.041521 | -0.256869 | 0.719737 | 0.033925 | 0.080786 | -0.695064 | -0.301172 | 0.815772 | ... | -0.591811 | 0.230054 | -0.092612 | -0.113522 | 0.283893 | 0.191670 | 0.223959 | -0.015963 | 0.590438 | 4 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.495092 | -0.004143 | 0.046024 | -0.071930 | 0.736199 | -0.383215 | 0.133914 | -0.643929 | -0.311601 | 0.836389 | ... | -0.523382 | 0.065871 | -0.296830 | -0.439479 | 0.277858 | 0.120116 | 0.336617 | 0.200686 | 0.774149 | 2 |
891e2055bc7ffff | -0.371683 | 0.075762 | 0.052570 | -0.125076 | 0.988465 | -0.519015 | 0.206818 | -0.360221 | -0.324198 | 0.882784 | ... | -0.544981 | 0.092606 | -0.298336 | -0.516742 | 0.293110 | 0.001199 | 0.301829 | 0.086672 | 0.475212 | 2 |
891e2055bcbffff | -0.648635 | -0.020019 | 0.090373 | -0.047570 | 0.759681 | -0.273785 | 0.171760 | -0.861461 | -0.302369 | 0.779831 | ... | -0.438793 | -0.070551 | -0.398328 | -0.475414 | 0.384725 | 0.172158 | 0.395729 | 0.355014 | 1.064769 | 2 |
891e205a967ffff | -0.652508 | 0.142373 | 0.134960 | -0.233271 | 0.383352 | 0.164160 | -0.037181 | -0.356311 | -0.416033 | 0.549010 | ... | -0.330454 | 0.397011 | 0.017968 | 0.163322 | -0.020132 | 0.220784 | 0.250572 | 0.063603 | 0.557158 | 4 |
891e205a9a7ffff | -1.331521 | -1.388956 | 0.461944 | -0.931037 | 1.022121 | -1.029820 | 0.045572 | -0.621947 | -0.418444 | 2.011283 | ... | -0.638899 | -0.112642 | -0.005163 | -0.585824 | 0.221812 | 0.611300 | 1.047100 | 0.111776 | 0.974985 | 2 |
2037 rows × 31 columns
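A quick sanity check after clustering is to count how many regions fall into each cluster. The sketch below does this with the standard library. The `labels` list holds only the ten values visible in the truncated output above, so the counts are illustrative and not the real distribution over all 2037 regions.

```python
from collections import Counter

# Cluster labels of the ten rows visible in the truncated output above.
labels = [4, 3, 3, 1, 4, 2, 2, 2, 4, 2]

# Count regions per cluster and list them in cluster-id order.
cluster_sizes = dict(sorted(Counter(labels).items()))
for cluster_id, size in cluster_sizes.items():
    print(f"cluster {cluster_id}: {size} regions")
```

In the notebook itself, `embeddings["cluster"].value_counts()` gives the same summary over every region. Note that the cell above writes the labels into `embeddings`, so re-running it would pass the `cluster` column to KMeans as an extra feature; dropping that column first (or fitting on a copy) avoids this.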