Contextual count embedder
from srai.embedders import ContextualCountEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.neighbourhoods import H3Neighbourhood
from srai.plotting.folium_wrapper import plot_numeric_data, plot_regions
from srai.regionalizers import H3Regionalizer
Data preparation¶
To use ContextualCountEmbedder we need to prepare three inputs: regions_gdf, features_gdf, and joint_gdf.
These are the outputs of Regionalizers, Loaders, and Joiners, respectively.
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Lisboa, PT")
plot_regions(area_gdf)
Regionalize the area using an H3Regionalizer¶
regionalizer = H3Regionalizer(resolution=9, buffer=True)
regions_gdf = regionalizer.transform(area_gdf)
regions_gdf
region_id | geometry |
---|---|
89393362b8bffff | POLYGON ((-9.16168 38.72026, -9.16339 38.71892... |
89393367083ffff | POLYGON ((-9.10427 38.73484, -9.10598 38.7335,... |
89393362bcbffff | POLYGON ((-9.16788 38.71303, -9.16959 38.71169... |
89393362b97ffff | POLYGON ((-9.15951 38.72668, -9.16121 38.72534... |
89393375eafffff | POLYGON ((-9.15199 38.78201, -9.1537 38.78067,... |
... | ... |
89393362d0fffff | POLYGON ((-9.16581 38.75232, -9.16752 38.75098... |
89393362a97ffff | POLYGON ((-9.18767 38.72099, -9.18938 38.71965... |
89393375bcfffff | POLYGON ((-9.12045 38.76445, -9.12216 38.76311... |
8939337595bffff | POLYGON ((-9.115 38.7805, -9.11671 38.77916, -... |
89393375303ffff | POLYGON ((-9.18932 38.77149, -9.19102 38.77015... |
830 rows × 1 columns
Download some objects from OpenStreetMap¶
Both OsmTagsFilter and GroupedOsmTagsFilter filters are supported. This example uses BASE_OSM_GROUPS_FILTER, a predefined GroupedOsmTagsFilter.
from srai.loaders.osm_loaders.filters import BASE_OSM_GROUPS_FILTER
loader = OSMPbfLoader()
features_gdf = loader.load(area_gdf, tags=BASE_OSM_GROUPS_FILTER)
features_gdf
feature_id | geometry | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
node/1026408767 | POINT (-9.19757 38.74976) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/1087174563 | POINT (-9.20027 38.69665) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | tourism=artwork | None | None |
node/1090129267 | POINT (-9.13516 38.73433) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/1104467812 | POINT (-9.20891 38.75276) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/1104467908 | POINT (-9.20858 38.75338) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
way/1122286633 | POLYGON ((-9.18873 38.77076, -9.18853 38.77056... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
way/1122286653 | POLYGON ((-9.18642 38.77196, -9.18634 38.77189... | None | None | None | None | None | None | None | landuse=grass | None | None | None | None | None | None | None | None | None | None |
way/1122286721 | POLYGON ((-9.19022 38.76877, -9.19029 38.76871... | None | None | None | None | None | None | None | natural=grassland | None | None | None | None | None | None | None | None | None | None |
way/1123088659 | POLYGON ((-9.16851 38.70583, -9.16844 38.70584... | None | None | None | None | None | None | None | None | None | historic=castle | None | None | None | None | None | None | None | None |
way/1123137162 | POLYGON ((-9.20778 38.69885, -9.20729 38.69898... | None | None | None | None | None | None | None | None | None | None | leisure=pitch | None | None | None | None | None | None | None |
32399 rows × 19 columns
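To make the two filter shapes concrete, here is a minimal sketch in plain Python. It assumes that an OsmTagsFilter-style mapping goes from OSM keys to True, a single value, or a list of values, and that a GroupedOsmTagsFilter-style mapping goes from group names to such filters. The matching rules, the example filter, and the helper names match_group and group_columns are illustrative assumptions, not srai's implementation. Only the key=value cell format is taken from the table above.

```python
from typing import Dict, List, Optional, Union

# Assumed, simplified shapes of the two filter kinds.
TagsFilter = Dict[str, Union[bool, str, List[str]]]
GroupedTagsFilter = Dict[str, TagsFilter]

# Hypothetical two-group filter, far smaller than BASE_OSM_GROUPS_FILTER.
EXAMPLE_GROUPED_FILTER: GroupedTagsFilter = {
    "tourism": {"tourism": ["artwork", "museum"]},
    "transportation": {"public_transport": True, "amenity": "parking"},
}


def match_group(tags: Dict[str, str], tags_filter: TagsFilter) -> Optional[str]:
    """Return the first 'key=value' pair of `tags` accepted by `tags_filter`."""
    for key, accepted in tags_filter.items():
        value = tags.get(key)
        if value is None:
            continue
        if accepted is True:  # assumed meaning: any value of this key
            return f"{key}={value}"
        if isinstance(accepted, str) and value == accepted:
            return f"{key}={value}"
        if isinstance(accepted, list) and value in accepted:
            return f"{key}={value}"
    return None


def group_columns(
    tags: Dict[str, str], grouped_filter: GroupedTagsFilter
) -> Dict[str, Optional[str]]:
    """Build one cell per group, like one row of the features table."""
    return {
        group: match_group(tags, tags_filter)
        for group, tags_filter in grouped_filter.items()
    }


stop_row = group_columns({"public_transport": "stop_position"}, EXAMPLE_GROUPED_FILTER)
artwork_row = group_columns({"tourism": "artwork"}, EXAMPLE_GROUPED_FILTER)
```

With a grouped filter, each feature gets one column per group rather than one column per OSM key, which is why the table above has 18 group columns next to geometry.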
Join the objects with the regions they belong to¶
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf)
joint_gdf
region_id | feature_id |
---|---|
89393362c13ffff | node/1026408767 |
89393360193ffff | node/1087174563 |
89393362923ffff | node/1090129267 |
8939337526fffff | node/1104467812 |
8939337526fffff | node/1104467908 |
... | ... |
89393375303ffff | way/1122286633 |
89393375317ffff | way/1122286653 |
89393375303ffff | way/1122286721 |
89393362a2fffff | way/1123088659 |
8939336050fffff | way/1123137162 |
35912 rows × 0 columns
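The joint index is what a count-based embedder works from: each (region_id, feature_id) pair says that a feature falls in a region, and the same feature can appear under several regions (35912 pairs for 32399 features above). The following is a minimal sketch of turning such pairs into per-region counts. All IDs, the feature_groups mapping, and the count_features helper are made up for illustration and are not part of srai.

```python
from collections import Counter, defaultdict
from typing import Dict, List, Tuple

# Toy stand-in for the (region_id, feature_id) index of joint_gdf.
joint_pairs: List[Tuple[str, str]] = [
    ("region_a", "node/1"),
    ("region_a", "node/2"),
    ("region_a", "way/3"),
    ("region_b", "way/3"),  # one feature may intersect several regions
]

# Toy stand-in for the non-empty group columns of each feature.
feature_groups: Dict[str, List[str]] = {
    "node/1": ["transportation"],
    "node/2": ["transportation"],
    "way/3": ["greenery", "leisure"],
}


def count_features(
    pairs: List[Tuple[str, str]], groups: Dict[str, List[str]]
) -> Dict[str, Dict[str, int]]:
    """Count, per region, how many joined features fall into each group."""
    counts: Dict[str, Counter] = defaultdict(Counter)
    for region_id, feature_id in pairs:
        counts[region_id].update(groups.get(feature_id, []))
    return {region_id: dict(counter) for region_id, counter in counts.items()}


region_counts = count_features(joint_pairs, feature_groups)
```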
Embed using features existing in data¶
ContextualCountEmbedder extends the basic CountEmbedder by incorporating the neighbourhood of each embedded region. This example uses H3Neighbourhood.
h3n = H3Neighbourhood()
Squashed vector version (default)¶
The embedder returns a vector of the same length as CountEmbedder. To each region's own counts it adds the averaged values from its neighbourhoods, each scaled down by the squared neighbour distance.
cce = ContextualCountEmbedder(
    neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=False
)
embeddings = cce.transform(regions_gdf, features_gdf, joint_gdf)
embeddings
region_id | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393362b8bffff | 0.001134 | 0.000000 | 3.651246 | 1.468443 | 6.203856 | 0.033018 | 0.972666 | 2.354758 | 2.375852 | 2.852459 | 10.037147 | 2.142248 | 19.012022 | 1.460579 | 15.094744 | 4.677305 | 22.419767 | 1.368091 |
89393367083ffff | 0.001064 | 0.000000 | 7.332154 | 0.110823 | 0.269581 | 0.001039 | 0.144910 | 2.575411 | 1.191108 | 0.358733 | 1.951287 | 1.479027 | 4.594223 | 0.327114 | 4.458051 | 0.547507 | 11.754366 | 1.244452 |
89393362bcbffff | 0.000000 | 0.000000 | 1.906946 | 0.198202 | 2.932747 | 0.012166 | 0.568701 | 7.300475 | 1.907012 | 2.338524 | 12.042169 | 1.056559 | 8.763961 | 1.535442 | 7.542049 | 4.240285 | 24.977098 | 0.444601 |
89393362b97ffff | 0.001641 | 0.000000 | 6.092667 | 0.258359 | 4.994278 | 0.082316 | 4.053902 | 2.732793 | 2.505698 | 2.540775 | 5.253876 | 0.964375 | 16.204945 | 1.907603 | 8.770687 | 5.132719 | 33.096069 | 0.247110 |
89393375eafffff | 0.000000 | 0.072681 | 0.304640 | 0.064294 | 1.537870 | 0.010795 | 0.078250 | 7.112021 | 1.261148 | 0.102022 | 6.232368 | 0.096647 | 1.025924 | 2.373134 | 2.608877 | 0.367234 | 17.521706 | 3.232470 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393362d0fffff | 0.028935 | 0.003182 | 0.771874 | 0.117181 | 1.789568 | 0.022724 | 1.390828 | 19.016603 | 1.758456 | 0.482746 | 17.349547 | 0.231543 | 5.089790 | 9.618187 | 6.199226 | 1.391343 | 27.706603 | 0.323878 |
89393362a97ffff | 0.001306 | 0.000000 | 0.226389 | 2.067917 | 0.231590 | 0.013051 | 0.057763 | 8.619739 | 0.103793 | 0.151739 | 6.589560 | 0.156698 | 0.709846 | 2.163447 | 2.607648 | 7.106821 | 10.345524 | 1.187639 |
89393375bcfffff | 0.002639 | 0.038735 | 2.661394 | 0.046818 | 1.501687 | 0.023815 | 2.168773 | 28.057468 | 2.324351 | 0.056888 | 4.761977 | 3.342389 | 11.882436 | 0.528960 | 1.890496 | 0.377546 | 13.973610 | 0.076731 |
8939337595bffff | 0.004116 | 0.061044 | 0.434290 | 0.059555 | 2.432357 | 0.019721 | 0.111703 | 5.723602 | 0.328868 | 1.394643 | 7.562593 | 0.119268 | 1.823380 | 1.600465 | 0.884410 | 0.292032 | 20.834515 | 0.067206 |
89393375303ffff | 0.001567 | 0.000000 | 0.523098 | 1.136584 | 1.497083 | 0.007591 | 0.118607 | 18.493812 | 0.304532 | 2.205469 | 7.955324 | 1.179516 | 1.351383 | 1.569012 | 0.737922 | 3.923212 | 18.504649 | 0.122719 |
830 rows × 18 columns
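The squashed aggregation can be sketched in plain Python on a toy neighbourhood. This is not srai's code: it assumes regions lie on a 1-D line (so the ring at distance d holds at most two regions) and it reads "diminished by the neighbour distance squared" literally as a division by d². The library's exact weighting may differ. The names ring, squashed_embedding, and toy_counts are hypothetical.

```python
from statistics import mean
from typing import Dict, List


def ring(region: int, distance: int, n_regions: int) -> List[int]:
    """Regions exactly `distance` steps away on a toy 1-D line of regions."""
    candidates = (region - distance, region + distance)
    return [r for r in candidates if 0 <= r < n_regions]


def squashed_embedding(
    counts: List[Dict[str, float]], region: int, max_distance: int
) -> Dict[str, float]:
    """Own counts plus ring averages, each divided by distance squared (assumed)."""
    result = dict(counts[region])  # distance 0: the region's own counts
    for distance in range(1, max_distance + 1):
        neighbours = ring(region, distance, len(counts))
        if not neighbours:  # ring lies entirely outside the toy area
            continue
        for feature in result:
            ring_average = mean([counts[r][feature] for r in neighbours])
            result[feature] += ring_average / distance**2
    return result


# Three regions in a row with made-up shop counts.
toy_counts = [{"shops": 4.0}, {"shops": 0.0}, {"shops": 8.0}]
middle = squashed_embedding(toy_counts, region=1, max_distance=2)
left = squashed_embedding(toy_counts, region=0, max_distance=2)
```

In this toy setup the middle region has no shops of its own but still gets a non-zero value from its neighbours, which mirrors how the fractional values in the table above arise from integer counts.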
Concatenated vector version¶
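The embedder returns a vector of length n * distance, where n is the number of features from the CountEmbedder and distance is the number of neighbourhoods analysed. The region itself counts as distance 0, so with 18 features and neighbourhood_distance=10 the output has 18 * 11 = 198 columns.
Each feature name is suffixed with _n, where n is the current distance. Values are averaged over all neighbours at that distance.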
wide_cce = ContextualCountEmbedder(
    neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=True
)
wide_embeddings = wide_cce.transform(regions_gdf, features_gdf, joint_gdf)
wide_embeddings
region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393362b8bffff | 0.0 | 0.0 | 1.0 | 1.0 | 5.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 | ... | 1.279070 | 1.255814 | 3.395349 | 1.255814 | 10.534884 | 1.232558 | 6.581395 | 3.255814 | 19.209302 | 0.581395 |
89393367083ffff | 0.0 | 0.0 | 6.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 | 0.0 | ... | 1.857143 | 1.928571 | 4.535714 | 1.321429 | 9.821429 | 0.964286 | 9.964286 | 4.464286 | 20.714286 | 1.000000 |
89393362bcbffff | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 4.0 | 1.0 | 1.0 | ... | 1.205128 | 1.000000 | 4.461538 | 1.179487 | 9.102564 | 1.128205 | 9.692308 | 5.333333 | 18.102564 | 1.076923 |
89393362b97ffff | 0.0 | 0.0 | 2.0 | 0.0 | 4.0 | 0.0 | 3.0 | 0.0 | 1.0 | 1.0 | ... | 1.127660 | 0.617021 | 5.148936 | 0.914894 | 5.702128 | 2.617021 | 4.936170 | 2.404255 | 14.319149 | 0.425532 |
89393375eafffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 4.0 | 1.0 | 0.0 | ... | 1.379310 | 0.482759 | 6.241379 | 0.551724 | 4.482759 | 2.413793 | 3.103448 | 1.241379 | 15.862069 | 0.103448 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393362d0fffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 13.0 | 1.0 | 0.0 | ... | 1.241379 | 0.758621 | 3.413793 | 0.775862 | 8.413793 | 1.379310 | 5.500000 | 2.293103 | 12.706897 | 0.206897 |
89393362a97ffff | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 7.0 | 0.0 | 0.0 | ... | 1.742857 | 1.771429 | 5.228571 | 0.800000 | 10.085714 | 0.800000 | 10.857143 | 5.771429 | 16.114286 | 0.657143 |
89393375bcfffff | 0.0 | 0.0 | 2.0 | 0.0 | 1.0 | 0.0 | 2.0 | 21.0 | 2.0 | 0.0 | ... | 1.083333 | 0.611111 | 4.166667 | 0.750000 | 9.361111 | 1.305556 | 6.055556 | 1.750000 | 20.694444 | 0.500000 |
8939337595bffff | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | ... | 0.592593 | 0.074074 | 2.592593 | 0.296296 | 1.481481 | 0.777778 | 1.333333 | 1.555556 | 7.703704 | 0.407407 |
89393375303ffff | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 | 0.0 | 0.0 | 12.0 | 0.0 | 2.0 | ... | 0.566667 | 1.033333 | 4.033333 | 0.300000 | 1.966667 | 1.300000 | 1.833333 | 2.966667 | 10.533333 | 1.033333 |
830 rows × 198 columns
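The column layout of the concatenated version can be illustrated with a small plain-Python sketch. As an assumption for illustration only, regions sit on a 1-D line and a ring with no regions inside the area contributes 0.0. The helpers ring and wide_embedding are hypothetical and do not reproduce srai's implementation.

```python
from statistics import mean
from typing import Dict, List


def ring(region: int, distance: int, n_regions: int) -> List[int]:
    """Regions exactly `distance` steps away on a toy 1-D line of regions."""
    candidates = (region - distance, region + distance)
    return [r for r in candidates if 0 <= r < n_regions]


def wide_embedding(
    counts: List[Dict[str, float]], region: int, max_distance: int
) -> Dict[str, float]:
    """One '<feature>_<distance>' entry per feature and per distance 0..max."""
    row: Dict[str, float] = {}
    for distance in range(max_distance + 1):
        members = [region] if distance == 0 else ring(region, distance, len(counts))
        for feature in counts[region]:
            values = [counts[r][feature] for r in members]
            # Assumption: an empty ring is reported as 0.0.
            row[f"{feature}_{distance}"] = mean(values) if values else 0.0
    return row


toy_counts = [
    {"shops": 4.0, "water": 1.0},
    {"shops": 0.0, "water": 0.0},
    {"shops": 8.0, "water": 3.0},
]
wide_row = wide_embedding(toy_counts, region=1, max_distance=2)
```

Two features over distances 0, 1, and 2 give 2 * 3 = 6 entries, the same arithmetic that yields 18 * 11 = 198 columns in the table above.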
Plotting example features¶
plot_numeric_data(regions_gdf, "leisure", embeddings)
plot_numeric_data(regions_gdf, "transportation", embeddings)
Other types of aggregations¶
By default, ContextualCountEmbedder averages the counts from neighbours. This can be changed through the aggregation_function parameter to one of: median, sum, min, max.
It is best combined with the wide format (concatenate_vectors=True).
sum_cce = ContextualCountEmbedder(
    neighbourhood=h3n,
    neighbourhood_distance=10,
    concatenate_vectors=True,
    aggregation_function="sum",
)
sum_embeddings = sum_cce.transform(regions_gdf, features_gdf, joint_gdf)
sum_embeddings
region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393362b8bffff | 0.0 | 0.0 | 1.0 | 1.0 | 5.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 | ... | 55.0 | 54.0 | 146.0 | 54.0 | 453.0 | 53.0 | 283.0 | 140.0 | 826.0 | 25.0 |
89393367083ffff | 0.0 | 0.0 | 6.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 | 0.0 | ... | 52.0 | 54.0 | 127.0 | 37.0 | 275.0 | 27.0 | 279.0 | 125.0 | 580.0 | 28.0 |
89393362bcbffff | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 4.0 | 1.0 | 1.0 | ... | 47.0 | 39.0 | 174.0 | 46.0 | 355.0 | 44.0 | 378.0 | 208.0 | 706.0 | 42.0 |
89393362b97ffff | 0.0 | 0.0 | 2.0 | 0.0 | 4.0 | 0.0 | 3.0 | 0.0 | 1.0 | 1.0 | ... | 53.0 | 29.0 | 242.0 | 43.0 | 268.0 | 123.0 | 232.0 | 113.0 | 673.0 | 20.0 |
89393375eafffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 4.0 | 1.0 | 0.0 | ... | 40.0 | 14.0 | 181.0 | 16.0 | 130.0 | 70.0 | 90.0 | 36.0 | 460.0 | 3.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393362d0fffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 13.0 | 1.0 | 0.0 | ... | 72.0 | 44.0 | 198.0 | 45.0 | 488.0 | 80.0 | 319.0 | 133.0 | 737.0 | 12.0 |
89393362a97ffff | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 7.0 | 0.0 | 0.0 | ... | 61.0 | 62.0 | 183.0 | 28.0 | 353.0 | 28.0 | 380.0 | 202.0 | 564.0 | 23.0 |
89393375bcfffff | 0.0 | 0.0 | 2.0 | 0.0 | 1.0 | 0.0 | 2.0 | 21.0 | 2.0 | 0.0 | ... | 39.0 | 22.0 | 150.0 | 27.0 | 337.0 | 47.0 | 218.0 | 63.0 | 745.0 | 18.0 |
8939337595bffff | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | ... | 16.0 | 2.0 | 70.0 | 8.0 | 40.0 | 21.0 | 36.0 | 42.0 | 208.0 | 11.0 |
89393375303ffff | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 | 0.0 | 0.0 | 12.0 | 0.0 | 2.0 | ... | 17.0 | 31.0 | 121.0 | 9.0 | 59.0 | 39.0 | 55.0 | 89.0 | 316.0 | 31.0 |
830 rows × 198 columns
plot_numeric_data(regions_gdf, "tourism_8", sum_embeddings)
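As a rough illustration of how the choice of aggregation changes a single cell, the sketch below applies each option to the counts of one ring of neighbours. The ring values, the AGGREGATIONS mapping, the aggregate_ring helper, and the "average" key used for the default are assumptions made for this example, not srai's API.

```python
from statistics import mean, median
from typing import Callable, Dict, List

# Assumed name-to-function mapping; "average" stands for the default behaviour.
AGGREGATIONS: Dict[str, Callable[[List[float]], float]] = {
    "average": mean,
    "median": median,
    "sum": sum,
    "min": min,
    "max": max,
}


def aggregate_ring(values: List[float], aggregation_function: str = "average") -> float:
    """Collapse the counts of one ring of neighbours into a single number."""
    if not values:  # assumption: an empty ring yields 0.0
        return 0.0
    return float(AGGREGATIONS[aggregation_function](values))


# Made-up tourism counts for the four regions of one ring.
ring_counts = [2, 0, 7, 3]
results = {name: aggregate_ring(ring_counts, name) for name in AGGREGATIONS}
```

The same relationship is visible in the outputs above: for region 89393362b8bffff, healthcare_10 is 1.279070 when averaged and 55.0 when summed, which is consistent with 43 regions at distance 10 inside the area (55 / 43 ≈ 1.279070).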