S2Vec Embedder
In [1]:
Copied!
import warnings
import matplotlib.pyplot as plt
import pandas as pd
import torch
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import CSVLogger
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from srai.embedders import S2VecEmbedder
from srai.embedders.s2vec.s2_utils import get_patches_from_img_gdf
from srai.loaders import OSMPbfLoader
from srai.loaders.osm_loaders.filters import GEOFABRIK_LAYERS
from srai.plotting import plot_numeric_data, plot_regions
from srai.regionalizers import S2Regionalizer, geocode_to_region_gdf
import warnings
import matplotlib.pyplot as plt
import pandas as pd
import torch
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import CSVLogger
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from srai.embedders import S2VecEmbedder
from srai.embedders.s2vec.s2_utils import get_patches_from_img_gdf
from srai.loaders import OSMPbfLoader
from srai.loaders.osm_loaders.filters import GEOFABRIK_LAYERS
from srai.plotting import plot_numeric_data, plot_regions
from srai.regionalizers import S2Regionalizer, geocode_to_region_gdf
In [2]:
Copied!
# Fix the random seed through Lightning's helper so the run is reproducible.
SEED = 71
seed_everything(seed=SEED)
# Fix the random seed through Lightning's helper so the run is reproducible.
SEED = 71
seed_everything(seed=SEED)
Seed set to 71
Out[2]:
71
Load data from OSM
First, use geocoding to get the area of interest.
In [3]:
Copied!
# Geocode the city name into a region GeoDataFrame and preview it on a map.
area_gdf = geocode_to_region_gdf("Wrocław, Poland")
plot_regions(
    area_gdf,
    tiles_style="CartoDB positron",
)
# Geocode the city name into a region GeoDataFrame and preview it on a map.
area_gdf = geocode_to_region_gdf("Wrocław, Poland")
plot_regions(
    area_gdf,
    tiles_style="CartoDB positron",
)
Out[3]:
Make this Notebook Trusted to load map: File -> Trust Notebook
In [4]:
Copied!
# S2 levels: coarser level-12 cells are the "images", finer level-16 cells
# are the patches cut out of them later.
patch_resolution = 16
img_resolution = 12

# Cover the area with buffered level-12 S2 cells; the geocoded index is
# discarded before regionalization.
img_regionalizer = S2Regionalizer(resolution=img_resolution, buffer=True)
img_s2_regions = img_regionalizer.transform(area_gdf.reset_index(drop=True))

# Merged footprint of all image cells.
img_s2_geometry = img_s2_regions.union_all()
print(f"Image regions: {len(img_s2_regions)}")
# S2 levels: coarser level-12 cells are the "images", finer level-16 cells
# are the patches cut out of them later.
patch_resolution = 16
img_resolution = 12

# Cover the area with buffered level-12 S2 cells; the geocoded index is
# discarded before regionalization.
img_regionalizer = S2Regionalizer(resolution=img_resolution, buffer=True)
img_s2_regions = img_regionalizer.transform(area_gdf.reset_index(drop=True))

