Highway2Vec Embedder
In [1]:
from IPython.display import display
Get an area to embed
In [2]:
from srai.utils import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Wrocław, PL")
area_gdf.plot()
Out[2]:
[Figure: plot of area_gdf, the boundary of Wrocław]
Regionalize the area with a regionalizer
In [3]:
from srai.regionalizers import H3Regionalizer
regionalizer = H3Regionalizer(9)  # hexagonal H3 cells at resolution 9
regions_gdf = regionalizer.transform(area_gdf)
print(len(regions_gdf))
display(regions_gdf.head(3))
regions_gdf.plot()
3168
region_id | geometry
---|---
891e20435c7ffff | POLYGON ((16.86311 51.14966, 16.86559 51.15058... |
891e2042557ffff | POLYGON ((16.93008 51.18452, 16.93012 51.18284... |
891e2041c27ffff | POLYGON ((16.95755 51.07936, 16.95758 51.07767... |
Out[3]:
[Figure: plot of regions_gdf, the H3 cells covering Wrocław]
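`H3Regionalizer(9)` returns resolution-9 H3 cells, and each `region_id` above is a 64-bit H3 cell index written in hexadecimal. As a quick sanity check, the stdlib sketch below decodes two header fields with plain bit operations, following the bit layout in the H3 index specification (index mode in bits 59 to 62, resolution in bits 52 to 55). The helper names `h3_resolution` and `h3_mode` are hypothetical; in practice the `h3` Python package exposes this directly (for example `h3.get_resolution` in h3-py 4.x).

```python
def h3_resolution(cell: str) -> int:
    """Read the resolution field (bits 52-55) of a hexadecimal H3 cell index."""
    return (int(cell, 16) >> 52) & 0xF


def h3_mode(cell: str) -> int:
    """Read the index mode field (bits 59-62); mode 1 marks an H3 cell."""
    return (int(cell, 16) >> 59) & 0xF


# region_id values copied from the regions_gdf preview above
for region_id in ["891e20435c7ffff", "891e2042557ffff", "891e2041c27ffff"]:
    print(region_id, "mode:", h3_mode(region_id), "resolution:", h3_resolution(region_id))
```

All three identifiers decode to mode 1 (cell) and resolution 9, matching the argument passed to the regionalizer.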
Download the road infrastructure for the area
In [4]:
from srai.loaders import OSMWayLoader
from srai.loaders.osm_way_loader import NetworkType
loader = OSMWayLoader(NetworkType.DRIVE)  # drivable road network only
nodes_gdf, edges_gdf = loader.load(area_gdf)
display(nodes_gdf.head(3))
display(edges_gdf.head(3))
ax = edges_gdf.plot(linewidth=1, figsize=(12, 7))
nodes_gdf.plot(ax=ax, markersize=3, color="red")
osmid | y | x | street_count | highway | ref | geometry
---|---|---|---|---|---|---
95584835 | 51.083111 | 17.049513 | 4 | NaN | NaN | POINT (17.04951 51.08311) |
95584841 | 51.084699 | 17.064367 | 3 | NaN | NaN | POINT (17.06437 51.08470) |
95584850 | 51.083328 | 17.035057 | 4 | NaN | NaN | POINT (17.03506 51.08333) |
feature_id | oneway | lanes-1 | lanes-2 | lanes-3 | lanes-4 | lanes-5 | lanes-6 | lanes-7 | lanes-8 | lanes-9 | ... | bicycle-official | lit-yes | lit-no | lit-sunset-sunrise | lit-24/7 | lit-automatic | lit-disused | lit-limited | lit-interval | geometry
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04947 51.083... |
1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.04951 51.08311, 17.04933 51.083... |
2 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | LINESTRING (17.05357 51.08301, 17.05335 51.082... |
3 rows × 219 columns
Out[4]:
[Figure: road edges of Wrocław with nodes overlaid in red]
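`edges_gdf` stores OSM way tags as 0/1 columns named `<key>-<value>` (`lanes-2`, `lit-yes`, ...), and one row can have several columns of the same key set: feature 2 has `lanes-2`, `lanes-3` and `lanes-4` all equal to 1. A minimal stdlib sketch of such a multi-hot encoding follows. The function `multi_hot`, the toy vocabulary and the tag values are hypothetical illustrations, not SRAI's loader code, and the sketch assumes that a merged edge can carry a list of values for one tag.

```python
def multi_hot(tags: dict, vocabulary: dict) -> dict:
    """Return {"key-value": 0 or 1} for every (key, value) pair in the vocabulary.

    A tag value may be a single string or a list of strings (assumed here to
    happen when several OSM ways are merged into one edge), so more than one
    column per key can be set to 1.
    """
    row = {}
    for key, values in vocabulary.items():
        present = tags.get(key, [])
        if isinstance(present, str):
            present = [present]
        for value in values:
            row[f"{key}-{value}"] = int(value in present)
    return row


# Hypothetical vocabulary and tags, chosen to resemble feature 2 above
vocabulary = {"lanes": ["1", "2", "3", "4"], "lit": ["yes", "no"]}
edge_tags = {"lanes": ["2", "3", "4"], "lit": "yes"}
print(multi_hot(edge_tags, vocabulary))
```

Keys missing from an edge's tags simply produce all-zero columns, which is why most entries in the 219-column table are 0.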
Find out which edges correspond to which regions
In [5]:
from srai.joiners import IntersectionJoiner
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, edges_gdf)
joint_gdf
Out[5]:
region_id | feature_id |
---|---|
891e20435c7ffff | 6980 |
891e204351bffff | 6980 |
891e20435cfffff | 6980 |
891e20435c7ffff | 8305 |
891e204351bffff | 8305 |
... | ... |
891e2042ba7ffff | 6370 |
891e2042ba7ffff | 8125 |
891e2042ba7ffff | 6372 |
891e2042ba7ffff | 6371 |
891e204759bffff | 5381 |
15339 rows × 0 columns
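The joiner produces one row per intersecting (region, edge) pair, so an edge that crosses several hexagons is listed once per hexagon: feature 6980 appears under three regions in the preview. The stdlib sketch below groups the first five pairs from that output in both directions; the variable names are illustrative only.

```python
from collections import defaultdict

# (region_id, feature_id) pairs copied from the first rows of joint_gdf above
pairs = [
    ("891e20435c7ffff", 6980),
    ("891e204351bffff", 6980),
    ("891e20435cfffff", 6980),
    ("891e20435c7ffff", 8305),
    ("891e204351bffff", 8305),
]

edges_per_region = defaultdict(list)  # region -> edges intersecting it
regions_per_edge = defaultdict(list)  # edge -> regions it passes through
for region_id, feature_id in pairs:
    edges_per_region[region_id].append(feature_id)
    regions_per_edge[feature_id].append(region_id)

print(dict(edges_per_region))
print({edge: len(regions) for edge, regions in regions_per_edge.items()})
```

On the full `joint_gdf`, whose two levels form a MultiIndex, pandas should give the per-region edge counts directly with `joint_gdf.groupby(level="region_id").size()`.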
Get the embeddings for regions based on the road infrastructure
In [6]:
from srai.embedders import Highway2VecEmbedder
from pytorch_lightning import seed_everything
seed_everything(42)  # seed the random number generators so training is reproducible
embedder = Highway2VecEmbedder()
embeddings = embedder.fit_transform(regions_gdf, edges_gdf, joint_gdf)
embeddings
Global seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/hostedtoolcache/Python/3.10.12/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
  warning_cache.warn(
  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 16.0 K
1 | decoder | Sequential | 16.2 K
---------------------------------------
32.1 K    Trainable params
0         Non-trainable params
32.1 K    Total params
0.128     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=10` reached.
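The model summary in the log reports 16.0 K encoder and 16.2 K decoder parameters. Those numbers are consistent with an encoder of two fully connected layers 218 → 64 → 30 and a mirrored decoder 30 → 64 → 218, where 218 is the number of non-geometry columns in `edges_gdf` and 30 matches the embedding width of the output table. The hidden size of 64 is an assumption inferred from this arithmetic, not read from the SRAI source. The sketch checks the count, assuming float32 (4-byte) parameters for the size estimate.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


# Assumed architecture: two linear layers per side; activations add no parameters.
n_features = 219 - 1  # edges_gdf has 219 columns; "geometry" is not a feature
hidden, embedding = 64, 30  # hidden size is an assumption, embedding size matches the output

encoder = linear_params(n_features, hidden) + linear_params(hidden, embedding)
decoder = linear_params(embedding, hidden) + linear_params(hidden, n_features)
total = encoder + decoder
size_mb = total * 4 / 1e6  # 4 bytes per float32 parameter

print(encoder, decoder, total, round(size_mb, 3))
```

This prints 15966, 16154, 32120 and 0.128, which round to the 16.0 K, 16.2 K, 32.1 K and 0.128 MB shown in the log.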
Out[6]:
region_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
891e2040007ffff | -0.588394 | -0.145310 | 0.329390 | -0.199922 | 0.333583 | -0.053468 | -0.029244 | -0.731217 | -0.276920 | 0.579728 | ... | 0.222044 | -0.383879 | 0.100069 | 0.175837 | 0.086110 | 0.089029 | 0.182174 | 0.340044 | 0.301474 | 0.296530 |
891e2040013ffff | -0.136220 | -0.105376 | 0.310452 | -0.394829 | 0.550002 | -0.056917 | -0.016684 | -0.762693 | -0.447748 | 0.844249 | ... | 0.426936 | -0.819544 | 0.241111 | -0.330227 | 0.048865 | 0.168811 | 0.544063 | 0.559132 | 0.266924 | 0.328178 |
891e2040017ffff | -0.257393 | -0.171729 | 0.178272 | -0.149288 | 0.176805 | -0.130743 | 0.100364 | -0.717141 | -0.172072 | 0.529592 | ... | 0.103032 | -0.533517 | 0.181119 | 0.002176 | -0.082720 | 0.124841 | 0.460044 | 0.334705 | 0.140837 | 0.298622 |
891e2040023ffff | -0.273172 | -0.386001 | 0.398755 | -0.082176 | 0.507767 | -0.113934 | 0.082887 | -0.447940 | -0.095078 | 0.584399 | ... | 0.136490 | -0.520931 | 0.175278 | 0.002595 | -0.069401 | 0.304857 | 0.444537 | 0.130316 | -0.027166 | 0.442138 |
891e2040027ffff | -0.623637 | -0.279568 | 0.180895 | -0.373369 | 0.629895 | 0.111698 | -0.021915 | -0.717526 | -0.295012 | 0.726249 | ... | -0.017105 | -0.623780 | 0.175905 | -0.149013 | -0.041487 | 0.410061 | 0.236078 | 0.222020 | 0.010708 | 0.578343 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
891e2055bc3ffff | -0.577809 | -0.089938 | 0.250177 | -0.174679 | 0.512738 | -0.190829 | 0.059606 | -0.842934 | -0.292998 | 0.720076 | ... | 0.416265 | -0.579532 | 0.033998 | -0.237962 | -0.288905 | 0.472219 | 0.131982 | 0.490269 | 0.309203 | 0.742934 |
891e2055bc7ffff | -0.409222 | 0.095479 | 0.269943 | -0.261908 | 0.821063 | -0.368225 | 0.143521 | -0.581799 | -0.377443 | 0.738584 | ... | 0.310113 | -0.636391 | 0.050789 | -0.241962 | -0.347940 | 0.489631 | 0.012203 | 0.435431 | 0.263317 | 0.519553 |
891e2055bcbffff | -0.771704 | -0.195970 | 0.328046 | -0.108660 | 0.493283 | 0.000060 | 0.057297 | -1.090019 | -0.258585 | 0.696927 | ... | 0.554331 | -0.470081 | -0.088866 | -0.341404 | -0.304259 | 0.633277 | 0.199738 | 0.570181 | 0.447835 | 0.974461 |
891e205a967ffff | -0.671492 | 0.112009 | 0.211116 | -0.277808 | 0.285113 | 0.217281 | -0.094777 | -0.439988 | -0.376565 | 0.489636 | ... | -0.294597 | -0.336413 | 0.384200 | -0.012502 | 0.182029 | 0.069917 | 0.270887 | 0.259210 | 0.021868 | 0.548017 |
891e205a9a7ffff | -1.641847 | -1.523713 | 0.689654 | -1.028500 | 1.017919 | -0.716355 | -0.096624 | -0.636827 | -0.365205 | 1.873264 | ... | 0.592826 | -0.650813 | -0.206562 | -0.080894 | -0.482702 | 0.543556 | 0.592136 | 1.087363 | 0.147414 | 1.048715 |
2027 rows × 30 columns
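The embedder returns one 30-dimensional vector per region, and only 2027 of the 3168 regions appear, which is consistent with regions that no drivable edge intersects receiving no embedding. One way to turn per-edge vectors into per-region vectors is mean pooling over the joined (region, edge) pairs. The sketch below assumes that aggregation rather than reproducing SRAI's code, and the function `mean_pool`, the region ids and the 2-dimensional edge vectors are made up for illustration.

```python
from collections import defaultdict


def mean_pool(pairs, edge_embeddings):
    """Average the embedding vectors of all edges joined to each region.

    pairs: iterable of (region_id, feature_id), like the joint_gdf index.
    edge_embeddings: {feature_id: list of floats}, e.g. encoder outputs.
    Regions without any pair get no entry in the result.
    """
    sums = {}
    counts = defaultdict(int)
    for region_id, feature_id in pairs:
        vector = edge_embeddings[feature_id]
        if region_id not in sums:
            sums[region_id] = [0.0] * len(vector)
        sums[region_id] = [s + v for s, v in zip(sums[region_id], vector)]
        counts[region_id] += 1
    return {r: [s / counts[r] for s in sums[r]] for r in sums}


# Toy edge vectors (made-up numbers, not model output)
edge_embeddings = {6980: [1.0, -1.0], 8305: [0.0, 3.0]}
pairs = [("region_a", 6980), ("region_a", 8305), ("region_b", 6980)]
print(mean_pool(pairs, edge_embeddings))
```

Here `region_a` averages two edges to `[0.5, 1.0]`, while `region_b` keeps its single edge's vector; a region absent from `pairs` would simply be missing, as the embedded row count suggests.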