Contextual count embedder
from srai.embedders import ContextualCountEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.neighbourhoods import H3Neighbourhood
from srai.plotting.folium_wrapper import plot_numeric_data, plot_regions
from srai.regionalizers import H3Regionalizer
Data preparation
To use ContextualCountEmbedder, we need to prepare three inputs: regions_gdf, features_gdf and joint_gdf.
These are the outputs of a Regionalizer, a Loader and a Joiner, respectively.
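To show how the three inputs relate, here is a minimal sketch that uses plain Python dictionaries and hypothetical toy data instead of GeoDataFrames, with geometries omitted. The joint table is treated as a list of (region_id, feature_id) pairs. Counting, for each region, the feature groups that carry a value gives per-region counts like the ones the embedders build on. This illustrates the idea under those assumptions; it is not srai's implementation.

```python
# Hypothetical toy stand-ins for the three inputs (geometries omitted).
regions = ["r1", "r2", "r3"]  # plays the role of regions_gdf.index (region_id)

# feature_id -> {group: "key=value"}, like the tag columns of features_gdf
features = {
    "node/1": {"transportation": "amenity=parking"},
    "node/2": {"sustenance": "amenity=restaurant"},
    "way/3": {"greenery": "landuse=grass", "leisure": "leisure=park"},
}

# (region_id, feature_id) pairs, like the index of joint_gdf;
# a feature intersecting several regions appears once per region (way/3 here).
joint = [("r1", "node/1"), ("r1", "node/2"), ("r2", "way/3"), ("r3", "way/3")]


def count_per_region(regions, features, joint):
    """Count, per region, how many joined features have a value in each group."""
    groups = sorted({group for tags in features.values() for group in tags})
    counts = {region_id: dict.fromkeys(groups, 0) for region_id in regions}
    for region_id, feature_id in joint:
        for group, value in features[feature_id].items():
            if value is not None:
                counts[region_id][group] += 1
    return counts


counts = count_per_region(regions, features, joint)
print(counts["r1"])  # {'greenery': 0, 'leisure': 0, 'sustenance': 1, 'transportation': 1}
```

This also suggests why joint_gdf can have more rows than features_gdf (37320 vs 33676 in the outputs below). A feature that intersects several regions is paired with each of them.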
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Lisboa, PT")
plot_regions(area_gdf)
Regionalize the area using an H3Regionalizer
regionalizer = H3Regionalizer(resolution=9, buffer=True)
regions_gdf = regionalizer.transform(area_gdf)
regions_gdf
| region_id | geometry |
|---|---|
| 89393362d87ffff | POLYGON ((-9.16951 38.76353, -9.17122 38.76219... |
| 89393375907ffff | POLYGON ((-9.10074 38.78934, -9.10245 38.788, ... |
| 89393375817ffff | POLYGON ((-9.13186 38.78607, -9.13357 38.78473... |
| 89393375a53ffff | POLYGON ((-9.14971 38.75557, -9.15142 38.75423... |
| 89393362eabffff | POLYGON ((-9.19942 38.73057, -9.20113 38.72923... |
| ... | ... |
| 893933605b3ffff | POLYGON ((-9.20888 38.71371, -9.21058 38.71236... |
| 893933628a3ffff | POLYGON ((-9.17016 38.73948, -9.17187 38.73814... |
| 89393362d73ffff | POLYGON ((-9.1669 38.74911, -9.16861 38.74777,... |
| 893933674b7ffff | POLYGON ((-9.11274 38.75405, -9.11444 38.75271... |
| 893933759a7ffff | POLYGON ((-9.09454 38.79657, -9.09624 38.79523... |
830 rows × 1 columns
Download some objects from OpenStreetMap
The loader accepts filters of either the OsmTagsFilter or the GroupedOsmTagsFilter type. This example uses BASE_OSM_GROUPS_FILTER, a predefined GroupedOsmTagsFilter.
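As a rough sketch of the difference between the two filter types, assume both are plain dictionaries. A flat filter maps an OSM key to True, a single value, or a list of values. A grouped filter maps a group name to a flat filter. The hypothetical helpers below show how a grouped filter yields one column per group, each holding a matched key=value string or None, as in features_gdf further down. The matching rules and the choice of the first matching tag are our simplification, not necessarily how OSMPbfLoader resolves tags.

```python
# Hypothetical filters shaped like the two filter types (assumed structure):
#   flat:    {osm_key: True | "value" | ["value", ...]}
#   grouped: {group_name: flat_filter}
flat_filter = {"amenity": ["parking", "restaurant"], "landuse": "grass", "historic": True}
grouped_filter = {
    "transportation": {"amenity": ["parking"]},
    "sustenance": {"amenity": ["restaurant", "cafe"]},
    "greenery": {"landuse": "grass"},
    "historic": {"historic": True},
}


def accepts(allowed, value):
    """Check one OSM tag value against a single filter entry."""
    if isinstance(allowed, bool):
        return allowed  # True accepts any value of this key
    if isinstance(allowed, str):
        return value == allowed
    return value in allowed


def first_match(osm_tags, tags_filter):
    """Return 'key=value' for the first tag the flat filter accepts, else None."""
    for key, value in osm_tags.items():
        if key in tags_filter and accepts(tags_filter[key], value):
            return f"{key}={value}"
    return None


def assign_groups(osm_tags, grouped):
    """Apply each group's flat filter, giving one 'key=value' (or None) per group."""
    return {group: first_match(osm_tags, group_filter) for group, group_filter in grouped.items()}


print(first_match({"amenity": "parking"}, flat_filter))
print(assign_groups({"historic": "castle", "amenity": "parking"}, grouped_filter))
```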
from srai.loaders.osm_loaders.filters import BASE_OSM_GROUPS_FILTER
loader = OSMPbfLoader()
features_gdf = loader.load(area_gdf, tags=BASE_OSM_GROUPS_FILTER)
features_gdf
Finished operation in 0:00:16
| feature_id | geometry | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| node/1662570350 | POINT (-9.13652 38.75509) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
| node/1662570365 | POINT (-9.13501 38.75859) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
| node/1662570400 | POINT (-9.13817 38.75503) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
| node/1662570428 | POINT (-9.13233 38.75529) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
| node/1662570468 | POINT (-9.13822 38.75534) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=restaurant | None | None | None |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| relation/12990755 | POLYGON ((-9.1692 38.70779, -9.16909 38.70789,... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
| relation/13884079 | POLYGON ((-9.13321 38.74216, -9.13322 38.74213... | None | None | None | None | None | None | None | landuse=grass | None | None | None | None | None | None | None | None | None | None |
| relation/8306287 | POLYGON ((-9.14249 38.71874, -9.14251 38.71877... | None | None | None | None | None | None | None | None | None | historic=castle | None | None | None | None | None | None | None | None |
| relation/11318814 | POLYGON ((-9.21057 38.697, -9.2103 38.69651, -... | None | None | None | None | None | None | None | None | None | None | leisure=garden | None | None | None | None | None | None | None |
| relation/15475183 | POLYGON ((-9.20112 38.69844, -9.20106 38.69825... | None | None | None | None | None | None | None | None | None | historic=castle | None | None | None | None | None | tourism=attraction | None | None |
33676 rows × 19 columns
Join the objects with the regions they belong to
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf)
joint_gdf
| region_id | feature_id |
|---|---|
| 89393375a0fffff | node/1662570350 |
| 89393375a07ffff | node/1662570365 |
| 89393375a0fffff | node/1662570400 |
| 89393375a3bffff | node/1662570428 |
| 89393375a0fffff | node/1662570468 |
| ... | ... |
| 89393362a27ffff | relation/12990755 |
| 893933674dbffff | relation/13884079 |
| 893933676dbffff | relation/8306287 |
| 89393360573ffff | relation/11318814 |
| 8939336052bffff | relation/15475183 |
37320 rows × 0 columns
Embed using features existing in data
ContextualCountEmbedder extends the basic CountEmbedder by also taking into account the neighbourhood of each embedded region. In this example, we use H3Neighbourhood.
h3n = H3Neighbourhood()
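The neighbourhood is consulted distance by distance: for each distance d, the embedder looks at the regions exactly d steps away. The sketch below illustrates these rings on a hypothetical square grid with Chebyshev distance instead of H3. The class is an assumption for illustration, and its method name only echoes srai's neighbourhood interface. On an H3 grid, a complete ring at distance d has 6 * d cells, while on this square grid it has 8 * d.

```python
class SquareGridNeighbourhood:
    """Hypothetical neighbourhood on integer (x, y) cells, using Chebyshev distance."""

    def get_neighbours_at_distance(self, cell, distance):
        """Return the cells exactly `distance` steps away (the cell itself for 0)."""
        x, y = cell
        if distance == 0:
            return {cell}
        return {
            (x + dx, y + dy)
            for dx in range(-distance, distance + 1)
            for dy in range(-distance, distance + 1)
            if max(abs(dx), abs(dy)) == distance
        }


grid = SquareGridNeighbourhood()
for d in range(1, 4):
    print(d, len(grid.get_neighbours_at_distance((0, 0), d)))  # ring sizes 8, 16, 24
```

By the 6 * d rule, neighbourhood_distance=10 on H3 covers up to 330 surrounding cells (6 * (1 + 2 + ... + 10)) when all rings fall inside the area. Near the area's boundary, fewer neighbours are present.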
Squashed vector version (default)
The embedder returns a vector of the same length as CountEmbedder's. To each region's own counts, it adds the counts averaged over the neighbours at each distance, diminished by the square of that distance.
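The following is a minimal pure-Python sketch of the squashed computation. It rests on two assumptions. First, each ring is averaged only over neighbours that are themselves regions in the data. Second, the ring at distance d is weighted by 1 / d**2, following the description above. srai's own implementation may offset or otherwise define this weight, so check its source before relying on exact values. A toy one-dimensional "grid" stands in for H3, and all data is hypothetical.

```python
def squashed_embedding(counts, neighbours_at, region, max_distance):
    """Own counts plus, for each distance d = 1..max_distance, the mean counts of
    the known neighbours at that distance divided by d ** 2 (assumed weighting)."""
    result = list(counts[region])
    for d in range(1, max_distance + 1):
        ring = [n for n in neighbours_at(region, d) if n in counts]
        if not ring:
            continue  # no known neighbours at this distance, nothing to add
        for j in range(len(result)):
            ring_mean = sum(counts[n][j] for n in ring) / len(ring)
            result[j] += ring_mean / d**2
    return result


def line_neighbours(i, d):
    """Toy 1-D neighbourhood: regions i - d and i + d are at distance d."""
    return {i - d, i + d}


# Hypothetical counts of two feature groups in four regions 0..3.
counts = {0: [1.0, 0.0], 1: [0.0, 2.0], 2: [3.0, 4.0], 3: [0.0, 0.0]}
print(squashed_embedding(counts, line_neighbours, 0, 2))  # [1.75, 3.0]
```

For region 0, distance 1 contributes region 1's counts unchanged, and distance 2 contributes region 2's counts divided by 4.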
cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=False
)
embeddings = cce.transform(regions_gdf, features_gdf, joint_gdf)
embeddings
| region_id | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 89393362d87ffff | 0.003553 | 0.008071 | 4.896608 | 0.034474 | 3.052658 | 0.017737 | 2.409247 | 4.622333 | 1.542947 | 0.207747 | 7.603585 | 1.274656 | 29.458276 | 5.040806 | 13.548876 | 1.588538 | 30.338425 | 0.412871 |
| 89393375907ffff | 0.010166 | 0.019579 | 0.499663 | 0.015315 | 0.373406 | 0.031400 | 0.095192 | 4.896985 | 0.184408 | 0.051209 | 1.872668 | 0.218977 | 1.960812 | 0.720309 | 0.581495 | 0.444326 | 10.759007 | 0.623011 |
| 89393375817ffff | 0.000000 | 2.544783 | 0.214371 | 0.014293 | 0.136029 | 0.025118 | 0.037967 | 5.859158 | 0.063102 | 0.038612 | 0.867616 | 0.150299 | 0.384585 | 0.197346 | 0.238165 | 0.116607 | 4.907818 | 0.032654 |
| 89393375a53ffff | 0.002250 | 0.021911 | 2.315896 | 0.210674 | 2.408794 | 0.070223 | 0.528682 | 6.609459 | 4.862885 | 0.605286 | 8.865136 | 2.343159 | 7.831306 | 1.982779 | 6.635632 | 1.822087 | 24.737613 | 0.234596 |
| 89393362eabffff | 0.001193 | 0.000000 | 0.160942 | 0.024593 | 0.138521 | 0.005053 | 0.032375 | 0.682701 | 0.092859 | 0.098159 | 5.030859 | 0.188682 | 0.497836 | 3.212923 | 0.316621 | 4.119140 | 2.154397 | 0.070713 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 893933605b3ffff | 0.000000 | 0.000000 | 1.804645 | 0.155954 | 0.708589 | 0.006456 | 0.142892 | 10.418686 | 0.423401 | 0.203517 | 13.158872 | 0.280481 | 8.504021 | 2.627477 | 4.083113 | 0.960359 | 17.122806 | 0.573305 |
| 893933628a3ffff | 0.034722 | 0.000000 | 7.265372 | 0.136688 | 0.604754 | 0.060365 | 1.461136 | 12.444069 | 2.758843 | 0.471496 | 4.154793 | 0.444337 | 2.912976 | 0.702100 | 4.158398 | 1.990955 | 23.101428 | 0.387400 |
| 89393362d73ffff | 0.111111 | 0.001687 | 2.043277 | 0.223947 | 2.590042 | 0.022025 | 0.512628 | 14.147280 | 0.851430 | 0.654473 | 10.912792 | 1.195939 | 4.527630 | 4.453599 | 11.170588 | 3.984539 | 24.785419 | 0.633617 |
| 893933674b7ffff | 0.003180 | 0.010031 | 1.739157 | 0.066453 | 2.506673 | 0.008225 | 0.118749 | 8.113353 | 1.311530 | 0.043330 | 7.719367 | 0.302407 | 1.495735 | 3.725419 | 0.618298 | 0.281479 | 26.120669 | 0.067945 |
| 893933759a7ffff | 0.006564 | 0.005209 | 0.439975 | 0.008578 | 1.268216 | 0.025225 | 0.069378 | 3.156037 | 0.138799 | 0.027171 | 7.931241 | 0.079161 | 0.791880 | 4.708692 | 0.527697 | 0.252690 | 9.372814 | 0.576183 |
830 rows × 18 columns
Concatenated vector version
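The embedder returns a vector of length n * (distance + 1), where n is the number of features from CountEmbedder and distance is the neighbourhood_distance. The extra 1 accounts for distance 0, the region itself. With 18 features and neighbourhood_distance=10, this gives the 198 columns shown below.
Each feature name is suffixed with _d, where d is the distance. The _0 columns hold the region's own counts, and the values at every other distance are averaged over all neighbours at that distance.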
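Here is a matching sketch of the wide layout, using the same toy one-dimensional neighbourhood and hypothetical counts as an illustration. Distance 0 keeps the region's own counts, and every further distance gets its own unweighted ring mean under a _d suffix. Filling a distance that has no known neighbours with 0.0 is our assumption.

```python
def concatenated_embedding(counts, neighbours_at, region, max_distance, names):
    """Wide row: '<name>_0' holds the region's own counts and '<name>_<d>' the mean
    over known neighbours at distance d (0.0 if there are none; an assumption)."""
    row = {f"{name}_0": value for name, value in zip(names, counts[region])}
    for d in range(1, max_distance + 1):
        ring = [n for n in neighbours_at(region, d) if n in counts]
        for j, name in enumerate(names):
            total = sum(counts[n][j] for n in ring)
            row[f"{name}_{d}"] = total / len(ring) if ring else 0.0
    return row


def line_neighbours(i, d):
    """Toy 1-D neighbourhood: regions i - d and i + d are at distance d."""
    return {i - d, i + d}


counts = {0: [1.0, 0.0], 1: [0.0, 2.0], 2: [3.0, 4.0], 3: [0.0, 0.0]}
row = concatenated_embedding(counts, line_neighbours, 1, 2, ["greenery", "leisure"])
print(row)

# Width check: 18 groups and distances 0..10 give 18 * (10 + 1) columns.
names = [f"group{i}" for i in range(18)]
wide = concatenated_embedding({0: [0.0] * 18}, line_neighbours, 0, 10, names)
print(len(wide))  # 198
```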
wide_cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=True
)
wide_embeddings = wide_cce.transform(regions_gdf, features_gdf, joint_gdf)
wide_embeddings
| region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 89393362d87ffff | 0.0 | 0.0 | 4.0 | 0.0 | 2.0 | 0.0 | 2.0 | 1.0 | 1.0 | 0.0 | ... | 1.095238 | 0.571429 | 4.523810 | 0.571429 | 6.166667 | 1.023810 | 3.857143 | 1.642857 | 13.785714 | 0.119048 |
| 89393375907ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 0.777778 | 0.055556 | 1.222222 | 0.444444 | 3.722222 | 0.388889 | 2.277778 | 0.500000 | 9.611111 | 0.277778 |
| 89393375817ffff | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 1.305556 | 0.250000 | 5.666667 | 0.472222 | 5.833333 | 2.361111 | 3.527778 | 0.833333 | 15.250000 | 0.500000 |
| 89393375a53ffff | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 3.0 | 4.0 | 0.0 | ... | 1.050000 | 1.133333 | 4.283333 | 0.650000 | 5.666667 | 1.383333 | 3.600000 | 2.266667 | 13.600000 | 0.266667 |
| 89393362eabffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 1.463415 | 1.609756 | 6.585366 | 1.097561 | 9.926829 | 1.512195 | 7.243902 | 3.170732 | 14.487805 | 1.000000 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 893933605b3ffff | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 6.0 | 0.0 | 0.0 | ... | 0.400000 | 1.050000 | 3.350000 | 0.650000 | 3.800000 | 0.550000 | 3.300000 | 1.350000 | 9.700000 | 0.400000 |
| 893933628a3ffff | 0.0 | 0.0 | 6.0 | 0.0 | 0.0 | 0.0 | 1.0 | 7.0 | 2.0 | 0.0 | ... | 1.366667 | 1.966667 | 4.966667 | 1.200000 | 12.850000 | 1.183333 | 10.983333 | 3.800000 | 17.166667 | 0.400000 |
| 89393362d73ffff | 0.0 | 0.0 | 1.0 | 0.0 | 2.0 | 0.0 | 0.0 | 7.0 | 0.0 | 0.0 | ... | 1.050000 | 0.850000 | 3.166667 | 0.700000 | 9.450000 | 1.166667 | 5.900000 | 2.333333 | 14.700000 | 0.283333 |
| 893933674b7ffff | 0.0 | 0.0 | 1.0 | 0.0 | 2.0 | 0.0 | 0.0 | 5.0 | 1.0 | 0.0 | ... | 1.270270 | 0.243243 | 2.567568 | 0.513514 | 7.432432 | 0.405405 | 5.756757 | 1.810811 | 14.945946 | 0.243243 |
| 893933759a7ffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | ... | 1.000000 | 0.071429 | 6.071429 | 0.285714 | 2.000000 | 1.214286 | 2.214286 | 1.214286 | 15.785714 | 1.000000 |
830 rows × 198 columns
Plotting example features
plot_numeric_data(regions_gdf, "leisure", embeddings)
plot_numeric_data(regions_gdf, "transportation", embeddings)
Other types of aggregations
By default, ContextualCountEmbedder averages the counts from neighbours. The aggregation_function parameter can change this to one of: median, sum, min or max.
These alternatives work best combined with the wide format (concatenate_vectors=True).
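The sketch below shows what each option does to the values of one feature across the neighbours at a single distance. It uses Python's standard statistics module and hypothetical ring values. The "mean" key is our own label for the default averaging and is not necessarily the string srai accepts.

```python
import statistics

# Stand-ins for the aggregation options named above; "mean" labels the default.
AGGREGATIONS = {
    "mean": statistics.fmean,
    "median": statistics.median,
    "sum": sum,
    "min": min,
    "max": max,
}


def aggregate_ring(ring_values, how="mean"):
    """Collapse one feature's counts over all neighbours at a single distance."""
    return float(AGGREGATIONS[how](ring_values))


ring = [0.0, 2.0, 7.0, 3.0]  # hypothetical counts of 4 neighbours at one distance
for how in AGGREGATIONS:
    print(how, aggregate_ring(ring, how))
```

Because the mean equals the sum divided by the number of neighbours found at that distance, the two wide outputs in this notebook can be cross-checked. For region 89393362d87ffff, healthcare_10 is 1.095238 with the default average and 46.0 with sum. That is consistent with 42 neighbours at distance 10 lying inside the area (46 / 42 ≈ 1.095238).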
sum_cce = ContextualCountEmbedder(
neighbourhood=h3n,
neighbourhood_distance=10,
concatenate_vectors=True,
aggregation_function="sum",
)
sum_embeddings = sum_cce.transform(regions_gdf, features_gdf, joint_gdf)
sum_embeddings
| region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 89393362d87ffff | 0.0 | 0.0 | 4.0 | 0.0 | 2.0 | 0.0 | 2.0 | 1.0 | 1.0 | 0.0 | ... | 46.0 | 24.0 | 190.0 | 24.0 | 259.0 | 43.0 | 162.0 | 69.0 | 579.0 | 5.0 |
| 89393375907ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 14.0 | 1.0 | 22.0 | 8.0 | 67.0 | 7.0 | 41.0 | 9.0 | 173.0 | 5.0 |
| 89393375817ffff | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 47.0 | 9.0 | 204.0 | 17.0 | 210.0 | 85.0 | 127.0 | 30.0 | 549.0 | 18.0 |
| 89393375a53ffff | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 3.0 | 4.0 | 0.0 | ... | 63.0 | 68.0 | 257.0 | 39.0 | 340.0 | 83.0 | 216.0 | 136.0 | 816.0 | 16.0 |
| 89393362eabffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 60.0 | 66.0 | 270.0 | 45.0 | 407.0 | 62.0 | 297.0 | 130.0 | 594.0 | 41.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 893933605b3ffff | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 6.0 | 0.0 | 0.0 | ... | 8.0 | 21.0 | 67.0 | 13.0 | 76.0 | 11.0 | 66.0 | 27.0 | 194.0 | 8.0 |
| 893933628a3ffff | 0.0 | 0.0 | 6.0 | 0.0 | 0.0 | 0.0 | 1.0 | 7.0 | 2.0 | 0.0 | ... | 82.0 | 118.0 | 298.0 | 72.0 | 771.0 | 71.0 | 659.0 | 228.0 | 1030.0 | 24.0 |
| 89393362d73ffff | 0.0 | 0.0 | 1.0 | 0.0 | 2.0 | 0.0 | 0.0 | 7.0 | 0.0 | 0.0 | ... | 63.0 | 51.0 | 190.0 | 42.0 | 567.0 | 70.0 | 354.0 | 140.0 | 882.0 | 17.0 |
| 893933674b7ffff | 0.0 | 0.0 | 1.0 | 0.0 | 2.0 | 0.0 | 0.0 | 5.0 | 1.0 | 0.0 | ... | 47.0 | 9.0 | 95.0 | 19.0 | 275.0 | 15.0 | 213.0 | 67.0 | 553.0 | 9.0 |
| 893933759a7ffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | ... | 14.0 | 1.0 | 85.0 | 4.0 | 28.0 | 17.0 | 31.0 | 17.0 | 221.0 | 14.0 |
830 rows × 198 columns
plot_numeric_data(regions_gdf, "tourism_8", sum_embeddings)