Highway2Vec Embedder
In [1]:
from IPython.display import display
from srai.plotting import plot_numeric_data, plot_regions
Get an area to embed
In [2]:
from srai.regionalizers import geocode_to_region_gdf
# Geocode the place name to a GeoDataFrame holding the area's boundary.
area_gdf = geocode_to_region_gdf("Wrocław, PL")
plot_regions(area_gdf, tiles_style="CartoDB positron")
Regionalize the area with a regionalizer
In [3]:
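from srai.regionalizers import H3Regionalizer
# Split the area into H3 hexagons at resolution 9.
regionalizer = H3Regionalizer(9)
regions_gdf = regionalizer.transform(area_gdf)
# Report how many regions were created and preview the first three.
print(len(regions_gdf))
display(regions_gdf.head(3))
plot_regions(regions_gdf, tiles_style="CartoDB positron")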
3168
| region_id | geometry |
|---|---|
| 891e2044647ffff | POLYGON ((17.1449 51.11772, 17.14493 51.11604,... |
| 891e2042d07ffff | POLYGON ((17.0308 51.17242, 17.03083 51.17074,... |
| 891e2042913ffff | POLYGON ((17.04116 51.15759, 17.04119 51.15591... |
Download the road infrastructure for the area
In [4]:
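from srai.loaders import OSMNetworkType, OSMWayLoader
# Load the drivable OSM road network for the area as two GeoDataFrames:
# nodes (intersections) and edges (road segments with their features).
loader = OSMWayLoader(OSMNetworkType.DRIVE)
nodes_gdf, edges_gdf = loader.load(area_gdf)
display(nodes_gdf.head(3))
display(edges_gdf.head(3))
# Draw the edges first, then overlay the nodes in red on the same axes.
ax = edges_gdf.plot(linewidth=1, figsize=(12, 7))
nodes_gdf.plot(ax=ax, markersize=3, color="red")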
| osmid | y | x | street_count | highway | railway | ref | geometry |
|---|---|---|---|---|---|---|---|
| 95584835 | 51.083111 | 17.049513 | 4 | NaN | NaN | NaN | POINT (17.04951 51.08311) |
| 95584841 | 51.084686 | 17.064329 | 3 | NaN | NaN | NaN | POINT (17.06433 51.08469) |
| 95584850 | 51.083328 | 17.035057 | 4 | NaN | NaN | NaN | POINT (17.03506 51.08333) |
| feature_id | oneway | lanes-1 | lanes-2 | lanes-3 | lanes-4 | lanes-5 | lanes-6 | lanes-7 | lanes-8 | lanes-9 | ... | bicycle-official | lit-yes | lit-no | lit-sunset-sunrise | lit-24/7 | lit-automatic | lit-disused | lit-limited | lit-interval | geometry |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04932 51.083... |
| 1 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04947 51.083... |
| 2 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.05357 51.08301, 17.05335 51.082... |
3 rows × 219 columns
Out[4]:
<Axes: >
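The `lanes-*` and `lit-*` columns in the edges table are 0/1 indicators, and some of the sample rows set more than one `lanes-*` flag at once. How `OSMWayLoader` derives these columns is not shown in this notebook, so the following is only a minimal pandas sketch of that kind of wide indicator layout. It assumes hypothetical raw tags whose multiple values are separated by `;`; the `raw` frame and the `multi_hot` helper are illustrative names, not part of srai.

```python
import pandas as pd

# Hypothetical raw tags (not real srai data); ";" separates multiple values.
raw = pd.DataFrame(
    {"lanes": ["2", "2;3", "2;3;4"], "lit": ["yes", "yes", "no"]},
    index=pd.Index([0, 1, 2], name="feature_id"),
)


def multi_hot(column: pd.Series) -> pd.DataFrame:
    """Turn a ';'-separated tag column into 0/1 columns named '<tag>-<value>'."""
    return column.str.get_dummies(sep=";").add_prefix(f"{column.name}-")


# One indicator block per tag, concatenated side by side.
features = pd.concat([multi_hot(raw[name]) for name in raw.columns], axis=1)
print(features)
```

With one block of indicator columns per tag, the number of columns grows with the number of distinct tag values, which is consistent with the wide (219-column) edges table shown here.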
Find out which edges correspond to which regions
In [5]:
from srai.joiners import IntersectionJoiner
# Pair every region with each road segment that intersects it.
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, edges_gdf)
joint_gdf
Out[5]:
| region_id | feature_id |
|---|---|
| 891e2040b03ffff | 0 |
| 891e2040b03ffff | 1 |
| 891e2040b03ffff | 2 |
| 891e2040b07ffff | 2 |
| 891e2040b03ffff | 3 |
| ... | ... |
| 891e20403dbffff | 10250 |
| 891e2047307ffff | 10251 |
| 891e2047307ffff | 10252 |
| 891e2047307ffff | 10253 |
| 891e2040c33ffff | 10254 |
15634 rows × 0 columns
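The joint table has no data columns: all of its information lives in the `(region_id, feature_id)` index. Feature ids only run up to 10254 while the table has 15634 rows, so a road segment that crosses a cell boundary is paired with every region it touches (feature 2 is listed under two regions). The sketch below uses toy stand-ins (the `regions` and `joint` frames are made up for illustration) to show how such an index could be queried for the number of edges per region and for regions that contain no roads at all.

```python
import pandas as pd

# Toy stand-ins: three regions and five (region, edge) pairs; region "c" has no roads.
regions = pd.DataFrame(index=pd.Index(["a", "b", "c"], name="region_id"))
joint = pd.DataFrame(
    index=pd.MultiIndex.from_tuples(
        [("a", 0), ("a", 1), ("b", 1), ("b", 2), ("b", 3)],
        names=["region_id", "feature_id"],
    )
)

# Region id of every pair, with repeats.
region_ids = joint.index.get_level_values("region_id")

# Number of intersecting edges per region.
edges_per_region = region_ids.value_counts()

# Regions that never appear in the joint index.
regions_without_edges = regions.index.difference(region_ids)

print(edges_per_region)
print(list(regions_without_edges))
```

Applied to the real data, the same two lines would use `regions_gdf` and `joint_gdf` in place of the toy frames.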
Get the embeddings for regions based on the road infrastructure
In [6]:
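from pytorch_lightning import seed_everything
from srai.embedders import Highway2VecEmbedder
# Fix the random seeds so that training is repeatable.
seed_everything(42)
# Train the model on the edge features and return one embedding per region.
embedder = Highway2VecEmbedder()
embeddings = embedder.fit_transform(regions_gdf, edges_gdf, joint_gdf)
embeddings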
Seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
  | Name    | Type       | Params | Mode
-----------------------------------------------
0 | encoder | Sequential | 16.0 K | train
1 | decoder | Sequential | 16.2 K | train
-----------------------------------------------
32.1 K    Trainable params
0         Non-trainable params
32.1 K    Total params
0.128     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode
/opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=10` reached.
/opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/srai/embedders/highway2vec/embedder.py:75: FutureWarning: The behavior of array concatenation with empty entries is deprecated. In a future version, this will no longer exclude empty items when determining the result dtype. To retain the old behavior, exclude the empty entries before the concat operation.
  embeddings_joint = joint_gdf.join(embeddings_df)
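The model summary in the log reports an encoder with 16.0 K parameters and a decoder with 16.2 K. As a sanity check, the arithmetic below reproduces those totals for a fully connected autoencoder with one hidden layer on each side. The layer sizes are assumptions inferred from the numbers in this notebook, not read from the srai source: 218 input features (the 219 edge columns minus `geometry`), a hidden width of 64, and a 30-dimensional embedding (the width of the output table). `linear_params` is a helper defined here for illustration.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


n_features = 218  # assumption: 219 edge columns minus the geometry column
hidden = 64       # assumption: chosen because it makes the totals match the log
embedding = 30    # width of the embeddings table in this notebook

encoder = linear_params(n_features, hidden) + linear_params(hidden, embedding)
decoder = linear_params(embedding, hidden) + linear_params(hidden, n_features)
total = encoder + decoder
size_mb = total * 4 / 1e6  # 4 bytes per float32 parameter

print(encoder, decoder, total, round(size_mb, 3))
# prints: 15966 16154 32120 0.128
```

These values round to the 16.0 K, 16.2 K, 32.1 K and 0.128 MB shown in the summary, so the assumed layer sizes are at least consistent with the log.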
Out[6]:
| region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 891e2040007ffff | -0.528359 | -0.164042 | 0.128171 | -0.068827 | 0.312713 | -0.157757 | 0.031731 | -0.669343 | -0.238158 | 0.750430 | ... | 0.353437 | -0.568471 | 0.221183 | 0.012523 | 0.010102 | 0.020913 | 0.244809 | 0.330183 | 0.001661 | 0.329018 |
| 891e2040013ffff | -0.082209 | -0.105255 | 0.162667 | -0.194649 | 0.545473 | -0.065436 | 0.016032 | -0.743190 | -0.393954 | 0.798285 | ... | 0.399343 | -0.717276 | 0.309425 | -0.335921 | 0.002151 | 0.001565 | 0.582529 | 0.362140 | 0.022971 | 0.330358 |
| 891e2040017ffff | -0.234723 | -0.249322 | 0.105677 | -0.137942 | 0.403094 | -0.177673 | 0.141853 | -0.875002 | -0.168599 | 0.653679 | ... | 0.264878 | -0.601499 | 0.290364 | -0.078554 | 0.028421 | 0.022644 | 0.550806 | 0.272872 | -0.020917 | 0.334649 |
| 891e2040023ffff | -0.255435 | -0.375575 | 0.265957 | 0.018996 | 0.507044 | -0.169918 | 0.122459 | -0.475371 | -0.070760 | 0.602137 | ... | 0.163511 | -0.537209 | 0.262486 | -0.001070 | -0.049220 | 0.188635 | 0.471438 | 0.078272 | -0.143121 | 0.452784 |
| 891e2040027ffff | -0.582498 | -0.313046 | 0.062118 | -0.241743 | 0.652053 | 0.030318 | 0.007733 | -0.747071 | -0.302134 | 0.797427 | ... | 0.001563 | -0.629347 | 0.259644 | -0.162560 | -0.026864 | 0.273631 | 0.282109 | 0.175811 | -0.096263 | 0.589438 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 891e2055bc3ffff | -0.501167 | -0.147116 | -0.027027 | -0.050435 | 0.547825 | -0.394404 | 0.091431 | -0.742640 | -0.334403 | 0.799427 | ... | 0.608394 | -0.546142 | 0.094322 | -0.300585 | -0.345335 | 0.199467 | 0.247516 | 0.225200 | -0.013494 | 0.671639 |
| 891e2055bc7ffff | -0.341917 | -0.075936 | -0.023385 | -0.116489 | 0.759822 | -0.530153 | 0.149626 | -0.507811 | -0.332619 | 0.826880 | ... | 0.471789 | -0.572400 | 0.113307 | -0.335588 | -0.392958 | 0.153375 | 0.124247 | 0.203245 | -0.140573 | 0.369904 |
| 891e2055bcbffff | -0.665629 | -0.197789 | -0.026268 | -0.029814 | 0.576673 | -0.314537 | 0.141973 | -0.951958 | -0.324014 | 0.756586 | ... | 0.734397 | -0.471919 | -0.027606 | -0.369756 | -0.414597 | 0.307323 | 0.329392 | 0.233454 | 0.112539 | 0.950219 |
| 891e205a967ffff | -0.659383 | 0.108068 | 0.133118 | -0.234619 | 0.363473 | 0.182678 | -0.092003 | -0.333566 | -0.425726 | 0.542888 | ... | -0.262608 | -0.348275 | 0.414538 | -0.081376 | 0.187144 | 0.036518 | 0.232909 | 0.235828 | 0.006770 | 0.599703 |
| 891e205a9a7ffff | -1.363809 | -1.546547 | 0.523880 | -0.880781 | 0.888798 | -0.883979 | 0.029907 | -0.834964 | -0.430355 | 1.778972 | ... | 0.634769 | -0.639539 | 0.048703 | -0.004560 | -0.337775 | 0.050854 | 0.833152 | 0.957786 | -0.121980 | 0.872835 |
2039 rows × 30 columns
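The result has 2039 rows although the area was split into 3168 regions, so only regions that intersect at least one road segment receive an embedding. The `joint_gdf.join(embeddings_df)` line in the warning above suggests that per-edge vectors are attached to the `(region_id, feature_id)` pairs and then combined per region. The sketch below illustrates that idea with toy data, assuming the combination is a plain mean (an assumption, not confirmed by this notebook); all names and values are made up.

```python
import numpy as np
import pandas as pd

# Toy per-edge vectors standing in for the model output (2 dimensions instead of 30).
edge_embeddings = pd.DataFrame(
    np.array([[0.0, 2.0], [2.0, 4.0], [4.0, 0.0]]),
    index=pd.Index([0, 1, 2], name="feature_id"),
)

# Toy (region, edge) pairs; edge 1 lies in both regions.
pairs = pd.MultiIndex.from_tuples(
    [("a", 0), ("a", 1), ("b", 1), ("b", 2)],
    names=["region_id", "feature_id"],
)

# Look up the vector of every paired edge, then re-label rows with the pair index.
per_pair = edge_embeddings.loc[pairs.get_level_values("feature_id")]
per_pair.index = pairs

# Average the edge vectors within each region.
region_embeddings = per_pair.groupby(level="region_id").mean()
print(region_embeddings)
```

A region with no pairs never enters the `groupby`, which matches the smaller row count of the embeddings table.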
In [7]:
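from sklearn.cluster import KMeans
# Group regions with similar road-network embeddings into 5 clusters.
clusterizer = KMeans(n_clusters=5, random_state=42)
clusterizer.fit(embeddings)
# Store each region's cluster label next to its embedding.
embeddings["cluster"] = clusterizer.labels_
embeddings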
Out[7]:
| region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | cluster |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 891e2040007ffff | -0.528359 | -0.164042 | 0.128171 | -0.068827 | 0.312713 | -0.157757 | 0.031731 | -0.669343 | -0.238158 | 0.750430 | ... | -0.568471 | 0.221183 | 0.012523 | 0.010102 | 0.020913 | 0.244809 | 0.330183 | 0.001661 | 0.329018 | 3 |
| 891e2040013ffff | -0.082209 | -0.105255 | 0.162667 | -0.194649 | 0.545473 | -0.065436 | 0.016032 | -0.743190 | -0.393954 | 0.798285 | ... | -0.717276 | 0.309425 | -0.335921 | 0.002151 | 0.001565 | 0.582529 | 0.362140 | 0.022971 | 0.330358 | 2 |
| 891e2040017ffff | -0.234723 | -0.249322 | 0.105677 | -0.137942 | 0.403094 | -0.177673 | 0.141853 | -0.875002 | -0.168599 | 0.653679 | ... | -0.601499 | 0.290364 | -0.078554 | 0.028421 | 0.022644 | 0.550806 | 0.272872 | -0.020917 | 0.334649 | 2 |
| 891e2040023ffff | -0.255435 | -0.375575 | 0.265957 | 0.018996 | 0.507044 | -0.169918 | 0.122459 | -0.475371 | -0.070760 | 0.602137 | ... | -0.537209 | 0.262486 | -0.001070 | -0.049220 | 0.188635 | 0.471438 | 0.078272 | -0.143121 | 0.452784 | 1 |
| 891e2040027ffff | -0.582498 | -0.313046 | 0.062118 | -0.241743 | 0.652053 | 0.030318 | 0.007733 | -0.747071 | -0.302134 | 0.797427 | ... | -0.629347 | 0.259644 | -0.162560 | -0.026864 | 0.273631 | 0.282109 | 0.175811 | -0.096263 | 0.589438 | 3 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 891e2055bc3ffff | -0.501167 | -0.147116 | -0.027027 | -0.050435 | 0.547825 | -0.394404 | 0.091431 | -0.742640 | -0.334403 | 0.799427 | ... | -0.546142 | 0.094322 | -0.300585 | -0.345335 | 0.199467 | 0.247516 | 0.225200 | -0.013494 | 0.671639 | 4 |
| 891e2055bc7ffff | -0.341917 | -0.075936 | -0.023385 | -0.116489 | 0.759822 | -0.530153 | 0.149626 | -0.507811 | -0.332619 | 0.826880 | ... | -0.572400 | 0.113307 | -0.335588 | -0.392958 | 0.153375 | 0.124247 | 0.203245 | -0.140573 | 0.369904 | 4 |
| 891e2055bcbffff | -0.665629 | -0.197789 | -0.026268 | -0.029814 | 0.576673 | -0.314537 | 0.141973 | -0.951958 | -0.324014 | 0.756586 | ... | -0.471919 | -0.027606 | -0.369756 | -0.414597 | 0.307323 | 0.329392 | 0.233454 | 0.112539 | 0.950219 | 0 |
| 891e205a967ffff | -0.659383 | 0.108068 | 0.133118 | -0.234619 | 0.363473 | 0.182678 | -0.092003 | -0.333566 | -0.425726 | 0.542888 | ... | -0.348275 | 0.414538 | -0.081376 | 0.187144 | 0.036518 | 0.232909 | 0.235828 | 0.006770 | 0.599703 | 3 |
| 891e205a9a7ffff | -1.363809 | -1.546547 | 0.523880 | -0.880781 | 0.888798 | -0.883979 | 0.029907 | -0.834964 | -0.430355 | 1.778972 | ... | -0.639539 | 0.048703 | -0.004560 | -0.337775 | 0.050854 | 0.833152 | 0.957786 | -0.121980 | 0.872835 | 0 |
2039 rows × 31 columns
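Because the cluster labels are written back into `embeddings`, running the clustering cell a second time would pass the `cluster` column to `KMeans` as if it were a feature (and recent scikit-learn versions may also reject the resulting mix of integer and string column names). A small guard, sketched here on a toy frame with made-up values, is to select the embedding dimensions explicitly before fitting and to check the cluster sizes afterwards.

```python
import pandas as pd

# Toy embeddings table that already carries a "cluster" column from an earlier run.
embeddings = pd.DataFrame(
    {0: [0.1, 0.2, 0.9, 1.0], 1: [0.0, 0.1, 0.8, 0.9], "cluster": [0, 0, 1, 1]},
    index=pd.Index(["a", "b", "c", "d"], name="region_id"),
)

# Keep only the embedding dimensions; errors="ignore" makes this safe on a first run too.
features = embeddings.drop(columns=["cluster"], errors="ignore")

# Number of regions per cluster, ordered by label.
cluster_sizes = embeddings["cluster"].value_counts().sort_index()

print(list(features.columns))
print(cluster_sizes)
```

In the notebook this would mean calling `clusterizer.fit(features)` instead of `clusterizer.fit(embeddings)`; the assignment of `clusterizer.labels_` stays the same.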
In [8]:
plot_numeric_data(regions_gdf, "cluster", embeddings)