Highway2Vec Embedder
In [1]:
from IPython.display import display
from srai.plotting import plot_numeric_data, plot_regions
Get an area to embed
In [2]:
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Wrocław, PL")
plot_regions(area_gdf, tiles_style="CartoDB positron")
Out[2]:
[Interactive map: the Wrocław area on CartoDB positron tiles]
Regionalize the area with a regionalizer
In [3]:
from srai.regionalizers import H3Regionalizer
regionalizer = H3Regionalizer(9)
regions_gdf = regionalizer.transform(area_gdf)
print(len(regions_gdf))
display(regions_gdf.head(3))
plot_regions(regions_gdf, tiles_style="CartoDB positron")
3168
region_id | geometry |
---|---|
891e2045663ffff | POLYGON ((17.08297 51.07804, 17.08301 51.07636... |
891e20476a3ffff | POLYGON ((17.07158 51.14335, 17.07162 51.14167... |
891e20433c7ffff | POLYGON ((16.86138 51.1168, 16.86142 51.11512,... |
Out[3]:
[Interactive map: H3 regions covering the Wrocław area on CartoDB positron tiles]
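The `region_id` values above are H3 cell indexes, and the resolution passed to `H3Regionalizer(9)` can be read back from the index itself. The sketch below assumes the standard H3 index bit layout, in which bits 52 to 55 of the 64-bit index store the resolution. The helper name `h3_resolution` is hypothetical and is not part of srai or the h3 library.

```python
# Decode the resolution field of an H3 cell index given as a hex string.
# Assumption: standard H3 index layout, where bits 52-55 of the 64-bit
# value hold the resolution (0-15).


def h3_resolution(cell_id: str) -> int:
    """Return the resolution encoded in an H3 cell index string."""
    value = int(cell_id, 16)
    return (value >> 52) & 0xF


# The three region ids shown in the output above.
sample_ids = ["891e2045663ffff", "891e20476a3ffff", "891e20433c7ffff"]
resolutions = [h3_resolution(cell_id) for cell_id in sample_ids]
print(resolutions)
```

Each of the three sample ids should decode to resolution 9, matching the argument given to the regionalizer.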
Download the road infrastructure for the area
In [4]:
from srai.loaders import OSMNetworkType, OSMWayLoader
loader = OSMWayLoader(OSMNetworkType.DRIVE)
nodes_gdf, edges_gdf = loader.load(area_gdf)
display(nodes_gdf.head(3))
display(edges_gdf.head(3))
ax = edges_gdf.plot(linewidth=1, figsize=(12, 7))
nodes_gdf.plot(ax=ax, markersize=3, color="red")
osmid | y | x | street_count | highway | railway | ref | geometry |
---|---|---|---|---|---|---|---|
95584835 | 51.083111 | 17.049513 | 4 | NaN | NaN | NaN | POINT (17.04951 51.08311) |
95584841 | 51.084699 | 17.064367 | 3 | NaN | NaN | NaN | POINT (17.06437 51.0847) |
95584850 | 51.083328 | 17.035057 | 4 | NaN | NaN | NaN | POINT (17.03506 51.08333) |
feature_id | oneway | lanes-1 | lanes-2 | lanes-3 | lanes-4 | lanes-5 | lanes-6 | lanes-7 | lanes-8 | lanes-9 | ... | bicycle-official | lit-yes | lit-no | lit-sunset-sunrise | lit-24/7 | lit-automatic | lit-disused | lit-limited | lit-interval | geometry |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04947 51.083... |
1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04933 51.083... |
2 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.05357 51.08301, 17.05335 51.082... |
3 rows × 219 columns
Out[4]:
<Axes: >
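The edges table stores road attributes as 0/1 indicator columns named `key-value` (for example `lanes-2` or `lit-yes`). Edge 2 has several `lanes-*` columns set at once, so a column group is not strictly one-hot. The sketch below illustrates that kind of encoding on toy data. It is not the loader's implementation, and the helper `encode_tags`, the small vocabulary, and the sample tags are all hypothetical.

```python
# Illustrative multi-hot encoding of tag values into "key-value" columns.
# Assumption: every name and value here is made up for the example.


def encode_tags(tags, vocabulary):
    """Map {key: [values]} to a flat {"key-value": 0 or 1} row."""
    row = {}
    for key, known_values in vocabulary.items():
        present = set(tags.get(key, []))
        for value in known_values:
            row[f"{key}-{value}"] = int(value in present)
    return row


vocabulary = {"lanes": ["1", "2", "3"], "lit": ["yes", "no"]}
edge_tags = {"lanes": ["2", "3"], "lit": ["yes"]}  # two lane counts on one edge

encoded = encode_tags(edge_tags, vocabulary)
print(encoded)
```

A fixed vocabulary keeps the column set identical for every edge, which is what lets the edge table feed a model with a fixed input width.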
Find out which edges correspond to which regions
In [5]:
from srai.joiners import IntersectionJoiner
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, edges_gdf)
joint_gdf
Out[5]:
region_id | feature_id |
---|---|
891e2040b03ffff | 0 |
891e2040b03ffff | 1 |
891e2040b03ffff | 2 |
891e2040b07ffff | 2 |
891e2040b03ffff | 3 |
... | ... |
891e2047267ffff | 10332 |
891e20403dbffff | 10333 |
891e2047307ffff | 10334 |
891e2047307ffff | 10335 |
891e2047307ffff | 10336 |
15704 rows × 0 columns
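The joiner output has no data columns: all the information sits in the `(region_id, feature_id)` index pairs. The sketch below groups the first five pairs from the output in both directions. It reads blank `region_id` cells as repeats of the value above, which is how pandas displays a MultiIndex. It uses plain dictionaries instead of the GeoDataFrame so that it runs on its own.

```python
from collections import defaultdict

# First five (region_id, feature_id) pairs from the joiner output above.
pairs = [
    ("891e2040b03ffff", 0),
    ("891e2040b03ffff", 1),
    ("891e2040b03ffff", 2),
    ("891e2040b07ffff", 2),
    ("891e2040b03ffff", 3),
]

features_by_region = defaultdict(list)
regions_by_feature = defaultdict(list)
for region_id, feature_id in pairs:
    features_by_region[region_id].append(feature_id)
    regions_by_feature[feature_id].append(region_id)

# Edges that intersect more than one region appear once per region.
shared = {f: r for f, r in regions_by_feature.items() if len(r) > 1}

print(dict(features_by_region))
print(shared)
```

Edge 2 crosses a cell boundary and is listed under two regions, which is why the join can contain more rows (15704) than there are distinct edges.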
Get the embeddings for regions based on the road infrastructure
In [6]:
from pytorch_lightning import seed_everything
from srai.embedders import Highway2VecEmbedder
seed_everything(42)
embedder = Highway2VecEmbedder()
embeddings = embedder.fit_transform(regions_gdf, edges_gdf, joint_gdf)
embeddings
Seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/hostedtoolcache/Python/3.10.16/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name    | Type       | Params | Mode
-----------------------------------------------
0 | encoder | Sequential | 16.0 K | train
1 | decoder | Sequential | 16.2 K | train
-----------------------------------------------
32.1 K    Trainable params
0         Non-trainable params
32.1 K    Total params
0.128     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode
/opt/hostedtoolcache/Python/3.10.16/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=10` reached.
/opt/hostedtoolcache/Python/3.10.16/x64/lib/python3.10/site-packages/srai/embedders/highway2vec/embedder.py:75: FutureWarning: The behavior of array concatenation with empty entries is deprecated. In a future version, this will no longer exclude empty items when determining the result dtype. To retain the old behavior, exclude the empty entries before the concat operation.
  embeddings_joint = joint_gdf.join(embeddings_df)
Out[6]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
891e2040007ffff | -0.513893 | -0.145070 | 0.080739 | -0.122525 | 0.356257 | -0.161801 | 0.013215 | -0.604649 | -0.221475 | 0.751927 | ... | 0.357176 | -0.524709 | 0.196952 | 0.077800 | -0.007118 | 0.007151 | 0.142733 | 0.345968 | 0.065985 | 0.394128 |
891e2040013ffff | -0.081005 | -0.095910 | 0.152072 | -0.233143 | 0.587691 | -0.045076 | -0.037996 | -0.640235 | -0.386453 | 0.818154 | ... | 0.417790 | -0.691335 | 0.286598 | -0.330986 | -0.027983 | 0.018665 | 0.410494 | 0.344340 | 0.039094 | 0.352501 |
891e2040017ffff | -0.268360 | -0.248471 | 0.073042 | -0.152062 | 0.436663 | -0.191000 | 0.138760 | -0.784994 | -0.180976 | 0.629028 | ... | 0.247412 | -0.577053 | 0.279971 | -0.015680 | 0.011789 | 0.058033 | 0.425232 | 0.229334 | 0.034790 | 0.415501 |
891e2040023ffff | -0.262698 | -0.385015 | 0.226048 | -0.007230 | 0.479784 | -0.158262 | 0.095209 | -0.398210 | -0.082673 | 0.621222 | ... | 0.155291 | -0.504012 | 0.233585 | 0.027048 | -0.091137 | 0.180363 | 0.374557 | 0.069816 | -0.121347 | 0.484671 |
891e2040027ffff | -0.596365 | -0.279333 | -0.018903 | -0.264261 | 0.613303 | 0.053933 | 0.022079 | -0.683633 | -0.317850 | 0.775937 | ... | 0.022755 | -0.617145 | 0.251375 | -0.108733 | -0.072900 | 0.261571 | 0.145573 | 0.196485 | -0.082726 | 0.639100 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.533033 | -0.081049 | 0.031047 | -0.068882 | 0.564926 | -0.296095 | 0.064596 | -0.613725 | -0.302779 | 0.821448 | ... | 0.591887 | -0.549273 | 0.051332 | -0.291371 | -0.344958 | 0.220053 | 0.086800 | 0.332113 | 0.168488 | 0.820158 |
891e2055bc7ffff | -0.373421 | 0.038110 | 0.004133 | -0.112134 | 0.858200 | -0.509740 | 0.139249 | -0.316913 | -0.305972 | 0.841957 | ... | 0.408013 | -0.552353 | 0.124785 | -0.301819 | -0.465245 | 0.226447 | -0.068748 | 0.254168 | 0.050024 | 0.530532 |
891e2055bcbffff | -0.714216 | -0.147771 | 0.074821 | -0.051116 | 0.533600 | -0.135746 | 0.084340 | -0.830490 | -0.282055 | 0.798509 | ... | 0.767942 | -0.487200 | -0.111118 | -0.384844 | -0.371428 | 0.305901 | 0.163340 | 0.400254 | 0.307522 | 1.120561 |
891e205a967ffff | -0.647208 | 0.116226 | 0.121393 | -0.243000 | 0.310114 | 0.216284 | -0.068945 | -0.350502 | -0.429043 | 0.526084 | ... | -0.257188 | -0.357095 | 0.382777 | -0.027045 | 0.215643 | -0.010033 | 0.221585 | 0.264183 | 0.042259 | 0.597591 |
891e205a9a7ffff | -1.469863 | -1.499591 | 0.310515 | -1.012978 | 0.908214 | -1.053621 | -0.170075 | -0.589309 | -0.485389 | 1.947888 | ... | 0.703161 | -0.632688 | -0.038029 | 0.061385 | -0.530599 | 0.153408 | 0.469294 | 0.964734 | -0.039563 | 1.023270 |
2036 rows × 30 columns
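The embeddings table has 2036 rows, while the regionalizer produced 3168 regions. A plausible reading is that only regions with at least one joined edge receive an embedding. The warning above shows edge-level embeddings being joined onto the region/edge pairs (`joint_gdf.join(embeddings_df)`), after which they have to be reduced to one vector per region. The sketch below assumes that reduction is a per-region mean. This is an assumption, not something the notebook confirms. All data and names are toy values.

```python
from statistics import fmean

# Toy edge embeddings (2-D for readability; the notebook's are 30-D).
edge_embeddings = {0: [1.0, 0.0], 1: [3.0, 2.0], 2: [5.0, 4.0]}

# Toy region/edge pairs; edge 2 intersects both regions.
pairs = [("region_a", 0), ("region_a", 1), ("region_b", 2), ("region_a", 2)]
all_regions = ["region_a", "region_b", "region_c"]  # region_c has no roads

# Collect the edge vectors that fall into each region.
grouped = {}
for region_id, feature_id in pairs:
    grouped.setdefault(region_id, []).append(edge_embeddings[feature_id])

# Assumed aggregation: component-wise mean of the edge vectors.
region_embeddings = {
    region_id: [fmean(column) for column in zip(*vectors)]
    for region_id, vectors in grouped.items()
}

# Regions without any edge never enter the result.
missing = [r for r in all_regions if r not in region_embeddings]

print(region_embeddings)
print(missing)
```

Under this reading, a downstream task that needs a vector for every region has to choose its own fill strategy for the regions that are absent from `embeddings`.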
In [7]:
from sklearn.cluster import KMeans
clusterizer = KMeans(n_clusters=5, random_state=42)
clusterizer.fit(embeddings)
embeddings["cluster"] = clusterizer.labels_
embeddings
Out[7]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | cluster |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
891e2040007ffff | -0.513893 | -0.145070 | 0.080739 | -0.122525 | 0.356257 | -0.161801 | 0.013215 | -0.604649 | -0.221475 | 0.751927 | ... | -0.524709 | 0.196952 | 0.077800 | -0.007118 | 0.007151 | 0.142733 | 0.345968 | 0.065985 | 0.394128 | 3 |
891e2040013ffff | -0.081005 | -0.095910 | 0.152072 | -0.233143 | 0.587691 | -0.045076 | -0.037996 | -0.640235 | -0.386453 | 0.818154 | ... | -0.691335 | 0.286598 | -0.330986 | -0.027983 | 0.018665 | 0.410494 | 0.344340 | 0.039094 | 0.352501 | 3 |
891e2040017ffff | -0.268360 | -0.248471 | 0.073042 | -0.152062 | 0.436663 | -0.191000 | 0.138760 | -0.784994 | -0.180976 | 0.629028 | ... | -0.577053 | 0.279971 | -0.015680 | 0.011789 | 0.058033 | 0.425232 | 0.229334 | 0.034790 | 0.415501 | 3 |
891e2040023ffff | -0.262698 | -0.385015 | 0.226048 | -0.007230 | 0.479784 | -0.158262 | 0.095209 | -0.398210 | -0.082673 | 0.621222 | ... | -0.504012 | 0.233585 | 0.027048 | -0.091137 | 0.180363 | 0.374557 | 0.069816 | -0.121347 | 0.484671 | 4 |
891e2040027ffff | -0.596365 | -0.279333 | -0.018903 | -0.264261 | 0.613303 | 0.053933 | 0.022079 | -0.683633 | -0.317850 | 0.775937 | ... | -0.617145 | 0.251375 | -0.108733 | -0.072900 | 0.261571 | 0.145573 | 0.196485 | -0.082726 | 0.639100 | 1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.533033 | -0.081049 | 0.031047 | -0.068882 | 0.564926 | -0.296095 | 0.064596 | -0.613725 | -0.302779 | 0.821448 | ... | -0.549273 | 0.051332 | -0.291371 | -0.344958 | 0.220053 | 0.086800 | 0.332113 | 0.168488 | 0.820158 | 1 |
891e2055bc7ffff | -0.373421 | 0.038110 | 0.004133 | -0.112134 | 0.858200 | -0.509740 | 0.139249 | -0.316913 | -0.305972 | 0.841957 | ... | -0.552353 | 0.124785 | -0.301819 | -0.465245 | 0.226447 | -0.068748 | 0.254168 | 0.050024 | 0.530532 | 1 |
891e2055bcbffff | -0.714216 | -0.147771 | 0.074821 | -0.051116 | 0.533600 | -0.135746 | 0.084340 | -0.830490 | -0.282055 | 0.798509 | ... | -0.487200 | -0.111118 | -0.384844 | -0.371428 | 0.305901 | 0.163340 | 0.400254 | 0.307522 | 1.120561 | 2 |
891e205a967ffff | -0.647208 | 0.116226 | 0.121393 | -0.243000 | 0.310114 | 0.216284 | -0.068945 | -0.350502 | -0.429043 | 0.526084 | ... | -0.357095 | 0.382777 | -0.027045 | 0.215643 | -0.010033 | 0.221585 | 0.264183 | 0.042259 | 0.597591 | 3 |
891e205a9a7ffff | -1.469863 | -1.499591 | 0.310515 | -1.012978 | 0.908214 | -1.053621 | -0.170075 | -0.589309 | -0.485389 | 1.947888 | ... | -0.632688 | -0.038029 | 0.061385 | -0.530599 | 0.153408 | 0.469294 | 0.964734 | -0.039563 | 1.023270 | 2 |
2036 rows × 31 columns
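The `cluster` column holds, for each region, the index of the K-Means centroid closest to its embedding. The sketch below reproduces only that assignment step on toy 2-D points with fixed centroids, then counts the cluster sizes. The centroids, the points, and the helper `nearest_centroid` are hypothetical, and no fitting is performed.

```python
from collections import Counter
from math import dist

# Toy centroids and points; a fitted KMeans model would supply the centroids.
centroids = [[0.0, 0.0], [1.0, 1.0]]
points = [[0.1, -0.2], [0.9, 1.2], [0.2, 0.1], [1.1, 0.8]]


def nearest_centroid(point, centroids):
    """Return the index of the centroid with the smallest Euclidean distance."""
    return min(range(len(centroids)), key=lambda i: dist(point, centroids[i]))


labels = [nearest_centroid(point, centroids) for point in points]
sizes = Counter(labels)

print(labels)
print(sizes)
```

Note that the cell above writes the labels back into `embeddings`. Re-running `clusterizer.fit(embeddings)` afterwards would treat `cluster` as a 31st feature, so it is safer to fit on a copy or on `embeddings.drop(columns="cluster")`.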