Contextual count embedder
from srai.embedders import ContextualCountEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.neighbourhoods import H3Neighbourhood
from srai.plotting.folium_wrapper import plot_numeric_data, plot_regions
from srai.regionalizers import H3Regionalizer
Data preparation
To use ContextualCountEmbedder, we need to prepare three inputs: regions_gdf, features_gdf and joint_gdf.
These are the outputs of Regionalizers, Loaders and Joiners, respectively.
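Plain Python structures can illustrate how the three inputs fit together. This is a toy model, not the srai data structures. The region IDs, feature IDs and tag values are hypothetical. The joint table is modelled as a list of (region_id, feature_id) pairs, which mirrors the two-level index of joint_gdf.

```python
# Hypothetical toy data standing in for regions_gdf, features_gdf and joint_gdf.
regions = ["r1", "r2", "r3"]
features = {
    "node/1": {"transportation": "amenity=parking"},
    "node/2": {"greenery": "leisure=park"},
    "way/3": {"buildings": "building=commercial"},
}
# Each joint row pairs a region with a feature that intersects it;
# the same feature may appear under several regions.
joint = [("r1", "node/1"), ("r1", "way/3"), ("r2", "way/3"), ("r2", "node/2")]


def features_per_region(regions, features, joint):
    """Group joint pairs into a region -> list of feature IDs mapping."""
    grouped = {region_id: [] for region_id in regions}
    for region_id, feature_id in joint:
        # Every joint row must reference a known region and a known feature.
        if region_id not in grouped or feature_id not in features:
            raise KeyError(f"unknown pair: {(region_id, feature_id)}")
        grouped[region_id].append(feature_id)
    return grouped


print(features_per_region(regions, features, joint))
# {'r1': ['node/1', 'way/3'], 'r2': ['way/3', 'node/2'], 'r3': []}
```

Regions without any joined feature, such as r3, still appear with an empty list. This matches the fact that later embeddings contain a row for every one of the 830 regions.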
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Lisboa, PT")
plot_regions(area_gdf)
Regionalize the area using an H3Regionalizer
regionalizer = H3Regionalizer(resolution=9, buffer=True)
regions_gdf = regionalizer.transform(area_gdf)
regions_gdf
region_id | geometry
---|---
893933629b3ffff | POLYGON ((-9.14493 38.74756, -9.14664 38.74622... |
89393375b23ffff | POLYGON ((-9.10142 38.7653, -9.10312 38.76396,... |
89393362a63ffff | POLYGON ((-9.17919 38.70177, -9.18089 38.70043... |
8939337580fffff | POLYGON ((-9.13001 38.78046, -9.13172 38.77912... |
893933676dbffff | POLYGON ((-9.14266 38.72112, -9.14436 38.71978... |
... | ... |
89393362abbffff | POLYGON ((-9.1818 38.7162, -9.1835 38.71485, -... |
89393375bcfffff | POLYGON ((-9.12045 38.76445, -9.12216 38.76311... |
89393362957ffff | POLYGON ((-9.15331 38.73391, -9.15501 38.73257... |
89393360517ffff | POLYGON ((-9.21105 38.70728, -9.21275 38.70594... |
89393375c5bffff | POLYGON ((-9.16374 38.7916, -9.16545 38.79026,... |
830 rows × 1 columns
Download some objects from OpenStreetMap
You can use either OsmTagsFilter or GroupedOsmTagsFilter filters. In this example, the predefined GroupedOsmTagsFilter BASE_OSM_GROUPS_FILTER is used.
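A grouped filter maps a group name to a regular tag filter. A matching tag is reported as a key=value string in that group's column, as with amenity=parking under transportation in the features_gdf output of this example. The sketch assumes a simplified filter shape in which each tag key maps to True (any value) or to a list of accepted values. The toy filter is hypothetical and does not reproduce the contents of BASE_OSM_GROUPS_FILTER. How multiple matches within one group are reported is also an assumption.

```python
# Hypothetical grouped filter: group name -> {tag key: True | list of accepted values}.
# This simplified shape is an assumption, not the contents of BASE_OSM_GROUPS_FILTER.
GROUPED_FILTER = {
    "transportation": {"amenity": ["parking"], "public_transport": ["stop_position"]},
    "greenery": {"leisure": ["park", "garden"], "landuse": ["grass"]},
    "buildings": {"building": True},
}


def match_groups(tags, grouped_filter):
    """Return {group: "key=value"} for every group whose filter matches the tags.

    Assumption: if several tags match within one group, the first match in
    filter order is reported.
    """
    result = {}
    for group, tag_filter in grouped_filter.items():
        for key, accepted in tag_filter.items():
            value = tags.get(key)
            if value is None:
                continue
            if accepted is True or value in accepted:
                result[group] = f"{key}={value}"
                break
    return result


print(match_groups({"amenity": "parking", "name": "P1"}, GROUPED_FILTER))
# {'transportation': 'amenity=parking'}
print(match_groups({"building": "commercial", "landuse": "grass"}, GROUPED_FILTER))
# {'greenery': 'landuse=grass', 'buildings': 'building=commercial'}
```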
from srai.loaders.osm_loaders.filters import BASE_OSM_GROUPS_FILTER
loader = OSMPbfLoader()
features_gdf = loader.load(area_gdf, tags=BASE_OSM_GROUPS_FILTER)
features_gdf
Finished operation in 0:00:17
feature_id | geometry | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
node/277831550 | POINT (-9.16476 38.73881) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/277831931 | POINT (-9.15994 38.73804) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/277834117 | POINT (-9.15942 38.75257) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/277834790 | POINT (-9.12065 38.73273) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
node/277835172 | POINT (-9.1964 38.75332) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
way/1015889842 | POLYGON ((-9.09644 38.77941, -9.09636 38.77929... | None | None | None | None | None | None | None | landuse=grass | None | None | None | None | None | None | None | None | None | None |
way/1015903406 | POLYGON ((-9.09851 38.78496, -9.09817 38.78508... | None | None | building=commercial | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None |
way/1015903426 | POLYGON ((-9.09413 38.78247, -9.09409 38.7824,... | None | None | None | None | None | None | None | None | None | None | leisure=garden | None | None | None | None | None | None | None |
way/1015903433 | POLYGON ((-9.0939 38.78182, -9.09389 38.78177,... | None | None | None | None | None | None | None | None | None | None | leisure=garden | None | None | None | None | None | None | None |
way/1015903434 | POLYGON ((-9.09402 38.78211, -9.09403 38.78206... | None | None | None | None | None | None | None | None | None | None | leisure=garden | None | None | None | None | None | None | None |
33643 rows × 19 columns
Join the objects with the regions they belong to
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf)
joint_gdf
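IntersectionJoiner pairs each feature with every region that its geometry intersects. A feature crossing region boundaries therefore yields several rows, which is why joint_gdf (37279 rows) can be longer than features_gdf (33643 rows). A feature outside every region yields no rows. The sketch below shows the idea with a deliberate simplification: regions and features are axis-aligned boxes (a point is a zero-size box) instead of H3 polygons and arbitrary geometries. The layout is hypothetical.

```python
from typing import NamedTuple


class Box(NamedTuple):
    """Axis-aligned bounding box; a point is a box with zero width and height."""

    min_x: float
    min_y: float
    max_x: float
    max_y: float

    def intersects(self, other: "Box") -> bool:
        # Boxes intersect (touching counts) unless one lies entirely to one side.
        return not (
            self.max_x < other.min_x
            or other.max_x < self.min_x
            or self.max_y < other.min_y
            or other.max_y < self.min_y
        )


def intersection_join(regions, features):
    """Return (region_id, feature_id) pairs for every intersecting combination."""
    return [
        (region_id, feature_id)
        for region_id, region in regions.items()
        for feature_id, feature in features.items()
        if region.intersects(feature)
    ]


# Hypothetical layout: two adjacent unit squares.
regions = {"A": Box(0, 0, 1, 1), "B": Box(1, 0, 2, 1)}
features = {
    "node/1": Box(0.5, 0.5, 0.5, 0.5),  # point inside A
    "way/2": Box(0.8, 0.2, 1.5, 0.4),  # polygon crossing the A/B border
    "node/3": Box(5, 5, 5, 5),  # point outside every region
}
pairs = intersection_join(regions, features)
print(pairs)
# [('A', 'node/1'), ('A', 'way/2'), ('B', 'way/2')]
```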
region_id | feature_id |
---|---|
893933628a7ffff | node/277831550 |
893933629c3ffff | node/277831931 |
89393362d23ffff | node/277834117 |
8939336747bffff | node/277834790 |
89393362c8fffff | node/277835172 |
... | ... |
89393375967ffff | way/1015889842 |
8939337593bffff | way/1015903406 |
8939337592bffff | way/1015903426 |
8939337592bffff | way/1015903433 |
8939337592bffff | way/1015903434 |
37279 rows × 0 columns
Embed using features existing in data
ContextualCountEmbedder extends the capabilities of the basic CountEmbedder by incorporating the neighbourhood of each embedded region. In this example we use H3Neighbourhood.
h3n = H3Neighbourhood()
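ContextualCountEmbedder builds on a counting step, which can be sketched in plain Python. For every region, the sketch counts the joined features that have a non-empty value in each group column. This is an approximation with hypothetical data, not srai's CountEmbedder implementation.

```python
# Hypothetical toy inputs: group columns hold "key=value" strings or are absent (empty).
GROUPS = ["buildings", "greenery", "transportation"]
features = {
    "node/1": {"transportation": "amenity=parking"},
    "way/2": {"greenery": "leisure=garden"},
    "way/3": {"buildings": "building=commercial", "greenery": "landuse=grass"},
}
regions = ["A", "B", "C"]
joint = [("A", "node/1"), ("A", "way/3"), ("B", "way/2"), ("B", "way/3")]


def count_embed(regions, features, joint, groups):
    """Count, per region, the joined features with a non-empty value in each group."""
    counts = {region_id: {group: 0 for group in groups} for region_id in regions}
    for region_id, feature_id in joint:
        for group in groups:
            if features[feature_id].get(group) is not None:
                counts[region_id][group] += 1
    return counts


counts = count_embed(regions, features, joint, GROUPS)
print(counts["B"])  # {'buildings': 1, 'greenery': 2, 'transportation': 0}
```

Region C has no joined features, so it still gets a row, filled with zeros.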
Squashed vector version (default)
The embedder returns a vector of the same length as the CountEmbedder output. However, it adds to each region's own counts the averaged counts from each neighbourhood ring, scaled down by the square of that ring's distance.
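For one feature column, the squashed value can be sketched as follows. Start from the region's own count. Then, for each distance d from 1 to neighbourhood_distance, add the mean count of the ring at distance d scaled by a distance weight. The sketch assumes a weight of 1 / d² based on the description above. srai's exact weighting may differ, for example by offsetting the distance. The counts are hypothetical.

```python
from statistics import mean


def inverse_square(d: int) -> float:
    """Assumed distance weighting, following the text: 1 / d**2."""
    return 1 / d**2


def squash(own_count, rings, weight=inverse_square):
    """Combine a region's own count with its neighbour rings into one value.

    own_count: count in the region itself (distance 0).
    rings: rings[d - 1] lists the neighbour counts at distance d.
    """
    total = float(own_count)
    for d, ring in enumerate(rings, start=1):
        if ring:  # skip empty rings, e.g. at the edge of the area
            total += mean(ring) * weight(d)
    return total


# Hypothetical counts for one column: 3 in the region, then rings at d = 1 and d = 2.
print(squash(3, [[6, 0, 0, 0, 0, 0], [12] * 12]))  # 3 + 1 / 1 + 12 / 4 = 7.0
```

Applying this to every column keeps the CountEmbedder width, which is 18 columns here.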
cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=False
)
embeddings = cce.transform(regions_gdf, features_gdf, joint_gdf)
embeddings
region_id | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
893933629b3ffff | 0.001940 | 0.008534 | 5.216979 | 0.369925 | 2.973309 | 0.099288 | 0.984853 | 13.323340 | 3.215023 | 2.383881 | 6.419669 | 0.510170 | 13.435151 | 1.862689 | 11.617866 | 2.870294 | 53.508983 | 1.144703 |
89393375b23ffff | 0.038121 | 0.010088 | 4.151945 | 0.087953 | 0.380836 | 0.067076 | 0.851376 | 43.057781 | 0.699073 | 0.178552 | 7.695632 | 0.290695 | 4.586812 | 1.511707 | 2.805844 | 3.301124 | 35.180132 | 0.783543 |
89393362a63ffff | 0.000000 | 0.000000 | 25.189043 | 0.149533 | 0.701158 | 0.047761 | 0.347083 | 3.631977 | 0.487919 | 0.884661 | 7.297855 | 0.549296 | 3.395439 | 0.489243 | 18.057240 | 7.554073 | 10.609302 | 1.668671 |
8939337580fffff | 0.000403 | 2.492358 | 0.483473 | 0.021564 | 0.231298 | 0.120762 | 0.097968 | 8.583693 | 0.094866 | 0.051908 | 8.480700 | 0.154466 | 0.699687 | 0.259171 | 0.489687 | 0.168565 | 10.457186 | 0.035804 |
893933676dbffff | 0.000708 | 0.000000 | 9.158606 | 1.895920 | 2.151092 | 0.225504 | 1.346171 | 8.426704 | 2.308130 | 15.634765 | 10.100089 | 3.264981 | 27.688539 | 3.471079 | 39.657571 | 19.521223 | 31.790198 | 3.757902 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393362abbffff | 0.001196 | 0.000000 | 0.408747 | 2.131416 | 0.309634 | 0.015078 | 0.101555 | 11.533457 | 0.209190 | 0.312631 | 5.735830 | 0.184977 | 1.343324 | 0.938530 | 1.104291 | 1.720988 | 7.794270 | 0.314789 |
89393375bcfffff | 0.002639 | 0.038735 | 2.664668 | 0.043544 | 1.515498 | 0.023815 | 2.172332 | 28.153222 | 2.327316 | 0.056888 | 4.786629 | 3.347528 | 11.918271 | 0.541493 | 1.903307 | 0.385058 | 14.192203 | 0.084379 |
89393362957ffff | 0.002999 | 0.000427 | 15.099440 | 2.443911 | 0.997706 | 0.099230 | 5.931812 | 8.060651 | 6.872466 | 2.125382 | 6.144911 | 2.671272 | 37.950649 | 1.624463 | 25.455413 | 19.002965 | 34.008526 | 3.772235 |
89393360517ffff | 0.000000 | 0.000000 | 7.642777 | 1.107966 | 1.716152 | 0.004740 | 1.143473 | 7.584112 | 1.373226 | 0.297312 | 24.798375 | 0.397808 | 5.822099 | 1.203289 | 3.215527 | 2.912565 | 15.609141 | 0.352171 |
89393375c5bffff | 0.000000 | 0.023091 | 0.139393 | 0.016246 | 0.251120 | 0.004224 | 0.056136 | 4.298787 | 0.111478 | 0.113048 | 1.635843 | 0.191396 | 0.577604 | 0.655485 | 0.291406 | 0.136175 | 3.552983 | 0.083029 |
830 rows × 18 columns
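Concatenated vector version
The embedder returns a vector of length n * (distance + 1). Here n is the number of features from the CountEmbedder, and distance is the neighbourhood_distance; the region itself counts as distance 0. In this example, that gives 18 × 11 = 198 columns.
Each feature name is suffixed with _d, where d is the distance of the ring it describes (from _0 to _10 here). Values at distance 0 are the region's own counts. Values at each further distance are averaged over all neighbours at that distance.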
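Plain dictionaries can sketch the wide layout. The distance-0 block holds the region's own counts. Each later block holds the per-column mean over the neighbours at that distance, named with the _d suffix. The column names and counts are hypothetical. Returning 0.0 for an empty ring is an assumption, not confirmed srai behaviour.

```python
from statistics import mean


def concatenate(own_counts, rings, columns):
    """Build a wide vector with one block of columns per distance.

    own_counts: {column: count} for the region itself (distance 0).
    rings: rings[d - 1] is a list of {column: count} dicts, one per neighbour at distance d.
    """
    wide = {f"{col}_0": float(own_counts[col]) for col in columns}
    for d, ring in enumerate(rings, start=1):
        for col in columns:
            values = [neighbour[col] for neighbour in ring]
            # Assumption: an empty ring (possible at the area's edge) yields 0.0.
            wide[f"{col}_{d}"] = float(mean(values)) if values else 0.0
    return wide


# Hypothetical subset of the 18 group columns and one ring of two neighbours.
columns = ["buildings", "greenery"]
wide = concatenate(
    {"buildings": 3, "greenery": 10},
    [[{"buildings": 1, "greenery": 2}, {"buildings": 3, "greenery": 4}]],
    columns,
)
print(wide)
# {'buildings_0': 3.0, 'greenery_0': 10.0, 'buildings_1': 2.0, 'greenery_1': 3.0}

# Width in this example: 18 groups times distances 0 through 10.
print(18 * (10 + 1))  # 198
```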
wide_cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=True
)
wide_embeddings = wide_cce.transform(regions_gdf, features_gdf, joint_gdf)
wide_embeddings
region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
893933629b3ffff | 0.0 | 0.0 | 3.0 | 0.0 | 2.0 | 0.0 | 0.0 | 10.0 | 2.0 | 2.0 | ... | 1.033333 | 1.366667 | 3.600000 | 1.033333 | 6.200000 | 1.066667 | 5.950000 | 2.916667 | 13.100000 | 0.316667 |
89393375b23ffff | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 39.0 | 0.0 | 0.0 | ... | 0.407407 | 0.111111 | 3.222222 | 0.444444 | 5.629630 | 1.592593 | 2.925926 | 0.814815 | 9.814815 | 0.148148 |
89393362a63ffff | 0.0 | 0.0 | 24.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 1.250000 | 3.093750 | 6.875000 | 1.375000 | 11.250000 | 1.500000 | 12.500000 | 4.843750 | 14.187500 | 0.468750 |
8939337580fffff | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 5.0 | 0.0 | 0.0 | ... | 0.804878 | 0.390244 | 4.073171 | 0.292683 | 5.707317 | 1.097561 | 3.951220 | 1.853659 | 14.853659 | 0.975610 |
893933676dbffff | 0.0 | 0.0 | 6.0 | 1.0 | 1.0 | 0.0 | 0.0 | 3.0 | 1.0 | 12.0 | ... | 0.771429 | 0.571429 | 3.685714 | 0.485714 | 1.971429 | 1.600000 | 2.000000 | 2.457143 | 11.600000 | 0.857143 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393362abbffff | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 9.0 | 0.0 | 0.0 | ... | 1.324324 | 2.945946 | 4.297297 | 0.837838 | 9.810811 | 0.594595 | 13.729730 | 7.054054 | 15.567568 | 1.000000 |
89393375bcfffff | 0.0 | 0.0 | 2.0 | 0.0 | 1.0 | 0.0 | 2.0 | 21.0 | 2.0 | 0.0 | ... | 1.111111 | 0.611111 | 4.166667 | 0.750000 | 9.444444 | 1.305556 | 6.055556 | 2.027778 | 22.027778 | 0.527778 |
89393362957ffff | 0.0 | 0.0 | 11.0 | 2.0 | 0.0 | 0.0 | 5.0 | 4.0 | 5.0 | 1.0 | ... | 0.396552 | 0.413793 | 3.448276 | 0.310345 | 2.344828 | 1.379310 | 1.931034 | 1.844828 | 11.000000 | 0.465517 |
89393360517ffff | 0.0 | 0.0 | 6.0 | 1.0 | 1.0 | 0.0 | 1.0 | 4.0 | 1.0 | 0.0 | ... | 0.277778 | 0.500000 | 3.111111 | 0.555556 | 1.444444 | 1.166667 | 2.222222 | 1.055556 | 9.166667 | 0.500000 |
89393375c5bffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | ... | 0.739130 | 0.130435 | 3.695652 | 0.391304 | 4.217391 | 0.782609 | 2.608696 | 0.826087 | 9.565217 | 0.217391 |
830 rows × 198 columns
Plotting example features
plot_numeric_data(regions_gdf, "leisure", embeddings)
plot_numeric_data(regions_gdf, "transportation", embeddings)
Other types of aggregations
By default, ContextualCountEmbedder averages the counts from neighbours. The aggregation_function parameter can change this to one of: median, sum, min or max.
It's best to combine it with the wide format (concatenate_vectors=True).
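Each option only changes how the counts of one ring of neighbours are reduced to a single number. The sketch uses Python built-ins and the statistics module as stand-ins for srai's own implementation. The key "mean" is a hypothetical label for the default averaging; the other keys are the values listed above. Since sum equals the mean multiplied by the number of cells in the ring, the sum version also reflects how many neighbours a region has at each distance.

```python
from statistics import mean, median

# The aggregations named above, applied to one ring of neighbour counts.
# Built-ins and the statistics module stand in for srai's implementation.
AGGREGATIONS = {
    "mean": mean,  # hypothetical label for the default averaging
    "median": median,
    "sum": sum,
    "min": min,
    "max": max,
}

ring = [0, 1, 1, 2, 6, 2]  # hypothetical counts of one column at one distance
for name, aggregate in AGGREGATIONS.items():
    print(name, aggregate(ring))
# mean 2, median 1.5, sum 12, min 0, max 6
```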
sum_cce = ContextualCountEmbedder(
neighbourhood=h3n,
neighbourhood_distance=10,
concatenate_vectors=True,
aggregation_function="sum",
)
sum_embeddings = sum_cce.transform(regions_gdf, features_gdf, joint_gdf)
sum_embeddings
region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
893933629b3ffff | 0.0 | 0.0 | 3.0 | 0.0 | 2.0 | 0.0 | 0.0 | 10.0 | 2.0 | 2.0 | ... | 62.0 | 82.0 | 216.0 | 62.0 | 372.0 | 64.0 | 357.0 | 175.0 | 786.0 | 19.0 |
89393375b23ffff | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 39.0 | 0.0 | 0.0 | ... | 11.0 | 3.0 | 87.0 | 12.0 | 152.0 | 43.0 | 79.0 | 22.0 | 265.0 | 4.0 |
89393362a63ffff | 0.0 | 0.0 | 24.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 40.0 | 99.0 | 220.0 | 44.0 | 360.0 | 48.0 | 400.0 | 155.0 | 454.0 | 15.0 |
8939337580fffff | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 5.0 | 0.0 | 0.0 | ... | 33.0 | 16.0 | 167.0 | 12.0 | 234.0 | 45.0 | 162.0 | 76.0 | 609.0 | 40.0 |
893933676dbffff | 0.0 | 0.0 | 6.0 | 1.0 | 1.0 | 0.0 | 0.0 | 3.0 | 1.0 | 12.0 | ... | 27.0 | 20.0 | 129.0 | 17.0 | 69.0 | 56.0 | 70.0 | 86.0 | 406.0 | 30.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393362abbffff | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 9.0 | 0.0 | 0.0 | ... | 49.0 | 109.0 | 159.0 | 31.0 | 363.0 | 22.0 | 508.0 | 261.0 | 576.0 | 37.0 |
89393375bcfffff | 0.0 | 0.0 | 2.0 | 0.0 | 1.0 | 0.0 | 2.0 | 21.0 | 2.0 | 0.0 | ... | 40.0 | 22.0 | 150.0 | 27.0 | 340.0 | 47.0 | 218.0 | 73.0 | 793.0 | 19.0 |
89393362957ffff | 0.0 | 0.0 | 11.0 | 2.0 | 0.0 | 0.0 | 5.0 | 4.0 | 5.0 | 1.0 | ... | 23.0 | 24.0 | 200.0 | 18.0 | 136.0 | 80.0 | 112.0 | 107.0 | 638.0 | 27.0 |
89393360517ffff | 0.0 | 0.0 | 6.0 | 1.0 | 1.0 | 0.0 | 1.0 | 4.0 | 1.0 | 0.0 | ... | 5.0 | 9.0 | 56.0 | 10.0 | 26.0 | 21.0 | 40.0 | 19.0 | 165.0 | 9.0 |
89393375c5bffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | ... | 17.0 | 3.0 | 85.0 | 9.0 | 97.0 | 18.0 | 60.0 | 19.0 | 220.0 | 5.0 |
830 rows × 198 columns
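Comparing this table with the averaged wide table shows how ring sizes enter. On a hexagonal grid, the ring at distance k around a cell holds 6k cells, ignoring H3's twelve pentagons. For a region whose full ring lies inside the area, the sum at distance 10 should therefore be 60 times the average. For region 893933629b3ffff, healthcare_10 is 62.0 here and 1.033333 in the averaged table. Regions near the edge of the area have smaller rings, so the factor is lower there. The arithmetic below is a worked check, not srai code.

```python
def ring_size(k: int) -> int:
    """Hexagons exactly k steps from a cell on a pentagon-free grid."""
    return 1 if k == 0 else 6 * k


def disk_size(k: int) -> int:
    """Hexagons within k steps, centre included; equals 3k(k + 1) + 1."""
    return sum(ring_size(i) for i in range(k + 1))


print(ring_size(10), disk_size(10))  # 60 331
# Sum at distance 10 divided by the full ring size gives the averaged value.
print(round(62.0 / ring_size(10), 6))  # 1.033333
```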
plot_numeric_data(regions_gdf, "tourism_8", sum_embeddings)