Highway2Vec Embedder
In [1]:
from IPython.display import display
from srai.plotting import plot_regions, plot_numeric_data
Get an area to embed
In [2]:
from srai.regionalizers import geocode_to_region_gdf
# Geocode the city name into a GeoDataFrame holding its boundary polygon
area_gdf = geocode_to_region_gdf("Wrocław, PL")
plot_regions(area_gdf, tiles_style="CartoDB positron")
Regionalize the area with a regionalizer
In [3]:
from srai.regionalizers import H3Regionalizer
# Split the area into H3 hexagons at resolution 9 (average cell area ≈ 0.105 km²)
regionalizer = H3Regionalizer(9)
regions_gdf = regionalizer.transform(area_gdf)
print(len(regions_gdf))  # number of regions
display(regions_gdf.head(3))
plot_regions(regions_gdf, tiles_style="CartoDB positron")
3168
region_id | geometry
---|---
891e2042237ffff | POLYGON ((16.91828 51.15136, 16.91832 51.14968... |
891e2040573ffff | POLYGON ((16.95642 51.12982, 16.95645 51.12814... |
891e204210bffff | POLYGON ((16.95306 51.16756, 16.95310 51.16588... |
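Each `region_id` above is an H3 cell index written in hexadecimal. As an illustration that is not part of srai, the sketch below decodes the resolution and base cell from such an index. It assumes the 64-bit layout from the H3 index specification: reading from the most significant bit, there is 1 reserved bit, then 4 mode bits, 3 reserved bits, 4 resolution bits, 7 base-cell bits, and fifteen 3-bit digits. `decode_h3_cell` is a hypothetical helper. In practice, the `h3` Python package exposes these fields directly.

```python
from typing import Dict, List, Union


def decode_h3_cell(index_hex: str) -> Dict[str, Union[int, List[int]]]:
    """Decode the header fields of an H3 cell index (hypothetical helper).

    Assumed 64-bit layout, most significant bit first: 1 reserved bit,
    4 mode bits, 3 reserved bits, 4 resolution bits, 7 base-cell bits,
    then fifteen 3-bit digits (digits beyond the resolution are set to 7).
    """
    h = int(index_hex, 16)
    mode = (h >> 59) & 0xF        # 1 means "cell" index
    resolution = (h >> 52) & 0xF
    base_cell = (h >> 45) & 0x7F
    # Digit i (1-based) sits in bits 3 * (15 - i) .. 3 * (15 - i) + 2.
    digits = [(h >> (3 * (15 - i))) & 0x7 for i in range(1, resolution + 1)]
    return {"mode": mode, "resolution": resolution, "base_cell": base_cell, "digits": digits}


for region_id in ["891e2042237ffff", "891e2040573ffff", "891e204210bffff"]:
    print(region_id, decode_h3_cell(region_id))
```

All three IDs decode to resolution 9, which matches the `H3Regionalizer(9)` call.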
Download the road infrastructure for the area
In [4]:
from srai.loaders import OSMNetworkType, OSMWayLoader
# Load the drivable road network from OpenStreetMap as graph nodes and edges
loader = OSMWayLoader(OSMNetworkType.DRIVE)
nodes_gdf, edges_gdf = loader.load(area_gdf)
display(nodes_gdf.head(3))
display(edges_gdf.head(3))
# Draw the road edges, then overlay the graph nodes in red
ax = edges_gdf.plot(linewidth=1, figsize=(12, 7))
nodes_gdf.plot(ax=ax, markersize=3, color="red")
/opt/hostedtoolcache/Python/3.10.13/x64/lib/python3.10/site-packages/srai/loaders/osm_way_loader/osm_way_loader.py:219: UserWarning: The clean_periphery argument has been deprecated and will be removed in a future release. Future behavior will be as though clean_periphery=True. G_directed = ox.graph_from_polygon(
osmid | y | x | street_count | highway | ref | geometry
---|---|---|---|---|---|---
95584835 | 51.083111 | 17.049513 | 4 | NaN | NaN | POINT (17.04951 51.08311) |
95584841 | 51.084699 | 17.064367 | 3 | NaN | NaN | POINT (17.06437 51.08470) |
95584850 | 51.083328 | 17.035057 | 4 | NaN | NaN | POINT (17.03506 51.08333) |
feature_id | oneway | lanes-1 | lanes-2 | lanes-3 | lanes-4 | lanes-5 | lanes-6 | lanes-7 | lanes-8 | lanes-9 | ... | bicycle-official | lit-yes | lit-no | lit-sunset-sunrise | lit-24/7 | lit-automatic | lit-disused | lit-limited | lit-interval | geometry
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04947 51.083... |
1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04933 51.083... |
2 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.05357 51.08301, 17.05335 51.082... |
3 rows × 219 columns
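The edge table is wide (219 columns) because categorical OSM tags are expanded into 0/1 indicator columns named `<tag>-<value>`, such as `lanes-2` or `lit-yes`. Below is a simplified pure-Python sketch of this kind of one-hot encoding. The function name and the tag vocabulary are hypothetical. srai's real preprocessing also differs in details: for example, `oneway` appears above as a single column rather than as `oneway-<value>` columns.

```python
from typing import Dict, List, Mapping, Sequence


def one_hot_tags(
    ways: Sequence[Mapping[str, str]],
    vocabulary: Mapping[str, Sequence[str]],
) -> List[Dict[str, int]]:
    """Encode each way's tags as 0/1 columns named "<tag>-<value>" (illustrative)."""
    columns = [f"{tag}-{value}" for tag, values in vocabulary.items() for value in values]
    rows = []
    for way in ways:
        row = dict.fromkeys(columns, 0)  # every column starts at 0
        for tag, value in way.items():
            column = f"{tag}-{value}"
            if column in row:  # tags or values outside the vocabulary are dropped
                row[column] = 1
        rows.append(row)
    return rows


vocabulary = {"lanes": ["1", "2", "3"], "lit": ["yes", "no"]}
ways = [{"lanes": "2", "lit": "yes"}, {"lit": "no", "name": "Example Street"}]
for row in one_hot_tags(ways, vocabulary):
    print(row)
```

Every way gets the same fixed set of columns. This turns the edges into a dense numeric matrix that a neural network can consume.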
Find out which edges correspond to which regions
In [5]:
from srai.joiners import IntersectionJoiner
# Pair every region with every road edge whose geometry intersects it
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, edges_gdf)
joint_gdf
Out[5]:
region_id | feature_id
---|---
891e2040573ffff | 7176
891e204057bffff | 7176
891e2040563ffff | 7176
891e2040573ffff | 7175
891e204057bffff | 7175
... | ...
891e2040563ffff | 7908
891e2040563ffff | 7907
891e2040563ffff | 7909
891e2040563ffff | 7913
891e2040563ffff | 7915
15349 rows × 0 columns
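The joint is a many-to-many relation, stored as a (`region_id`, `feature_id`) MultiIndex with no columns. One edge can cross several hexagons (feature 7176 appears in three regions above), and one region can contain many edges. Regions that no drive edge touches never appear in the joint. This is consistent with the embeddings table further down having 2029 rows rather than one for each of the 3168 regions.

The plain-Python sketch below shows how such pairs can pool per-edge vectors into one vector per region. Averaging is an assumption chosen for illustration, and `mean_pool_by_region` is a hypothetical helper, not srai's API.

```python
from collections import defaultdict
from typing import Dict, Hashable, Iterable, List, Sequence, Tuple


def mean_pool_by_region(
    joint_pairs: Iterable[Tuple[Hashable, Hashable]],
    edge_vectors: Dict[Hashable, Sequence[float]],
) -> Dict[Hashable, List[float]]:
    """Average the vectors of all edges paired with each region (illustrative)."""
    # Invert the pair list into region -> [feature ids]
    edges_per_region: Dict[Hashable, List[Hashable]] = defaultdict(list)
    for region_id, feature_id in joint_pairs:
        edges_per_region[region_id].append(feature_id)

    pooled: Dict[Hashable, List[float]] = {}
    for region_id, feature_ids in edges_per_region.items():
        vectors = [edge_vectors[f] for f in feature_ids]
        # Element-wise mean over the region's edge vectors
        pooled[region_id] = [sum(col) / len(vectors) for col in zip(*vectors)]
    return pooled


# Edge 1 crosses regions "A" and "B"; region "C" has no edges at all.
pairs = [("A", 1), ("B", 1), ("B", 2)]
vectors = {1: [1.0, 0.0], 2: [3.0, 2.0]}
print(mean_pool_by_region(pairs, vectors))  # {'A': [1.0, 0.0], 'B': [2.0, 1.0]}
```

Region "C" is absent from the result, just as edge-free hexagons are absent from the joint.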
Get the embeddings for regions based on the road infrastructure
In [6]:
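from srai.embedders import Highway2VecEmbedder
from pytorch_lightning import seed_everything
# Seed the Python, NumPy and PyTorch random generators for repeatable training
seed_everything(42)
# Fit an autoencoder on the edge features and derive one embedding per region
embedder = Highway2VecEmbedder()
embeddings = embedder.fit_transform(regions_gdf, edges_gdf, joint_gdf)
embeddings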
Seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/hostedtoolcache/Python/3.10.13/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 16.0 K
1 | decoder | Sequential | 16.2 K
---------------------------------------
32.1 K    Trainable params
0         Non-trainable params
32.1 K    Total params
0.128     Total estimated model params size (MB)
/opt/hostedtoolcache/Python/3.10.13/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=10` reached.
Out[6]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
891e2040007ffff | -0.619877 | -0.139891 | 0.311713 | -0.235325 | 0.286864 | 0.021996 | -0.010311 | -0.577181 | -0.241507 | 0.677906 | ... | 0.149497 | -0.451310 | 0.064326 | 0.089845 | 0.146459 | 0.120934 | 0.116879 | 0.332750 | 0.190347 | 0.360663 |
891e2040013ffff | -0.224642 | -0.143098 | 0.427563 | -0.393280 | 0.489236 | -0.054072 | 0.014470 | -0.577511 | -0.354027 | 0.803737 | ... | 0.422881 | -0.773623 | 0.164668 | -0.253249 | 0.161901 | 0.145316 | 0.491370 | 0.500055 | 0.250518 | 0.370982 |
891e2040017ffff | -0.304012 | -0.203452 | 0.250050 | -0.206345 | 0.197903 | -0.175079 | 0.128693 | -0.519562 | -0.226868 | 0.507805 | ... | 0.107964 | -0.423496 | 0.161032 | 0.033017 | 0.014773 | 0.085093 | 0.373949 | 0.346180 | 0.179789 | 0.314990 |
891e2040023ffff | -0.305751 | -0.365035 | 0.463593 | -0.072082 | 0.496328 | -0.111837 | 0.149594 | -0.378091 | -0.074685 | 0.616050 | ... | 0.143213 | -0.518410 | 0.155986 | 0.032004 | 0.015227 | 0.275388 | 0.427325 | 0.098358 | -0.020611 | 0.460642 |
891e2040027ffff | -0.632700 | -0.309423 | 0.261198 | -0.390320 | 0.631320 | 0.146557 | 0.070962 | -0.655877 | -0.287360 | 0.751151 | ... | -0.009895 | -0.607455 | 0.140857 | -0.122701 | 0.075262 | 0.367290 | 0.249054 | 0.215070 | 0.003655 | 0.550938 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.627002 | -0.064194 | 0.325893 | -0.337385 | 0.434200 | -0.243677 | 0.203394 | -0.900395 | -0.356680 | 0.735271 | ... | 0.501860 | -0.627793 | -0.014291 | -0.178347 | -0.109331 | 0.382347 | 0.258068 | 0.454882 | 0.313613 | 0.702337 |
891e2055bc7ffff | -0.438971 | 0.040592 | 0.333510 | -0.393801 | 0.695110 | -0.353364 | 0.282117 | -0.758274 | -0.453044 | 0.714118 | ... | 0.395762 | -0.649027 | -0.006232 | -0.197612 | -0.145356 | 0.363767 | 0.171695 | 0.460346 | 0.201980 | 0.432349 |
891e2055bcbffff | -0.809552 | -0.099434 | 0.409700 | -0.315831 | 0.409537 | -0.118376 | 0.250347 | -1.149436 | -0.332271 | 0.701659 | ... | 0.670045 | -0.557726 | -0.138531 | -0.270370 | -0.093538 | 0.537909 | 0.376044 | 0.499072 | 0.514148 | 0.947484 |
891e205a967ffff | -0.714737 | 0.083641 | 0.243519 | -0.337393 | 0.355515 | 0.228682 | -0.058184 | -0.216769 | -0.366524 | 0.573776 | ... | -0.303209 | -0.328912 | 0.354445 | 0.018837 | 0.218598 | 0.100005 | 0.164957 | 0.216653 | 0.033482 | 0.571845 |
891e205a9a7ffff | -1.458610 | -1.494464 | 0.893583 | -1.073221 | 0.934651 | -0.626731 | 0.280122 | -0.734369 | -0.280892 | 1.656427 | ... | 0.542885 | -0.568228 | -0.104253 | -0.027485 | -0.098446 | 0.345150 | 0.906987 | 1.147529 | 0.131658 | 0.887252 |
2029 rows × 30 columns
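In the training log, the model summary reports 16.0 K parameters for the encoder and 16.2 K for the decoder. These figures fit a fully connected autoencoder of shape 218 → 64 → 30 → 64 → 218:

- 218 is the number of edge feature columns (219 columns minus `geometry`).
- 30 is the embedding size visible in the output.
- The hidden width of 64 is inferred from the counts, not read from the srai source.

The arithmetic below checks that inference. It counts weights plus biases per dense layer and assumes 4-byte float32 parameters for the size estimate.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Parameters of a dense layer: an n_in x n_out weight matrix plus n_out biases."""
    return n_in * n_out + n_out


N_FEATURES = 218  # 219 edge columns minus geometry
HIDDEN = 64       # inferred hidden width (assumption)
EMBEDDING = 30    # embedding size seen in the output table

encoder = linear_params(N_FEATURES, HIDDEN) + linear_params(HIDDEN, EMBEDDING)
decoder = linear_params(EMBEDDING, HIDDEN) + linear_params(HIDDEN, N_FEATURES)
total = encoder + decoder

print(encoder, decoder, total)      # 15966 16154 32120
print(f"{total * 4 / 1e6:.3f} MB")  # 0.128 MB, matching the log
```

Activation functions add no parameters, so this check cannot tell which nonlinearity the model uses.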
In [7]:
from sklearn.cluster import KMeans
# Group the region embeddings into 5 clusters. n_init=10 (the old default) is set
# explicitly so scikit-learn does not emit a FutureWarning about its changing default
clusterizer = KMeans(n_clusters=5, n_init=10, random_state=42)
clusterizer.fit(embeddings)
embeddings["cluster"] = clusterizer.labels_
embeddings
Out[7]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | cluster
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
891e2040007ffff | -0.619877 | -0.139891 | 0.311713 | -0.235325 | 0.286864 | 0.021996 | -0.010311 | -0.577181 | -0.241507 | 0.677906 | ... | -0.451310 | 0.064326 | 0.089845 | 0.146459 | 0.120934 | 0.116879 | 0.332750 | 0.190347 | 0.360663 | 4 |
891e2040013ffff | -0.224642 | -0.143098 | 0.427563 | -0.393280 | 0.489236 | -0.054072 | 0.014470 | -0.577511 | -0.354027 | 0.803737 | ... | -0.773623 | 0.164668 | -0.253249 | 0.161901 | 0.145316 | 0.491370 | 0.500055 | 0.250518 | 0.370982 | 2 |
891e2040017ffff | -0.304012 | -0.203452 | 0.250050 | -0.206345 | 0.197903 | -0.175079 | 0.128693 | -0.519562 | -0.226868 | 0.507805 | ... | -0.423496 | 0.161032 | 0.033017 | 0.014773 | 0.085093 | 0.373949 | 0.346180 | 0.179789 | 0.314990 | 4 |
891e2040023ffff | -0.305751 | -0.365035 | 0.463593 | -0.072082 | 0.496328 | -0.111837 | 0.149594 | -0.378091 | -0.074685 | 0.616050 | ... | -0.518410 | 0.155986 | 0.032004 | 0.015227 | 0.275388 | 0.427325 | 0.098358 | -0.020611 | 0.460642 | 0 |
891e2040027ffff | -0.632700 | -0.309423 | 0.261198 | -0.390320 | 0.631320 | 0.146557 | 0.070962 | -0.655877 | -0.287360 | 0.751151 | ... | -0.607455 | 0.140857 | -0.122701 | 0.075262 | 0.367290 | 0.249054 | 0.215070 | 0.003655 | 0.550938 | 4 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.627002 | -0.064194 | 0.325893 | -0.337385 | 0.434200 | -0.243677 | 0.203394 | -0.900395 | -0.356680 | 0.735271 | ... | -0.627793 | -0.014291 | -0.178347 | -0.109331 | 0.382347 | 0.258068 | 0.454882 | 0.313613 | 0.702337 | 1 |
891e2055bc7ffff | -0.438971 | 0.040592 | 0.333510 | -0.393801 | 0.695110 | -0.353364 | 0.282117 | -0.758274 | -0.453044 | 0.714118 | ... | -0.649027 | -0.006232 | -0.197612 | -0.145356 | 0.363767 | 0.171695 | 0.460346 | 0.201980 | 0.432349 | 1 |
891e2055bcbffff | -0.809552 | -0.099434 | 0.409700 | -0.315831 | 0.409537 | -0.118376 | 0.250347 | -1.149436 | -0.332271 | 0.701659 | ... | -0.557726 | -0.138531 | -0.270370 | -0.093538 | 0.537909 | 0.376044 | 0.499072 | 0.514148 | 0.947484 | 3 |
891e205a967ffff | -0.714737 | 0.083641 | 0.243519 | -0.337393 | 0.355515 | 0.228682 | -0.058184 | -0.216769 | -0.366524 | 0.573776 | ... | -0.328912 | 0.354445 | 0.018837 | 0.218598 | 0.100005 | 0.164957 | 0.216653 | 0.033482 | 0.571845 | 4 |
891e205a9a7ffff | -1.458610 | -1.494464 | 0.893583 | -1.073221 | 0.934651 | -0.626731 | 0.280122 | -0.734369 | -0.280892 | 1.656427 | ... | -0.568228 | -0.104253 | -0.027485 | -0.098446 | 0.345150 | 0.906987 | 1.147529 | 0.131658 | 0.887252 | 3 |
2029 rows × 31 columns
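KMeans assigns each region to the nearest of five centroids in the 30-dimensional embedding space. For intuition, the sketch below implements the textbook Lloyd iteration in plain Python: assign points to the nearest centroid, then move each centroid to the mean of its points. This is not scikit-learn's implementation, which also uses k-means++ initialisation and keeps the best of `n_init` runs. The function name `lloyd_kmeans` and the naive "first k points" initialisation are assumptions for illustration.

```python
import math
from typing import List, Sequence, Tuple


def lloyd_kmeans(
    points: Sequence[Sequence[float]], k: int, n_iter: int = 100
) -> Tuple[List[int], List[List[float]]]:
    """Plain Lloyd's k-means with a naive init (first k points). Illustrative only."""
    centroids = [list(p) for p in points[:k]]
    labels = [0] * len(points)
    for _ in range(n_iter):
        # Assignment step: nearest centroid by Euclidean distance
        new_labels = [
            min(range(k), key=lambda c: math.dist(p, centroids[c])) for p in points
        ]
        # Update step: move each centroid to the mean of its assigned points
        for c in range(k):
            members = [p for p, lab in zip(points, new_labels) if lab == c]
            if members:  # keep the old centroid if a cluster empties out
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
        if new_labels == labels:  # assignments stable: converged
            break
        labels = new_labels
    return labels, centroids


pts = [(0.0, 0.0), (10.0, 10.0), (0.0, 1.0), (10.0, 11.0)]
labels, centroids = lloyd_kmeans(pts, k=2)
print(labels)     # [0, 1, 0, 1]
print(centroids)  # [[0.0, 0.5], [10.0, 10.5]]
```

Cluster numbers are only identifiers. A different initialisation or seed can permute them without changing the grouping.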
In [8]:
plot_numeric_data(regions_gdf, "cluster", embeddings)