Hex2vec embedder
In [1]:
Copied!
import warnings
import matplotlib.pyplot as plt
import pandas as pd
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import CSVLogger
from srai.embedders import Hex2VecEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders import OSMOnlineLoader
from srai.neighbourhoods import H3Neighbourhood
from srai.plotting import plot_numeric_data, plot_regions
from srai.regionalizers import H3Regionalizer, geocode_to_region_gdf
import warnings
import matplotlib.pyplot as plt
import pandas as pd
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import CSVLogger
from srai.embedders import Hex2VecEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders import OSMOnlineLoader
from srai.neighbourhoods import H3Neighbourhood
from srai.plotting import plot_numeric_data, plot_regions
from srai.regionalizers import H3Regionalizer, geocode_to_region_gdf
In [2]:
Copied!
# Global random seed. It is passed to `seed_everything` here and reused as
# KMeans' `random_state` later, so training and clustering are reproducible.
SEED = 71
seed_everything(SEED)
Seed set to 71
Out[2]:
71
Load data from OSM
First, use geocoding to get the area of interest.
In [3]:
Copied!
# Geocode the place name into a region GeoDataFrame (a single network call)
# and preview it on a light basemap.
area_gdf = geocode_to_region_gdf("Wrocław, Poland")
plot_regions(area_gdf, tiles_style="CartoDB positron")
Out[3]:
Make this Notebook Trusted to load map: File -> Trust Notebook
Next, download the data for the selected region and the specified tags. We're using `OSMOnlineLoader` here, as it's faster for a small number of tags. In a real-life scenario with more tags, you would likely want to use `OSMPbfLoader` instead.
In [4]:
Copied!
# OSM tag filter: each key maps to a single value or a list of accepted values.
tags = {
    "leisure": "park",
    "landuse": "forest",
    "amenity": ["bar", "restaurant", "cafe"],
    "water": "river",
    "sport": "soccer",
}

# Download the matching features inside the area. Done once: this queries OSM online.
loader = OSMOnlineLoader()
features_gdf = loader.load(area_gdf, tags)

# Draw the area with a fully transparent fill (alpha 0) so only its outline
# sits under the downloaded features.
folium_map = plot_regions(area_gdf, colormap=["rgba(0,0,0,0)"], tiles_style="CartoDB positron")
features_gdf.explore(m=folium_map)
Out[4]:
Make this Notebook Trusted to load map: File -> Trust Notebook
Prepare the data for embedding
After downloading the data, we need to prepare it for embedding. Namely, we need to regionalize the selected area and join the features with the regions.
In [5]:
Copied!
# Split the area into H3 hexagons at resolution 9; each hexagon is one region
# that will receive an embedding.
regionalizer = H3Regionalizer(resolution=9)
regions_gdf = regionalizer.transform(area_gdf)
plot_regions(regions_gdf, tiles_style="CartoDB positron")
Out[5]:
Make this Notebook Trusted to load map: File -> Trust Notebook
In [6]:
Copied!
# Pair every region with the OSM features that intersect it. As the output shows,
# the result is indexed by (region_id, feature_id) and has no data columns.
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf)
joint_gdf
Out[6]:
| region_id | feature_id |
|---|---|
| 891e2040897ffff | node/280727473 |
| 891e2040d4bffff | node/300461026 |
|  | node/300461036 |
| 891e2040d5bffff | node/300461042 |
| 891e2040887ffff | node/300461045 |
| ... | ... |
| 891e2042e73ffff | way/1427496434 |
| 891e2040a8fffff | way/1428809179 |
| 891e2045203ffff | way/1429016156 |
| 891e2045217ffff | way/1429016156 |
| 891e2040e43ffff | way/1429586876 |
4189 rows × 0 columns
Embedding
After preparing the data, we can proceed with generating embeddings for the regions.
In [7]:
Copied!
# Adjacency between the H3 regions, passed to the embedder for training.
neighbourhood = H3Neighbourhood(regions_gdf)
# Encoder layer sizes. The last one (10) matches the 10 embedding columns
# in the output below.
embedder = Hex2VecEmbedder([15, 10])
# Write training metrics to CSV so the next cell can plot them.
csv_logger = CSVLogger(save_dir="hex2vec_logs")

# Silence warnings only while training; catch_warnings restores the previous
# filters when the block exits.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    embeddings = embedder.fit_transform(
        regions_gdf,
        features_gdf,
        joint_gdf,
        neighbourhood,
        trainer_kwargs={"max_epochs": 5, "accelerator": "cpu", "logger": csv_logger},
        batch_size=100,
    )
embeddings
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
  | Name    | Type       | Params | Mode
-----------------------------------------------
0 | encoder | Sequential | 280    | train
-----------------------------------------------
280       Trainable params
0         Non-trainable params
280       Total params
0.001     Total estimated model params size (MB)
4         Modules in train mode
0         Modules in eval mode
`Trainer.fit` stopped: `max_epochs=5` reached.
Out[7]:
|  | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|---|---|---|---|---|---|---|---|---|---|---|
| region_id | | | | | | | | | | |
| 891e204246bffff | 0.308667 | -0.221366 | -0.016485 | 0.273814 | 0.223453 | -0.108205 | -0.269239 | -0.074259 | -0.185711 | 0.355687 |
| 891e2040d4bffff | 2.530500 | -0.389116 | -1.033928 | -0.391546 | 0.494700 | -3.941719 | 6.012843 | 0.225300 | 1.163661 | -3.452672 |
| 891e20436afffff | -0.453659 | 0.153939 | -0.385682 | -0.609465 | -0.041406 | 0.245630 | -0.310095 | 0.349441 | -0.131722 | -0.183063 |
| 891e2040653ffff | -0.312010 | -0.001030 | -0.280687 | -0.446141 | -0.001030 | 0.241800 | -0.387260 | 0.318737 | -0.096331 | -0.023286 |
| 891e20434bbffff | -0.312010 | -0.001030 | -0.280687 | -0.446141 | -0.001030 | 0.241800 | -0.387260 | 0.318737 | -0.096331 | -0.023286 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 891e204e403ffff | 0.308667 | -0.221366 | -0.016485 | 0.273814 | 0.223453 | -0.108205 | -0.269239 | -0.074259 | -0.185711 | 0.355687 |
| 891e20462cbffff | 0.338925 | -0.246208 | 0.198810 | 0.398731 | 0.243021 | -0.179302 | 0.049243 | -0.137178 | -0.013175 | 0.235332 |
| 891e2043547ffff | -0.573060 | 0.721390 | 0.434355 | 0.376330 | -0.440009 | 0.239848 | 0.343041 | -0.266044 | 0.662304 | -0.068802 |
| 891e2045647ffff | 0.308667 | -0.221366 | -0.016485 | 0.273814 | 0.223453 | -0.108205 | -0.269239 | -0.074259 | -0.185711 | 0.355687 |
| 891e2042537ffff | 0.308667 | -0.221366 | -0.016485 | 0.273814 | 0.223453 | -0.108205 | -0.269239 | -0.074259 | -0.185711 | 0.355687 |
3168 rows × 10 columns
In [8]:
Copied!
# Load the metrics written by the CSV logger. Keep only rows that carry an
# epoch-level F1 value.
metrics_df = pd.read_csv(f"{csv_logger.log_dir}/metrics.csv").dropna(subset=["train_f1_epoch"])

# F1 on the left axis and loss on a twin right axis, so each metric keeps
# its own y-scale.
fig, ax1 = plt.subplots(1, 1, figsize=(10, 5))
ax2 = ax1.twinx()
line1 = ax1.plot(metrics_df["epoch"], metrics_df["train_f1_epoch"])
line2 = ax2.plot(metrics_df["epoch"], metrics_df["train_loss_epoch"], color="orange")

# A single legend covering the lines from both axes ("center right" == loc code 7).
ax1.legend(line1 + line2, ["F1", "Loss"], loc="center right")
ax1.set_title("Training metrics")
ax1.set_ylabel("F1")
ax2.set_ylabel("Loss")
ax1.set_xlabel("Training epoch")
plt.show()
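Visualizing the embeddings' similarity
To inspect which regions end up with similar embeddings, group the embedding vectors into clusters with KMeans and color each region on the map by its cluster.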
In [9]:
Copied!
from sklearn.cluster import KMeans

# Group regions with similar embedding vectors into 5 clusters
# (seeded via SEED for reproducibility).
clusterizer = KMeans(n_clusters=5, random_state=SEED)
# Fit only on the embedding dimensions. Drop any "cluster" column left by a
# previous run of this cell, so re-running it does not feed the old labels
# back in as a feature.
clusterizer.fit(embeddings.drop(columns="cluster", errors="ignore"))
embeddings["cluster"] = clusterizer.labels_
embeddings
Out[9]:
|  | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | cluster |
|---|---|---|---|---|---|---|---|---|---|---|---|
| region_id | | | | | | | | | | | |
| 891e204246bffff | 0.308667 | -0.221366 | -0.016485 | 0.273814 | 0.223453 | -0.108205 | -0.269239 | -0.074259 | -0.185711 | 0.355687 | 0 |
| 891e2040d4bffff | 2.530500 | -0.389116 | -1.033928 | -0.391546 | 0.494700 | -3.941719 | 6.012843 | 0.225300 | 1.163661 | -3.452672 | 4 |
| 891e20436afffff | -0.453659 | 0.153939 | -0.385682 | -0.609465 | -0.041406 | 0.245630 | -0.310095 | 0.349441 | -0.131722 | -0.183063 | 2 |
| 891e2040653ffff | -0.312010 | -0.001030 | -0.280687 | -0.446141 | -0.001030 | 0.241800 | -0.387260 | 0.318737 | -0.096331 | -0.023286 | 2 |
| 891e20434bbffff | -0.312010 | -0.001030 | -0.280687 | -0.446141 | -0.001030 | 0.241800 | -0.387260 | 0.318737 | -0.096331 | -0.023286 | 2 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 891e204e403ffff | 0.308667 | -0.221366 | -0.016485 | 0.273814 | 0.223453 | -0.108205 | -0.269239 | -0.074259 | -0.185711 | 0.355687 | 0 |
| 891e20462cbffff | 0.338925 | -0.246208 | 0.198810 | 0.398731 | 0.243021 | -0.179302 | 0.049243 | -0.137178 | -0.013175 | 0.235332 | 0 |
| 891e2043547ffff | -0.573060 | 0.721390 | 0.434355 | 0.376330 | -0.440009 | 0.239848 | 0.343041 | -0.266044 | 0.662304 | -0.068802 | 1 |
| 891e2045647ffff | 0.308667 | -0.221366 | -0.016485 | 0.273814 | 0.223453 | -0.108205 | -0.269239 | -0.074259 | -0.185711 | 0.355687 | 0 |
| 891e2042537ffff | 0.308667 | -0.221366 | -0.016485 | 0.273814 | 0.223453 | -0.108205 | -0.269239 | -0.074259 | -0.185711 | 0.355687 | 0 |
3168 rows × 11 columns
In [10]:
Copied!
# Color each H3 region on the map by its KMeans label from embeddings["cluster"].
plot_numeric_data(regions_gdf, "cluster", embeddings)
Out[10]:
Make this Notebook Trusted to load map: File -> Trust Notebook