Highway2Vec Embedder
In [1]:
from IPython.display import display
from srai.plotting import plot_regions, plot_numeric_data
Get an area to embed
In [2]:
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Wrocław, PL")
plot_regions(area_gdf, tiles_style="CartoDB positron")
Out[2]:
[Interactive map: boundary of the Wrocław area]
Regionalize the area with an H3 regionalizer
In [3]:
from srai.regionalizers import H3Regionalizer
regionalizer = H3Regionalizer(9)
regions_gdf = regionalizer.transform(area_gdf)
print(len(regions_gdf))
display(regions_gdf.head(3))
plot_regions(regions_gdf, tiles_style="CartoDB positron")
3168
region_id | geometry |
---|---|
891e2040077ffff | POLYGON ((16.94428 51.11178, 16.94432 51.11010... |
891e2043593ffff | POLYGON ((16.86287 51.15974, 16.86291 51.15806... |
891e2047513ffff | POLYGON ((17.11143 51.15966, 17.11146 51.15798... |
Out[3]:
[Interactive map: H3 regions covering the Wrocław area]
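The `region_id` values in the table above are H3 cell indexes written as hexadecimal strings. As a quick sanity check that the regions really come from resolution 9, the resolution can be read directly from the index. This is a minimal sketch that assumes the standard H3 64-bit layout (mode in bits 59-62, resolution in bits 52-55); the helper names `h3_resolution` and `h3_mode` are made up for this illustration and are not part of `srai`.

```python
def h3_resolution(cell_id: str) -> int:
    """Return the resolution stored in an H3 cell index (hex string)."""
    value = int(cell_id, 16)
    # Assumed layout: bits 52-55 of the 64-bit index hold the resolution.
    return (value >> 52) & 0xF


def h3_mode(cell_id: str) -> int:
    """Return the index mode; mode 1 denotes a regular cell."""
    value = int(cell_id, 16)
    # Assumed layout: bits 59-62 of the 64-bit index hold the mode.
    return (value >> 59) & 0xF


# The three region identifiers shown in the output above.
sample_ids = ["891e2040077ffff", "891e2043593ffff", "891e2047513ffff"]
resolutions = [h3_resolution(cell_id) for cell_id in sample_ids]
modes = [h3_mode(cell_id) for cell_id in sample_ids]
print(resolutions, modes)
```

In practice the `h3` Python package is the more robust way to inspect cell indexes; the bit arithmetic here only illustrates why every identifier starts with `89`.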
Download the road infrastructure for the area
In [4]:
from srai.loaders import OSMNetworkType, OSMWayLoader
loader = OSMWayLoader(OSMNetworkType.DRIVE)
nodes_gdf, edges_gdf = loader.load(area_gdf)
display(nodes_gdf.head(3))
display(edges_gdf.head(3))
ax = edges_gdf.plot(linewidth=1, figsize=(12, 7))
nodes_gdf.plot(ax=ax, markersize=3, color="red")
/opt/hostedtoolcache/Python/3.10.13/x64/lib/python3.10/site-packages/srai/loaders/osm_way_loader/osm_way_loader.py:219: UserWarning: The clean_periphery argument has been deprecated and will be removed in a future release. Future behavior will be as though clean_periphery=True.
  G_directed = ox.graph_from_polygon(
osmid | y | x | street_count | highway | ref | geometry |
---|---|---|---|---|---|---|
95584835 | 51.083111 | 17.049513 | 4 | NaN | NaN | POINT (17.04951 51.08311) |
95584841 | 51.084699 | 17.064367 | 3 | NaN | NaN | POINT (17.06437 51.08470) |
95584850 | 51.083328 | 17.035057 | 4 | NaN | NaN | POINT (17.03506 51.08333) |
feature_id | oneway | lanes-1 | lanes-2 | lanes-3 | lanes-4 | lanes-5 | lanes-6 | lanes-7 | lanes-8 | lanes-9 | ... | bicycle-official | lit-yes | lit-no | lit-sunset-sunrise | lit-24/7 | lit-automatic | lit-disused | lit-limited | lit-interval | geometry |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04947 51.083... |
1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04933 51.083... |
2 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.05357 51.08301, 17.05335 51.082... |
3 rows × 219 columns
Out[4]:
<Axes: >
Find out which edges correspond to which regions
In [5]:
from srai.joiners import IntersectionJoiner
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, edges_gdf)
joint_gdf
Out[5]:
region_id | feature_id |
---|---|
891e2040a03ffff | 3238 |
891e2040a0bffff | 3238 |
891e2040a03ffff | 3239 |
891e2040a0fffff | 3239 |
891e2040a03ffff | 3236 |
... | ... |
891e2042b8bffff | 5267 |
891e20442dbffff | 5203 |
 | 7819 |
 | 7820 |
 | 7821 |
15349 rows × 0 columns
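The joined frame has no data columns; all of its information sits in the (`region_id`, `feature_id`) index pairs. The output above already shows that one edge can be paired with several regions, and one region with several edges. The sketch below counts both directions for the first five pairs from that output using only the standard library; on the real `joint_gdf` the same pairs could be read from its index, which is an assumption about how the frame is used rather than something shown here.

```python
from collections import Counter

# The first five (region_id, feature_id) pairs from the output above.
pairs = [
    ("891e2040a03ffff", 3238),
    ("891e2040a0bffff", 3238),
    ("891e2040a03ffff", 3239),
    ("891e2040a0fffff", 3239),
    ("891e2040a03ffff", 3236),
]

# How many edges touch each region.
edges_per_region = Counter(region_id for region_id, _ in pairs)

# How many regions each edge passes through.
regions_per_edge = Counter(feature_id for _, feature_id in pairs)

print(edges_per_region.most_common(1))
print(regions_per_edge)
```

An edge that crosses a hexagon boundary therefore contributes to every region it intersects, which matters for how the per-region embeddings are later aggregated.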
Get the embeddings for regions based on the road infrastructure
In [6]:
from srai.embedders import Highway2VecEmbedder
from pytorch_lightning import seed_everything
seed_everything(42)
embedder = Highway2VecEmbedder()
embeddings = embedder.fit_transform(regions_gdf, edges_gdf, joint_gdf)
embeddings
Seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/hostedtoolcache/Python/3.10.13/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 16.0 K
1 | decoder | Sequential | 16.2 K
---------------------------------------
32.1 K    Trainable params
0         Non-trainable params
32.1 K    Total params
0.128     Total estimated model params size (MB)
/opt/hostedtoolcache/Python/3.10.13/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=10` reached.
Out[6]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
891e2040007ffff | -0.606281 | -0.138842 | 0.304199 | -0.228422 | 0.266903 | 0.025754 | -0.010462 | -0.585527 | -0.240833 | 0.673681 | ... | 0.158427 | -0.457539 | 0.058718 | 0.076031 | 0.136089 | 0.124990 | 0.127766 | 0.333599 | 0.192963 | 0.366278 |
891e2040013ffff | -0.216977 | -0.145342 | 0.417459 | -0.402054 | 0.465120 | -0.064714 | 0.024610 | -0.601585 | -0.367011 | 0.811321 | ... | 0.419746 | -0.792594 | 0.156041 | -0.276614 | 0.132031 | 0.146711 | 0.516092 | 0.504026 | 0.246584 | 0.374882 |
891e2040017ffff | -0.286712 | -0.206757 | 0.235690 | -0.200215 | 0.183379 | -0.184813 | 0.128704 | -0.523017 | -0.226942 | 0.511229 | ... | 0.114081 | -0.428691 | 0.161389 | 0.024387 | -0.005048 | 0.086457 | 0.377714 | 0.341217 | 0.175955 | 0.311676 |
891e2040023ffff | -0.296102 | -0.365457 | 0.450568 | -0.069854 | 0.481113 | -0.110100 | 0.147706 | -0.377631 | -0.078125 | 0.608483 | ... | 0.147528 | -0.520215 | 0.149968 | 0.016940 | -0.002120 | 0.281002 | 0.427855 | 0.097800 | -0.023010 | 0.464972 |
891e2040027ffff | -0.628833 | -0.310601 | 0.249207 | -0.388397 | 0.610593 | 0.147230 | 0.073520 | -0.650836 | -0.291396 | 0.744639 | ... | -0.012181 | -0.615358 | 0.138278 | -0.140941 | 0.055909 | 0.372398 | 0.248022 | 0.207902 | -0.003125 | 0.563227 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.613999 | -0.063764 | 0.303529 | -0.326326 | 0.399839 | -0.242278 | 0.183222 | -0.863839 | -0.353749 | 0.733669 | ... | 0.494852 | -0.633912 | -0.008650 | -0.191272 | -0.138710 | 0.378147 | 0.232807 | 0.430400 | 0.298963 | 0.708956 |
891e2055bc7ffff | -0.430927 | 0.047062 | 0.310531 | -0.392069 | 0.658037 | -0.354934 | 0.261160 | -0.710268 | -0.452049 | 0.709795 | ... | 0.378803 | -0.657929 | -0.006701 | -0.223662 | -0.176097 | 0.364315 | 0.136698 | 0.437906 | 0.181847 | 0.443744 |
891e2055bcbffff | -0.795925 | -0.103043 | 0.384312 | -0.301696 | 0.376047 | -0.126913 | 0.228671 | -1.115448 | -0.334385 | 0.701097 | ... | 0.662348 | -0.566517 | -0.120830 | -0.271869 | -0.129187 | 0.527656 | 0.350461 | 0.461671 | 0.493865 | 0.948593 |
891e205a967ffff | -0.711801 | 0.091555 | 0.238810 | -0.326909 | 0.345014 | 0.221190 | -0.058576 | -0.217269 | -0.361725 | 0.583299 | ... | -0.298507 | -0.335969 | 0.354557 | 0.015543 | 0.214512 | 0.098273 | 0.171889 | 0.208658 | 0.039592 | 0.592684 |
891e205a9a7ffff | -1.465427 | -1.521233 | 0.865493 | -1.106137 | 0.900861 | -0.669065 | 0.262976 | -0.682952 | -0.289703 | 1.661705 | ... | 0.543084 | -0.590690 | -0.117270 | -0.062198 | -0.164327 | 0.384806 | 0.903105 | 1.117275 | 0.086381 | 0.883071 |
2029 rows × 30 columns
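The embedding table has 2029 rows, while the regionalization step produced 3168 regions. A plausible reading is that a region only receives a vector when at least one road edge was joined to it. The sketch below illustrates that behaviour under a stated assumption: that a region's vector is the mean of the vectors of its joined edges. This is a simplification for illustration, not a copy of the `Highway2VecEmbedder` internals, and the helper `mean_per_region` and the toy data are hypothetical.

```python
from collections import defaultdict


def mean_per_region(pairs, edge_vectors):
    """Average edge vectors per region, given (region_id, feature_id) pairs."""
    grouped = defaultdict(list)
    for region_id, feature_id in pairs:
        grouped[region_id].append(edge_vectors[feature_id])
    # Column-wise mean over all vectors collected for a region.
    return {
        region_id: [sum(column) / len(vectors) for column in zip(*vectors)]
        for region_id, vectors in grouped.items()
    }


# Hypothetical toy data: three edges with 2-dimensional vectors, three regions.
edge_vectors = {0: [1.0, 0.0], 1: [3.0, 2.0], 2: [0.0, 4.0]}
pairs = [("region_a", 0), ("region_a", 1), ("region_b", 1), ("region_b", 2)]
all_regions = ["region_a", "region_b", "region_c"]

region_vectors = mean_per_region(pairs, edge_vectors)

# Regions without any joined edge get no vector at all.
missing = [region for region in all_regions if region not in region_vectors]
print(region_vectors, missing)
```

When the embeddings are later merged back onto `regions_gdf`, regions absent from the embedding table need to be handled explicitly, for example by leaving them unplotted.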
In [7]:
from sklearn.cluster import KMeans
clusterizer = KMeans(n_clusters=5, random_state=42)
clusterizer.fit(embeddings)
embeddings["cluster"] = clusterizer.labels_
embeddings
/opt/hostedtoolcache/Python/3.10.13/x64/lib/python3.10/site-packages/sklearn/cluster/_kmeans.py:1412: FutureWarning: The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the value of `n_init` explicitly to suppress the warning
  super()._check_params_vs_input(X, default_n_init=10)
Out[7]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | cluster |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
891e2040007ffff | -0.606281 | -0.138842 | 0.304199 | -0.228422 | 0.266903 | 0.025754 | -0.010462 | -0.585527 | -0.240833 | 0.673681 | ... | -0.457539 | 0.058718 | 0.076031 | 0.136089 | 0.124990 | 0.127766 | 0.333599 | 0.192963 | 0.366278 | 4 |
891e2040013ffff | -0.216977 | -0.145342 | 0.417459 | -0.402054 | 0.465120 | -0.064714 | 0.024610 | -0.601585 | -0.367011 | 0.811321 | ... | -0.792594 | 0.156041 | -0.276614 | 0.132031 | 0.146711 | 0.516092 | 0.504026 | 0.246584 | 0.374882 | 0 |
891e2040017ffff | -0.286712 | -0.206757 | 0.235690 | -0.200215 | 0.183379 | -0.184813 | 0.128704 | -0.523017 | -0.226942 | 0.511229 | ... | -0.428691 | 0.161389 | 0.024387 | -0.005048 | 0.086457 | 0.377714 | 0.341217 | 0.175955 | 0.311676 | 3 |
891e2040023ffff | -0.296102 | -0.365457 | 0.450568 | -0.069854 | 0.481113 | -0.110100 | 0.147706 | -0.377631 | -0.078125 | 0.608483 | ... | -0.520215 | 0.149968 | 0.016940 | -0.002120 | 0.281002 | 0.427855 | 0.097800 | -0.023010 | 0.464972 | 3 |
891e2040027ffff | -0.628833 | -0.310601 | 0.249207 | -0.388397 | 0.610593 | 0.147230 | 0.073520 | -0.650836 | -0.291396 | 0.744639 | ... | -0.615358 | 0.138278 | -0.140941 | 0.055909 | 0.372398 | 0.248022 | 0.207902 | -0.003125 | 0.563227 | 4 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.613999 | -0.063764 | 0.303529 | -0.326326 | 0.399839 | -0.242278 | 0.183222 | -0.863839 | -0.353749 | 0.733669 | ... | -0.633912 | -0.008650 | -0.191272 | -0.138710 | 0.378147 | 0.232807 | 0.430400 | 0.298963 | 0.708956 | 1 |
891e2055bc7ffff | -0.430927 | 0.047062 | 0.310531 | -0.392069 | 0.658037 | -0.354934 | 0.261160 | -0.710268 | -0.452049 | 0.709795 | ... | -0.657929 | -0.006701 | -0.223662 | -0.176097 | 0.364315 | 0.136698 | 0.437906 | 0.181847 | 0.443744 | 1 |
891e2055bcbffff | -0.795925 | -0.103043 | 0.384312 | -0.301696 | 0.376047 | -0.126913 | 0.228671 | -1.115448 | -0.334385 | 0.701097 | ... | -0.566517 | -0.120830 | -0.271869 | -0.129187 | 0.527656 | 0.350461 | 0.461671 | 0.493865 | 0.948593 | 2 |
891e205a967ffff | -0.711801 | 0.091555 | 0.238810 | -0.326909 | 0.345014 | 0.221190 | -0.058576 | -0.217269 | -0.361725 | 0.583299 | ... | -0.335969 | 0.354557 | 0.015543 | 0.214512 | 0.098273 | 0.171889 | 0.208658 | 0.039592 | 0.592684 | 4 |
891e205a9a7ffff | -1.465427 | -1.521233 | 0.865493 | -1.106137 | 0.900861 | -0.669065 | 0.262976 | -0.682952 | -0.289703 | 1.661705 | ... | -0.590690 | -0.117270 | -0.062198 | -0.164327 | 0.384806 | 0.903105 | 1.117275 | 0.086381 | 0.883071 | 2 |
2029 rows × 31 columns
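The `cluster` column comes from scikit-learn's `KMeans`, which groups regions whose embedding vectors lie close together. To make the idea concrete, here is a toy pure-Python version of the basic assign-and-update loop (Lloyd's algorithm). It is only an illustration: scikit-learn uses a smarter initialization (k-means++) and several restarts by default, and the function name `toy_kmeans` and the sample points are invented for this sketch.

```python
import random
from math import dist


def toy_kmeans(points, k, n_iter=20, seed=42):
    """Toy Lloyd's algorithm: return one cluster label per point."""
    rng = random.Random(seed)
    # Start from k distinct input points chosen at random.
    centers = rng.sample(points, k)
    labels = [0] * len(points)
    for _ in range(n_iter):
        # Assignment step: attach every point to its nearest center.
        labels = [
            min(range(k), key=lambda c: dist(point, centers[c]))
            for point in points
        ]
        # Update step: move each center to the mean of its members.
        for c in range(k):
            members = [p for p, label in zip(points, labels) if label == c]
            if members:  # keep the old center if a cluster went empty
                centers[c] = tuple(sum(col) / len(members) for col in zip(*members))
    return labels


# Hypothetical 2-dimensional "embeddings" forming two well-separated groups.
points = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0),
          (100.0, 0.0), (101.0, 0.0), (102.0, 0.0)]
labels = toy_kmeans(points, k=2)
print(labels)
```

The label numbers themselves carry no meaning; only the grouping does. Fixing the seed (here `seed=42`, mirroring `random_state=42` in the notebook) keeps the result reproducible between runs.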
In [8]:
plot_numeric_data(regions_gdf, "cluster", embeddings)
Out[8]:
[Interactive map: regions colored by cluster]