Spatial splitting with stratification¶
The SRAI library contains dedicated functions for splitting a point dataset into train / test (and optionally validation) sets by separating the points spatially while also keeping them stratified based on a given target.
These functions work only on point datasets and use the H3
indexing system to cluster points together and separate H3 cells into different splits.
When working with most machine learning datasets, splitting into training and testing sets is straightforward: pick a random subset for testing, and (optionally) use stratification to keep the distribution of a target variable balanced between the two. This works fine when the data points are independent.
Geospatial data plays by different rules. Nearby locations often share similar characteristics - a phenomenon called spatial autocorrelation. If we split data randomly, our training and test sets might end up covering the same areas, meaning the model is “tested” on locations that are practically identical to ones it has already seen. This can make performance look much better than it really is, and it prevents us from testing the model’s ability to generalize from spatial features.
That’s why for geo-related tasks, we need spatial splitting: making sure the training and test sets are separated in space so that evaluation reflects real-world conditions. Sometimes we also want to stratify these spatial splits by a numerical value to ensure both sets still have similar value distributions. Standard train_test_split
functions can’t combine these two needs, so we provide a dedicated function for spatially aware splitting with optional stratification.
This notebook shows how the different splitting modes work, using a buildings dataset from the Overture Maps Foundation.
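As a preview, the core call used throughout this notebook looks like this (the buildings GeoDataFrame is loaded in the next sections):

# Spatially separated 80/20 split, stratified by a numerical target column.
train_gdf, test_gdf = train_test_spatial_split(
    input_gdf=buildings,
    parent_h3_resolution=9,
    geometry_column="centroid",
    target_column="height",
    test_size=0.2,
    random_state=42,
)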
How does it work?¶
To separate the input dataset into multiple outputs, the H3 indexing system is used to group points together.
First, the algorithm transforms the points into H3 cells at a given resolution and calculates statistics per H3 cell (number of points per bucket / category).
Next, all H3 cells are shuffled (with an optional random_state
to ensure reproducibility) and iterated one by one.
For each split (train, validation, test) and each bucket within a split, a running count of points is kept. While iterating over each H3 cell and the group of points inside it, the algorithm calculates the point count that each split would have if the cell were assigned to it, and how far the resulting ratio would deviate from the expected one. The current H3 cell is assigned to the split where the difference to the expected ratio is the lowest.
After iterating over all H3 cells, the original dataset of points is split based on the list of assigned H3 cells.
A split report is printed with the differences between expected and actual ratios.
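To make the assignment step concrete, here is a minimal, single-bucket sketch of the greedy idea (illustrative only, not SRAI's actual implementation; the assign_cells helper and its inputs are hypothetical):

import random

def assign_cells(cell_counts: dict, ratios: dict, seed: int = 42) -> dict:
    """Greedily assign H3 cells to splits (simplified, single-bucket illustration)."""
    rng = random.Random(seed)
    cells = list(cell_counts)
    rng.shuffle(cells)  # shuffle cells, reproducible via the seed
    totals = dict.fromkeys(ratios, 0)  # running number of points per split
    assignment = {}
    for cell in cells:
        points = cell_counts[cell]
        new_total = sum(totals.values()) + points
        # Assign the cell to the split whose resulting share of all points
        # would deviate least from its expected ratio.
        best = min(ratios, key=lambda s: abs((totals[s] + points) / new_total - ratios[s]))
        totals[best] += points
        assignment[cell] = best
    return assignment

# Ten equally sized cells with an 80/20 target: eight cells land in train, two in test.
print(assign_cells({f"cell_{i}": 10 for i in range(10)}, {"train": 0.8, "test": 0.2}))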
import contextily as cx
import geopandas as gpd
import matplotlib.pyplot as plt
import overturemaestro as om
import pyarrow.compute as pc
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.patches import Patch
from sklearn.model_selection import train_test_split
from srai.h3 import h3_to_geoseries, shapely_geometry_to_h3
from srai.spatial_split import spatial_split_points, train_test_spatial_split
Let's start with downloading example data. Here we will use Vancouver buildings from the Overture Maps dataset.
We only want buildings with both height
and subtype
columns filled.
Height will be used in the numerical split example and subtype in the categorical split example.
Because the splitting only works on points, we will assign a centroid to each building as an additional column. Centroids will be calculated in the corresponding projected Coordinate Reference System.
VANCOUVER_BOUNDING_BOX = (-123.148670, 49.255555, -123.076572, 49.296907)
VANCOUVER_PROJECTED_CRS = 26910 # NAD83 / UTM zone 10N
H3_RESOLUTION = 9
buildings = om.convert_bounding_box_to_geodataframe(
theme="buildings",
type="building",
bbox=VANCOUVER_BOUNDING_BOX,
release="2025-07-23.0",
pyarrow_filter=pc.field("subtype").is_valid() & pc.field("height").is_valid(),
columns_to_download=["subtype", "height"],
)
buildings["centroid"] = buildings.to_crs(VANCOUVER_PROJECTED_CRS).centroid.to_crs(4326)
buildings
Finished operation in 0:00:13
id | geometry | subtype | height | centroid
---|---|---|---|---
194b888e-182e-4c49-89a5-9e44f18b4682 | POLYGON ((-123.1479 49.25701, -123.14821 49.25... | residential | 6.712387 | POINT (-123.14808 49.25695) |
e95cbac8-8f90-49c7-a636-6f0e0e193fe0 | POLYGON ((-123.14761 49.25752, -123.14816 49.2... | residential | 7.460882 | POINT (-123.14788 49.25745) |
4288241a-912e-4835-923c-240edd246863 | POLYGON ((-123.1471 49.25689, -123.14705 49.25... | residential | 6.850000 | POINT (-123.14719 49.25693) |
d4c6de14-a75f-4304-9a18-95447eec11d1 | POLYGON ((-123.14763 49.257, -123.14781 49.257... | residential | 8.560000 | POINT (-123.14772 49.25694) |
def45e2f-2423-46ea-91e8-035e60fa1e12 | POLYGON ((-123.14654 49.25667, -123.14665 49.2... | residential | 3.350000 | POINT (-123.14659 49.25665) |
... | ... | ... | ... | ... |
104af1b3-8da9-4d72-8d4f-e40f4adefd9c | POLYGON ((-123.0835 49.25951, -123.08363 49.25... | residential | 4.922461 | POINT (-123.08357 49.25944) |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | POLYGON ((-123.0831 49.2594, -123.0831 49.2595... | commercial | 5.506591 | POINT (-123.08316 49.25949) |
4007a55a-8a43-4d1a-a55d-94a7a26c0f78 | POLYGON ((-123.0817 49.25721, -123.08169 49.25... | residential | 7.308841 | POINT (-123.08179 49.25725) |
30705051-74aa-436e-af3a-e1956bab30a1 | POLYGON ((-123.08187 49.2573, -123.08171 49.25... | residential | 6.070602 | POINT (-123.08179 49.25733) |
0318b079-1404-44ca-97c4-28f443080728 | POLYGON ((-123.08004 49.25661, -123.08007 49.2... | residential | 6.237479 | POINT (-123.08001 49.25655) |
3272 rows × 4 columns
First, let's see what a random split without spatial context looks like.
# train_test_split function from scikit-learn
random_train_gdf, random_test_gdf = train_test_split(buildings, test_size=0.2, random_state=42)
ax = random_train_gdf.plot(color="#1E88E5", figsize=(15, 12), label="train", legend=True)
random_test_gdf.plot(color="#FFC107", ax=ax, label="test", legend=True)
buildings.exterior.plot(ax=ax, color="black", linewidth=0.3, alpha=0.5)
ax.set_title("Vancouver buildings data - random split")
ax.legend(
handles=[Patch(facecolor="#1E88E5"), Patch(facecolor="#FFC107")],
labels=["Train", "Test"],
)
cx.add_basemap(
ax,
source=cx.providers.CartoDB.VoyagerNoLabels,
crs=4326,
zoom=15,
alpha=0.8,
)
ax.set_axis_off()
plt.show()
As shown, the buildings are split at random, resulting in both sets covering the same geographic area.
With this approach, you can’t properly evaluate the model’s ability to generalize based on spatial patterns.
Without target column - default¶
A target column isn't required for spatial splitting.
By default, the algorithm calculates the density of points per H3 cell and uses it for stratification. This way, both splits contain dense as well as sparse regions.
train_default_gdf, test_default_gdf = train_test_spatial_split(
input_gdf=buildings,
parent_h3_resolution=H3_RESOLUTION,
geometry_column="centroid",
target_column=None,
test_size=0.2,
random_state=42,
)
train_default_gdf
Summary of the split:
Train: 167 H3 cells (2651 points)
Test: 43 H3 cells (621 points)
Expected ratios: {'train': 0.8, 'validation': 0, 'test': 0.2}
Actual ratios: {'train': 0.81, 'test': 0.19}
Actual ratios difference: {'train': -0.01, 'test': 0.01}

bucket | train_ratio | test_ratio | train_ratio_difference | test_ratio_difference | train_points | test_points
---|---|---|---|---|---|---
0 | 0.80000 | 0.20000 | 0.00000 | 0.00000 | 40 | 10
1 | 0.79861 | 0.20139 | 0.00139 | -0.00139 | 115 | 29
2 | 0.78400 | 0.21600 | 0.01600 | -0.01600 | 98 | 27
3 | 0.80060 | 0.19940 | -0.00060 | 0.00060 | 269 | 67
4 | 0.80870 | 0.19130 | -0.00870 | 0.00870 | 372 | 88
5 | 0.79845 | 0.20155 | 0.00155 | -0.00155 | 618 | 156
6 | 0.82357 | 0.17643 | -0.02357 | 0.02357 | 1139 | 244
id | geometry | subtype | height | centroid
---|---|---|---|---
194b888e-182e-4c49-89a5-9e44f18b4682 | POLYGON ((-123.1479 49.25701, -123.14821 49.25... | residential | 6.712387 | POINT (-123.14808 49.25695) |
e95cbac8-8f90-49c7-a636-6f0e0e193fe0 | POLYGON ((-123.14761 49.25752, -123.14816 49.2... | residential | 7.460882 | POINT (-123.14788 49.25745) |
4288241a-912e-4835-923c-240edd246863 | POLYGON ((-123.1471 49.25689, -123.14705 49.25... | residential | 6.850000 | POINT (-123.14719 49.25693) |
d4c6de14-a75f-4304-9a18-95447eec11d1 | POLYGON ((-123.14763 49.257, -123.14781 49.257... | residential | 8.560000 | POINT (-123.14772 49.25694) |
def45e2f-2423-46ea-91e8-035e60fa1e12 | POLYGON ((-123.14654 49.25667, -123.14665 49.2... | residential | 3.350000 | POINT (-123.14659 49.25665) |
... | ... | ... | ... | ... |
71998d2c-ad3e-438a-b705-3877b62396d6 | POLYGON ((-123.08488 49.25953, -123.08501 49.2... | residential | 5.109287 | POINT (-123.08495 49.25946) |
8b803974-00e9-491b-a420-0e881db8ca38 | POLYGON ((-123.08502 49.25953, -123.08514 49.2... | residential | 4.885464 | POINT (-123.08508 49.25946) |
4007a55a-8a43-4d1a-a55d-94a7a26c0f78 | POLYGON ((-123.0817 49.25721, -123.08169 49.25... | residential | 7.308841 | POINT (-123.08179 49.25725) |
30705051-74aa-436e-af3a-e1956bab30a1 | POLYGON ((-123.08187 49.2573, -123.08171 49.25... | residential | 6.070602 | POINT (-123.08179 49.25733) |
0318b079-1404-44ca-97c4-28f443080728 | POLYGON ((-123.08004 49.25661, -123.08007 49.2... | residential | 6.237479 | POINT (-123.08001 49.25655) |
2651 rows × 4 columns
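Before plotting, we can sanity-check the spatial separation: since whole H3 cells are assigned to splits, the covering cell sets of the two splits should be disjoint. A quick check using the same helpers as below:

# The two splits must not share any H3 cell at the chosen resolution.
train_cells = set(shapely_geometry_to_h3(train_default_gdf["centroid"], H3_RESOLUTION))
test_cells = set(shapely_geometry_to_h3(test_default_gdf["centroid"], H3_RESOLUTION))
assert train_cells.isdisjoint(test_cells)
print(len(train_cells), len(test_cells))  # should match the 167 / 43 cells from the report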
covering_h3_cells = h3_to_geoseries(shapely_geometry_to_h3(buildings["centroid"], H3_RESOLUTION))
ax = train_default_gdf.plot(color="#1E88E5", figsize=(15, 12), zorder=2)
test_default_gdf.plot(color="#FFC107", ax=ax, zorder=2)
buildings.exterior.plot(ax=ax, color="black", linewidth=0.3, alpha=0.5, zorder=2)
covering_h3_cells.boundary.plot(color="black", ax=ax, linewidth=0.3, alpha=0.5, zorder=1)
ax.set_title("Vancouver buildings data - count split")
ax.legend(
handles=[Patch(facecolor="#1E88E5"), Patch(facecolor="#FFC107")],
labels=["Train", "Test"],
)
cx.add_basemap(
ax,
source=cx.providers.CartoDB.VoyagerNoLabels,
crs=4326,
zoom=15,
alpha=0.8,
)
ax.set_axis_off()
plt.show()
With numerical target column¶
If a target column is provided, it will automatically be treated as a numerical column, split into buckets (default: 7
) and stratified based on those buckets. The value distribution will be roughly the same in both splits.
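To build intuition for the bucketing step, here is what quantile binning of the target looks like with pandas (the same qcut call is used for plotting below; the library's internal binning may differ in detail):

# Illustration: split the numerical target into 7 quantile bins of roughly equal size.
height_buckets = gpd.pd.qcut(buildings["height"], q=7, labels=False)
print(height_buckets.value_counts().sort_index())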
train_height_gdf, test_height_gdf = train_test_spatial_split(
input_gdf=buildings,
parent_h3_resolution=9,
geometry_column="centroid",
target_column="height",
n_bins=7,
test_size=0.2,
random_state=42,
)
Summary of the split:
Train: 139 H3 cells (2617 points)
Test: 71 H3 cells (655 points)
Expected ratios: {'train': 0.8, 'validation': 0, 'test': 0.2}
Actual ratios: {'train': 0.8, 'test': 0.2}
Actual ratios difference: {'train': 0.0, 'test': 0.0}

bucket | train_ratio | test_ratio | train_ratio_difference | test_ratio_difference | train_points | test_points
---|---|---|---|---|---|---
0 | 0.79487 | 0.20513 | 0.00513 | -0.00513 | 372 | 96
1 | 0.80300 | 0.19700 | -0.00300 | 0.00300 | 375 | 92
2 | 0.79872 | 0.20128 | 0.00128 | -0.00128 | 373 | 94
3 | 0.80556 | 0.19444 | -0.00556 | 0.00556 | 377 | 91
4 | 0.79657 | 0.20343 | 0.00343 | -0.00343 | 372 | 95
5 | 0.80086 | 0.19914 | -0.00086 | 0.00086 | 374 | 93
6 | 0.79915 | 0.20085 | 0.00085 | -0.00085 | 374 | 94
ax = sns.kdeplot(
data=train_height_gdf,
x="height",
fill=True,
label="train",
log_scale=True,
)
sns.kdeplot(
data=test_height_gdf,
x="height",
fill=True,
label="test",
ax=ax,
)
ax.legend()
ax.set_xlim(left=1)
plt.show()
train_covering_h3_cells = h3_to_geoseries(
shapely_geometry_to_h3(train_height_gdf["centroid"], H3_RESOLUTION)
)
test_covering_h3_cells = h3_to_geoseries(
shapely_geometry_to_h3(test_height_gdf["centroid"], H3_RESOLUTION)
)
with plt.rc_context({"hatch.linewidth": 0.3}):
ax = buildings.plot(
gpd.pd.qcut(buildings["height"], 7),
figsize=(15, 12),
cmap="Spectral_r",
legend=True,
legend_kwds=dict(title="Height category (m)"),
zorder=2,
)
buildings.exterior.plot(ax=ax, color="black", linewidth=0.3, zorder=2, alpha=0.5)
train_covering_h3_cells.boundary.plot(color="black", ax=ax, linewidth=0.3, zorder=1)
test_covering_h3_cells.plot(
ax=ax,
linewidth=0.3,
color=(0, 0, 0, 0),
edgecolor="black",
hatch="//",
zorder=1,
)
ax2 = ax.twinx()
ax2.get_yaxis().set_visible(False)
ax2.legend(
handles=[
Patch(edgecolor="black", facecolor=(0, 0, 0, 0)),
Patch(edgecolor="black", facecolor=(0, 0, 0, 0), hatch="///"),
],
labels=["Train", "Test"],
loc=2,
)
ax.set_title("Vancouver buildings data - numerical split")
cx.add_basemap(
ax,
source=cx.providers.CartoDB.VoyagerNoLabels,
crs=4326,
zoom=15,
alpha=0.5,
)
ax.set_axis_off()
plt.show()
With categorical target column¶
Stratification can also be done based on an existing categorical column, without using buckets.
In that case, the categorical
parameter must be set to True
.
buildings["subtype"].value_counts()
subtype
residential       2662
commercial         401
industrial          52
education           34
religious           32
civic               23
medical             19
outbuilding         17
transportation      15
entertainment       12
service              5
Name: count, dtype: int64
train_categorical_gdf, test_categorical_gdf = train_test_spatial_split(
input_gdf=buildings,
parent_h3_resolution=9,
geometry_column="centroid",
target_column="subtype",
categorical=True,
test_size=0.2,
random_state=42,
)
Summary of the split:
Train: 155 H3 cells (2603 points)
Test: 55 H3 cells (669 points)
Expected ratios: {'train': 0.8, 'validation': 0, 'test': 0.2}
Actual ratios: {'train': 0.796, 'test': 0.204}
Actual ratios difference: {'train': 0.004, 'test': -0.004}

bucket | train_ratio | test_ratio | train_ratio_difference | test_ratio_difference | train_points | test_points
---|---|---|---|---|---|---
civic | 0.82609 | 0.17391 | -0.02609 | 0.02609 | 19 | 4
commercial | 0.80050 | 0.19950 | -0.00050 | 0.00050 | 321 | 80
education | 0.79412 | 0.20588 | 0.00588 | -0.00588 | 27 | 7
entertainment | 0.83333 | 0.16667 | -0.03333 | 0.03333 | 10 | 2
industrial | 0.80769 | 0.19231 | -0.00769 | 0.00769 | 42 | 10
medical | 0.78947 | 0.21053 | 0.01053 | -0.01053 | 15 | 4
outbuilding | 0.82353 | 0.17647 | -0.02353 | 0.02353 | 14 | 3
religious | 0.81250 | 0.18750 | -0.01250 | 0.01250 | 26 | 6
residential | 0.79376 | 0.20624 | 0.00624 | -0.00624 | 2113 | 549
service | 0.80000 | 0.20000 | 0.00000 | 0.00000 | 4 | 1
transportation | 0.80000 | 0.20000 | 0.00000 | 0.00000 | 12 | 3
train_categories_stats = train_categorical_gdf["subtype"].value_counts().reset_index()
train_categories_stats["count"] /= train_categories_stats["count"].sum()
train_categories_stats["split"] = "train"
test_categories_stats = test_categorical_gdf["subtype"].value_counts().reset_index()
test_categories_stats["count"] /= test_categories_stats["count"].sum()
test_categories_stats["split"] = "test"
sns.barplot(
data=gpd.pd.concat([train_categories_stats, test_categories_stats]),
x="count",
y="subtype",
hue="split",
)
plt.show()
train_covering_h3_cells = h3_to_geoseries(
shapely_geometry_to_h3(train_categorical_gdf["centroid"], H3_RESOLUTION)
)
test_covering_h3_cells = h3_to_geoseries(
shapely_geometry_to_h3(test_categorical_gdf["centroid"], H3_RESOLUTION)
)
with plt.rc_context({"hatch.linewidth": 0.3}):
ax = buildings.plot(
"subtype",
categories=buildings["subtype"].value_counts().index,
figsize=(15, 12),
cmap="Set3",
legend=True,
legend_kwds=dict(title="Building subtype"),
zorder=2,
)
buildings.exterior.plot(ax=ax, color="black", linewidth=0.3, zorder=2, alpha=0.5)
train_covering_h3_cells.boundary.plot(color="black", ax=ax, linewidth=0.3, zorder=1)
test_covering_h3_cells.plot(
ax=ax,
linewidth=0.3,
color=(0, 0, 0, 0),
edgecolor="black",
hatch="//",
zorder=1,
)
ax2 = ax.twinx()
ax2.get_yaxis().set_visible(False)
ax2.legend(
handles=[
Patch(edgecolor="black", facecolor=(0, 0, 0, 0)),
Patch(edgecolor="black", facecolor=(0, 0, 0, 0), hatch="///"),
],
labels=["Train", "Test"],
loc=2,
)
ax.set_title("Vancouver buildings data - categorical split")
cx.add_basemap(
ax,
source=cx.providers.CartoDB.VoyagerNoLabels,
crs=4326,
zoom=15,
alpha=0.5,
)
ax.set_axis_off()
plt.show()
Splitting into three datasets at once¶
By using another function, spatial_split_points
, the user can split the dataset into three groups at once (train, validation, test).
Usually users split the data into train and test sets and then run the splitting again to obtain a validation set, but SRAI
exposes a function that splits directly into three sets. It returns a dictionary with the split data.
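For comparison, the manual two-step approach would look roughly like this; each call stratifies independently, so the final three-way ratios are harder to control than with spatial_split_points:

# Two consecutive splits: carve out 20% for test, then 12.5% of the
# remainder for validation (0.125 * 0.8 = 0.1 of the full dataset).
_train_val_gdf, _test_gdf = train_test_spatial_split(
    input_gdf=buildings,
    parent_h3_resolution=H3_RESOLUTION,
    geometry_column="centroid",
    target_column=None,
    test_size=0.2,
    random_state=42,
    verbose=False,
)
_train_gdf, _validation_gdf = train_test_spatial_split(
    input_gdf=_train_val_gdf,
    parent_h3_resolution=H3_RESOLUTION,
    geometry_column="centroid",
    target_column=None,
    test_size=0.125,
    random_state=42,
    verbose=False,
)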
splits = spatial_split_points(
input_gdf=buildings,
parent_h3_resolution=H3_RESOLUTION,
geometry_column="centroid",
target_column=None,
# Size can also be passed as an expected number of points, not only a fraction
test_size=1000,
validation_size=500,
random_state=42,
)
Summary of the split:
Train: 109 H3 cells (1773 points)
Validation: 34 H3 cells (519 points)
Test: 67 H3 cells (980 points)
Expected ratios: {'train': 0.541564792176039, 'validation': 0.1528117359413203, 'test': 0.3056234718826406}
Actual ratios: {'train': 0.542, 'validation': 0.159, 'test': 0.3}
Actual ratios difference: {'train': -0.0, 'validation': -0.006, 'test': 0.006}

bucket | train_ratio | validation_ratio | test_ratio | train_ratio_difference | validation_ratio_difference | test_ratio_difference | train_points | validation_points | test_points
---|---|---|---|---|---|---|---|---|---
0 | 0.54000 | 0.16000 | 0.30000 | 0.00156 | -0.00719 | 0.00562 | 27 | 8 | 15
1 | 0.52778 | 0.15972 | 0.31250 | 0.01378 | -0.00691 | -0.00688 | 76 | 23 | 45
2 | 0.52800 | 0.16000 | 0.31200 | 0.01356 | -0.00719 | -0.00638 | 66 | 20 | 39
3 | 0.54167 | 0.14881 | 0.30952 | -0.00011 | 0.00400 | -0.00390 | 182 | 50 | 104
4 | 0.53043 | 0.15870 | 0.31087 | 0.01113 | -0.00589 | -0.00525 | 244 | 73 | 143
5 | 0.53747 | 0.16279 | 0.29974 | 0.00409 | -0.00998 | 0.00588 | 416 | 126 | 232
6 | 0.55098 | 0.15835 | 0.29067 | -0.00942 | -0.00554 | 0.01495 | 762 | 219 | 402
print(splits.keys())
dict_keys(['train', 'validation', 'test'])
ax = splits["train"].plot(color="#1E88E5", figsize=(15, 12), zorder=2)
splits["test"].plot(color="#FFC107", ax=ax, zorder=2)
splits["validation"].plot(color="#D81B60", ax=ax, zorder=2)
buildings.exterior.plot(ax=ax, color="black", linewidth=0.3, alpha=0.5, zorder=2)
covering_h3_cells.boundary.plot(color="black", ax=ax, linewidth=0.3, alpha=0.5, zorder=1)
ax.set_title("Vancouver buildings data - count split into three sets")
ax.legend(
handles=[
Patch(facecolor="#1E88E5"),
Patch(facecolor="#FFC107"),
Patch(facecolor="#D81B60"),
],
labels=["Train", "Test", "Validation"],
)
cx.add_basemap(
ax,
source=cx.providers.CartoDB.VoyagerNoLabels,
crs=4326,
zoom=15,
alpha=0.8,
)
ax.set_axis_off()
plt.show()
Parse split report manually¶
You can use the return_split_stats
parameter to get the splitting report as a pandas DataFrame and manually validate the splitting ratios.
You can also use the verbose
parameter to disable the printed output.
splits, split_report = spatial_split_points(
input_gdf=buildings,
parent_h3_resolution=H3_RESOLUTION,
geometry_column="centroid",
target_column=None,
# Can also be passed as an expected number of points, not only a fraction
test_size=1000,
validation_size=500,
random_state=42,
return_split_stats=True,
verbose=False,
)
split_report
 | bucket | train_ratio | validation_ratio | test_ratio | train_ratio_difference | validation_ratio_difference | test_ratio_difference | train_points | validation_points | test_points
---|---|---|---|---|---|---|---|---|---|---
0 | 0 | 0.54000 | 0.16000 | 0.30000 | 0.00156 | -0.00719 | 0.00562 | 27 | 8 | 15 |
1 | 1 | 0.52778 | 0.15972 | 0.31250 | 0.01378 | -0.00691 | -0.00688 | 76 | 23 | 45 |
2 | 2 | 0.52800 | 0.16000 | 0.31200 | 0.01356 | -0.00719 | -0.00638 | 66 | 20 | 39 |
3 | 3 | 0.54167 | 0.14881 | 0.30952 | -0.00011 | 0.00400 | -0.00390 | 182 | 50 | 104 |
4 | 4 | 0.53043 | 0.15870 | 0.31087 | 0.01113 | -0.00589 | -0.00525 | 244 | 73 | 143 |
5 | 5 | 0.53747 | 0.16279 | 0.29974 | 0.00409 | -0.00998 | 0.00588 | 416 | 126 | 232 |
6 | 6 | 0.55098 | 0.15835 | 0.29067 | -0.00942 | -0.00554 | 0.01495 | 762 | 219 | 402 |
split_report[["train_ratio", "validation_ratio", "test_ratio"]].mean()
train_ratio         0.536619
validation_ratio    0.158339
test_ratio          0.305043
dtype: float64
split_report[
[
"train_ratio_difference",
"validation_ratio_difference",
"test_ratio_difference",
]
].mean()
train_ratio_difference         0.004941
validation_ratio_difference   -0.005529
test_ratio_difference          0.000577
dtype: float64
split_report[["train_points", "validation_points", "test_points"]].sum()
train_points         1773
validation_points     519
test_points           980
dtype: int64
Different H3 resolutions¶
You can perform splitting at different H3 resolutions, and the choice of resolution will affect the results.
- Higher resolutions (smaller hexagons) produce a split ratio closer to your target, but the regions are physically closer together, which reduces true spatial separation.
- Lower resolutions (larger hexagons) improve spatial separation but may cause the actual split ratio to deviate more from the target.
Selecting a proper H3 resolution
As a rule of thumb, choose the lowest resolution that still keeps the split ratio difference within an acceptable range for your use case.
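One practical way to apply this rule is to scan a few candidate resolutions and compare the resulting ratio differences before committing to one; the plots below then visualize the same trade-off:

# Compare how far the actual test ratio lands from the requested 20%
# at several H3 resolutions (a smaller difference means better ratio fidelity).
for resolution in [7, 8, 9, 10]:
    _train_gdf, _test_gdf = train_test_spatial_split(
        input_gdf=buildings,
        parent_h3_resolution=resolution,
        geometry_column="centroid",
        target_column=None,
        test_size=0.2,
        random_state=42,
        verbose=False,
    )
    diff = 0.2 - len(_test_gdf) / len(buildings)
    print(f"H3 resolution {resolution}: test ratio difference = {diff:+.4f}")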
def split_per_resolution(resolution: int, ax: Axes, h3_edge_alpha: float) -> None:
"""Split the data using given resolution."""
test_ratio = 0.4
_train_gdf, _test_gdf = train_test_spatial_split(
input_gdf=buildings,
parent_h3_resolution=resolution,
geometry_column="centroid",
target_column="height",
test_size=test_ratio,
random_state=42,
verbose=False,
)
buildings.exterior.plot(ax=ax, color="black", linewidth=0.3, alpha=0.5, zorder=2)
actual_test_ratio = len(_test_gdf) / len(buildings)
test_ratio_diff = test_ratio - actual_test_ratio
_train_covering_h3_cells = h3_to_geoseries(
shapely_geometry_to_h3(_train_gdf["centroid"], resolution)
)
_test_covering_h3_cells = h3_to_geoseries(
shapely_geometry_to_h3(_test_gdf["centroid"], resolution)
)
_train_gdf.plot(color="#1E88E5", ax=ax, zorder=2)
_test_gdf.plot(color="#FFC107", ax=ax, zorder=2)
_train_covering_h3_cells.boundary.plot(
color="black", ax=ax, linewidth=0.3, alpha=h3_edge_alpha, zorder=1
)
_test_covering_h3_cells.boundary.plot(
color="black", ax=ax, linewidth=0.3, alpha=h3_edge_alpha, zorder=1
)
ax.legend(
handles=[Patch(facecolor="#1E88E5"), Patch(facecolor="#FFC107")],
labels=["Train", "Test"],
)
ax.set_title(
f"Vancouver buildings data - numerical split (H3 resolution: {resolution})\n"
f"Expected test ratio: {test_ratio:.2f}, "
f"Actual test ratio: {actual_test_ratio:.2f}, "
f"Diff: {test_ratio_diff:.3f}"
)
cx.add_basemap(
ax,
source=cx.providers.CartoDB.VoyagerNoLabels,
crs=4326,
zoom=15,
alpha=0.8,
)
ax.set_axis_off()
with plt.rc_context({"hatch.linewidth": 0.3}):
fig, axes = plt.subplots(2, 2, figsize=(20, 18), sharex=True, sharey=True)
pairs = [
(7, axes[0][0], 1.0),
(8, axes[0][1], 0.9),
(9, axes[1][0], 0.8),
(10, axes[1][1], 0.7),
]
for h3_res, ax, h3_edge_alpha in pairs:
split_per_resolution(h3_res, ax, h3_edge_alpha)
buildings_bounds = buildings.total_bounds
for ax in axes.flatten():
ax.set_xlim(buildings_bounds[0] - 0.001, buildings_bounds[2] + 0.001)
ax.set_ylim(buildings_bounds[1] - 0.001, buildings_bounds[3] + 0.001)
plt.tight_layout()
plt.show()
What to do with timeseries data?¶
When working with geospatial datasets that include a time component — for example, store locations with monthly performance data over the past year — it’s important to consider how the split is performed.
If you split purely at the row level, the same store might appear in both training and test sets for different months. This creates data leakage: the model could learn store-specific patterns from the training set and then see almost the same data in the test set, inflating performance metrics.
A better approach is to split at the entity level. For stores, that means assigning each store to a single split (train or test) and including all its historical monthly records in that split. This ensures that the model is evaluated on entirely unseen stores, which is especially important when the goal is to build a whitespot model for identifying promising new locations.
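To make the leakage concrete, here is a quick check you could run on the dummy sales table generated later in this notebook (df_sales): a naive row-level split leaks almost every store into both sets.

# Naive row-level split of the monthly records: the same store IDs
# end up in both train and test.
_rows_train, _rows_test = train_test_split(df_sales.reset_index(), test_size=0.2, random_state=42)
leaked = set(_rows_train["id"]) & set(_rows_test["id"])
print(f"Stores appearing in both splits: {len(leaked)} of {df_sales.index.nunique()}")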
Utilizing temporal component
If your dataset is big enough (data from multiple years), you can combine spatial splitting with temporal splitting to test how the model generalizes to both unseen stores and future time periods.
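A minimal sketch of that combination, assuming hypothetical stores_gdf / sales_df tables shaped like the ones built below (point geometries plus an 'id'-indexed sales table with a sortable 'month' column): split the stores spatially, then additionally hold out the most recent months for testing.

# Hypothetical sketch: spatial split of the entities plus a temporal cutoff.
cutoff = "2025-06"
train_stores, test_stores = train_test_spatial_split(
    input_gdf=stores_gdf,
    parent_h3_resolution=8,
    geometry_column="centroid",
    target_column=None,
    test_size=0.2,
    random_state=42,
    verbose=False,
)
# Train on seen stores and past months; test on unseen stores and future months.
train_rows = sales_df.loc[train_stores.index].query("month < @cutoff")
test_rows = sales_df.loc[test_stores.index].query("month >= @cutoff")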
The example below shows how to use monthly transaction data to split the locations for whitespot analysis.
First, let's select only the commercial buildings from the dataset.
stores = buildings[buildings["subtype"] == "commercial"].copy()
stores
id | geometry | subtype | height | centroid
---|---|---|---|---
4249f3e3-360b-42e2-8a0a-9f324b8f4c7d | POLYGON ((-123.14605 49.26473, -123.14588 49.2... | commercial | 6.687434 | POINT (-123.14596 49.26489) |
1ff56f02-1afb-4c78-bc10-f77b6767e93f | POLYGON ((-123.14556 49.265, -123.14556 49.264... | commercial | 6.753078 | POINT (-123.14541 49.26487) |
23e4f62b-22a8-4062-aa03-5bc1e4cd129c | POLYGON ((-123.1446 49.26499, -123.14501 49.26... | commercial | 11.917554 | POINT (-123.14483 49.26485) |
2b57bdcf-4997-48ee-a0ac-23cc9d78760c | POLYGON ((-123.14531 49.26425, -123.14525 49.2... | commercial | 10.611635 | POINT (-123.1451 49.26438) |
1683a7a9-f295-45e2-9276-fbd9f550f3aa | POLYGON ((-123.14627 49.26357, -123.14632 49.2... | commercial | 8.197035 | POINT (-123.1462 49.26345) |
... | ... | ... | ... | ... |
4f18f7b7-db10-4053-9472-c84231a18a56 | POLYGON ((-123.09372 49.26236, -123.09414 49.2... | commercial | 7.688835 | POINT (-123.09365 49.26227) |
c5de42ae-d2bd-4ca7-ad94-17389aa7a3f4 | POLYGON ((-123.08882 49.26278, -123.08898 49.2... | commercial | 5.626834 | POINT (-123.0889 49.2627) |
d7dd54ee-9b04-4e35-82ae-0704e40d0631 | POLYGON ((-123.08856 49.2629, -123.08856 49.26... | commercial | 7.098448 | POINT (-123.08849 49.26276) |
caae285e-45ca-4c1e-8adb-38de4e11d28d | POLYGON ((-123.07726 49.26484, -123.07727 49.2... | commercial | 8.902528 | POINT (-123.07709 49.2646) |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | POLYGON ((-123.0831 49.2594, -123.0831 49.2595... | commercial | 5.506591 | POINT (-123.08316 49.25949) |
401 rows × 4 columns
ax = stores.plot(figsize=(15, 15))
cx.add_basemap(
ax,
source=cx.providers.CartoDB.VoyagerNoLabels,
crs=4326,
zoom=15,
alpha=0.8,
)
ax.set_axis_off()
ax.set_title("Vancouver - commercial buildings")
plt.show()
Now, we can generate the dummy monthly sales data for the last year.
import numpy as np
import pandas as pd
def generate_monthly_sales(store_ids: pd.Index, seed=None):
"""
Generate dummy monthly sales data for the past 12 months.
Args:
store_ids (pd.Index): IDs of locations.
seed (int, optional): Random seed for reproducibility.
Returns:
pd.DataFrame: Monthly sales with columns ['month', 'sales'], indexed by 'id'.
"""
rng = np.random.default_rng(seed=seed)
# Generate month labels (last 12 months, newest last)
months = pd.date_range(end=pd.Timestamp.today(), periods=12, freq="ME")
month_labels = months.strftime("%Y-%m").tolist()
# Seasonal multiplier with peak at December (sinusoidal pattern)
phases = 2 * np.pi * (months.month - 12) / 12.0
seasonal_factor = 1.0 + 0.2 * np.cos(phases)
data = []
for loc_id in store_ids:
# Start with a base sales value for this location
base_sales = rng.integers(8000, 20000)
# Create gradual monthly changes using a small random walk
gradual_changes = np.cumsum(rng.normal(loc=0, scale=300, size=12))
# Combine base + changes + seasonality
sales = (base_sales + gradual_changes) * seasonal_factor
# Ensure sales are positive
sales = np.clip(sales, 0, None)
# Append to dataset
for month_str, value in zip(month_labels, sales):
data.append((loc_id, month_str, round(value, 2)))
df = pd.DataFrame(data, columns=["id", "month", "sales"]).set_index("id")
return df
df_sales = generate_monthly_sales(store_ids=stores.index, seed=42)
df_sales
id | month | sales
---|---|---
4249f3e3-360b-42e2-8a0a-9f324b8f4c7d | 2024-09 | 8759.00 |
4249f3e3-360b-42e2-8a0a-9f324b8f4c7d | 2024-10 | 9882.55 |
4249f3e3-360b-42e2-8a0a-9f324b8f4c7d | 2024-11 | 10871.28 |
4249f3e3-360b-42e2-8a0a-9f324b8f4c7d | 2024-12 | 10417.20 |
4249f3e3-360b-42e2-8a0a-9f324b8f4c7d | 2025-01 | 9726.28 |
... | ... | ... |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-04 | 13797.42 |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-05 | 12472.93 |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-06 | 12141.19 |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-07 | 12741.54 |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-08 | 13526.53 |
4812 rows × 2 columns
fig, ax = plt.subplots(figsize=(12, 6))
sns.lineplot(df_sales, x="month", y="sales", hue="id", legend=False, alpha=0.4, ax=ax)
ax.set_title("Monthly sales data per store")
ax.set_ylabel("Sales")
ax.set_xlabel("Month")
plt.tight_layout()
plt.show()
Now that we have store locations and a dataframe with monthly sales data per location, we will calculate the average monthly sales per store and use this information to stratify the spatial split.
mean_monthly_sales = df_sales.groupby("id")["sales"].mean() # You can also use median
sns.histplot(mean_monthly_sales, kde=True)
plt.show()
Now we have to assign the mean values to the original dataframe.
stores["mean_monthly_sales"] = mean_monthly_sales
stores
id | geometry | subtype | height | centroid | mean_monthly_sales
---|---|---|---|---|---
4249f3e3-360b-42e2-8a0a-9f324b8f4c7d | POLYGON ((-123.14605 49.26473, -123.14588 49.2... | commercial | 6.687434 | POINT (-123.14596 49.26489) | 8523.560000 |
1ff56f02-1afb-4c78-bc10-f77b6767e93f | POLYGON ((-123.14556 49.265, -123.14556 49.264... | commercial | 6.753078 | POINT (-123.14541 49.26487) | 17558.229167 |
23e4f62b-22a8-4062-aa03-5bc1e4cd129c | POLYGON ((-123.1446 49.26499, -123.14501 49.26... | commercial | 11.917554 | POINT (-123.14483 49.26485) | 17858.095833 |
2b57bdcf-4997-48ee-a0ac-23cc9d78760c | POLYGON ((-123.14531 49.26425, -123.14525 49.2... | commercial | 10.611635 | POINT (-123.1451 49.26438) | 10637.268333 |
1683a7a9-f295-45e2-9276-fbd9f550f3aa | POLYGON ((-123.14627 49.26357, -123.14632 49.2... | commercial | 8.197035 | POINT (-123.1462 49.26345) | 17073.358333 |
... | ... | ... | ... | ... | ... |
4f18f7b7-db10-4053-9472-c84231a18a56 | POLYGON ((-123.09372 49.26236, -123.09414 49.2... | commercial | 7.688835 | POINT (-123.09365 49.26227) | 12521.415833 |
c5de42ae-d2bd-4ca7-ad94-17389aa7a3f4 | POLYGON ((-123.08882 49.26278, -123.08898 49.2... | commercial | 5.626834 | POINT (-123.0889 49.2627) | 10955.856667 |
d7dd54ee-9b04-4e35-82ae-0704e40d0631 | POLYGON ((-123.08856 49.2629, -123.08856 49.26... | commercial | 7.098448 | POINT (-123.08849 49.26276) | 8408.697500 |
caae285e-45ca-4c1e-8adb-38de4e11d28d | POLYGON ((-123.07726 49.26484, -123.07727 49.2... | commercial | 8.902528 | POINT (-123.07709 49.2646) | 14436.159167 |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | POLYGON ((-123.0831 49.2594, -123.0831 49.2595... | commercial | 5.506591 | POINT (-123.08316 49.25949) | 15689.902500 |
401 rows × 5 columns
Let's do the split based on the mean monthly sales. We will reduce the number of bins to decrease the actual ratio difference.
train_sales_gdf, test_sales_gdf = train_test_spatial_split(
input_gdf=stores,
parent_h3_resolution=8,
geometry_column="centroid",
target_column="mean_monthly_sales",
n_bins=5,
test_size=0.2,
random_state=42,
)
Summary of the split:
Train: 26 H3 cells (318 points)
Test: 10 H3 cells (83 points)
Expected ratios: {'train': 0.8, 'validation': 0, 'test': 0.2}
Actual ratios: {'train': 0.793, 'test': 0.207}
Actual ratios difference: {'train': 0.007, 'test': -0.007}

bucket | train_ratio | test_ratio | train_ratio_difference | test_ratio_difference | train_points | test_points
---|---|---|---|---|---|---
0 | 0.80247 | 0.19753 | -0.00247 | 0.00247 | 65 | 16
1 | 0.77500 | 0.22500 | 0.02500 | -0.02500 | 62 | 18
2 | 0.80000 | 0.20000 | 0.00000 | 0.00000 | 64 | 16
3 | 0.78750 | 0.21250 | 0.01250 | -0.01250 | 63 | 17
4 | 0.80000 | 0.20000 | 0.00000 | 0.00000 | 64 | 16
Here is the value distribution in the two sets.
ax = sns.kdeplot(
data=train_sales_gdf,
x="mean_monthly_sales",
fill=True,
label="train",
)
sns.kdeplot(
data=test_sales_gdf,
x="mean_monthly_sales",
fill=True,
label="test",
ax=ax,
)
ax.legend()
ax.set_title("Mean monthly sales distribution per split")
plt.show()
train_covering_h3_cells = h3_to_geoseries(shapely_geometry_to_h3(train_sales_gdf["centroid"], 8))
test_covering_h3_cells = h3_to_geoseries(shapely_geometry_to_h3(test_sales_gdf["centroid"], 8))
with plt.rc_context({"hatch.linewidth": 0.3}):
ax = stores.plot(
"mean_monthly_sales",
figsize=(15, 10),
# cmap="Spectral_r",
cmap="RdYlBu_r",
legend=True,
legend_kwds=dict(
shrink=0.9,
orientation="horizontal",
pad=0.01,
label="Mean monthly sales",
aspect=60,
fraction=0.03,
),
zorder=2,
)
stores.exterior.plot(ax=ax, color="black", linewidth=0.3, zorder=2)
train_covering_h3_cells.boundary.plot(color="black", ax=ax, linewidth=0.3, zorder=1)
test_covering_h3_cells.plot(
ax=ax,
linewidth=0.3,
color=(0, 0, 0, 0),
edgecolor="black",
hatch="/",
zorder=1,
)
ax2 = ax.twinx()
ax2.get_yaxis().set_visible(False)
ax2.legend(
handles=[
Patch(edgecolor="black", facecolor=(0, 0, 0, 0)),
Patch(edgecolor="black", facecolor=(0, 0, 0, 0), hatch="///"),
],
labels=["Train", "Test"],
loc=2,
)
ax.set_title("Vancouver - commercial buildings - sales numerical split")
cx.add_basemap(
ax,
source=cx.providers.CartoDB.VoyagerNoLabels,
crs=4326,
zoom=15,
alpha=0.8,
)
ax.set_axis_off()
stores_bounds = stores.total_bounds
ax.set_xlim(stores_bounds[0] - 0.01, stores_bounds[2] + 0.01)
ax.set_ylim(stores_bounds[1] - 0.01, stores_bounds[3] + 0.01)
plt.show()
Now we can select transaction data based on location IDs.
train_store_sales = df_sales.loc[train_sales_gdf.index]
test_store_sales = df_sales.loc[test_sales_gdf.index]
print(len(train_store_sales), len(test_store_sales))
train_store_sales
3816 996
id | month | sales
---|---|---
6c18201f-b48f-44df-a121-e1d789a723e9 | 2024-09 | 14202.30 |
6c18201f-b48f-44df-a121-e1d789a723e9 | 2024-10 | 15069.89 |
6c18201f-b48f-44df-a121-e1d789a723e9 | 2024-11 | 15901.63 |
6c18201f-b48f-44df-a121-e1d789a723e9 | 2024-12 | 16245.45 |
6c18201f-b48f-44df-a121-e1d789a723e9 | 2025-01 | 16504.94 |
... | ... | ... |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-04 | 13797.42 |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-05 | 12472.93 |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-06 | 12141.19 |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-07 | 12741.54 |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-08 | 13526.53 |
3816 rows × 2 columns
fig, ax = plt.subplots(figsize=(12, 6))
sns.lineplot(train_store_sales, x="month", y="sales", legend=True, ax=ax, label="train")
sns.lineplot(test_store_sales, x="month", y="sales", legend=True, ax=ax, label="test")
ax.set_title("Monthly sales data per store")
ax.set_ylabel("Sales")
ax.set_xlabel("Month")
plt.tight_layout()
plt.show()
ax = sns.kdeplot(
data=train_store_sales,
x="sales",
fill=True,
label="train",
)
sns.kdeplot(
data=test_store_sales,
x="sales",
fill=True,
label="test",
ax=ax,
)
ax.legend()
ax.set_title("Sales distribution per split")
plt.show()