Highway2Vec Embedder
In [1]:
from IPython.display import display
from srai.plotting import plot_numeric_data, plot_regions
Get an area to embed
In [2]:
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Wrocław, PL")
plot_regions(area_gdf, tiles_style="CartoDB positron")
Out[2]:
[Interactive map of the Wrocław area]
Regionalize the area with a regionalizer
In [3]:
from srai.regionalizers import H3Regionalizer
regionalizer = H3Regionalizer(9)
regions_gdf = regionalizer.transform(area_gdf)
print(len(regions_gdf))
display(regions_gdf.head(3))
plot_regions(regions_gdf, tiles_style="CartoDB positron")
3168
region_id | geometry
---|---
891e204296bffff | POLYGON ((17.04393 51.14506, 17.04397 51.14338...
891e204022bffff | POLYGON ((16.93214 51.09373, 16.93218 51.09205...
891e20402b7ffff | POLYGON ((16.93436 51.10643, 16.9344 51.10475,...
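Each `region_id` above is an H3 cell index. As a quick, dependency-free check that `H3Regionalizer(9)` produced resolution-9 cells, the resolution can be read directly from the index bits. This sketch assumes the 64-bit layout from the H3 index specification (mode in bits 59 to 62, resolution in bits 52 to 55); the helper names are hypothetical, and the `h3` package exposes the same fields through its own API.

```python
def h3_mode(index_hex: str) -> int:
    """Mode field (bits 59-62) of an H3 index; 1 means 'cell'."""
    return (int(index_hex, 16) >> 59) & 0xF


def h3_resolution(index_hex: str) -> int:
    """Resolution field (bits 52-55) of an H3 index."""
    return (int(index_hex, 16) >> 52) & 0xF


# Region ids copied from the table above.
region_ids = ["891e204296bffff", "891e204022bffff", "891e20402b7ffff"]
for rid in region_ids:
    print(rid, h3_mode(rid), h3_resolution(rid))  # e.g. 891e204296bffff 1 9
```

A mode of 1 marks a cell index, and a resolution of 9 matches the argument passed to the regionalizer.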
Out[3]:
[Interactive map of the H3 regions]
Download the road infrastructure for the area
In [4]:
from srai.loaders import OSMNetworkType, OSMWayLoader
loader = OSMWayLoader(OSMNetworkType.DRIVE)
nodes_gdf, edges_gdf = loader.load(area_gdf)
display(nodes_gdf.head(3))
display(edges_gdf.head(3))
ax = edges_gdf.plot(linewidth=1, figsize=(12, 7))
nodes_gdf.plot(ax=ax, markersize=3, color="red")
osmid | y | x | street_count | highway | railway | ref | geometry
---|---|---|---|---|---|---|---
95584835 | 51.083111 | 17.049513 | 4 | NaN | NaN | NaN | POINT (17.04951 51.08311)
95584841 | 51.084699 | 17.064367 | 3 | NaN | NaN | NaN | POINT (17.06437 51.0847)
95584850 | 51.083328 | 17.035057 | 4 | NaN | NaN | NaN | POINT (17.03506 51.08333)
feature_id | oneway | lanes-1 | lanes-2 | lanes-3 | lanes-4 | lanes-5 | lanes-6 | lanes-7 | lanes-8 | lanes-9 | ... | bicycle-official | lit-yes | lit-no | lit-sunset-sunrise | lit-24/7 | lit-automatic | lit-disused | lit-limited | lit-interval | geometry
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04947 51.083...
1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04933 51.083...
2 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.05357 51.08301, 17.05335 51.082...
3 rows × 219 columns
Out[4]:
<Axes: >
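The edge table stores OSM tags as 0/1 indicator columns named `<key>-<value>`, such as `lanes-2` or `lit-yes`. The row with `feature_id` 2 has `lanes-2`, `lanes-3` and `lanes-4` all set, which suggests that one edge can carry several values for the same key (for example after consecutive OSM ways are merged into a single graph edge). The sketch below shows such a multi-hot encoding; `one_hot_tags`, the vocabulary and the sample edges are hypothetical and are not taken from `OSMWayLoader`.

```python
from typing import Dict, List, Union

TagValue = Union[str, List[str], None]


def one_hot_tags(
    edges: List[Dict[str, TagValue]], vocabulary: Dict[str, List[str]]
) -> List[Dict[str, int]]:
    """Turn raw tag dicts into 0/1 columns named "<key>-<value>".

    A tag may hold a single value or a list of values (multi-hot).
    Keys or values outside the vocabulary are ignored.
    """
    columns = [f"{key}-{value}" for key, values in vocabulary.items() for value in values]
    rows = []
    for tags in edges:
        row = dict.fromkeys(columns, 0)
        for key, raw in tags.items():
            if key not in vocabulary or raw is None:
                continue
            values = raw if isinstance(raw, list) else [raw]
            for value in values:
                column = f"{key}-{value}"
                if column in row:
                    row[column] = 1
        rows.append(row)
    return rows


# Hypothetical vocabulary and edges, for illustration only.
vocab = {"lanes": ["1", "2", "3", "4"], "lit": ["yes", "no"]}
edges = [
    {"lanes": "2", "lit": "yes"},
    {"lanes": ["2", "3", "4"], "lit": "yes"},  # merged edge with several lane counts
]
for row in one_hot_tags(edges, vocab):
    print(row)
```

A fixed vocabulary keeps the column set identical across cities, which matters if embeddings from different areas are to be compared.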
Find out which edges correspond to which regions
In [5]:
from srai.joiners import IntersectionJoiner
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, edges_gdf)
joint_gdf
Out[5]:
region_id | feature_id
---|---
891e204296bffff | 5792
 | 8189
 | 8188
 | 8203
 | 8204
... | ...
891e204227bffff | 7663
 | 7666
 | 7665
 | 7667
 | 9476
15722 rows × 0 columns
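The result has 0 columns because all of its information lives in the (`region_id`, `feature_id`) index: one row per region/edge pair that intersects. Below is a minimal sketch of the same idea under simplifying assumptions: regions are axis-aligned boxes instead of hexagons, intersection is tested segment by segment with Liang-Barsky clipping, and a brute-force double loop stands in for a spatial index. `IntersectionJoiner` itself works on the actual geometries; all names and data here are hypothetical.

```python
from typing import Dict, List, Tuple

Point = Tuple[float, float]
Box = Tuple[float, float, float, float]  # (min_x, min_y, max_x, max_y)


def segment_intersects_box(p: Point, q: Point, box: Box) -> bool:
    """Liang-Barsky test: does the segment p->q touch the closed box?"""
    min_x, min_y, max_x, max_y = box
    dx, dy = q[0] - p[0], q[1] - p[1]
    t0, t1 = 0.0, 1.0
    for denom, num in (
        (-dx, p[0] - min_x),
        (dx, max_x - p[0]),
        (-dy, p[1] - min_y),
        (dy, max_y - p[1]),
    ):
        if denom == 0:
            if num < 0:  # parallel to this boundary and outside it
                return False
            continue
        t = num / denom
        if denom < 0:
            t0 = max(t0, t)  # entering the box
        else:
            t1 = min(t1, t)  # leaving the box
        if t0 > t1:
            return False
    return True


def intersection_join(
    regions: Dict[str, Box], edges: Dict[int, List[Point]]
) -> List[Tuple[str, int]]:
    """Return (region_id, feature_id) pairs for every polyline touching a region."""
    pairs = []
    for region_id, box in regions.items():
        for feature_id, line in edges.items():
            if any(segment_intersects_box(a, b, box) for a, b in zip(line, line[1:])):
                pairs.append((region_id, feature_id))
    return pairs


regions = {"A": (0.0, 0.0, 1.0, 1.0), "B": (1.0, 0.0, 2.0, 1.0)}
edges = {
    0: [(0.2, 0.5), (1.5, 0.5)],  # crosses from A into B
    1: [(0.2, 0.2), (0.4, 0.4)],  # stays inside A
    2: [(3.0, 3.0), (4.0, 4.0)],  # touches neither region
}
print(intersection_join(regions, edges))  # [('A', 0), ('A', 1), ('B', 0)]
```

An edge crossing a region boundary is paired with every region it touches, which is why the join has more rows (15722) than there are edges in a single region.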
Get the embeddings for regions based on the road infrastructure
In [6]:
from pytorch_lightning import seed_everything
from srai.embedders import Highway2VecEmbedder
seed_everything(42)
embedder = Highway2VecEmbedder()
embeddings = embedder.fit_transform(regions_gdf, edges_gdf, joint_gdf)
embeddings
Seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/hostedtoolcache/Python/3.10.15/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
  | Name    | Type       | Params | Mode
-----------------------------------------------
0 | encoder | Sequential | 16.0 K | train
1 | decoder | Sequential | 16.2 K | train
-----------------------------------------------
32.1 K    Trainable params
0         Non-trainable params
32.1 K    Total params
0.128     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode
/opt/hostedtoolcache/Python/3.10.15/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=10` reached.
/opt/hostedtoolcache/Python/3.10.15/x64/lib/python3.10/site-packages/srai/embedders/highway2vec/embedder.py:75: FutureWarning: The behavior of array concatenation with empty entries is deprecated. In a future version, this will no longer exclude empty items when determining the result dtype. To retain the old behavior, exclude the empty entries before the concat operation.
  embeddings_joint = joint_gdf.join(embeddings_df)
Out[6]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
891e2040007ffff | -0.479510 | -0.170161 | 0.141499 | -0.131385 | 0.356130 | -0.170253 | 0.082514 | -0.594185 | -0.235572 | 0.742436 | ... | 0.336453 | -0.519993 | 0.133136 | 0.112709 | 0.013159 | -0.010626 | 0.178799 | 0.334447 | 0.163934 | 0.368286
891e2040013ffff | -0.053364 | -0.095388 | 0.259534 | -0.273419 | 0.601486 | -0.127698 | 0.008694 | -0.672795 | -0.470622 | 0.827055 | ... | 0.405657 | -0.760753 | 0.165349 | -0.286263 | 0.014819 | 0.039674 | 0.456278 | 0.353700 | 0.205730 | 0.390251
891e2040017ffff | -0.223151 | -0.260382 | 0.203482 | -0.188966 | 0.428696 | -0.247260 | 0.226724 | -0.783058 | -0.208041 | 0.601338 | ... | 0.224511 | -0.559403 | 0.128883 | 0.047662 | 0.077394 | 0.010018 | 0.484214 | 0.256647 | 0.183641 | 0.395048
891e2040023ffff | -0.230657 | -0.363989 | 0.309760 | -0.017237 | 0.507161 | -0.186252 | 0.154596 | -0.425472 | -0.097725 | 0.611651 | ... | 0.151504 | -0.503910 | 0.142945 | 0.062385 | -0.037322 | 0.164119 | 0.409733 | 0.088270 | -0.011714 | 0.466407
891e2040027ffff | -0.540541 | -0.267364 | 0.084434 | -0.310432 | 0.636206 | 0.008829 | 0.107817 | -0.712202 | -0.339696 | 0.760957 | ... | 0.013064 | -0.608031 | 0.147565 | -0.054307 | 0.001965 | 0.230121 | 0.214177 | 0.215097 | 0.073425 | 0.621710
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ...
891e2055bc3ffff | -0.458865 | -0.024982 | 0.189483 | -0.148232 | 0.570965 | -0.339826 | 0.255847 | -0.743790 | -0.316874 | 0.777996 | ... | 0.528354 | -0.560291 | -0.137184 | -0.185775 | -0.229802 | 0.218840 | 0.220230 | 0.379667 | 0.356566 | 0.749199
891e2055bc7ffff | -0.301984 | 0.067865 | 0.129250 | -0.229516 | 0.841595 | -0.542149 | 0.344944 | -0.466103 | -0.370153 | 0.782236 | ... | 0.305966 | -0.580967 | -0.050131 | -0.214895 | -0.340888 | 0.217396 | 0.066982 | 0.312978 | 0.256146 | 0.464525
891e2055bcbffff | -0.630770 | -0.058533 | 0.289752 | -0.102037 | 0.559412 | -0.160441 | 0.306311 | -0.977856 | -0.258928 | 0.759188 | ... | 0.734534 | -0.482689 | -0.345045 | -0.264224 | -0.224084 | 0.330036 | 0.322070 | 0.444812 | 0.515427 | 1.013915
891e205a967ffff | -0.616801 | 0.123673 | 0.193281 | -0.259642 | 0.319446 | 0.160594 | -0.038908 | -0.363509 | -0.401363 | 0.544136 | ... | -0.279612 | -0.326650 | 0.313171 | 0.047732 | 0.233037 | -0.036650 | 0.232588 | 0.292018 | 0.118049 | 0.574624
891e205a9a7ffff | -1.412213 | -1.500952 | 0.435123 | -1.049657 | 1.131710 | -1.103062 | 0.149565 | -0.713480 | -0.489774 | 1.805687 | ... | 0.713913 | -0.570885 | -0.016611 | 0.096510 | -0.397899 | 0.092768 | 0.679317 | 0.989048 | 0.128019 | 0.926884
2034 rows × 30 columns
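Two sketches may help in reading this output, both resting on assumptions the notebook does not confirm. First, the logged model summary (16.0 K encoder and 16.2 K decoder parameters, 0.128 MB) is consistent with fully connected layers of widths 218 → 64 → 30 in the encoder and the mirror image in the decoder, stored as 4-byte floats: 218 is the 219 edge-table columns minus `geometry`, 30 is the embedding width in the table above, and the hidden width of 64 is an assumption. Second, the table has 2034 rows although the regionalizer produced 3168 hexagons. That is consistent with each region's embedding being an aggregate (assumed here to be the mean) of the embeddings of the edges joined to it, so regions without any drivable road get no row. Both helpers (`mlp_params`, `mean_per_region`) are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def mlp_params(widths: Sequence[int]) -> int:
    """Weights plus biases of stacked fully connected layers (activations add none)."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(widths, widths[1:]))


# Assumed widths: 218 input features, hidden size 64 (assumption), embedding size 30.
encoder = mlp_params([218, 64, 30])
decoder = mlp_params([30, 64, 218])
total = encoder + decoder
print(encoder, decoder, total)  # 15966 16154 32120
print(round(total * 4 / 1e6, 3))  # 0.128 (MB at 4 bytes per float32 parameter)


def mean_per_region(
    pairs: List[Tuple[str, int]], edge_vectors: Dict[int, List[float]]
) -> Dict[str, List[float]]:
    """Average the vectors of all edges joined to each region.

    Regions that appear in no (region_id, feature_id) pair get no entry.
    """
    grouped: Dict[str, List[List[float]]] = defaultdict(list)
    for region_id, feature_id in pairs:
        grouped[region_id].append(edge_vectors[feature_id])
    return {
        region_id: [sum(column) / len(vectors) for column in zip(*vectors)]
        for region_id, vectors in grouped.items()
    }


# Hypothetical toy data: region "C" has no edges, so it drops out of the result.
regions = ["A", "B", "C"]
pairs = [("A", 0), ("A", 1), ("B", 0)]
edge_vectors = {0: [1.0, 0.0], 1: [3.0, 2.0]}
result = mean_per_region(pairs, edge_vectors)
print(result)  # {'A': [2.0, 1.0], 'B': [1.0, 0.0]}
print([r for r in regions if r not in result])  # ['C']
```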
In [7]:
from sklearn.cluster import KMeans
clusterizer = KMeans(n_clusters=5, random_state=42)
clusterizer.fit(embeddings)
embeddings["cluster"] = clusterizer.labels_
embeddings
Out[7]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | cluster
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
891e2040007ffff | -0.479510 | -0.170161 | 0.141499 | -0.131385 | 0.356130 | -0.170253 | 0.082514 | -0.594185 | -0.235572 | 0.742436 | ... | -0.519993 | 0.133136 | 0.112709 | 0.013159 | -0.010626 | 0.178799 | 0.334447 | 0.163934 | 0.368286 | 1
891e2040013ffff | -0.053364 | -0.095388 | 0.259534 | -0.273419 | 0.601486 | -0.127698 | 0.008694 | -0.672795 | -0.470622 | 0.827055 | ... | -0.760753 | 0.165349 | -0.286263 | 0.014819 | 0.039674 | 0.456278 | 0.353700 | 0.205730 | 0.390251 | 0
891e2040017ffff | -0.223151 | -0.260382 | 0.203482 | -0.188966 | 0.428696 | -0.247260 | 0.226724 | -0.783058 | -0.208041 | 0.601338 | ... | -0.559403 | 0.128883 | 0.047662 | 0.077394 | 0.010018 | 0.484214 | 0.256647 | 0.183641 | 0.395048 | 4
891e2040023ffff | -0.230657 | -0.363989 | 0.309760 | -0.017237 | 0.507161 | -0.186252 | 0.154596 | -0.425472 | -0.097725 | 0.611651 | ... | -0.503910 | 0.142945 | 0.062385 | -0.037322 | 0.164119 | 0.409733 | 0.088270 | -0.011714 | 0.466407 | 3
891e2040027ffff | -0.540541 | -0.267364 | 0.084434 | -0.310432 | 0.636206 | 0.008829 | 0.107817 | -0.712202 | -0.339696 | 0.760957 | ... | -0.608031 | 0.147565 | -0.054307 | 0.001965 | 0.230121 | 0.214177 | 0.215097 | 0.073425 | 0.621710 | 1
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ...
891e2055bc3ffff | -0.458865 | -0.024982 | 0.189483 | -0.148232 | 0.570965 | -0.339826 | 0.255847 | -0.743790 | -0.316874 | 0.777996 | ... | -0.560291 | -0.137184 | -0.185775 | -0.229802 | 0.218840 | 0.220230 | 0.379667 | 0.356566 | 0.749199 | 2
891e2055bc7ffff | -0.301984 | 0.067865 | 0.129250 | -0.229516 | 0.841595 | -0.542149 | 0.344944 | -0.466103 | -0.370153 | 0.782236 | ... | -0.580967 | -0.050131 | -0.214895 | -0.340888 | 0.217396 | 0.066982 | 0.312978 | 0.256146 | 0.464525 | 1
891e2055bcbffff | -0.630770 | -0.058533 | 0.289752 | -0.102037 | 0.559412 | -0.160441 | 0.306311 | -0.977856 | -0.258928 | 0.759188 | ... | -0.482689 | -0.345045 | -0.264224 | -0.224084 | 0.330036 | 0.322070 | 0.444812 | 0.515427 | 1.013915 | 2
891e205a967ffff | -0.616801 | 0.123673 | 0.193281 | -0.259642 | 0.319446 | 0.160594 | -0.038908 | -0.363509 | -0.401363 | 0.544136 | ... | -0.326650 | 0.313171 | 0.047732 | 0.233037 | -0.036650 | 0.232588 | 0.292018 | 0.118049 | 0.574624 | 1
891e205a9a7ffff | -1.412213 | -1.500952 | 0.435123 | -1.049657 | 1.131710 | -1.103062 | 0.149565 | -0.713480 | -0.489774 | 1.805687 | ... | -0.570885 | -0.016611 | 0.096510 | -0.397899 | 0.092768 | 0.679317 | 0.989048 | 0.128019 | 0.926884 | 2
2034 rows × 31 columns
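The `cluster` column holds each region's K-means label. For reference, here is a sketch of plain Lloyd's algorithm, which alternates between assigning every point to its nearest centroid and moving each centroid to the mean of its points. It is not scikit-learn's implementation: `KMeans` uses k-means++ initialization by default, while this sketch simply starts from the first `k` points. The function name and sample points are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def kmeans(
    points: Sequence[Sequence[float]], k: int, max_iter: int = 100
) -> Tuple[List[int], List[List[float]]]:
    """Plain Lloyd's algorithm, seeded with the first k points as centroids."""
    centroids = [list(p) for p in points[:k]]
    labels: List[int] = []
    for _ in range(max_iter):
        # Assignment step: each point joins its nearest centroid (Euclidean).
        new_labels = [
            min(range(k), key=lambda c: math.dist(p, centroids[c])) for p in points
        ]
        if new_labels == labels:
            break  # assignments stopped changing: converged
        labels = new_labels
        # Update step: move each centroid to the mean of its members.
        for c in range(k):
            members = [p for p, label in zip(points, labels) if label == c]
            if members:  # an empty cluster keeps its previous centroid
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return labels, centroids


# Hypothetical 2-D points forming two obvious groups.
points = [(0.0, 0.0), (10.0, 10.0), (0.0, 1.0), (1.0, 0.0), (10.0, 11.0), (11.0, 10.0)]
labels, centroids = kmeans(points, k=2)
print(labels)  # [0, 1, 0, 0, 1, 1]
```

In the notebook the same procedure runs on one 30-dimensional vector per region. Cluster numbers are arbitrary labels, so only the grouping itself carries meaning.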
In [8]:
plot_numeric_data(regions_gdf, "cluster", embeddings)
Out[8]:
[Interactive map of the regions colored by cluster]