Contextual count embedder
```python
from srai.embedders import ContextualCountEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.neighbourhoods import H3Neighbourhood
from srai.plotting.folium_wrapper import plot_numeric_data, plot_regions
from srai.regionalizers import H3Regionalizer
```
Data preparation
To use the ContextualCountEmbedder, we first need to prepare three inputs: regions_gdf, features_gdf and joint_gdf.
These are the outputs of a Regionalizer, a Loader and a Joiner, respectively.
```python
from srai.regionalizers import geocode_to_region_gdf

area_gdf = geocode_to_region_gdf("Lisboa, PT")
plot_regions(area_gdf)
```
Regionalize the area using an H3Regionalizer
```python
regionalizer = H3Regionalizer(resolution=9, buffer=True)
regions_gdf = regionalizer.transform(area_gdf)
regions_gdf
```
| region_id | geometry |
|---|---|
| 89393362e1bffff | POLYGON ((-9.20856 38.72573, -9.21026 38.72439... |
| 893933629dbffff | POLYGON ((-9.1632 38.73789, -9.16491 38.73655,... |
| 89393375877ffff | POLYGON ((-9.12708 38.77806, -9.12878 38.77672... |
| 8939337526fffff | POLYGON ((-9.20977 38.75539, -9.21147 38.75405... |
| 89393375e8bffff | POLYGON ((-9.16701 38.78197, -9.16871 38.78063... |
| ... | ... |
| 89393367437ffff | POLYGON ((-9.10796 38.74605, -9.10966 38.74471... |
| 8939336054bffff | POLYGON ((-9.21941 38.69363, -9.22112 38.69228... |
| 893933628bbffff | POLYGON ((-9.17419 38.73866, -9.17589 38.73732... |
| 8939337534fffff | POLYGON ((-9.19258 38.76186, -9.19428 38.76052... |
| 89393375c57ffff | POLYGON ((-9.15863 38.79562, -9.16033 38.79428... |
830 rows × 1 columns
Download some objects from OpenStreetMap
You can use either an OsmTagsFilter or a GroupedOsmTagsFilter. In this example, the predefined GroupedOsmTagsFilter BASE_OSM_GROUPS_FILTER is used.
```python
from srai.loaders.osm_loaders.filters import BASE_OSM_GROUPS_FILTER

loader = OSMPbfLoader()
features_gdf = loader.load(area_gdf, tags=BASE_OSM_GROUPS_FILTER)
features_gdf
```
Finished operation in 0:00:15
| feature_id | geometry | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| way/975228400 | LINESTRING (-9.3237 38.67323, -9.32366 38.6733... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | natural=coastline |
| way/971986160 | POLYGON ((-9.22989 38.69505, -9.23037 38.69537... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | natural=beach |
| node/8984389091 | POINT (-9.2294 38.6961) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | tourism=artwork | None | None |
| way/1431592332 | POLYGON ((-9.22573 38.69592, -9.22576 38.69588... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
| way/1431592333 | POLYGON ((-9.22581 38.69583, -9.22584 38.69579... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| node/6555312687 | POINT (-9.13278 38.70649) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
| node/6354971911 | POINT (-9.13343 38.70629) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=ferry_terminal | None |
| way/935203997 | POLYGON ((-9.13163 38.70793, -9.13167 38.70791... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=bar | None | None | None |
| way/935203993 | POLYGON ((-9.13097 38.70821, -9.131 38.70819, ... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=bar | None | None | None |
| way/1226716239 | POLYGON ((-9.13234 38.70734, -9.13221 38.70716... | None | None | None | None | None | None | None | None | None | None | leisure=marina | None | None | None | None | None | None | None |
33567 rows × 19 columns
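In the output above, each group column holds the `key=value` tag that placed the feature in that group, and None otherwise. The sketch below mimics that mapping with plain dicts instead of a GeoDataFrame. `TOY_GROUPED_FILTER` and `match_groups` are hypothetical illustrations: the toy filter has the same shape as a grouped filter (group name to tag filter) but is not the content of BASE_OSM_GROUPS_FILTER.

```python
# Hypothetical grouped filter: group name -> {OSM key: list of allowed values, or True for any value}.
# This is a toy with the same shape as a grouped filter, NOT the real BASE_OSM_GROUPS_FILTER.
TOY_GROUPED_FILTER = {
    "sustenance": {"amenity": ["bar", "cafe", "restaurant"]},
    "transportation": {"amenity": ["parking", "ferry_terminal"], "public_transport": True},
    "water": {"natural": ["beach", "coastline"]},
}


def match_groups(tags, grouped_filter):
    """Return {group: "key=value"} for every group whose filter matches the feature's tags."""
    matches = {}
    for group, key_filter in grouped_filter.items():
        for key, allowed in key_filter.items():
            value = tags.get(key)
            if value is None:
                continue  # the feature does not have this key at all
            if allowed is True or value in allowed:
                matches[group] = f"{key}={value}"
                break  # in this sketch, one matching tag per group is enough
    return matches


print(match_groups({"amenity": "bar"}, TOY_GROUPED_FILTER))
print(match_groups({"public_transport": "stop_position"}, TOY_GROUPED_FILTER))
```

A feature whose tags match no group gets an empty dict, which corresponds to a row of None values in every group column.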
Join the objects with the regions they belong to
```python
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf)
joint_gdf
```
| region_id | feature_id |
|---|---|
| 89393360097ffff | way/975228400 |
| 8939336054bffff | way/975228400 |
| 8939336042fffff | way/975228400 |
| 8939336055bffff | way/975228400 |
| 89393360423ffff | way/975228400 |
| ... | ... |
| 89393360db3ffff | node/6354971911 |
| 89393360db3ffff | way/935203997 |
| 8939336764fffff | way/935203993 |
| 89393360db3ffff | way/1226716239 |
| 8939336764fffff | way/1226716239 |
37202 rows × 0 columns
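Each row of joint_gdf is one (region, feature) pair, so a feature crossing several hexagons, such as the coastline way/975228400, appears once per hexagon. The plain CountEmbedder, which the ContextualCountEmbedder builds on, can be thought of as counting these pairs per feature group. The sketch below illustrates that idea with the standard library only; the IDs and `count_per_region` helper are hypothetical, and this is not srai's implementation, which works on GeoDataFrames.

```python
from collections import Counter, defaultdict

# Toy, hypothetical inputs (not taken from the tables above).
region_ids = ["region_a", "region_b", "region_c"]
joint_pairs = [
    ("region_a", "way/1"),
    ("region_b", "way/1"),  # one feature may intersect several regions
    ("region_a", "node/2"),
    ("region_a", "node/3"),
]
feature_groups = {
    "way/1": {"water": "natural=coastline"},
    "node/2": {"sustenance": "amenity=bar"},
    "node/3": {"sustenance": "amenity=cafe"},
}
groups = ["sustenance", "water"]


def count_per_region(regions, pairs, groups_by_feature, columns):
    """Count, for every region, how many joined features fall into each group."""
    counts = defaultdict(Counter)
    for region_id, feature_id in pairs:
        for group in groups_by_feature.get(feature_id, {}):
            counts[region_id][group] += 1
    # Regions without any feature still get an all-zero vector of the same length.
    return {region: [counts[region][col] for col in columns] for region in regions}


print(count_per_region(region_ids, joint_pairs, feature_groups, groups))
```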
Embed using features existing in data
ContextualCountEmbedder extends the capabilities of the basic CountEmbedder by also taking into account the neighbourhood of each embedded region. In this example, we use the H3Neighbourhood.
```python
h3n = H3Neighbourhood()
```
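A hexagonal neighbourhood organises the cells around a region into rings: distance 0 is the region itself, and distance k is the set of cells exactly k steps away. The sketch below illustrates this on a toy hex grid in axial (q, r) coordinates rather than real H3 indexes; `hex_distance` and `ring` are hypothetical helpers, not part of srai or the H3 API. On a complete grid, each ring at distance k ≥ 1 contains 6k cells.

```python
def hex_distance(a, b):
    """Grid distance between two hexagons given in axial (q, r) coordinates."""
    dq = a[0] - b[0]
    dr = a[1] - b[1]
    # The third cube coordinate is s = -q - r, so |ds| = |dq + dr|.
    return (abs(dq) + abs(dr) + abs(dq + dr)) // 2


def ring(center, k):
    """All hexagons exactly k steps away from center (brute force over a bounding box)."""
    q0, r0 = center
    cells = []
    for dq in range(-k, k + 1):
        for dr in range(-k, k + 1):
            cell = (q0 + dq, r0 + dr)
            if hex_distance(center, cell) == k:
                cells.append(cell)
    return cells


for k in range(4):
    print(k, len(ring((0, 0), k)))
```

The brute-force search is chosen for clarity over speed; any cell within distance k lies inside the box |dq|, |dr| ≤ k, so nothing is missed.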
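Squashed vector version (default)
The embedder returns a vector of the same length as the CountEmbedder's. For every distance up to neighbourhood_distance, the counts of the neighbours at that distance are averaged, down-weighted by the square of the distance, and added to the region's own counts, so farther neighbours have less influence.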
```python
cce = ContextualCountEmbedder(
    neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=False
)
embeddings = cce.transform(regions_gdf, features_gdf, joint_gdf)
embeddings
```
| region_id | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 89393362e1bffff | 0.000000 | 0.000000 | 1.142153 | 0.026911 | 0.155364 | 0.004613 | 0.022340 | 0.784345 | 0.081835 | 0.077059 | 2.660051 | 0.137571 | 0.418138 | 0.678845 | 0.351375 | 3.784888 | 3.401824 | 1.080450 |
| 893933629dbffff | 0.013750 | 0.000138 | 6.744283 | 1.259504 | 1.946478 | 0.041948 | 1.961811 | 41.035235 | 1.035692 | 1.501397 | 14.349762 | 1.552464 | 4.255394 | 1.670401 | 6.645188 | 7.964320 | 33.721066 | 1.296090 |
| 89393375877ffff | 0.001461 | 1.363276 | 5.562752 | 0.056164 | 1.224696 | 0.141724 | 0.132622 | 7.092292 | 0.134551 | 0.081780 | 10.303041 | 0.144550 | 0.892832 | 0.367130 | 0.605388 | 0.313189 | 44.626591 | 0.042110 |
| 89393360553ffff | 0.000000 | 0.000000 | 1.894545 | 0.147459 | 3.551903 | 0.003090 | 0.124213 | 5.096767 | 2.345634 | 0.263859 | 22.997243 | 0.323470 | 6.563625 | 2.512733 | 7.513019 | 1.050174 | 12.515799 | 0.511238 |
| 89393360c97ffff | 0.000000 | 0.000000 | 12.034435 | 3.199125 | 0.481008 | 0.010944 | 0.360787 | 2.046259 | 0.484064 | 1.175769 | 3.711816 | 0.707314 | 4.377218 | 0.389654 | 4.325226 | 1.959565 | 17.563975 | 2.863723 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 89393375813ffff | 0.000000 | 1.512845 | 0.220690 | 0.013268 | 0.148713 | 0.022485 | 0.034724 | 8.382283 | 0.068070 | 0.047117 | 0.802296 | 0.123413 | 0.389784 | 0.228757 | 0.230918 | 0.163358 | 2.513358 | 0.044094 |
| 89393362e4fffff | 0.000000 | 0.000000 | 2.386850 | 0.081764 | 0.476771 | 0.006441 | 0.054610 | 6.578803 | 1.238248 | 0.180590 | 12.266482 | 0.145571 | 3.003110 | 7.021042 | 0.781236 | 2.032564 | 6.903290 | 0.269339 |
| 893933676b7ffff | 0.000000 | 0.000000 | 0.746027 | 0.215421 | 0.476896 | 0.009978 | 0.303919 | 2.755259 | 0.551066 | 0.485270 | 1.665136 | 1.498212 | 4.638268 | 0.414667 | 3.388278 | 1.656918 | 13.946540 | 0.262393 |
| 89393375a67ffff | 0.000910 | 0.018551 | 0.493258 | 0.082255 | 6.739237 | 0.010060 | 0.221766 | 7.751106 | 1.380306 | 0.102567 | 4.763932 | 0.293021 | 2.603673 | 0.964559 | 1.659336 | 1.974941 | 22.776253 | 0.089686 |
| 893933628bbffff | 0.028935 | 0.000000 | 1.062638 | 0.106076 | 1.436313 | 0.025752 | 0.249317 | 7.248780 | 0.552109 | 1.499987 | 6.265162 | 0.340900 | 2.061953 | 0.608225 | 1.386669 | 8.028384 | 7.043493 | 0.298673 |
830 rows × 18 columns
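The squashed computation described above can be sketched in plain Python. This is a hypothetical helper working on lists instead of DataFrames, not srai's code. The exact weighting is an assumption: 1/d² follows the description above, but an offset such as 1/(d + 1)² would be an equally plausible reading, so the weight is passed in as a parameter.

```python
from statistics import mean


def squash(base, rings, weight=lambda d: 1 / d**2):
    """Combine a region's own counts with down-weighted per-ring neighbour averages.

    base:   the region's own count vector (distance 0).
    rings:  rings[d - 1] is a list of count vectors of the neighbours at distance d.
    weight: ASSUMED down-weighting per distance; 1 / d**2 follows the text above.
    """
    result = [float(x) for x in base]
    for d, neighbours in enumerate(rings, start=1):
        if not neighbours:  # no neighbours available at this distance
            continue
        # Column-wise average over all neighbours at distance d.
        averaged = [mean(column) for column in zip(*neighbours)]
        result = [r + a * weight(d) for r, a in zip(result, averaged)]
    return result


base = [2, 0]  # toy counts for two feature groups
rings = [
    [[1, 0], [3, 4]],  # distance 1: average [2, 2], weight 1
    [[4, 8]],          # distance 2: average [4, 8], weight 1/4
]
print(squash(base, rings))
```

The output keeps one value per feature group, which is why the squashed table has the same 18 columns as a plain count embedding.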
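Concatenated vector version
The embedder returns a vector of length n × (d + 1), where n is the number of features from the CountEmbedder and d is the neighbourhood_distance: distance 0 holds the region's own counts, and distances 1 to d hold the neighbourhood values. With 18 features and neighbourhood_distance=10, this gives 18 × 11 = 198 columns.
Each feature name is suffixed with _k, where k is the distance. Values at each distance k ≥ 1 are averaged over all neighbours at that distance.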
```python
wide_cce = ContextualCountEmbedder(
    neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=True
)
wide_embeddings = wide_cce.transform(regions_gdf, features_gdf, joint_gdf)
wide_embeddings
```
| region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 89393375877ffff | 0.0 | 1.0 | 5.0 | 0.0 | 1.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 0.976744 | 0.418605 | 3.651163 | 0.348837 | 4.790698 | 1.139535 | 2.813953 | 1.813953 | 13.860465 | 0.767442 |
| 893933628d3ffff | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 1.790698 | 1.581395 | 6.790698 | 1.093023 | 8.930233 | 2.744186 | 6.418605 | 3.534884 | 17.139535 | 0.813953 |
| 893933674a7ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 1.545455 | 0.515152 | 3.303030 | 0.545455 | 10.818182 | 0.424242 | 7.484848 | 2.181818 | 20.575758 | 0.181818 |
| 89393375b17ffff | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 16.0 | 4.0 | 0.0 | ... | 0.966667 | 0.200000 | 4.066667 | 0.566667 | 4.733333 | 2.066667 | 2.500000 | 0.566667 | 14.233333 | 0.233333 |
| 89393362873ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | ... | 1.375000 | 1.729167 | 5.437500 | 1.000000 | 14.375000 | 1.958333 | 12.229167 | 4.041667 | 21.291667 | 0.666667 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 893933674c3ffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 14.0 | 1.0 | 1.0 | ... | 1.500000 | 3.130435 | 5.173913 | 1.369565 | 14.369565 | 2.282609 | 14.239130 | 5.717391 | 19.630435 | 0.391304 |
| 89393375c73ffff | 0.0 | 0.0 | 2.0 | 0.0 | 3.0 | 0.0 | 0.0 | 4.0 | 0.0 | 0.0 | ... | 0.500000 | 0.272727 | 3.909091 | 0.136364 | 3.000000 | 0.909091 | 1.954545 | 0.454545 | 11.863636 | 0.363636 |
| 893933758cfffff | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 0.636364 | 0.181818 | 3.727273 | 0.454545 | 2.575758 | 2.181818 | 1.878788 | 1.030303 | 16.424242 | 0.333333 |
| 89393362b07ffff | 0.0 | 0.0 | 2.0 | 1.0 | 2.0 | 0.0 | 2.0 | 3.0 | 2.0 | 4.0 | ... | 0.942857 | 0.571429 | 3.600000 | 0.571429 | 6.285714 | 0.914286 | 3.685714 | 2.200000 | 18.914286 | 0.571429 |
| 89393367437ffff | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | 8.0 | 0.0 | 0.0 | ... | 1.875000 | 0.656250 | 5.000000 | 1.125000 | 15.812500 | 1.062500 | 10.593750 | 3.531250 | 27.218750 | 0.375000 |
830 rows × 198 columns
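The 198 columns follow from the naming scheme: every feature group is repeated once per distance, from 0 to 10. The sketch below rebuilds the column names in the order visible in the header above (all features at distance 0 first, then distance 1, and so on). The `wide_columns` helper is hypothetical; only the naming pattern and the feature list come from the output.

```python
def wide_columns(features, max_distance):
    """Column names of the concatenated layout: each feature repeated for distances 0..max_distance."""
    return [f"{feature}_{d}" for d in range(max_distance + 1) for feature in features]


features = [
    "aerialway", "airports", "buildings", "culture_art_entertainment", "education",
    "emergency", "finances", "greenery", "healthcare", "historic", "leisure", "other",
    "shops", "sport", "sustenance", "tourism", "transportation", "water",
]
columns = wide_columns(features, 10)
print(len(columns))  # 198 = 18 features x 11 distances
print(columns[:2], columns[-1])
```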
Plotting example features
```python
plot_numeric_data(regions_gdf, "leisure", embeddings)
plot_numeric_data(regions_gdf, "transportation", embeddings)
```
Other types of aggregations
By default, the ContextualCountEmbedder averages the counts of the neighbours at each distance. This aggregation_function can be changed to one of: median, sum, min or max.
It works best combined with the wide format (concatenate_vectors=True), where each distance keeps its own set of columns.
```python
sum_cce = ContextualCountEmbedder(
    neighbourhood=h3n,
    neighbourhood_distance=10,
    concatenate_vectors=True,
    aggregation_function="sum",
)
sum_embeddings = sum_cce.transform(regions_gdf, features_gdf, joint_gdf)
sum_embeddings
```
| region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 89393375877ffff | 0.0 | 1.0 | 5.0 | 0.0 | 1.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 42.0 | 18.0 | 157.0 | 15.0 | 206.0 | 49.0 | 121.0 | 78.0 | 596.0 | 33.0 |
| 89393375e8bffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 4.0 | 0.0 | 0.0 | ... | 30.0 | 21.0 | 140.0 | 16.0 | 241.0 | 73.0 | 132.0 | 51.0 | 387.0 | 6.0 |
| 89393362b1bffff | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 1.0 | 2.0 | 4.0 | 3.0 | ... | 39.0 | 37.0 | 164.0 | 38.0 | 317.0 | 41.0 | 200.0 | 90.0 | 781.0 | 22.0 |
| 893933628d3ffff | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 77.0 | 68.0 | 292.0 | 47.0 | 384.0 | 118.0 | 276.0 | 152.0 | 737.0 | 35.0 |
| 893933674a7ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 51.0 | 17.0 | 109.0 | 18.0 | 357.0 | 14.0 | 247.0 | 72.0 | 679.0 | 6.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 89393375923ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 11.0 | 0.0 | 44.0 | 11.0 | 83.0 | 20.0 | 26.0 | 16.0 | 125.0 | 3.0 |
| 8939337533bffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 15.0 | 0.0 | 0.0 | ... | 23.0 | 17.0 | 89.0 | 14.0 | 34.0 | 24.0 | 45.0 | 40.0 | 353.0 | 6.0 |
| 89393362e97ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | ... | 13.0 | 27.0 | 209.0 | 10.0 | 61.0 | 33.0 | 51.0 | 60.0 | 316.0 | 24.0 |
| 893933758cfffff | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 21.0 | 6.0 | 123.0 | 15.0 | 85.0 | 72.0 | 62.0 | 34.0 | 542.0 | 11.0 |
| 8939336054bffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 14.0 | 0.0 | 2.0 | ... | 8.0 | 4.0 | 38.0 | 4.0 | 51.0 | 4.0 | 53.0 | 30.0 | 69.0 | 8.0 |
830 rows × 198 columns
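Changing the aggregation only changes how the neighbour values at each distance are reduced. Comparing the two wide tables, the results look consistent with aggregating over the neighbours present in regions_gdf. For region 89393375877ffff, healthcare_10 is 42.0 with sum and 0.976744 with the default averaging, and 42 / 43 ≈ 0.976744. This suggests 43 neighbours at distance 10 rather than a full ring of 60 cells. The sketch below maps the aggregation names onto Python built-ins and the statistics module. The `AGGREGATIONS` dict and `aggregate_ring` helper are hypothetical, and "average" as the key for the default is an assumed name.

```python
from statistics import mean, median

# Hypothetical mapping from aggregation names to Python functions.
# "average" is an assumed name for the default; the others are listed in the text above.
AGGREGATIONS = {"average": mean, "median": median, "sum": sum, "min": min, "max": max}


def aggregate_ring(neighbours, how):
    """Reduce the count vectors of all neighbours at one distance, column by column."""
    func = AGGREGATIONS[how]
    return [func(column) for column in zip(*neighbours)]


ring_values = [[0, 3], [2, 5], [1, 1]]  # toy counts of 3 neighbours for 2 feature groups
for how in AGGREGATIONS:
    print(how, aggregate_ring(ring_values, how))
```

With the wide format, each of these reductions stays in its own _k column, so the per-distance values remain directly comparable across regions.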
```python
plot_numeric_data(regions_gdf, "tourism_8", sum_embeddings)
```