Highway2Vec Embedder
In [1]:
from IPython.display import display
from srai.plotting import plot_numeric_data, plot_regions
Get an area to embed
In [2]:
from srai.regionalizers import geocode_to_region_gdf
# Geocode the city name into a GeoDataFrame holding its boundary polygon
area_gdf = geocode_to_region_gdf("Wrocław, PL")
plot_regions(area_gdf, tiles_style="CartoDB positron")
Regionalize the area with an H3 regionalizer
In [3]:
from srai.regionalizers import H3Regionalizer
# Cover the area with H3 hexagons at resolution 9
regionalizer = H3Regionalizer(9)
regions_gdf = regionalizer.transform(area_gdf)
print(len(regions_gdf))
display(regions_gdf.head(3))
plot_regions(regions_gdf, tiles_style="CartoDB positron")
3168
region_id | geometry
---|---
891e204553bffff | POLYGON ((17.14018 51.10243, 17.14022 51.10075...
891e20455c3ffff | POLYGON ((17.12002 51.1069, 17.12005 51.10522,...
891e20473d3ffff | POLYGON ((17.09954 51.1265, 17.09958 51.12482,...
Download the road infrastructure for the area
In [4]:
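from srai.loaders import OSMNetworkType, OSMWayLoader
# Download the drivable road network: nodes are junctions and endpoints, edges are road segments with OSM tags as 0/1 columns
loader = OSMWayLoader(OSMNetworkType.DRIVE)
nodes_gdf, edges_gdf = loader.load(area_gdf)
display(nodes_gdf.head(3))
display(edges_gdf.head(3))
# Draw the edges, then overlay the nodes in red
ax = edges_gdf.plot(linewidth=1, figsize=(12, 7))
nodes_gdf.plot(ax=ax, markersize=3, color="red")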
osmid | y | x | street_count | highway | railway | ref | geometry
---|---|---|---|---|---|---|---
95584835 | 51.083111 | 17.049513 | 4 | NaN | NaN | NaN | POINT (17.04951 51.08311)
95584841 | 51.084686 | 17.064329 | 3 | NaN | NaN | NaN | POINT (17.06433 51.08469)
95584850 | 51.083328 | 17.035057 | 4 | NaN | NaN | NaN | POINT (17.03506 51.08333)
feature_id | oneway | lanes-1 | lanes-2 | lanes-3 | lanes-4 | lanes-5 | lanes-6 | lanes-7 | lanes-8 | lanes-9 | ... | bicycle-official | lit-yes | lit-no | lit-sunset-sunrise | lit-24/7 | lit-automatic | lit-disused | lit-limited | lit-interval | geometry
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04932 51.083... |
1 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04947 51.083... |
2 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.05357 51.08301, 17.05335 51.082... |
3 rows × 219 columns
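The edge table stores OSM tags as 0/1 columns named `key-value` (for example `lanes-2` or `lit-yes`), and one edge can have several columns of the same key set at once: row 1 has both `lanes-2` and `lanes-3`. The snippet below is a minimal, hypothetical sketch of that kind of multi-hot encoding in plain Python; `encode_tags` and the small vocabulary are illustrative assumptions, not the code `OSMWayLoader` actually runs.

```python
def encode_tags(tags, vocabulary):
    """Turn {key: value or [values]} into a dict of 0/1 columns named 'key-value'.

    Hypothetical helper: the column naming mirrors the edge table above,
    but this is not srai's internal implementation.
    """
    row = {column: 0 for column in vocabulary}
    for key, values in tags.items():
        # Accept either a single value or a list of values for one key
        if not isinstance(values, list):
            values = [values]
        for value in values:
            column = f"{key}-{value}"
            if column in row:  # values outside the vocabulary are ignored
                row[column] = 1
    return row


# Assumed, tiny vocabulary; the real table has 218 feature columns
vocabulary = ["lanes-1", "lanes-2", "lanes-3", "lit-yes", "lit-no"]
row = encode_tags({"lanes": ["2", "3"], "lit": "yes"}, vocabulary)
print(row)  # {'lanes-1': 0, 'lanes-2': 1, 'lanes-3': 1, 'lit-yes': 1, 'lit-no': 0}
```

Multi-hot columns like these let an autoencoder consume heterogeneous, partly multi-valued tags as a single fixed-length numeric vector per edge.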
Find out which edges intersect which regions
In [5]:
from srai.joiners import IntersectionJoiner
# Pair every region with every edge whose geometry intersects it
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, edges_gdf)
joint_gdf
Out[5]:
region_id | feature_id
---|---
891e2040b03ffff | 0
891e2040b03ffff | 1
891e2040b03ffff | 2
891e2040b07ffff | 2
891e2040b03ffff | 3
... | ...
891e20403dbffff | 10212
891e2047307ffff | 10213
891e2047307ffff | 10214
891e2047307ffff | 10215
891e2040c33ffff | 10216
15588 rows × 0 columns
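Each row of `joint_gdf` is a (`region_id`, `feature_id`) pair, and an edge that crosses a cell boundary is listed once per region it touches: feature 2 appears under both 891e2040b03ffff and 891e2040b07ffff. That is why the join has 15588 rows while the edge IDs only run up to 10216. The stdlib-only sketch below illustrates the same pairing idea under simplifying assumptions: regions are axis-aligned squares instead of H3 hexagons, edges are polylines, and each segment is tested against a box with Liang-Barsky clipping. It is not the `IntersectionJoiner` implementation, which operates on the GeoDataFrames' actual geometries.

```python
def segment_hits_box(p, q, box):
    """Liang-Barsky clipping test: does segment p->q touch the closed box?"""
    (x0, y0), (x1, y1) = p, q
    xmin, ymin, xmax, ymax = box
    dx, dy = x1 - x0, y1 - y0
    t_enter, t_leave = 0.0, 1.0
    # One (p_k, q_k) pair per box side: left, right, bottom, top
    for pk, qk in ((-dx, x0 - xmin), (dx, xmax - x0), (-dy, y0 - ymin), (dy, ymax - y0)):
        if pk == 0:
            if qk < 0:
                return False  # parallel to this side and entirely outside it
        else:
            t = qk / pk
            if pk < 0:
                t_enter = max(t_enter, t)
            else:
                t_leave = min(t_leave, t)
            if t_enter > t_leave:
                return False
    return True


def intersection_join(regions, edges):
    """Return sorted (region_id, feature_id) pairs for every region an edge touches."""
    pairs = []
    for region_id, box in regions.items():
        for feature_id, line in edges.items():
            segments = zip(line, line[1:])
            if any(segment_hits_box(p, q, box) for p, q in segments):
                pairs.append((region_id, feature_id))
    return sorted(pairs)


# Two unit squares side by side stand in for neighbouring H3 cells (assumption)
regions = {"A": (0.0, 0.0, 1.0, 1.0), "B": (1.0, 0.0, 2.0, 1.0)}
edges = {
    0: [(0.2, 0.5), (0.8, 0.5)],  # inside A only
    1: [(0.5, 0.2), (1.5, 0.2)],  # crosses the A/B boundary
    2: [(1.2, 0.9), (1.8, 0.9)],  # inside B only
    3: [(5.0, 5.0), (6.0, 6.0)],  # outside both regions
}
joint = intersection_join(regions, edges)
print(joint)  # [('A', 0), ('A', 1), ('B', 1), ('B', 2)]
```

Edge 1 yields two rows and edge 3 yields none, mirroring how the real join can have more rows than edges while some edges or regions never appear.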
Get region embeddings based on the road infrastructure
In [6]:
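from pytorch_lightning import seed_everything
from srai.embedders import Highway2VecEmbedder
# Fix random seeds so that the autoencoder training run is repeatable
seed_everything(42)
embedder = Highway2VecEmbedder()
# Train the autoencoder on edge features, then aggregate edge embeddings per region
embeddings = embedder.fit_transform(regions_gdf, edges_gdf, joint_gdf)
embeddings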
Seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
 | Name | Type | Params | Mode
---|---|---|---|---
0 | encoder | Sequential | 16.0 K | train
1 | decoder | Sequential | 16.2 K | train
32.1 K Trainable params
0 Non-trainable params
32.1 K Total params
0.128 Total estimated model params size (MB)
8 Modules in train mode
0 Modules in eval mode
`Trainer.fit` stopped: `max_epochs=10` reached.
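The training log reports 16.0 K parameters in the encoder and 16.2 K in the decoder. Those figures are consistent with fully connected layers of sizes 218 → 64 → 30 in the encoder and 30 → 64 → 218 in the decoder, assuming 218 input features (the 219 edge columns minus `geometry`), a hidden width of 64 and the 30-dimensional embedding visible in the output below. These layer sizes are inferred from the counts, not read from the srai source. A dense layer holds `n_in * n_out` weights plus `n_out` biases, so the arithmetic is:

```python
def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


def stack_params(sizes):
    """Total parameters of consecutive dense layers, e.g. [218, 64, 30]."""
    return sum(dense_params(a, b) for a, b in zip(sizes, sizes[1:]))


# Assumed layer sizes, inferred from the logged parameter counts
encoder = stack_params([218, 64, 30])  # 15966, logged as 16.0 K
decoder = stack_params([30, 64, 218])  # 16154, logged as 16.2 K
total = encoder + decoder              # 32120, logged as 32.1 K
size_mb = total * 4 / 1e6              # 4 bytes per float32 parameter
print(encoder, decoder, total, round(size_mb, 3))  # 15966 16154 32120 0.128
```

The 0.128 MB estimate in the log matches 32120 float32 parameters at 4 bytes each.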
Out[6]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
891e2040007ffff | -0.534016 | -0.201778 | 0.144246 | -0.095623 | 0.316268 | -0.132133 | 0.061305 | -0.663294 | -0.218176 | 0.744132 | ... | 0.336672 | -0.553832 | 0.194347 | 0.014645 | 0.034789 | 0.047576 | 0.226890 | 0.342445 | 0.089569 | 0.309732 |
891e2040013ffff | -0.094293 | -0.178858 | 0.210738 | -0.277024 | 0.533318 | -0.011199 | 0.021215 | -0.716154 | -0.419597 | 0.807884 | ... | 0.386895 | -0.712610 | 0.276299 | -0.381424 | 0.024393 | 0.050257 | 0.553129 | 0.448246 | 0.115272 | 0.343464 |
891e2040017ffff | -0.248451 | -0.298605 | 0.154628 | -0.180253 | 0.368104 | -0.143678 | 0.174051 | -0.855640 | -0.154088 | 0.674387 | ... | 0.261616 | -0.591597 | 0.225092 | -0.064742 | 0.050834 | 0.053356 | 0.545746 | 0.303658 | 0.069200 | 0.291704 |
891e2040023ffff | -0.271283 | -0.418891 | 0.294150 | -0.009502 | 0.482758 | -0.145368 | 0.144936 | -0.439638 | -0.059179 | 0.615518 | ... | 0.157684 | -0.525371 | 0.215768 | 0.010485 | -0.032696 | 0.216907 | 0.449858 | 0.104989 | -0.076041 | 0.403865 |
891e2040027ffff | -0.591952 | -0.344464 | 0.115241 | -0.264010 | 0.599564 | 0.069988 | 0.062022 | -0.739905 | -0.286227 | 0.800987 | ... | -0.002508 | -0.625947 | 0.185755 | -0.133184 | 0.021147 | 0.307178 | 0.285219 | 0.221658 | 0.011187 | 0.523993 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.537135 | -0.172080 | 0.018940 | -0.085241 | 0.520160 | -0.361842 | 0.088952 | -0.699181 | -0.286357 | 0.772064 | ... | 0.566357 | -0.540186 | 0.084428 | -0.284199 | -0.272115 | 0.266706 | 0.226986 | 0.242860 | 0.123624 | 0.667049 |
891e2055bc7ffff | -0.381154 | -0.083696 | 0.002206 | -0.167084 | 0.759510 | -0.499913 | 0.150527 | -0.472900 | -0.300341 | 0.805560 | ... | 0.426389 | -0.572051 | 0.089742 | -0.308517 | -0.320987 | 0.255431 | 0.112616 | 0.206560 | 0.010866 | 0.385196 |
891e2055bcbffff | -0.700467 | -0.227193 | 0.040580 | -0.041293 | 0.528812 | -0.271484 | 0.136393 | -0.906035 | -0.271587 | 0.706812 | ... | 0.692370 | -0.453415 | -0.014989 | -0.361717 | -0.328462 | 0.361966 | 0.304717 | 0.267644 | 0.257206 | 0.923234 |
891e205a967ffff | -0.647052 | 0.084924 | 0.199249 | -0.252316 | 0.303917 | 0.191526 | -0.024409 | -0.357718 | -0.420558 | 0.547459 | ... | -0.249330 | -0.361599 | 0.347880 | -0.042167 | 0.249846 | 0.028283 | 0.268046 | 0.302642 | 0.086434 | 0.516717 |
891e205a9a7ffff | -1.520746 | -1.560468 | 0.324574 | -1.016496 | 0.838082 | -0.815146 | -0.037224 | -0.676902 | -0.251171 | 1.622622 | ... | 0.579083 | -0.607954 | -0.086732 | -0.015792 | -0.361900 | 0.285142 | 0.708317 | 0.643517 | -0.109267 | 0.890957 |
2038 rows × 30 columns
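Highway2Vec embeds each edge with the autoencoder and then combines the edge embeddings belonging to each region through the pairs in `joint_gdf`. The sketch below assumes mean pooling for that aggregation; it is a plain-Python illustration with hypothetical names, not srai's code. It also suggests why the result has fewer rows than `regions_gdf` (2038 versus 3168): a region that never appears in the join, for example one with no drivable road, receives no embedding.

```python
from collections import defaultdict


def pool_region_embeddings(pairs, edge_embeddings):
    """Average edge vectors per region from (region_id, feature_id) pairs.

    Mean pooling is an assumption made for this sketch.
    """
    sums = {}
    counts = defaultdict(int)
    for region_id, feature_id in pairs:
        vector = edge_embeddings[feature_id]
        if region_id not in sums:
            sums[region_id] = [0.0] * len(vector)
        sums[region_id] = [s + v for s, v in zip(sums[region_id], vector)]
        counts[region_id] += 1
    return {region: [s / counts[region] for s in total] for region, total in sums.items()}


pairs = [("A", 0), ("A", 1), ("B", 1)]  # hypothetical region "C" has no edges
edge_embeddings = {0: [1.0, 0.0], 1: [0.0, 2.0]}
region_embeddings = pool_region_embeddings(pairs, edge_embeddings)
print(region_embeddings)  # {'A': [0.5, 1.0], 'B': [0.0, 2.0]}
```

Region "C" is simply absent from the result, just as regions without edges are absent from the 2038-row table.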
In [7]:
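from sklearn.cluster import KMeans
# Group regions into 5 clusters based on their 30-dimensional embeddings
clusterizer = KMeans(n_clusters=5, random_state=42)
clusterizer.fit(embeddings)
# Store the labels as a new column; re-running this cell would also feed "cluster" into KMeans
embeddings["cluster"] = clusterizer.labels_
embeddings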
Out[7]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | cluster
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
891e2040007ffff | -0.534016 | -0.201778 | 0.144246 | -0.095623 | 0.316268 | -0.132133 | 0.061305 | -0.663294 | -0.218176 | 0.744132 | ... | -0.553832 | 0.194347 | 0.014645 | 0.034789 | 0.047576 | 0.226890 | 0.342445 | 0.089569 | 0.309732 | 2 |
891e2040013ffff | -0.094293 | -0.178858 | 0.210738 | -0.277024 | 0.533318 | -0.011199 | 0.021215 | -0.716154 | -0.419597 | 0.807884 | ... | -0.712610 | 0.276299 | -0.381424 | 0.024393 | 0.050257 | 0.553129 | 0.448246 | 0.115272 | 0.343464 | 4 |
891e2040017ffff | -0.248451 | -0.298605 | 0.154628 | -0.180253 | 0.368104 | -0.143678 | 0.174051 | -0.855640 | -0.154088 | 0.674387 | ... | -0.591597 | 0.225092 | -0.064742 | 0.050834 | 0.053356 | 0.545746 | 0.303658 | 0.069200 | 0.291704 | 4 |
891e2040023ffff | -0.271283 | -0.418891 | 0.294150 | -0.009502 | 0.482758 | -0.145368 | 0.144936 | -0.439638 | -0.059179 | 0.615518 | ... | -0.525371 | 0.215768 | 0.010485 | -0.032696 | 0.216907 | 0.449858 | 0.104989 | -0.076041 | 0.403865 | 1 |
891e2040027ffff | -0.591952 | -0.344464 | 0.115241 | -0.264010 | 0.599564 | 0.069988 | 0.062022 | -0.739905 | -0.286227 | 0.800987 | ... | -0.625947 | 0.185755 | -0.133184 | 0.021147 | 0.307178 | 0.285219 | 0.221658 | 0.011187 | 0.523993 | 2 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.537135 | -0.172080 | 0.018940 | -0.085241 | 0.520160 | -0.361842 | 0.088952 | -0.699181 | -0.286357 | 0.772064 | ... | -0.540186 | 0.084428 | -0.284199 | -0.272115 | 0.266706 | 0.226986 | 0.242860 | 0.123624 | 0.667049 | 3 |
891e2055bc7ffff | -0.381154 | -0.083696 | 0.002206 | -0.167084 | 0.759510 | -0.499913 | 0.150527 | -0.472900 | -0.300341 | 0.805560 | ... | -0.572051 | 0.089742 | -0.308517 | -0.320987 | 0.255431 | 0.112616 | 0.206560 | 0.010866 | 0.385196 | 3 |
891e2055bcbffff | -0.700467 | -0.227193 | 0.040580 | -0.041293 | 0.528812 | -0.271484 | 0.136393 | -0.906035 | -0.271587 | 0.706812 | ... | -0.453415 | -0.014989 | -0.361717 | -0.328462 | 0.361966 | 0.304717 | 0.267644 | 0.257206 | 0.923234 | 0 |
891e205a967ffff | -0.647052 | 0.084924 | 0.199249 | -0.252316 | 0.303917 | 0.191526 | -0.024409 | -0.357718 | -0.420558 | 0.547459 | ... | -0.361599 | 0.347880 | -0.042167 | 0.249846 | 0.028283 | 0.268046 | 0.302642 | 0.086434 | 0.516717 | 2 |
891e205a9a7ffff | -1.520746 | -1.560468 | 0.324574 | -1.016496 | 0.838082 | -0.815146 | -0.037224 | -0.676902 | -0.251171 | 1.622622 | ... | -0.607954 | -0.086732 | -0.015792 | -0.361900 | 0.285142 | 0.708317 | 0.643517 | -0.109267 | 0.890957 | 0 |
2038 rows × 31 columns
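`KMeans` assigns each region embedding to the nearest of 5 centroids and repeatedly moves every centroid to the mean of its members. The sketch below is plain Lloyd iteration from fixed starting centroids on toy 2-D points; scikit-learn's estimator adds k-means++ initialization by default and a convergence tolerance, so treat this only as an illustration of the idea. It also shows that cluster numbers are arbitrary labels: only which regions share a label carries meaning, not the number itself.

```python
import math


def lloyd_kmeans(points, initial_centroids, n_iter=10):
    """Plain Lloyd's algorithm: alternate assignment and centroid updates."""
    centroids = [list(c) for c in initial_centroids]
    labels = []
    for _ in range(n_iter):
        # Assignment step: index of the nearest centroid for every point
        labels = [
            min(range(len(centroids)), key=lambda k: math.dist(p, centroids[k]))
            for p in points
        ]
        # Update step: move each centroid to the mean of its members
        for k in range(len(centroids)):
            members = [p for p, label in zip(points, labels) if label == k]
            if members:  # an empty cluster keeps its previous centroid
                centroids[k] = [sum(axis) / len(members) for axis in zip(*members)]
    return labels, centroids


points = [(0.0, 0.0), (0.0, 1.0), (10.0, 10.0), (10.0, 11.0)]
labels, centroids = lloyd_kmeans(points, initial_centroids=[(0.0, 0.0), (10.0, 10.0)])
print(labels)     # [0, 0, 1, 1]
print(centroids)  # [[0.0, 0.5], [10.0, 10.5]]
```

Swapping the two initial centroids would swap the labels to `[1, 1, 0, 0]` while producing the same grouping.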
In [8]:
plot_numeric_data(regions_gdf, "cluster", embeddings)