Highway2Vec Embedder
In [1]:
from IPython.display import display
from srai.plotting import plot_regions, plot_numeric_data
Get an area to embed
In [2]:
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Wrocław, PL")
plot_regions(area_gdf, tiles_style="CartoDB positron")
Regionalize the area with a regionalizer
In [3]:
from srai.regionalizers import H3Regionalizer
regionalizer = H3Regionalizer(9)
regions_gdf = regionalizer.transform(area_gdf)
print(len(regions_gdf))
display(regions_gdf.head(3))
plot_regions(regions_gdf, tiles_style="CartoDB positron")
3168
region_id | geometry |
---|---|
891e2040d8fffff | POLYGON ((17.04436 51.12488, 17.04440 51.12320... |
891e20433b7ffff | POLYGON ((16.88131 51.12246, 16.88135 51.12078... |
891e204709bffff | POLYGON ((17.08910 51.14639, 17.08913 51.14471... |
Download the road infrastructure for the area
In [4]:
from srai.loaders import OSMNetworkType, OSMWayLoader
loader = OSMWayLoader(OSMNetworkType.DRIVE)
nodes_gdf, edges_gdf = loader.load(area_gdf)
display(nodes_gdf.head(3))
display(edges_gdf.head(3))
ax = edges_gdf.plot(linewidth=1, figsize=(12, 7))
nodes_gdf.plot(ax=ax, markersize=3, color="red")
/opt/hostedtoolcache/Python/3.10.12/x64/lib/python3.10/site-packages/srai/loaders/osm_way_loader/osm_way_loader.py:219: UserWarning: The clean_periphery argument has been deprecated and will be removed in a future release. Future behavior will be as though clean_periphery=True. G_directed = ox.graph_from_polygon(
osmid | y | x | street_count | highway | ref | geometry |
---|---|---|---|---|---|---|
95584835 | 51.083111 | 17.049513 | 4 | NaN | NaN | POINT (17.04951 51.08311) |
95584841 | 51.084699 | 17.064367 | 3 | NaN | NaN | POINT (17.06437 51.08470) |
95584850 | 51.083328 | 17.035057 | 4 | NaN | NaN | POINT (17.03506 51.08333) |
feature_id | oneway | lanes-1 | lanes-2 | lanes-3 | lanes-4 | lanes-5 | lanes-6 | lanes-7 | lanes-8 | lanes-9 | ... | bicycle-official | lit-yes | lit-no | lit-sunset-sunrise | lit-24/7 | lit-automatic | lit-disused | lit-limited | lit-interval | geometry |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04947 51.083... |
1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04933 51.083... |
2 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.05357 51.08301, 17.05335 51.082... |
3 rows × 219 columns
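The edge table stores road attributes as 0/1 indicator columns whose names follow a `<tag>-<value>` pattern, such as `lanes-2` or `lit-yes`. The sketch below shows one way such a row could be built from a dictionary of OSM tags. The `encode_tags` helper and the short `COLUMNS` list are hypothetical and only illustrate the encoding; they are not the loader's implementation.

```python
# Illustrative only: a hypothetical encoder, not OSMWayLoader's implementation.
# Column names follow the "<tag>-<value>" pattern visible in the edge table.
COLUMNS = ["lanes-1", "lanes-2", "lanes-3", "lit-yes", "lit-no"]


def encode_tags(tags):
    """Return a 0/1 row with one entry per column in COLUMNS."""
    active = {f"{key}-{value}" for key, value in tags.items()}
    return [1 if column in active else 0 for column in COLUMNS]


# A two-lane, lit road sets exactly the "lanes-2" and "lit-yes" indicators.
row = encode_tags({"lanes": "2", "lit": "yes"})
print(dict(zip(COLUMNS, row)))
```

Tags that match no known column are simply ignored in this sketch, so every row has the same fixed width.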
Out[4]:
<Axes: >
Find out which edges correspond to which regions
In [5]:
from srai.joiners import IntersectionJoiner
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, edges_gdf)
joint_gdf
Out[5]:
region_id | feature_id |
---|---|
891e2040d8fffff | 2769 |
891e2040d13ffff | 2769 |
891e2040d8fffff | 2770 |
891e2040d13ffff | 2770 |
891e2040d17ffff | 2770 |
... | ... |
891e2041d1bffff | 1086 |
891e2041d1bffff | 32 |
891e2041d1bffff | 1088 |
891e2041d1bffff | 34 |
891e2041d1bffff | 1087 |
15303 rows × 0 columns
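The joiner output is a table of (`region_id`, `feature_id`) pairs with no data columns, which is why it reports 0 columns. An edge that crosses a cell boundary appears once per region it intersects. A minimal sketch, using the first five pairs shown in the output and plain dictionaries instead of the GeoDataFrame, makes the many-to-many relation explicit:

```python
from collections import defaultdict

# First pairs from the joiner output: (region_id, feature_id).
pairs = [
    ("891e2040d8fffff", 2769),
    ("891e2040d13ffff", 2769),
    ("891e2040d8fffff", 2770),
    ("891e2040d13ffff", 2770),
    ("891e2040d17ffff", 2770),
]

edges_by_region = defaultdict(list)
regions_by_edge = defaultdict(list)
for region_id, feature_id in pairs:
    edges_by_region[region_id].append(feature_id)
    regions_by_edge[feature_id].append(region_id)

# Edges intersecting each region, and how many regions edge 2770 touches.
print(dict(edges_by_region))
print(len(regions_by_edge[2770]))
```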
Get the embeddings for regions based on the road infrastructure
In [6]:
from srai.embedders import Highway2VecEmbedder
from pytorch_lightning import seed_everything
seed_everything(42)
embedder = Highway2VecEmbedder()
embeddings = embedder.fit_transform(regions_gdf, edges_gdf, joint_gdf)
embeddings
Global seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/hostedtoolcache/Python/3.10.12/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
  warning_cache.warn(

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 16.0 K
1 | decoder | Sequential | 16.2 K
---------------------------------------
32.1 K    Trainable params
0         Non-trainable params
32.1 K    Total params
0.128     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=10` reached.
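The parameter counts in the model summary can be sanity-checked by hand. The sketch below assumes a two-layer encoder and a mirrored two-layer decoder built from fully connected layers, with 218 input features (the 219 edge columns minus `geometry`), a hidden size of 64, and the 30 output dimensions seen in the result. The hidden size is an assumption, not something the notebook states; it is chosen because it reproduces the reported numbers.

```python
# Assumed layer sizes (the hidden size is not stated in the notebook):
# 218 input features, a hidden layer of 64, and a 30-dimensional embedding.
n_features, hidden, embedding = 218, 64, 30


def linear_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


encoder = linear_params(n_features, hidden) + linear_params(hidden, embedding)
decoder = linear_params(embedding, hidden) + linear_params(hidden, n_features)
total = encoder + decoder
size_mb = total * 4 / 1e6  # 4 bytes per float32 parameter

print(encoder, decoder, total, size_mb)
```

Under these assumptions the counts are 15,966 and 16,154 parameters, which round to the 16.0 K and 16.2 K in the summary, and the 32,120 total at 4 bytes each gives about 0.128 MB.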
Out[6]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
891e2040007ffff | -0.461249 | -0.097246 | 0.283835 | -0.340094 | 0.444099 | -0.089585 | -0.036850 | -0.676215 | -0.339848 | 0.617979 | ... | 0.267708 | -0.472880 | 0.040998 | 0.146030 | 0.080322 | 0.103697 | 0.048034 | 0.457679 | 0.141208 | 0.213029 |
891e2040013ffff | -0.136390 | -0.058529 | 0.308364 | -0.441377 | 0.603708 | -0.148072 | -0.064378 | -0.678277 | -0.419989 | 0.777352 | ... | 0.518513 | -0.850709 | 0.245799 | -0.237865 | 0.032768 | 0.120408 | 0.505136 | 0.490513 | 0.179292 | 0.350829 |
891e2040017ffff | -0.161810 | -0.132121 | 0.154428 | -0.186884 | 0.185939 | -0.203705 | 0.167834 | -0.643824 | -0.098391 | 0.548911 | ... | 0.102457 | -0.569440 | 0.176136 | 0.053619 | -0.099321 | 0.050696 | 0.488355 | 0.408395 | 0.100037 | 0.196078 |
891e2040023ffff | -0.299133 | -0.351207 | 0.411926 | -0.141444 | 0.555849 | -0.153168 | 0.094474 | -0.406186 | -0.103325 | 0.615908 | ... | 0.160088 | -0.568392 | 0.190292 | 0.047913 | -0.104510 | 0.281145 | 0.438018 | 0.126148 | -0.067245 | 0.484021 |
891e2040027ffff | -0.624578 | -0.248959 | 0.205029 | -0.460974 | 0.703157 | 0.067102 | -0.005280 | -0.656151 | -0.317916 | 0.774102 | ... | -0.002013 | -0.684031 | 0.175323 | -0.084356 | -0.076293 | 0.390546 | 0.218503 | 0.236500 | -0.043623 | 0.596306 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.631238 | 0.018650 | 0.290267 | -0.258436 | 0.607869 | -0.239184 | 0.035131 | -0.685573 | -0.315166 | 0.837931 | ... | 0.530711 | -0.700306 | 0.005977 | -0.153235 | -0.293272 | 0.423207 | 0.065331 | 0.427945 | 0.296351 | 0.814009 |
891e2055bc7ffff | -0.418185 | 0.170102 | 0.332182 | -0.339462 | 0.924004 | -0.417339 | 0.137030 | -0.464724 | -0.412662 | 0.826780 | ... | 0.392677 | -0.732027 | 0.036106 | -0.174536 | -0.405658 | 0.467268 | -0.043914 | 0.392487 | 0.225896 | 0.541943 |
891e2055bcbffff | -0.859970 | -0.024042 | 0.361322 | -0.218212 | 0.573248 | -0.055005 | 0.000301 | -0.895226 | -0.284975 | 0.834218 | ... | 0.734197 | -0.644883 | -0.161828 | -0.246370 | -0.261343 | 0.563054 | 0.122927 | 0.484361 | 0.453442 | 1.103095 |
891e205a967ffff | -0.707361 | 0.126996 | 0.202224 | -0.380303 | 0.344584 | 0.201205 | -0.084687 | -0.352298 | -0.393757 | 0.553689 | ... | -0.310467 | -0.377117 | 0.368715 | 0.055262 | 0.161857 | 0.045700 | 0.242622 | 0.261642 | -0.000286 | 0.563209 |
891e205a9a7ffff | -1.467914 | -1.278422 | 0.889408 | -1.234111 | 1.200635 | -0.929764 | -0.115005 | -0.793235 | -0.522718 | 1.885711 | ... | 0.685181 | -0.895595 | -0.182822 | 0.017098 | -0.367274 | 0.573289 | 0.699673 | 1.190632 | 0.144614 | 1.027629 |
2027 rows × 30 columns
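The result has 2027 rows although the regionalizer produced 3168 regions. A plausible reading is that a region only receives an embedding when at least one edge intersects it. The sketch below illustrates that idea under the assumption that a region vector is the mean of the vectors of its edges; the ids and 2-D vectors are made up, and the averaging step is an assumption about the aggregation rather than a statement of the embedder's internals.

```python
# Toy 2-D edge embeddings; the values and ids are made up for illustration.
edge_vectors = {10: [1.0, 0.0], 11: [0.0, 1.0], 12: [2.0, 2.0]}
# (region_id, feature_id) pairs as produced by a joiner; "C" has no edges.
pairs = [("A", 10), ("A", 11), ("B", 12)]
all_regions = ["A", "B", "C"]

grouped = {}
for region_id, feature_id in pairs:
    grouped.setdefault(region_id, []).append(edge_vectors[feature_id])

# Average the edge vectors of each region, dimension by dimension.
region_vectors = {
    region_id: [sum(dim) / len(vectors) for dim in zip(*vectors)]
    for region_id, vectors in grouped.items()
}

print(region_vectors)
print(f"toy coverage: {len(region_vectors)}/{len(all_regions)}")
print(f"notebook coverage: {2027 / 3168:.1%}")
```

Region "C" is absent from the result because nothing was joined to it, mirroring the regions that are missing from the embeddings table.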
In [7]:
from sklearn.cluster import KMeans
clusterizer = KMeans(n_clusters=5, random_state=42)
clusterizer.fit(embeddings)
embeddings["cluster"] = clusterizer.labels_
embeddings
/opt/hostedtoolcache/Python/3.10.12/x64/lib/python3.10/site-packages/sklearn/cluster/_kmeans.py:1412: FutureWarning: The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the value of `n_init` explicitly to suppress the warning super()._check_params_vs_input(X, default_n_init=10)
Out[7]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | cluster |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
891e2040007ffff | -0.461249 | -0.097246 | 0.283835 | -0.340094 | 0.444099 | -0.089585 | -0.036850 | -0.676215 | -0.339848 | 0.617979 | ... | -0.472880 | 0.040998 | 0.146030 | 0.080322 | 0.103697 | 0.048034 | 0.457679 | 0.141208 | 0.213029 | 1 |
891e2040013ffff | -0.136390 | -0.058529 | 0.308364 | -0.441377 | 0.603708 | -0.148072 | -0.064378 | -0.678277 | -0.419989 | 0.777352 | ... | -0.850709 | 0.245799 | -0.237865 | 0.032768 | 0.120408 | 0.505136 | 0.490513 | 0.179292 | 0.350829 | 0 |
891e2040017ffff | -0.161810 | -0.132121 | 0.154428 | -0.186884 | 0.185939 | -0.203705 | 0.167834 | -0.643824 | -0.098391 | 0.548911 | ... | -0.569440 | 0.176136 | 0.053619 | -0.099321 | 0.050696 | 0.488355 | 0.408395 | 0.100037 | 0.196078 | 3 |
891e2040023ffff | -0.299133 | -0.351207 | 0.411926 | -0.141444 | 0.555849 | -0.153168 | 0.094474 | -0.406186 | -0.103325 | 0.615908 | ... | -0.568392 | 0.190292 | 0.047913 | -0.104510 | 0.281145 | 0.438018 | 0.126148 | -0.067245 | 0.484021 | 3 |
891e2040027ffff | -0.624578 | -0.248959 | 0.205029 | -0.460974 | 0.703157 | 0.067102 | -0.005280 | -0.656151 | -0.317916 | 0.774102 | ... | -0.684031 | 0.175323 | -0.084356 | -0.076293 | 0.390546 | 0.218503 | 0.236500 | -0.043623 | 0.596306 | 1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.631238 | 0.018650 | 0.290267 | -0.258436 | 0.607869 | -0.239184 | 0.035131 | -0.685573 | -0.315166 | 0.837931 | ... | -0.700306 | 0.005977 | -0.153235 | -0.293272 | 0.423207 | 0.065331 | 0.427945 | 0.296351 | 0.814009 | 4 |
891e2055bc7ffff | -0.418185 | 0.170102 | 0.332182 | -0.339462 | 0.924004 | -0.417339 | 0.137030 | -0.464724 | -0.412662 | 0.826780 | ... | -0.732027 | 0.036106 | -0.174536 | -0.405658 | 0.467268 | -0.043914 | 0.392487 | 0.225896 | 0.541943 | 4 |
891e2055bcbffff | -0.859970 | -0.024042 | 0.361322 | -0.218212 | 0.573248 | -0.055005 | 0.000301 | -0.895226 | -0.284975 | 0.834218 | ... | -0.644883 | -0.161828 | -0.246370 | -0.261343 | 0.563054 | 0.122927 | 0.484361 | 0.453442 | 1.103095 | 2 |
891e205a967ffff | -0.707361 | 0.126996 | 0.202224 | -0.380303 | 0.344584 | 0.201205 | -0.084687 | -0.352298 | -0.393757 | 0.553689 | ... | -0.377117 | 0.368715 | 0.055262 | 0.161857 | 0.045700 | 0.242622 | 0.261642 | -0.000286 | 0.563209 | 1 |
891e205a9a7ffff | -1.467914 | -1.278422 | 0.889408 | -1.234111 | 1.200635 | -0.929764 | -0.115005 | -0.793235 | -0.522718 | 1.885711 | ... | -0.895595 | -0.182822 | 0.017098 | -0.367274 | 0.573289 | 0.699673 | 1.190632 | 0.144614 | 1.027629 | 2 |
2027 rows × 31 columns
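The `cluster` column holds, for each region, the index of the k-means centroid closest to its embedding. A minimal sketch of that assignment step, using made-up 2-D vectors and two fixed centroids instead of the fitted 30-dimensional model, could look like this:

```python
from collections import Counter
from math import dist

# Toy 2-D vectors and centroids, made up for illustration.
vectors = [(0.1, 0.2), (0.9, 0.8), (0.2, -0.1), (1.2, 1.0)]
centroids = [(0.0, 0.0), (1.0, 1.0)]


def nearest_centroid(vector):
    """Index of the centroid with the smallest Euclidean distance."""
    return min(range(len(centroids)), key=lambda i: dist(vector, centroids[i]))


labels = [nearest_centroid(v) for v in vectors]
print(labels)
print(Counter(labels))  # number of vectors assigned to each cluster
```

Counting labels this way is a quick check that no cluster is nearly empty before plotting the result on a map.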
In [8]:
plot_numeric_data(regions_gdf, "cluster", embeddings)