Contextual count embedder
from srai.embedders import ContextualCountEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.neighbourhoods import H3Neighbourhood
from srai.plotting.folium_wrapper import plot_numeric_data, plot_regions
from srai.regionalizers import H3Regionalizer
Data preparation
To use ContextualCountEmbedder, we need to prepare three inputs: regions_gdf, features_gdf, and joint_gdf.
These are the outputs of a Regionalizer, a Loader, and a Joiner, respectively.
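To show how the three inputs fit together, here is a minimal sketch in plain Python, on hypothetical toy data, of the per-region counting that CountEmbedder (and therefore ContextualCountEmbedder) builds on. It assumes a feature is counted once in every region it is joined to, for every column where it has a non-empty value. The function `count_features` and all data are illustrative assumptions; srai itself works on GeoDataFrames and may differ in details.

```python
# Hypothetical toy data standing in for the three inputs.
regions = ["r1", "r2", "r3"]  # region ids, as in regions_gdf's index
features = {  # feature id -> {column: value}, like the columns of features_gdf
    "node/1": {"transportation": "public_transport=stop_position"},
    "way/2": {"leisure": "leisure=pitch", "sport": "sport=soccer"},
    "way/3": {"greenery": "landuse=grass"},
}
joint = [  # (region_id, feature_id) pairs, like joint_gdf's index
    ("r1", "node/1"),
    ("r1", "way/2"),
    ("r2", "way/2"),  # a polygon feature may intersect several regions
    ("r2", "way/3"),
]


def count_features(regions, features, joint):
    """Count, per region, how many joined features have a value in each column."""
    columns = sorted({column for attrs in features.values() for column in attrs})
    counts = {region_id: dict.fromkeys(columns, 0) for region_id in regions}
    for region_id, feature_id in joint:
        for column, value in features[feature_id].items():
            if value is not None:
                counts[region_id][column] += 1
    return counts


counts = count_features(regions, features, joint)
```

Region "r3" has no joined features, so it still appears with all-zero counts: every region from the regionalizer gets a row.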
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Lisboa, PT")
plot_regions(area_gdf)
Regionalize the area using an H3Regionalizer
regionalizer = H3Regionalizer(resolution=9, buffer=True)
regions_gdf = regionalizer.transform(area_gdf)
regions_gdf
| region_id | geometry |
|---|---|
| 89393367097ffff | POLYGON ((-9.10318 38.73804, -9.10489 38.73671... |
| 89393375a6bffff | POLYGON ((-9.13688 38.74919, -9.13859 38.74785... |
| 89393362eabffff | POLYGON ((-9.19942 38.73057, -9.20113 38.72923... |
| 89393375b23ffff | POLYGON ((-9.10142 38.7653, -9.10312 38.76396,... |
| 89393362b0fffff | POLYGON ((-9.15288 38.71307, -9.15458 38.71173... |
| ... | ... |
| 89393375863ffff | POLYGON ((-9.12817 38.77486, -9.12987 38.77352... |
| 89393375ad3ffff | POLYGON ((-9.1545 38.76357, -9.1562 38.76223, ... |
| 89393375ec3ffff | POLYGON ((-9.17212 38.77795, -9.17383 38.77661... |
| 89393362a23ffff | POLYGON ((-9.17299 38.70901, -9.1747 38.70766,... |
| 89393375933ffff | POLYGON ((-9.09672 38.79015, -9.09842 38.78881... |
830 rows × 1 columns
Download some objects from OpenStreetMap
You can use either an OsmTagsFilter or a GroupedOsmTagsFilter. This example uses BASE_OSM_GROUPS_FILTER, a predefined GroupedOsmTagsFilter.
from srai.loaders.osm_loaders.filters import BASE_OSM_GROUPS_FILTER
loader = OSMPbfLoader()
features_gdf = loader.load(area_gdf, tags=BASE_OSM_GROUPS_FILTER)
features_gdf
Finished operation in 0:00:17
| feature_id | geometry | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| node/21433772 | POINT (-9.19059 38.7288) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
| node/21433776 | POINT (-9.19376 38.72666) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
| node/25414208 | POINT (-9.16663 38.74018) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
| node/25414256 | POINT (-9.10286 38.74711) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
| node/25414265 | POINT (-9.10273 38.74707) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| way/1148218026 | POLYGON ((-9.20274 38.70175, -9.20279 38.70163... | None | None | None | None | None | None | None | None | None | None | leisure=pitch | None | None | sport=soccer | None | None | None | None |
| way/1148218053 | POLYGON ((-9.20364 38.70032, -9.20364 38.70026... | None | None | None | None | None | None | None | None | None | None | leisure=swimming_pool | None | None | None | None | None | None | None |
| way/1148218054 | POLYGON ((-9.20366 38.70015, -9.20361 38.70015... | None | None | None | None | None | None | None | None | None | None | leisure=swimming_pool | None | None | None | None | None | None | None |
| way/1148218070 | POLYGON ((-9.20478 38.7011, -9.20453 38.70101,... | None | None | None | None | None | None | None | landuse=grass | None | None | None | None | None | None | None | None | None | None |
| way/1148218071 | POLYGON ((-9.20443 38.70098, -9.20443 38.70092... | None | None | None | None | None | None | None | landuse=grass | None | None | None | None | None | None | None | None | None | None |
33676 rows × 19 columns
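The table shows that each column of features_gdf is a group name and each value is the matching `key=value` OSM tag. The sketch below illustrates that mapping under stated assumptions: a grouped filter is taken to be a dict of group name to `{tag key: True | value | list of values}`, and the first matching tag wins within a group. The group contents in `GROUPED_FILTER` are invented for illustration and are not the contents of BASE_OSM_GROUPS_FILTER; `assign_groups` is a hypothetical helper, not a srai function.

```python
# Hypothetical grouped filter: group name -> {OSM tag key -> True | value | [values]}.
GROUPED_FILTER = {
    "leisure": {"leisure": ["pitch", "swimming_pool", "park"]},
    "sport": {"sport": True},  # True: accept any value of this key
    "greenery": {"landuse": ["grass", "forest"], "leisure": ["park"]},
    "transportation": {"public_transport": ["stop_position", "platform"]},
}


def _matches(accepted, value: str) -> bool:
    """Check one tag value against a filter entry (True, a single value, or a list)."""
    if accepted is True:
        return True
    if isinstance(accepted, str):
        return value == accepted
    return value in accepted


def assign_groups(tags: dict[str, str], grouped_filter: dict) -> dict[str, str]:
    """Return {group: "key=value"} for every group whose filter matches the tags."""
    result = {}
    for group, tag_filter in grouped_filter.items():
        for key, accepted in tag_filter.items():
            value = tags.get(key)
            if value is not None and _matches(accepted, value):
                result[group] = f"{key}={value}"
                break  # assumption: keep only the first matching tag per group
    return result


pitch = assign_groups({"leisure": "pitch", "sport": "soccer"}, GROUPED_FILTER)
```

Here `pitch` ends up with a value in both the leisure and sport groups, which mirrors rows such as way/1148218026 above.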
Join the objects with the regions they belong to
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf)
joint_gdf
| region_id | feature_id |
|---|---|
| 893933628dbffff | node/21433772 |
| 89393362e37ffff | node/21433776 |
| 893933628b7ffff | node/25414208 |
| 893933675c3ffff | node/25414256 |
| | node/25414265 |
| ... | ... |
| 89393360523ffff | way/1148218026 |
| 8939336053bffff | way/1148218053 |
| | way/1148218054 |
| | way/1148218070 |
| | way/1148218071 |
37320 rows × 0 columns
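joint_gdf has more rows (37320) than features_gdf has features (33676) because a single feature, typically a polygon, can intersect several regions and then appears once per region. The sketch below illustrates an intersection join under a simplifying assumption: both regions and features are axis-aligned boxes (a point is a zero-size box) instead of real hexagons and OSM geometries. `Box` and `intersection_join` are hypothetical names; IntersectionJoiner operates on the actual geometries.

```python
from typing import NamedTuple


class Box(NamedTuple):
    """Axis-aligned box used as a stand-in geometry; a point has min == max."""
    min_x: float
    min_y: float
    max_x: float
    max_y: float


def intersects(a: Box, b: Box) -> bool:
    # Closed intervals: boxes that only touch still count as intersecting.
    return (a.min_x <= b.max_x and b.min_x <= a.max_x
            and a.min_y <= b.max_y and b.min_y <= a.max_y)


def intersection_join(regions: dict[str, Box], features: dict[str, Box]) -> list[tuple[str, str]]:
    """Return every (region_id, feature_id) pair whose geometries intersect."""
    return [
        (region_id, feature_id)
        for region_id, region in regions.items()
        for feature_id, feature in features.items()
        if intersects(region, feature)
    ]


regions = {"left": Box(0, 0, 1, 1), "right": Box(1, 0, 2, 1)}
features = {
    "node/1": Box(0.5, 0.5, 0.5, 0.5),  # a point inside "left"
    "way/2": Box(0.8, 0.2, 1.3, 0.4),   # a polygon straddling both regions
    "node/3": Box(5, 5, 5, 5),          # a point outside every region
}
pairs = intersection_join(regions, features)
```

"way/2" yields two pairs and "node/3" yields none, which is how the joined table can be both longer than the feature table and missing some features.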
Embed using features existing in data
ContextualCountEmbedder extends the basic CountEmbedder by also taking into account the neighbourhood of each embedded region. In this example we use H3Neighbourhood.
h3n = H3Neighbourhood()
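H3 cells at a fixed resolution form a hexagonal grid, so the neighbours "at distance k" of a region are the cells exactly k grid steps away, a ring of 6k cells on a regular patch of the grid. The sketch below computes such rings on a plain axial hexagon grid; it does not use real H3 indexes or the h3 library, and `hex_distance` and `ring` are hypothetical helpers. Regions near the boundary of the area will in practice have only part of each ring available in regions_gdf.

```python
# Sketch only: plain axial hex coordinates (q, r), NOT real H3 cell indexes.
# The six neighbour directions of a hexagon in axial coordinates.
DIRECTIONS = [(1, 0), (1, -1), (0, -1), (-1, 0), (-1, 1), (0, 1)]


def hex_distance(a: tuple[int, int], b: tuple[int, int]) -> int:
    """Number of grid steps between two hexagons."""
    dq, dr = a[0] - b[0], a[1] - b[1]
    return (abs(dq) + abs(dr) + abs(dq + dr)) // 2


def ring(center: tuple[int, int], k: int) -> list[tuple[int, int]]:
    """All hexagons exactly k steps from center (k = 0 returns the center itself)."""
    if k == 0:
        return [center]
    # Start k steps away in direction 4, then walk k steps along each of the six sides.
    q = center[0] + DIRECTIONS[4][0] * k
    r = center[1] + DIRECTIONS[4][1] * k
    cells = []
    for dq, dr in DIRECTIONS:
        for _ in range(k):
            cells.append((q, r))
            q, r = q + dq, r + dr
    return cells


ring_10 = ring((0, 0), 10)  # a full ring at distance 10 holds 6 * 10 = 60 cells
```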
Squashed vector version (default)
The embedder returns a vector of the same length as CountEmbedder does: the region's own counts plus, for each neighbourhood distance, the counts averaged over the neighbours at that distance and diminished by the square of the distance.
cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=False
)
embeddings = cce.transform(regions_gdf, features_gdf, joint_gdf)
embeddings
| region_id | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 89393367097ffff | 0.001709 | 0.000000 | 6.968348 | 0.262736 | 0.279909 | 0.002109 | 0.081015 | 2.858649 | 0.258521 | 0.290721 | 1.326690 | 1.367278 | 1.768903 | 0.435475 | 3.426741 | 0.647477 | 16.987121 | 1.265094 |
| 89393375a6bffff | 0.000950 | 0.011091 | 1.214162 | 1.151128 | 2.958092 | 0.018515 | 0.616967 | 6.307910 | 3.017699 | 1.230854 | 14.885221 | 1.382215 | 10.141025 | 1.916552 | 5.335819 | 2.199736 | 54.538192 | 0.094769 |
| 89393362eabffff | 0.001193 | 0.000000 | 0.160942 | 0.024593 | 0.138521 | 0.005053 | 0.032375 | 0.682701 | 0.092859 | 0.098159 | 5.030859 | 0.188682 | 0.497836 | 3.212923 | 0.316621 | 4.119140 | 2.154397 | 0.070713 |
| 89393375b23ffff | 0.038121 | 0.010088 | 5.151945 | 0.087953 | 0.380836 | 0.067076 | 1.059710 | 14.851022 | 0.699073 | 0.178552 | 7.695632 | 0.290695 | 4.813970 | 1.511707 | 2.991029 | 3.301124 | 28.248875 | 0.783543 |
| 89393362b0fffff | 0.000000 | 0.000000 | 7.664333 | 0.598483 | 3.033973 | 0.053098 | 4.391127 | 5.592660 | 3.034139 | 9.014358 | 4.597525 | 5.386641 | 49.390022 | 0.385905 | 57.768863 | 20.126258 | 32.446827 | 1.659618 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 89393375863ffff | 0.001415 | 1.313799 | 9.434916 | 0.025522 | 0.227488 | 0.336316 | 0.300306 | 6.693917 | 0.168136 | 0.104101 | 5.196271 | 1.158497 | 1.116657 | 0.488135 | 1.112214 | 1.350520 | 16.965022 | 0.042516 |
| 89393375ad3ffff | 0.001516 | 0.042673 | 1.097285 | 0.080688 | 2.974660 | 0.022558 | 0.432942 | 5.012365 | 1.818031 | 0.444930 | 10.800818 | 0.161256 | 4.439855 | 3.940326 | 4.755387 | 2.211952 | 22.036049 | 0.129344 |
| 89393375ec3ffff | 0.000590 | 0.011136 | 0.368287 | 0.022892 | 0.487223 | 0.010108 | 0.124784 | 10.185254 | 0.232645 | 0.631143 | 3.654183 | 0.185091 | 1.057462 | 0.796301 | 0.698169 | 0.486984 | 4.738405 | 1.302625 |
| 89393362a23ffff | 0.000000 | 0.000000 | 7.836574 | 0.150023 | 2.702786 | 0.011089 | 1.356533 | 13.681095 | 3.725450 | 3.964703 | 7.905942 | 1.653704 | 25.223663 | 0.449763 | 20.191243 | 6.558954 | 50.821009 | 3.465506 |
| 89393375933ffff | 0.012689 | 0.012803 | 0.685608 | 0.015342 | 0.391365 | 0.037386 | 0.108833 | 4.746487 | 0.176063 | 0.039674 | 4.553705 | 0.193940 | 1.229323 | 1.347861 | 0.748765 | 1.480506 | 8.600802 | 7.508649 |
830 rows × 18 columns
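The squashing step can be sketched for a single region as follows. The exact weight is an assumption on our part: the text above only says that averaged values are diminished by the squared distance, and this sketch uses 1 / (1 + d)² for ring d, so the nearest ring contributes a quarter of its average and farther rings progressively less. `squashed_embedding` is a hypothetical helper operating on plain lists, not a srai function.

```python
from statistics import fmean


def squashed_embedding(own, rings):
    """
    own: the region's own counts (distance 0).
    rings: {distance: [count vectors of the neighbours at that distance]}.
    Returns own + sum over d of mean(ring_d) * weight(d).
    """
    result = [float(v) for v in own]
    for distance, neighbours in rings.items():
        if not neighbours:
            continue  # nothing to average at this distance
        # Assumed weighting: 1 / (1 + d) ** 2, so farther rings count for less.
        weight = 1 / (1 + distance) ** 2
        for i in range(len(result)):
            result[i] += weight * fmean([vec[i] for vec in neighbours])
    return result


example = squashed_embedding([5, 1], {1: [[2, 0], [4, 2]], 2: [[9, 3]]})
```

Because the output keeps one value per feature column, the squashed embeddings have the same 18 columns as a plain CountEmbedder, but the values are no longer integers.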
Concatenated vector version
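The embedder returns a vector of length n × (d + 1), where n is the number of feature columns from CountEmbedder and d is the neighbourhood_distance: one block for the region itself (distance 0) and one block for each distance from 1 to d. Here, 18 columns × 11 distances give 198 columns.
Each column name gets a _k suffix, where k is the distance. The distance-0 block holds the region's own counts; each block for k ≥ 1 holds values averaged over all neighbours at that distance.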
wide_cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=True
)
wide_embeddings = wide_cce.transform(regions_gdf, features_gdf, joint_gdf)
wide_embeddings
| region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 89393367097ffff | 0.0 | 0.0 | 5.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 1.357143 | 1.285714 | 5.250000 | 1.750000 | 15.750000 | 1.428571 | 10.035714 | 4.035714 | 25.285714 | 0.750000 |
| 89393375a6bffff | 0.0 | 0.0 | 0.0 | 1.0 | 2.0 | 0.0 | 0.0 | 3.0 | 2.0 | 1.0 | ... | 1.350000 | 1.583333 | 4.266667 | 0.983333 | 7.416667 | 1.200000 | 6.116667 | 2.983333 | 16.000000 | 0.433333 |
| 89393362eabffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 1.463415 | 1.609756 | 6.585366 | 1.097561 | 9.926829 | 1.512195 | 7.243902 | 3.170732 | 14.487805 | 1.000000 |
| 89393375b23ffff | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | 0.0 | 0.0 | 12.0 | 0.0 | 0.0 | ... | 0.407407 | 0.111111 | 3.222222 | 0.444444 | 5.666667 | 1.592593 | 2.925926 | 0.814815 | 9.851852 | 0.148148 |
| 89393362b0fffff | 0.0 | 0.0 | 5.0 | 0.0 | 2.0 | 0.0 | 3.0 | 3.0 | 2.0 | 6.0 | ... | 1.515152 | 0.606061 | 2.878788 | 0.969697 | 11.151515 | 0.666667 | 6.606061 | 3.151515 | 19.303030 | 0.363636 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 89393375863ffff | 0.0 | 1.0 | 9.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 0.772727 | 0.522727 | 3.636364 | 0.386364 | 6.068182 | 1.318182 | 3.522727 | 2.181818 | 13.840909 | 1.204545 |
| 89393375ad3ffff | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 2.0 | 1.0 | 0.0 | ... | 1.060000 | 0.600000 | 4.940000 | 0.820000 | 8.400000 | 1.000000 | 4.920000 | 2.420000 | 16.920000 | 0.200000 |
| 89393375ec3ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 7.0 | 0.0 | 0.0 | ... | 1.357143 | 0.357143 | 2.071429 | 0.500000 | 10.750000 | 0.714286 | 5.214286 | 2.500000 | 13.821429 | 0.321429 |
| 89393362a23ffff | 0.0 | 0.0 | 6.0 | 0.0 | 2.0 | 0.0 | 1.0 | 10.0 | 3.0 | 3.0 | ... | 1.324324 | 1.810811 | 5.162162 | 1.054054 | 6.378378 | 0.756757 | 8.135135 | 5.756757 | 15.594595 | 0.432432 |
| 89393375933ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 0.764706 | 0.000000 | 2.117647 | 0.647059 | 5.411765 | 0.705882 | 2.117647 | 0.764706 | 8.705882 | 0.294118 |
830 rows × 198 columns
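The wide layout can be sketched for one region in plain Python: the own counts form the _0 block, and the mean over the neighbours at each distance forms the _k block. Treating a distance with no neighbours in the data as zeros is an assumption of this sketch, and `concatenated_embedding` is a hypothetical helper, not a srai function.

```python
from statistics import fmean


def concatenated_embedding(columns, own, rings, max_distance):
    """
    Build the wide vector for one region.

    columns: base feature names, e.g. ["leisure", "sport"].
    own: the region's own counts (distance 0).
    rings: {distance: [count vectors of the neighbours at that distance]}.
    Returns (names, values), each with len(columns) * (max_distance + 1) entries.
    """
    names: list[str] = []
    values: list[float] = []
    for distance in range(max_distance + 1):
        if distance == 0:
            block = [float(v) for v in own]
        else:
            neighbours = rings.get(distance, [])
            # Assumption: a distance with no neighbours in the data contributes zeros.
            block = [
                fmean([vec[i] for vec in neighbours]) if neighbours else 0.0
                for i in range(len(columns))
            ]
        names.extend(f"{column}_{distance}" for column in columns)
        values.extend(block)
    return names, values


names, values = concatenated_embedding(
    ["leisure", "sport"], [1, 0], {1: [[2, 2], [0, 4]], 2: [[3, 1]]}, max_distance=2
)
```

With the 18 base columns used in this notebook and a distance of 10, the same construction yields 18 × 11 = 198 names.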
Plotting example features
plot_numeric_data(regions_gdf, "leisure", embeddings)
plot_numeric_data(regions_gdf, "transportation", embeddings)
Other types of aggregations
By default, ContextualCountEmbedder averages the counts from neighbours. The aggregation_function parameter changes this to one of: median, sum, min, max.
This works best in combination with the wide format (concatenate_vectors=True).
sum_cce = ContextualCountEmbedder(
neighbourhood=h3n,
neighbourhood_distance=10,
concatenate_vectors=True,
aggregation_function="sum",
)
sum_embeddings = sum_cce.transform(regions_gdf, features_gdf, joint_gdf)
sum_embeddings
| region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 89393367097ffff | 0.0 | 0.0 | 5.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 38.0 | 36.0 | 147.0 | 49.0 | 441.0 | 40.0 | 281.0 | 113.0 | 708.0 | 21.0 |
| 89393375a6bffff | 0.0 | 0.0 | 0.0 | 1.0 | 2.0 | 0.0 | 0.0 | 3.0 | 2.0 | 1.0 | ... | 81.0 | 95.0 | 256.0 | 59.0 | 445.0 | 72.0 | 367.0 | 179.0 | 960.0 | 26.0 |
| 89393362eabffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 60.0 | 66.0 | 270.0 | 45.0 | 407.0 | 62.0 | 297.0 | 130.0 | 594.0 | 41.0 |
| 89393375b23ffff | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | 0.0 | 0.0 | 12.0 | 0.0 | 0.0 | ... | 11.0 | 3.0 | 87.0 | 12.0 | 153.0 | 43.0 | 79.0 | 22.0 | 266.0 | 4.0 |
| 89393362b0fffff | 0.0 | 0.0 | 5.0 | 0.0 | 2.0 | 0.0 | 3.0 | 3.0 | 2.0 | 6.0 | ... | 50.0 | 20.0 | 95.0 | 32.0 | 368.0 | 22.0 | 218.0 | 104.0 | 637.0 | 12.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 89393375863ffff | 0.0 | 1.0 | 9.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 34.0 | 23.0 | 160.0 | 17.0 | 267.0 | 58.0 | 155.0 | 96.0 | 609.0 | 53.0 |
| 89393375ad3ffff | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 2.0 | 1.0 | 0.0 | ... | 53.0 | 30.0 | 247.0 | 41.0 | 420.0 | 50.0 | 246.0 | 121.0 | 846.0 | 10.0 |
| 89393375ec3ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 7.0 | 0.0 | 0.0 | ... | 38.0 | 10.0 | 58.0 | 14.0 | 301.0 | 20.0 | 146.0 | 70.0 | 387.0 | 9.0 |
| 89393362a23ffff | 0.0 | 0.0 | 6.0 | 0.0 | 2.0 | 0.0 | 1.0 | 10.0 | 3.0 | 3.0 | ... | 49.0 | 67.0 | 191.0 | 39.0 | 236.0 | 28.0 | 301.0 | 213.0 | 577.0 | 16.0 |
| 89393375933ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 13.0 | 0.0 | 36.0 | 11.0 | 92.0 | 12.0 | 36.0 | 13.0 | 148.0 | 5.0 |
830 rows × 198 columns
plot_numeric_data(regions_gdf, "tourism_8", sum_embeddings)
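The aggregation options reduce each feature column over the neighbours found at one distance. The sketch below shows that reduction with Python's statistics module. The name "average" for the default, the `AGGREGATIONS` mapping, and the `aggregate_ring` helper are illustrative assumptions rather than srai's internals.

```python
from statistics import fmean, median

# Assumed mapping from aggregation name to a per-column reducer;
# "average" is our label for the default behaviour described above.
AGGREGATIONS = {
    "average": fmean,
    "median": median,
    "sum": sum,
    "min": min,
    "max": max,
}


def aggregate_ring(neighbours, aggregation_function="average"):
    """Reduce each column over the neighbours found at one distance."""
    reducer = AGGREGATIONS[aggregation_function]
    n_columns = len(neighbours[0])
    return [float(reducer([vec[i] for vec in neighbours])) for i in range(n_columns)]


ring_counts = [[1, 4], [2, 0], [6, 2]]  # three neighbours, two feature columns
by_sum = aggregate_ring(ring_counts, "sum")
by_average = aggregate_ring(ring_counts)
```

Unlike the average, a sum is not normalised by how many neighbours exist, so a region near the edge of the area, whose outer rings are only partly covered by regions_gdf, would get smaller totals than a region in the middle with the same surroundings.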