# Merged footprint of all image cells.
img_s2_geometry = img_s2_regions.union_all()
print(f"Image regions: {len(img_s2_regions)}")
Image regions: 85
Download the data
Next, download the data for the selected region and the specified tags.
In [5]:
Copied!
# Load OSM features for the Geofabrik layer set, restricted to the buffered
# S2 image regions.
loader = OSMPbfLoader()
tags = GEOFABRIK_LAYERS
features_gdf = loader.load(img_s2_regions, tags)
# Load OSM features for the Geofabrik layer set, restricted to the buffered
# S2 image regions.
loader = OSMPbfLoader()
tags = GEOFABRIK_LAYERS
features_gdf = loader.load(img_s2_regions, tags)
/opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/quackosm/osm_extracts/__init__.py:602: GeometryNotCoveredWarning: Skipping extract because of low IoU value (geofabrik_europe_poland_dolnoslaskie, 0.000265). warnings.warn(
Finished operation in 0:00:20
Prepare the data for embedding
After downloading the data, we need to prepare it for embedding. Earlier, we regionalized the selected area and buffered it; now the features have to be joined with these prepared regions.
In [6]:
Copied!
# Preview the buffered S2 image regions the features will be matched against.
plot_regions(
    img_s2_regions,
    tiles_style="CartoDB positron",
)
# Preview the buffered S2 image regions the features will be matched against.
plot_regions(
    img_s2_regions,
    tiles_style="CartoDB positron",
)
Out[6]:
Make this Notebook Trusted to load map: File -> Trust Notebook
S2Vec embedding
After preparing the data, we can proceed with generating embeddings for the regions.
In [7]:
Copied!
embedder = S2VecEmbedder(
    # S2 levels of the image cells and of their patches; these reuse the
    # resolutions chosen for regionalization.
    img_res=img_resolution,
    patch_res=patch_resolution,
    # OSM feature layers the model learns from (the same set that was loaded).
    target_features=GEOFABRIK_LAYERS,
    # Model and training sizes; each region gets a 64-dimensional embedding.
    embedding_dim=64,
    decoder_dim=32,
    batch_size=8,
)
embedder = S2VecEmbedder(
    # S2 levels of the image cells and of their patches; these reuse the
    # resolutions chosen for regionalization.
    img_res=img_resolution,
    patch_res=patch_resolution,
    # OSM feature layers the model learns from (the same set that was loaded).
    target_features=GEOFABRIK_LAYERS,
    # Model and training sizes; each region gets a 64-dimensional embedding.
    embedding_dim=64,
    decoder_dim=32,
    batch_size=8,
)
In [8]:
Copied!
# Train the S2Vec model on the regions and their OSM features, logging metrics
# to CSV, then show the first rows of the embedding table. All warnings are
# silenced only while this block runs.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    csv_logger = CSVLogger(save_dir="s2vec_logs")
    # Force CPU when Apple MPS is available (presumably to sidestep MPS backend
    # issues -- TODO confirm); otherwise let Lightning choose the device.
    accelerator = "cpu" if torch.backends.mps.is_available() else "auto"
    trainer_kwargs = {
        "max_epochs": 5,  # raise to e.g. 20 for a longer training
        "accelerator": accelerator,
        "logger": csv_logger,
    }
    embeddings = embedder.fit_transform(
        regions_gdf=img_s2_regions,
        features_gdf=features_gdf,
        trainer_kwargs=trainer_kwargs,
        learning_rate=0.001,
    )
embeddings.head()
# Train the S2Vec model on the regions and their OSM features, logging metrics
# to CSV, then show the first rows of the embedding table. All warnings are
# silenced only while this block runs.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    csv_logger = CSVLogger(save_dir="s2vec_logs")
    # Force CPU when Apple MPS is available (presumably to sidestep MPS backend
    # issues -- TODO confirm); otherwise let Lightning choose the device.
    accelerator = "cpu" if torch.backends.mps.is_available() else "auto"
    trainer_kwargs = {
        "max_epochs": 5,  # raise to e.g. 20 for a longer training
        "accelerator": accelerator,
        "logger": csv_logger,
    }
    embeddings = embedder.fit_transform(
        regions_gdf=img_s2_regions,
        features_gdf=features_gdf,
        trainer_kwargs=trainer_kwargs,
        learning_rate=0.001,
    )
embeddings.head()
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
| Name | Type | Params | Mode ----------------------------------------------------- 0 | patch_embed | Linear | 22.3 K | train 1 | encoder | MAEEncoder | 300 K | train 2 | decoder_embed | Linear | 2.1 K | train 3 | decoder | MAEDecoder | 36.9 K | train | other params | n/a | 24.8 K | n/a ----------------------------------------------------- 361 K Trainable params 24.7 K Non-trainable params 386 K Total params 1.544 Total estimated model params size (MB) 185 Modules in train mode 0 Modules in eval mode
`Trainer.fit` stopped: `max_epochs=5` reached.
Out[8]:
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
region_id | |||||||||||||||||||||
470f94801 | -1.479785 | -1.105935 | 0.842880 | 0.212038 | 0.863173 | -0.183333 | 0.286089 | 2.009435 | -0.689480 | 0.860651 | ... | -1.556884 | 0.128759 | 0.289503 | 1.585362 | 0.233475 | -0.086120 | 1.894688 | 0.476262 | 0.154510 | -0.829343 |
470f94803 | -1.504945 | -1.107852 | 0.825963 | 0.174862 | 0.865863 | -0.216160 | 0.282480 | 1.979087 | -0.722283 | 0.862320 | ... | -1.526596 | 0.081813 | 0.279964 | 1.606718 | 0.228851 | -0.032660 | 1.890381 | 0.487481 | 0.157385 | -0.892933 |
470f9481d | -1.510059 | -1.113408 | 0.839492 | 0.171124 | 0.869098 | -0.210906 | 0.285950 | 1.978307 | -0.723305 | 0.866281 | ... | -1.535362 | 0.060504 | 0.300072 | 1.620742 | 0.226636 | -0.032244 | 1.886249 | 0.493242 | 0.162692 | -0.885299 |
470f9481f | -1.509028 | -1.109931 | 0.861396 | 0.171525 | 0.867936 | -0.202461 | 0.295846 | 1.983008 | -0.715796 | 0.864069 | ... | -1.544552 | 0.064595 | 0.309477 | 1.620151 | 0.221593 | -0.038922 | 1.874383 | 0.502787 | 0.171163 | -0.853185 |
470f94821 | -1.500708 | -1.105572 | 0.873479 | 0.179596 | 0.862860 | -0.202229 | 0.307782 | 1.982434 | -0.705466 | 0.858610 | ... | -1.544924 | 0.089422 | 0.299647 | 1.609755 | 0.216551 | -0.046512 | 1.862943 | 0.506624 | 0.171887 | -0.823495 |
5 rows × 64 columns
In [9]:
Copied!
# Plot the per-epoch training loss recorded by the CSVLogger. Only rows that
# carry an epoch-level loss value are kept.
metrics_df = pd.read_csv(csv_logger.log_dir + "/metrics.csv").dropna(subset=["train_loss_epoch"])
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
# Markers make the few epoch points visible (max_epochs is small here).
ax.plot(metrics_df["epoch"], metrics_df["train_loss_epoch"], marker="o")
ax.set(title="Training metrics", xlabel="Epoch", ylabel="Loss")
plt.show()
# Plot the per-epoch training loss recorded by the CSVLogger. Only rows that
# carry an epoch-level loss value are kept.
metrics_df = pd.read_csv(csv_logger.log_dir + "/metrics.csv").dropna(subset=["train_loss_epoch"])
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
# Markers make the few epoch points visible (max_epochs is small here).
ax.plot(metrics_df["epoch"], metrics_df["train_loss_epoch"], marker="o")
ax.set(title="Training metrics", xlabel="Epoch", ylabel="Loss")
plt.show()
In [10]:
Copied!
# Visualize the embeddings: project them onto their first three principal
# components, rescale each component to 0-255 and use the result as an RGB
# fill colour for the corresponding patch on the map.
patch_s2_regions, _ = get_patches_from_img_gdf(img_s2_regions, target_level=patch_resolution)

# PCA with three components (one per RGB channel), keeping the region index.
pca = PCA(n_components=3)
pca_embeddings = pd.DataFrame(pca.fit_transform(embeddings), index=embeddings.index)

# Min-max scale each component to [0, 255]. A constant component has a zero
# range, which would yield NaN and make `astype(int)` raise; divide by 1
# instead so that channel simply becomes 0 everywhere.
component_min = pca_embeddings.min()
component_range = pca_embeddings.max() - component_min
component_range = component_range.where(component_range != 0, 1)
pca_embeddings = ((pca_embeddings - component_min) / component_range * 255).astype(int)

# Format every row as a CSS colour string, e.g. "rgb(12, 200, 37)".
pca_embeddings["rgb"] = (
    "rgb("
    + pca_embeddings[0].astype(str)
    + ", "
    + pca_embeddings[1].astype(str)
    + ", "
    + pca_embeddings[2].astype(str)
    + ")"
)

# Positional row number -> colour. The double reset_index below first moves the
# region_id index into a column (used by the tooltip) and then adds a 0..n-1
# "index" column; those positions are the keys the `cmap` callable looks up.
color_dict = dict(enumerate(patch_s2_regions.index.map(pca_embeddings["rgb"].to_dict()).to_list()))
patch_s2_regions.reset_index().reset_index().explore(
    column="index",
    tooltip="region_id",
    tiles="CartoDB positron",
    legend=False,
    cmap=lambda x: color_dict[x],
    style_kwds=dict(color="#444", opacity=0.0, fillOpacity=0.5),
)
# Visualize the embeddings: project them onto their first three principal
# components, rescale each component to 0-255 and use the result as an RGB
# fill colour for the corresponding patch on the map.
patch_s2_regions, _ = get_patches_from_img_gdf(img_s2_regions, target_level=patch_resolution)

# PCA with three components (one per RGB channel), keeping the region index.
pca = PCA(n_components=3)
pca_embeddings = pd.DataFrame(pca.fit_transform(embeddings), index=embeddings.index)

# Min-max scale each component to [0, 255]. A constant component has a zero
# range, which would yield NaN and make `astype(int)` raise; divide by 1
# instead so that channel simply becomes 0 everywhere.
component_min = pca_embeddings.min()
component_range = pca_embeddings.max() - component_min
component_range = component_range.where(component_range != 0, 1)
pca_embeddings = ((pca_embeddings - component_min) / component_range * 255).astype(int)

# Format every row as a CSS colour string, e.g. "rgb(12, 200, 37)".
pca_embeddings["rgb"] = (
    "rgb("
    + pca_embeddings[0].astype(str)
    + ", "
    + pca_embeddings[1].astype(str)
    + ", "
    + pca_embeddings[2].astype(str)
    + ")"
)

# Positional row number -> colour. The double reset_index below first moves the
# region_id index into a column (used by the tooltip) and then adds a 0..n-1
# "index" column; those positions are the keys the `cmap` callable looks up.
color_dict = dict(enumerate(patch_s2_regions.index.map(pca_embeddings["rgb"].to_dict()).to_list()))
patch_s2_regions.reset_index().reset_index().explore(
    column="index",
    tooltip="region_id",
    tiles="CartoDB positron",
    legend=False,
    cmap=lambda x: color_dict[x],
    style_kwds=dict(color="#444", opacity=0.0, fillOpacity=0.5),
)
Out[10]:
Make this Notebook Trusted to load map: File -> Trust Notebook