Spatial splitting with stratification¶
The SRAI library contains dedicated functions for splitting a points dataset into train / test (and optionally validation) splits by separating the points spatially while also keeping them stratified based on a given target.
The functions only work on points datasets and use the H3 indexing system to cluster points together and assign whole H3 cells to different splits.
When working with most machine learning datasets, splitting into training and testing sets is straightforward: pick a random subset for testing, and (optionally) use stratification to keep the distribution of a target variable balanced between the two. This works fine when the data points are independent.
Geospatial data plays by different rules. Nearby locations often share similar characteristics - a phenomenon called spatial autocorrelation. If we split data randomly, our training and test sets might end up covering the same areas, meaning the model is “tested” on locations that are practically identical to ones it has already seen. This can make performance look much better than it really is, and it prevents us from testing the model's ability to generalize from spatial features.
That’s why for geo-related tasks, we need spatial splitting: making sure the training and test sets are separated in space so that evaluation reflects real-world conditions. Sometimes we also want to stratify these spatial splits by a target value to ensure both sets still have similar distributions. Standard train_test_split
functions can’t combine these two needs, so we provide a dedicated function for spatially aware splitting with optional stratification.
This notebook shows how the different splitting modes work, using a buildings dataset from the Overture Maps Foundation.
How does it work?¶
To separate the input dataset into multiple outputs, the H3 indexing system is used to group nearby points together.
First, the algorithm transforms the points into H3 cells at a given resolution and calculates statistics per H3 cell (number of points per bucket / category).
Next, all H3 cells are shuffled (with an optional random_state
to ensure reproducibility) and iterated one by one.
For each split (train, validation, test) and each bucket within a split, a running count of assigned points is kept. For each H3 cell and its group of points, the algorithm calculates what the split ratios would become if the cell were assigned to each split, and assigns the cell to the split where the difference to the expected ratio is the lowest.
After iterating over all H3 cells, the original dataset of points is split based on the lists of assigned H3 cells.
A report of the split is printed with the differences between expected and actual ratios.
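To make the procedure concrete, below is a minimal sketch of the greedy assignment described above. It is an illustration, not the actual SRAI implementation: bucketing, tie-breaking and reporting are omitted, and the assign_cells, cells and expected names are made up for the example.

import random


def assign_cells(
    cells: dict[str, int],  # H3 cell id -> number of points inside it
    expected: dict[str, float],  # split name -> expected ratio (sums to 1.0)
    seed: int = 42,
) -> dict[str, str]:
    """Greedily assign each H3 cell to the most under-filled split (a sketch)."""
    rng = random.Random(seed)
    cell_ids = list(cells)
    rng.shuffle(cell_ids)  # reproducible shuffle of the cells

    total_points = sum(cells.values())
    assigned_points = dict.fromkeys(expected, 0)
    assignment: dict[str, str] = {}

    for cell_id in cell_ids:
        # Ratio each split would reach if this cell's points were added,
        # minus its expected ratio; the lowest value marks the split that
        # stays furthest below its quota.
        def difference(split: str) -> float:
            new_points = assigned_points[split] + cells[cell_id]
            return new_points / total_points - expected[split]

        best_split = min(expected, key=difference)
        assignment[cell_id] = best_split
        assigned_points[best_split] += cells[cell_id]

    return assignment


# Example: assign_cells({"cell_a": 40, "cell_b": 10}, {"train": 0.8, "test": 0.2})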
import contextily as cx
import geopandas as gpd
import matplotlib.pyplot as plt
import overturemaestro as om
import pyarrow.compute as pc
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.patches import Patch
from sklearn.model_selection import train_test_split
from srai.h3 import h3_to_geoseries, shapely_geometry_to_h3
from srai.spatial_split import spatial_split_points, train_test_spatial_split
Let's start by downloading example data. Here we will use Vancouver buildings from the Overture Maps dataset.
We only want buildings with both the height and subtype columns filled.
Height will be used in the numerical split example and subtype in the categorical split example.
Because the splitting only works on points, we will assign a centroid to each building as an additional column. Centroids will be calculated in the corresponding projected Coordinate Reference System.
VANCOUVER_BOUNDING_BOX = (-123.148670, 49.255555, -123.076572, 49.296907)
VANCOUVER_PROJECTED_CRS = 26910 # NAD83 / UTM zone 10N
H3_RESOLUTION = 9
buildings = om.convert_bounding_box_to_geodataframe(
theme="buildings",
type="building",
bbox=VANCOUVER_BOUNDING_BOX,
pyarrow_filter=pc.field("subtype").is_valid() & pc.field("height").is_valid(),
columns_to_download=["subtype", "height"],
)
buildings["centroid"] = buildings.to_crs(VANCOUVER_PROJECTED_CRS).centroid.to_crs(4326)
buildings
Finished operation in 0:00:10
geometry | subtype | height | centroid | |
---|---|---|---|---|
id | ||||
194b888e-182e-4c49-89a5-9e44f18b4682 | POLYGON ((-123.1479 49.25701, -123.14821 49.25... | residential | 6.712387 | POINT (-123.14808 49.25695) |
e95cbac8-8f90-49c7-a636-6f0e0e193fe0 | POLYGON ((-123.14761 49.25752, -123.14816 49.2... | residential | 7.460882 | POINT (-123.14788 49.25745) |
4288241a-912e-4835-923c-240edd246863 | POLYGON ((-123.1471 49.25689, -123.14705 49.25... | residential | 6.850000 | POINT (-123.14719 49.25693) |
d4c6de14-a75f-4304-9a18-95447eec11d1 | POLYGON ((-123.14763 49.257, -123.14781 49.257... | residential | 8.560000 | POINT (-123.14772 49.25694) |
def45e2f-2423-46ea-91e8-035e60fa1e12 | POLYGON ((-123.14654 49.25667, -123.14665 49.2... | residential | 3.350000 | POINT (-123.14659 49.25665) |
... | ... | ... | ... | ... |
104af1b3-8da9-4d72-8d4f-e40f4adefd9c | POLYGON ((-123.0835 49.25951, -123.08363 49.25... | residential | 4.922461 | POINT (-123.08357 49.25944) |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | POLYGON ((-123.0831 49.2594, -123.0831 49.2595... | commercial | 5.506591 | POINT (-123.08316 49.25949) |
30705051-74aa-436e-af3a-e1956bab30a1 | POLYGON ((-123.08187 49.2573, -123.08171 49.25... | residential | 6.070602 | POINT (-123.08179 49.25733) |
4007a55a-8a43-4d1a-a55d-94a7a26c0f78 | POLYGON ((-123.0817 49.25721, -123.08169 49.25... | residential | 7.308841 | POINT (-123.08179 49.25725) |
0318b079-1404-44ca-97c4-28f443080728 | POLYGON ((-123.08004 49.25661, -123.08007 49.2... | residential | 6.237479 | POINT (-123.08001 49.25655) |
3299 rows × 4 columns
First, let's see what a random split without spatial context looks like.
# train_test_split function from scikit-learn
random_train_gdf, random_test_gdf = train_test_split(
buildings, test_size=0.2, random_state=42
)
ax = random_train_gdf.plot(
color="#1E88E5", figsize=(15, 12), label="train", legend=True
)
random_test_gdf.plot(color="#FFC107", ax=ax, label="test", legend=True)
buildings.exterior.plot(ax=ax, color="black", linewidth=0.3, alpha=0.5)
ax.set_title("Vancouver buildings data - random split")
ax.legend(
handles=[Patch(facecolor="#1E88E5"), Patch(facecolor="#FFC107")],
labels=["Train", "Test"],
)
cx.add_basemap(
ax,
source=cx.providers.CartoDB.VoyagerNoLabels,
crs=4326,
zoom=15,
alpha=0.8,
)
ax.set_axis_off()
plt.show()
As shown, the buildings are split at random, resulting in both sets covering the same geographic area.
With this approach, you can’t properly evaluate the model’s ability to generalize based on spatial patterns.
Without target column - default¶
A target column isn't required for spatial splitting.
By default, the algorithm calculates the density of points per H3 cell and uses it for stratification. This way, both splits contain both dense and sparse regions.
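To see what the default mode stratifies on, you can count the points per H3 cell yourself. Here is a small illustration using the helper imported above; how the library buckets these counts internally is not shown here.

import pandas as pd

# Points per H3 cell - the quantity the default (no target) mode stratifies on.
cell_ids = shapely_geometry_to_h3(buildings["centroid"], H3_RESOLUTION)
points_per_cell = pd.Series(cell_ids).value_counts()
print(points_per_cell.describe())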
train_default_gdf, test_default_gdf = train_test_spatial_split(
input_gdf=buildings,
parent_h3_resolution=H3_RESOLUTION,
geometry_column="centroid",
target_column=None,
test_size=0.2,
random_state=42,
)
train_default_gdf
Summary of the split:
Train: 164 H3 cells (2623 points)
Test: 46 H3 cells (676 points)
Expected ratios: {'train': 0.8, 'validation': 0, 'test': 0.2}
Actual ratios: {'train': 0.795, 'test': 0.205}
Actual ratios difference: {'train': 0.005, 'test': -0.005}

bucket | train_ratio | test_ratio | train_ratio_difference | test_ratio_difference | train_points | test_points
---|---|---|---|---|---|---
0 | 0.80000 | 0.20000 | 0.00000 | 0.00000 | 40 | 10
1 | 0.80000 | 0.20000 | 0.00000 | 0.00000 | 116 | 29
2 | 0.79293 | 0.20707 | 0.00707 | -0.00707 | 157 | 41
3 | 0.80377 | 0.19623 | -0.00377 | 0.00377 | 213 | 52
4 | 0.80309 | 0.19691 | -0.00309 | 0.00309 | 416 | 102
5 | 0.78290 | 0.21710 | 0.01710 | -0.01710 | 577 | 160
6 | 0.79654 | 0.20346 | 0.00346 | -0.00346 | 1104 | 282
geometry | subtype | height | centroid | |
---|---|---|---|---|
id | ||||
194b888e-182e-4c49-89a5-9e44f18b4682 | POLYGON ((-123.1479 49.25701, -123.14821 49.25... | residential | 6.712387 | POINT (-123.14808 49.25695) |
e95cbac8-8f90-49c7-a636-6f0e0e193fe0 | POLYGON ((-123.14761 49.25752, -123.14816 49.2... | residential | 7.460882 | POINT (-123.14788 49.25745) |
4288241a-912e-4835-923c-240edd246863 | POLYGON ((-123.1471 49.25689, -123.14705 49.25... | residential | 6.850000 | POINT (-123.14719 49.25693) |
d4c6de14-a75f-4304-9a18-95447eec11d1 | POLYGON ((-123.14763 49.257, -123.14781 49.257... | residential | 8.560000 | POINT (-123.14772 49.25694) |
def45e2f-2423-46ea-91e8-035e60fa1e12 | POLYGON ((-123.14654 49.25667, -123.14665 49.2... | residential | 3.350000 | POINT (-123.14659 49.25665) |
... | ... | ... | ... | ... |
104af1b3-8da9-4d72-8d4f-e40f4adefd9c | POLYGON ((-123.0835 49.25951, -123.08363 49.25... | residential | 4.922461 | POINT (-123.08357 49.25944) |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | POLYGON ((-123.0831 49.2594, -123.0831 49.2595... | commercial | 5.506591 | POINT (-123.08316 49.25949) |
30705051-74aa-436e-af3a-e1956bab30a1 | POLYGON ((-123.08187 49.2573, -123.08171 49.25... | residential | 6.070602 | POINT (-123.08179 49.25733) |
4007a55a-8a43-4d1a-a55d-94a7a26c0f78 | POLYGON ((-123.0817 49.25721, -123.08169 49.25... | residential | 7.308841 | POINT (-123.08179 49.25725) |
0318b079-1404-44ca-97c4-28f443080728 | POLYGON ((-123.08004 49.25661, -123.08007 49.2... | residential | 6.237479 | POINT (-123.08001 49.25655) |
2623 rows × 4 columns
covering_h3_cells = h3_to_geoseries(
shapely_geometry_to_h3(buildings["centroid"], H3_RESOLUTION)
)
ax = train_default_gdf.plot(color="#1E88E5", figsize=(15, 12), zorder=2)
test_default_gdf.plot(color="#FFC107", ax=ax, zorder=2)
buildings.exterior.plot(ax=ax, color="black", linewidth=0.3, alpha=0.5, zorder=2)
covering_h3_cells.boundary.plot(
color="black", ax=ax, linewidth=0.3, alpha=0.5, zorder=1
)
ax.set_title("Vancouver buildings data - count split")
ax.legend(
handles=[Patch(facecolor="#1E88E5"), Patch(facecolor="#FFC107")],
labels=["Train", "Test"],
)
cx.add_basemap(
ax,
source=cx.providers.CartoDB.VoyagerNoLabels,
crs=4326,
zoom=15,
alpha=0.8,
)
ax.set_axis_off()
plt.show()
With numerical target column¶
If a target column is provided, it will be treated as a numerical column by default, split into buckets (default: 7
) and stratified based on those buckets. The value distribution will be roughly the same in both splits.
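To see roughly what those buckets correspond to, you can bin the target yourself. The sketch below uses quantile-based binning (pd.qcut), which is consistent with the near-equal per-bucket point counts in the report below; the library's exact binning strategy is an assumption here.

import pandas as pd

# Illustrative only: quantile-based buckets over the stratification target.
height_buckets = pd.qcut(buildings["height"], q=7, labels=False)
print(height_buckets.value_counts().sort_index())  # ~equal-sized buckets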
train_height_gdf, test_height_gdf = train_test_spatial_split(
input_gdf=buildings,
parent_h3_resolution=H3_RESOLUTION,
geometry_column="centroid",
target_column="height",
n_bins=7,
test_size=0.2,
random_state=42,
)
Summary of the split:
Train: 152 H3 cells (2628 points)
Test: 58 H3 cells (671 points)
Expected ratios: {'train': 0.8, 'validation': 0, 'test': 0.2}
Actual ratios: {'train': 0.797, 'test': 0.203}
Actual ratios difference: {'train': 0.003, 'test': -0.003}

bucket | train_ratio | test_ratio | train_ratio_difference | test_ratio_difference | train_points | test_points
---|---|---|---|---|---|---
0 | 0.79873 | 0.20127 | 0.00127 | -0.00127 | 377 | 95
1 | 0.79830 | 0.20170 | 0.00170 | -0.00170 | 376 | 95
2 | 0.80042 | 0.19958 | -0.00042 | 0.00042 | 377 | 94
3 | 0.80042 | 0.19958 | -0.00042 | 0.00042 | 377 | 94
4 | 0.77495 | 0.22505 | 0.02505 | -0.02505 | 365 | 106
5 | 0.80255 | 0.19745 | -0.00255 | 0.00255 | 378 | 93
6 | 0.80085 | 0.19915 | -0.00085 | 0.00085 | 378 | 94
ax = sns.kdeplot(
data=train_height_gdf,
x="height",
fill=True,
label="train",
log_scale=True,
)
sns.kdeplot(
data=test_height_gdf,
x="height",
fill=True,
label="test",
ax=ax,
)
ax.legend()
ax.set_xlim(left=1)
plt.show()
train_covering_h3_cells = h3_to_geoseries(
shapely_geometry_to_h3(train_height_gdf["centroid"], H3_RESOLUTION)
)
test_covering_h3_cells = h3_to_geoseries(
shapely_geometry_to_h3(test_height_gdf["centroid"], H3_RESOLUTION)
)
with plt.rc_context({"hatch.linewidth": 0.3}):
ax = buildings.plot(
gpd.pd.qcut(buildings["height"], 7),
figsize=(15, 12),
cmap="Spectral_r",
legend=True,
legend_kwds=dict(title="Height category (m)"),
zorder=2,
)
buildings.exterior.plot(ax=ax, color="black", linewidth=0.3, zorder=2, alpha=0.5)
train_covering_h3_cells.boundary.plot(color="black", ax=ax, linewidth=0.3, zorder=1)
test_covering_h3_cells.plot(
ax=ax,
linewidth=0.3,
color=(0, 0, 0, 0),
edgecolor="black",
hatch="//",
zorder=1,
)
ax2 = ax.twinx()
ax2.get_yaxis().set_visible(False)
ax2.legend(
handles=[
Patch(edgecolor="black", facecolor=(0, 0, 0, 0)),
Patch(edgecolor="black", facecolor=(0, 0, 0, 0), hatch="///"),
],
labels=["Train", "Test"],
loc=2,
)
ax.set_title("Vancouver buildings data - numerical split")
cx.add_basemap(
ax,
source=cx.providers.CartoDB.VoyagerNoLabels,
crs=4326,
zoom=15,
alpha=0.5,
)
ax.set_axis_off()
plt.show()
With categorical target column¶
Stratification can also be done based on an existing categorical column, without using buckets.
In that case, the categorical parameter must be set to True.
buildings["subtype"].value_counts()
subtype
residential       2686
commercial         405
industrial          51
education           34
religious           32
civic               23
medical             19
outbuilding         17
transportation      15
entertainment       12
service              5
Name: count, dtype: int64
train_categorical_gdf, test_categorical_gdf = train_test_spatial_split(
input_gdf=buildings,
parent_h3_resolution=H3_RESOLUTION,
geometry_column="centroid",
target_column="subtype",
categorical=True,
test_size=0.2,
random_state=42,
)
Summary of the split:
Train: 151 H3 cells (2628 points)
Test: 59 H3 cells (671 points)
Expected ratios: {'train': 0.8, 'validation': 0, 'test': 0.2}
Actual ratios: {'train': 0.797, 'test': 0.203}
Actual ratios difference: {'train': 0.003, 'test': -0.003}

bucket | train_ratio | test_ratio | train_ratio_difference | test_ratio_difference | train_points | test_points
---|---|---|---|---|---|---
civic | 0.78261 | 0.21739 | 0.01739 | -0.01739 | 18 | 5
commercial | 0.79506 | 0.20494 | 0.00494 | -0.00494 | 322 | 83
education | 0.79412 | 0.20588 | 0.00588 | -0.00588 | 27 | 7
entertainment | 0.83333 | 0.16667 | -0.03333 | 0.03333 | 10 | 2
industrial | 0.80392 | 0.19608 | -0.00392 | 0.00392 | 41 | 10
medical | 0.78947 | 0.21053 | 0.01053 | -0.01053 | 15 | 4
outbuilding | 0.82353 | 0.17647 | -0.02353 | 0.02353 | 14 | 3
religious | 0.81250 | 0.18750 | -0.01250 | 0.01250 | 26 | 6
residential | 0.79635 | 0.20365 | 0.00365 | -0.00365 | 2139 | 547
service | 0.80000 | 0.20000 | 0.00000 | 0.00000 | 4 | 1
transportation | 0.80000 | 0.20000 | 0.00000 | 0.00000 | 12 | 3
train_categories_stats = train_categorical_gdf["subtype"].value_counts().reset_index()
train_categories_stats["count"] /= train_categories_stats["count"].sum()
train_categories_stats["split"] = "train"
test_categories_stats = test_categorical_gdf["subtype"].value_counts().reset_index()
test_categories_stats["count"] /= test_categories_stats["count"].sum()
test_categories_stats["split"] = "test"
sns.barplot(
data=gpd.pd.concat([train_categories_stats, test_categories_stats]),
x="count",
y="subtype",
hue="split",
)
plt.show()
train_covering_h3_cells = h3_to_geoseries(
shapely_geometry_to_h3(train_categorical_gdf["centroid"], H3_RESOLUTION)
)
test_covering_h3_cells = h3_to_geoseries(
shapely_geometry_to_h3(test_categorical_gdf["centroid"], H3_RESOLUTION)
)
with plt.rc_context({"hatch.linewidth": 0.3}):
ax = buildings.plot(
"subtype",
categories=buildings["subtype"].value_counts().index,
figsize=(15, 12),
cmap="Set3",
legend=True,
legend_kwds=dict(title="Building subtype"),
zorder=2,
)
buildings.exterior.plot(ax=ax, color="black", linewidth=0.3, zorder=2, alpha=0.5)
train_covering_h3_cells.boundary.plot(color="black", ax=ax, linewidth=0.3, zorder=1)
test_covering_h3_cells.plot(
ax=ax,
linewidth=0.3,
color=(0, 0, 0, 0),
edgecolor="black",
hatch="//",
zorder=1,
)
ax2 = ax.twinx()
ax2.get_yaxis().set_visible(False)
ax2.legend(
handles=[
Patch(edgecolor="black", facecolor=(0, 0, 0, 0)),
Patch(edgecolor="black", facecolor=(0, 0, 0, 0), hatch="///"),
],
labels=["Train", "Test"],
loc=2,
)
ax.set_title("Vancouver buildings data - categorical split")
cx.add_basemap(
ax,
source=cx.providers.CartoDB.VoyagerNoLabels,
crs=4326,
zoom=15,
alpha=0.5,
)
ax.set_axis_off()
plt.show()
Splitting into three datasets at once¶
By using another function, spatial_split_points
, you can split the dataset into three groups at once (train, validation, test).
Usually users split data into train and test sets and then run the splitting again to get the validation set, but SRAI
exposes a function that splits directly into 3 sets. This function returns a dictionary with the split data.
splits = spatial_split_points(
input_gdf=buildings,
parent_h3_resolution=H3_RESOLUTION,
geometry_column="centroid",
target_column=None,
# Size can also be passed as an expected number of points, not only a fraction
test_size=1000,
validation_size=500,
random_state=42,
)
Summary of the split:
Train: 113 H3 cells (1835 points)
Validation: 33 H3 cells (489 points)
Test: 64 H3 cells (975 points)
Expected ratios: {'train': 0.5453167626553501, 'validation': 0.15156107911488328, 'test': 0.30312215822976657}
Actual ratios: {'train': 0.556, 'validation': 0.148, 'test': 0.296}
Actual ratios difference: {'train': -0.011, 'validation': 0.004, 'test': 0.007}

bucket | train_ratio | validation_ratio | test_ratio | train_ratio_difference | validation_ratio_difference | test_ratio_difference | train_points | validation_points | test_points
---|---|---|---|---|---|---|---|---|---
0 | 0.54000 | 0.16000 | 0.30000 | 0.00532 | -0.00844 | 0.00312 | 27 | 8 | 15
1 | 0.54483 | 0.15862 | 0.29655 | 0.00049 | -0.00706 | 0.00657 | 79 | 23 | 43
2 | 0.52525 | 0.16162 | 0.31313 | 0.02007 | -0.01006 | -0.01001 | 104 | 32 | 62
3 | 0.54340 | 0.16604 | 0.29057 | 0.00192 | -0.01448 | 0.01255 | 144 | 44 | 77
4 | 0.55212 | 0.13900 | 0.30888 | -0.00680 | 0.01256 | -0.00576 | 286 | 72 | 160
5 | 0.54953 | 0.14247 | 0.30801 | -0.00421 | 0.00909 | -0.00489 | 405 | 105 | 227
6 | 0.56999 | 0.14791 | 0.28211 | -0.02467 | 0.00365 | 0.02101 | 790 | 205 | 391
print(splits.keys())
dict_keys(['train', 'validation', 'test'])
ax = splits["train"].plot(color="#1E88E5", figsize=(15, 12), zorder=2)
splits["test"].plot(color="#FFC107", ax=ax, zorder=2)
splits["validation"].plot(color="#D81B60", ax=ax, zorder=2)
buildings.exterior.plot(ax=ax, color="black", linewidth=0.3, alpha=0.5, zorder=2)
covering_h3_cells.boundary.plot(
color="black", ax=ax, linewidth=0.3, alpha=0.5, zorder=1
)
ax.set_title("Vancouver buildings data - count split into three sets")
ax.legend(
handles=[
Patch(facecolor="#1E88E5"),
Patch(facecolor="#FFC107"),
Patch(facecolor="#D81B60"),
],
labels=["Train", "Test", "Validation"],
)
cx.add_basemap(
ax,
source=cx.providers.CartoDB.VoyagerNoLabels,
crs=4326,
zoom=15,
alpha=0.8,
)
ax.set_axis_off()
plt.show()
Parse split report manually¶
You can use the return_split_stats parameter to get the splitting report as a pandas DataFrame and manually validate the splitting ratios.
You can also use the verbose parameter to disable the printed output.
splits, split_report = spatial_split_points(
input_gdf=buildings,
parent_h3_resolution=H3_RESOLUTION,
geometry_column="centroid",
target_column=None,
# Can also be passed as an expected number of points, not only a fraction
test_size=1000,
validation_size=500,
random_state=42,
return_split_stats=True,
verbose=False,
)
split_report
bucket | train_ratio | validation_ratio | test_ratio | train_ratio_difference | validation_ratio_difference | test_ratio_difference | train_points | validation_points | test_points | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 0.54000 | 0.16000 | 0.30000 | 0.00532 | -0.00844 | 0.00312 | 27 | 8 | 15 |
1 | 1 | 0.54483 | 0.15862 | 0.29655 | 0.00049 | -0.00706 | 0.00657 | 79 | 23 | 43 |
2 | 2 | 0.52525 | 0.16162 | 0.31313 | 0.02007 | -0.01006 | -0.01001 | 104 | 32 | 62 |
3 | 3 | 0.54340 | 0.16604 | 0.29057 | 0.00192 | -0.01448 | 0.01255 | 144 | 44 | 77 |
4 | 4 | 0.55212 | 0.13900 | 0.30888 | -0.00680 | 0.01256 | -0.00576 | 286 | 72 | 160 |
5 | 5 | 0.54953 | 0.14247 | 0.30801 | -0.00421 | 0.00909 | -0.00489 | 405 | 105 | 227 |
6 | 6 | 0.56999 | 0.14791 | 0.28211 | -0.02467 | 0.00365 | 0.02101 | 790 | 205 | 391 |
split_report[["train_ratio", "validation_ratio", "test_ratio"]].mean()
train_ratio         0.546446
validation_ratio    0.153666
test_ratio          0.299893
dtype: float64
split_report[
[
"train_ratio_difference",
"validation_ratio_difference",
"test_ratio_difference",
]
].mean()
train_ratio_difference        -0.001126
validation_ratio_difference   -0.002106
test_ratio_difference          0.003227
dtype: float64
split_report[["train_points", "validation_points", "test_points"]].sum()
train_points 1835 validation_points 489 test_points 975 dtype: int64
Different H3 resolutions¶
You can perform splitting at different H3 resolutions, and the choice of resolution will affect the results.
- Higher resolutions (smaller hexagons) produce a split ratio closer to your target, but the regions are physically closer together, which reduces true spatial separation.
- Lower resolutions (larger hexagons) improve spatial separation but may cause the actual split ratio to deviate more from the target.
Selecting proper H3 resolution
As a rule of thumb, choose the lowest resolution that still keeps the split ratio difference within an acceptable range for your use case.
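Below is a sketch of this rule of thumb: sweep resolutions from coarse to fine and keep the first one whose ratio difference is acceptable. The pick_resolution helper and the tolerance threshold are illustrative, not part of the library, and it assumes a GeoDataFrame with a centroid column like the buildings one used here.

def pick_resolution(gdf, resolutions=(6, 7, 8, 9, 10), test_size=0.2, tolerance=0.02):
    """Return the lowest H3 resolution whose test-ratio difference is acceptable."""
    for resolution in resolutions:  # coarse (better separation) to fine
        _, _test_gdf = train_test_spatial_split(
            input_gdf=gdf,
            parent_h3_resolution=resolution,
            geometry_column="centroid",
            target_column=None,
            test_size=test_size,
            random_state=42,
            verbose=False,
        )
        diff = abs(test_size - len(_test_gdf) / len(gdf))
        if diff <= tolerance:  # lowest acceptable resolution found
            return resolution, diff
    return resolutions[-1], diff  # fall back to the finest resolution tried


# Example: resolution, diff = pick_resolution(buildings)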
def split_per_resolution(resolution: int, ax: Axes, h3_edge_alpha: float) -> None:
"""Split the data using given resolution."""
test_ratio = 0.4
_train_gdf, _test_gdf = train_test_spatial_split(
input_gdf=buildings,
parent_h3_resolution=resolution,
geometry_column="centroid",
target_column="height",
test_size=test_ratio,
random_state=42,
verbose=False,
)
buildings.exterior.plot(ax=ax, color="black", linewidth=0.3, alpha=0.5, zorder=2)
actual_test_ratio = len(_test_gdf) / len(buildings)
test_ratio_diff = test_ratio - actual_test_ratio
_train_covering_h3_cells = h3_to_geoseries(
shapely_geometry_to_h3(_train_gdf["centroid"], resolution)
)
_test_covering_h3_cells = h3_to_geoseries(
shapely_geometry_to_h3(_test_gdf["centroid"], resolution)
)
_train_gdf.plot(color="#1E88E5", ax=ax, zorder=2)
_test_gdf.plot(color="#FFC107", ax=ax, zorder=2)
_train_covering_h3_cells.boundary.plot(
color="black", ax=ax, linewidth=0.3, alpha=h3_edge_alpha, zorder=1
)
_test_covering_h3_cells.boundary.plot(
color="black", ax=ax, linewidth=0.3, alpha=h3_edge_alpha, zorder=1
)
ax.legend(
handles=[Patch(facecolor="#1E88E5"), Patch(facecolor="#FFC107")],
labels=["Train", "Test"],
)
ax.set_title(
f"Vancouver buildings data - numerical split (H3 resolution: {resolution})\n"
f"Expected test ratio: {test_ratio:.2f}, "
f"Actual test ratio: {actual_test_ratio:.2f}, "
f"Diff: {test_ratio_diff:.3f}"
)
cx.add_basemap(
ax,
source=cx.providers.CartoDB.VoyagerNoLabels,
crs=4326,
zoom=15,
alpha=0.8,
)
ax.set_axis_off()
with plt.rc_context({"hatch.linewidth": 0.3}):
fig, axes = plt.subplots(2, 2, figsize=(20, 18), sharex=True, sharey=True)
pairs = [
(7, axes[0][0], 1.0),
(8, axes[0][1], 0.9),
(9, axes[1][0], 0.8),
(10, axes[1][1], 0.7),
]
for h3_res, ax, h3_edge_alpha in pairs:
split_per_resolution(h3_res, ax, h3_edge_alpha)
buildings_bounds = buildings.total_bounds
for ax in axes.flatten():
ax.set_xlim(buildings_bounds[0] - 0.001, buildings_bounds[2] + 0.001)
ax.set_ylim(buildings_bounds[1] - 0.001, buildings_bounds[3] + 0.001)
plt.tight_layout()
plt.show()
What to do with timeseries data?¶
When working with geospatial datasets that include a time component — for example, store locations with monthly performance data over the past year — it’s important to consider how the split is performed.
If you split purely at the row level, the same store might appear in both training and test sets for different months. This creates data leakage: the model could learn store-specific patterns from the training set and then see almost the same data in the test set, inflating performance metrics.
A better approach is to split at the entity level. For stores, that means assigning each store to a single split (train or test) and including all its historical monthly records in that split. This ensures that the model is evaluated on entirely unseen stores, which is especially important when the goal is to build a whitespot model for identifying promising new locations.
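For comparison, this is what a plain entity-level split looks like without any spatial awareness, using scikit-learn's GroupShuffleSplit. The records frame and location_id column are hypothetical names for a one-row-per-store-month dataset:

from sklearn.model_selection import GroupShuffleSplit


def entity_level_split(records, group_column="location_id", test_size=0.2, random_state=42):
    """Assign every store (group) to exactly one split, keeping all its months together."""
    gss = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
    train_idx, test_idx = next(gss.split(records, groups=records[group_column]))
    return records.iloc[train_idx], records.iloc[test_idx]

This keeps each store's records together, but does nothing about spatial separation; the rest of this section therefore splits the stores spatially first and only then joins the monthly records.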
Utilizing temporal component
If your dataset is big enough (data from multiple years), you can combine spatial splitting with temporal splitting to test how the model generalizes to both unseen stores and future time periods.
The example below shows how to use monthly transaction data to split the locations for a whitespot analysis.
First, let's select only the commercial buildings from the dataset.
stores = buildings[buildings["subtype"] == "commercial"].copy()
stores
geometry | subtype | height | centroid | |
---|---|---|---|---|
id | ||||
4249f3e3-360b-42e2-8a0a-9f324b8f4c7d | POLYGON ((-123.14605 49.26473, -123.14588 49.2... | commercial | 6.687434 | POINT (-123.14596 49.26489) |
1ff56f02-1afb-4c78-bc10-f77b6767e93f | POLYGON ((-123.14556 49.265, -123.14556 49.264... | commercial | 6.753078 | POINT (-123.14541 49.26487) |
23e4f62b-22a8-4062-aa03-5bc1e4cd129c | POLYGON ((-123.1446 49.26499, -123.14501 49.26... | commercial | 11.917554 | POINT (-123.14483 49.26485) |
2b57bdcf-4997-48ee-a0ac-23cc9d78760c | POLYGON ((-123.14531 49.26425, -123.14525 49.2... | commercial | 10.611635 | POINT (-123.1451 49.26438) |
1683a7a9-f295-45e2-9276-fbd9f550f3aa | POLYGON ((-123.14627 49.26357, -123.14632 49.2... | commercial | 8.197035 | POINT (-123.1462 49.26345) |
... | ... | ... | ... | ... |
4f18f7b7-db10-4053-9472-c84231a18a56 | POLYGON ((-123.09372 49.26236, -123.09414 49.2... | commercial | 7.688835 | POINT (-123.09365 49.26227) |
c5de42ae-d2bd-4ca7-ad94-17389aa7a3f4 | POLYGON ((-123.08882 49.26278, -123.08898 49.2... | commercial | 5.626834 | POINT (-123.0889 49.2627) |
d7dd54ee-9b04-4e35-82ae-0704e40d0631 | POLYGON ((-123.08856 49.2629, -123.08856 49.26... | commercial | 7.098448 | POINT (-123.08849 49.26276) |
caae285e-45ca-4c1e-8adb-38de4e11d28d | POLYGON ((-123.07726 49.26484, -123.07727 49.2... | commercial | 8.902528 | POINT (-123.07709 49.2646) |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | POLYGON ((-123.0831 49.2594, -123.0831 49.2595... | commercial | 5.506591 | POINT (-123.08316 49.25949) |
405 rows × 4 columns
ax = stores.plot(figsize=(15, 15))
cx.add_basemap(
ax,
source=cx.providers.CartoDB.VoyagerNoLabels,
crs=4326,
zoom=15,
alpha=0.8,
)
ax.set_axis_off()
ax.set_title("Vancouver - commercial buildings")
plt.show()
Now, we can generate the dummy monthly sales data for the last year.
import numpy as np
import pandas as pd
def generate_monthly_sales(store_ids: pd.Index, seed: int | None = None) -> pd.DataFrame:
"""
Generate dummy monthly sales data for the past 12 months.
Args:
store_ids (pd.Index): IDs of locations.
seed (int, optional): Random seed for reproducibility.
Returns:
pd.DataFrame: Columns = ['location_id', 'month', 'sales']
"""
rng = np.random.default_rng(seed=seed)
# Generate month labels (last 12 months, newest last)
months = pd.date_range(end=pd.Timestamp.today(), periods=12, freq="ME")
month_labels = months.strftime("%Y-%m").tolist()
# Seasonal multiplier with peak at December (sinusoidal pattern)
phases = 2 * np.pi * (months.month - 12) / 12.0
seasonal_factor = 1.0 + 0.2 * np.cos(phases)
data = []
for loc_id in store_ids:
# Start with a base sales value for this location
base_sales = rng.integers(8000, 20000)
# Create gradual monthly changes using a small random walk
gradual_changes = np.cumsum(rng.normal(loc=0, scale=300, size=12))
# Combine base + changes + seasonality
sales = (base_sales + gradual_changes) * seasonal_factor
# Ensure sales are positive
sales = np.clip(sales, 0, None)
# Append to dataset
for month_str, value in zip(month_labels, sales):
data.append((loc_id, month_str, round(value, 2)))
df = pd.DataFrame(data, columns=["id", "month", "sales"]).set_index("id")
return df
df_sales = generate_monthly_sales(store_ids=stores.index, seed=42)
df_sales
month | sales | |
---|---|---|
id | ||
4249f3e3-360b-42e2-8a0a-9f324b8f4c7d | 2024-10 | 9634.91 |
4249f3e3-360b-42e2-8a0a-9f324b8f4c7d | 2024-11 | 10540.24 |
4249f3e3-360b-42e2-8a0a-9f324b8f4c7d | 2024-12 | 11119.57 |
4249f3e3-360b-42e2-8a0a-9f324b8f4c7d | 2025-01 | 10184.59 |
4249f3e3-360b-42e2-8a0a-9f324b8f4c7d | 2025-02 | 9119.38 |
... | ... | ... |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-05 | 13708.08 |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-06 | 13281.32 |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-07 | 13734.69 |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-08 | 14667.70 |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-09 | 16537.98 |
4860 rows × 2 columns
fig, ax = plt.subplots(figsize=(12, 6))
sns.lineplot(df_sales, x="month", y="sales", hue="id", legend=False, alpha=0.4, ax=ax)
ax.set_title("Monthly sales data per store")
ax.set_ylabel("Sales")
ax.set_xlabel("Month")
plt.tight_layout()
plt.show()
Now that we have store locations and a dataframe with monthly sales data per location, we will calculate the average sales per month and use this information to stratify the spatial split.
mean_monthly_sales = df_sales.groupby("id")["sales"].mean() # You can also use median
sns.histplot(mean_monthly_sales, kde=True)
plt.show()
Now we have to assign the mean values to the original dataframe.
stores["mean_monthly_sales"] = mean_monthly_sales
stores
geometry | subtype | height | centroid | mean_monthly_sales | |
---|---|---|---|---|---|
id | |||||
4249f3e3-360b-42e2-8a0a-9f324b8f4c7d | POLYGON ((-123.14605 49.26473, -123.14588 49.2... | commercial | 6.687434 | POINT (-123.14596 49.26489) | 8536.880833 |
1ff56f02-1afb-4c78-bc10-f77b6767e93f | POLYGON ((-123.14556 49.265, -123.14556 49.264... | commercial | 6.753078 | POINT (-123.14541 49.26487) | 17561.853333 |
23e4f62b-22a8-4062-aa03-5bc1e4cd129c | POLYGON ((-123.1446 49.26499, -123.14501 49.26... | commercial | 11.917554 | POINT (-123.14483 49.26485) | 17848.665833 |
2b57bdcf-4997-48ee-a0ac-23cc9d78760c | POLYGON ((-123.14531 49.26425, -123.14525 49.2... | commercial | 10.611635 | POINT (-123.1451 49.26438) | 10640.552500 |
1683a7a9-f295-45e2-9276-fbd9f550f3aa | POLYGON ((-123.14627 49.26357, -123.14632 49.2... | commercial | 8.197035 | POINT (-123.1462 49.26345) | 17076.764167 |
... | ... | ... | ... | ... | ... |
4f18f7b7-db10-4053-9472-c84231a18a56 | POLYGON ((-123.09372 49.26236, -123.09414 49.2... | commercial | 7.688835 | POINT (-123.09365 49.26227) | 15693.730833 |
c5de42ae-d2bd-4ca7-ad94-17389aa7a3f4 | POLYGON ((-123.08882 49.26278, -123.08898 49.2... | commercial | 5.626834 | POINT (-123.0889 49.2627) | 15149.172500 |
d7dd54ee-9b04-4e35-82ae-0704e40d0631 | POLYGON ((-123.08856 49.2629, -123.08856 49.26... | commercial | 7.098448 | POINT (-123.08849 49.26276) | 8338.610833 |
caae285e-45ca-4c1e-8adb-38de4e11d28d | POLYGON ((-123.07726 49.26484, -123.07727 49.2... | commercial | 8.902528 | POINT (-123.07709 49.2646) | 17367.416667 |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | POLYGON ((-123.0831 49.2594, -123.0831 49.2595... | commercial | 5.506591 | POINT (-123.08316 49.25949) | 16675.399167 |
405 rows × 5 columns
Let's do the split based on the mean monthly sales. We will reduce the number of bins to decrease the actual ratio difference.
train_sales_gdf, test_sales_gdf = train_test_spatial_split(
input_gdf=stores,
parent_h3_resolution=8,
geometry_column="centroid",
target_column="mean_monthly_sales",
n_bins=5,
test_size=0.2,
random_state=42,
)
Summary of the split:
Train: 23 H3 cells (323 points)
Test: 13 H3 cells (82 points)
Expected ratios: {'train': 0.8, 'validation': 0, 'test': 0.2}
Actual ratios: {'train': 0.798, 'test': 0.202}
Actual ratios difference: {'train': 0.002, 'test': -0.002}

bucket | train_ratio | test_ratio | train_ratio_difference | test_ratio_difference | train_points | test_points
---|---|---|---|---|---|---
0 | 0.80247 | 0.19753 | -0.00247 | 0.00247 | 65 | 16
1 | 0.77778 | 0.22222 | 0.02222 | -0.02222 | 63 | 18
2 | 0.77778 | 0.22222 | 0.02222 | -0.02222 | 63 | 18
3 | 0.80247 | 0.19753 | -0.00247 | 0.00247 | 65 | 16
4 | 0.82716 | 0.17284 | -0.02716 | 0.02716 | 67 | 14
Here is the target distribution between the two sets.
ax = sns.kdeplot(
data=train_sales_gdf,
x="mean_monthly_sales",
fill=True,
label="train",
)
sns.kdeplot(
data=test_sales_gdf,
x="mean_monthly_sales",
fill=True,
label="test",
ax=ax,
)
ax.legend()
ax.set_title("Mean monthly sales distribution per split")
plt.show()
train_covering_h3_cells = h3_to_geoseries(
shapely_geometry_to_h3(train_sales_gdf["centroid"], 8)
)
test_covering_h3_cells = h3_to_geoseries(
shapely_geometry_to_h3(test_sales_gdf["centroid"], 8)
)
with plt.rc_context({"hatch.linewidth": 0.3}):
ax = stores.plot(
"mean_monthly_sales",
figsize=(15, 10),
cmap="RdYlBu_r",
legend=True,
legend_kwds=dict(
shrink=0.9,
orientation="horizontal",
pad=0.01,
label="Mean monthly sales",
aspect=60,
fraction=0.03,
),
zorder=2,
)
stores.exterior.plot(ax=ax, color="black", linewidth=0.3, zorder=2)
train_covering_h3_cells.boundary.plot(color="black", ax=ax, linewidth=0.3, zorder=1)
test_covering_h3_cells.plot(
ax=ax,
linewidth=0.3,
color=(0, 0, 0, 0),
edgecolor="black",
hatch="/",
zorder=1,
)
ax2 = ax.twinx()
ax2.get_yaxis().set_visible(False)
ax2.legend(
handles=[
Patch(edgecolor="black", facecolor=(0, 0, 0, 0)),
Patch(edgecolor="black", facecolor=(0, 0, 0, 0), hatch="///"),
],
labels=["Train", "Test"],
loc=2,
)
ax.set_title("Vancouver - commercial buildings - sales numerical split")
cx.add_basemap(
ax,
source=cx.providers.CartoDB.VoyagerNoLabels,
crs=4326,
zoom=15,
alpha=0.8,
)
ax.set_axis_off()
stores_bounds = stores.total_bounds
ax.set_xlim(stores_bounds[0] - 0.01, stores_bounds[2] + 0.01)
ax.set_ylim(stores_bounds[1] - 0.01, stores_bounds[3] + 0.01)
plt.show()
Now we can select transaction data based on location IDs.
train_store_sales = df_sales.loc[train_sales_gdf.index]
test_store_sales = df_sales.loc[test_sales_gdf.index]
print(len(train_store_sales), len(test_store_sales))
train_store_sales
3876 984
month | sales | |
---|---|---|
id | ||
4249f3e3-360b-42e2-8a0a-9f324b8f4c7d | 2024-10 | 9634.91 |
4249f3e3-360b-42e2-8a0a-9f324b8f4c7d | 2024-11 | 10540.24 |
4249f3e3-360b-42e2-8a0a-9f324b8f4c7d | 2024-12 | 11119.57 |
4249f3e3-360b-42e2-8a0a-9f324b8f4c7d | 2025-01 | 10184.59 |
4249f3e3-360b-42e2-8a0a-9f324b8f4c7d | 2025-02 | 9119.38 |
... | ... | ... |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-05 | 13708.08 |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-06 | 13281.32 |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-07 | 13734.69 |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-08 | 14667.70 |
c2a8a5cc-19ed-41de-93cf-e55b63264aa9 | 2025-09 | 16537.98 |
3876 rows × 2 columns
fig, ax = plt.subplots(figsize=(12, 6))
sns.lineplot(train_store_sales, x="month", y="sales", legend=True, ax=ax, label="train")
sns.lineplot(test_store_sales, x="month", y="sales", legend=True, ax=ax, label="test")
ax.set_title("Monthly sales data per store")
ax.set_ylabel("Sales")
ax.set_xlabel("Month")
plt.tight_layout()
plt.show()
ax = sns.kdeplot(
data=train_store_sales,
x="sales",
fill=True,
label="train",
)
sns.kdeplot(
data=test_store_sales,
x="sales",
fill=True,
label="test",
ax=ax,
)
ax.legend()
ax.set_title("Sales distribution per split")
plt.show()
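Finally, as suggested in the tip earlier, a longer history would let you combine the spatial split with a temporal holdout. Here is a minimal sketch on top of the split we just made; the 3-month cutoff is an arbitrary illustration.

# Hypothetical temporal holdout on top of the spatial split: train on past
# months of the train stores, evaluate on the most recent months of the test
# stores. "YYYY-MM" strings compare correctly in lexicographic order.
cutoff_month = sorted(df_sales["month"].unique())[-3]

train_past_sales = train_store_sales[train_store_sales["month"] < cutoff_month]
test_future_sales = test_store_sales[test_store_sales["month"] >= cutoff_month]
print(len(train_past_sales), len(test_future_sales))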