Contextual count embedder
from srai.embedders import ContextualCountEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.neighbourhoods import H3Neighbourhood
from srai.plotting.folium_wrapper import plot_numeric_data, plot_regions
from srai.regionalizers import H3Regionalizer
Data preparation¶
To use ContextualCountEmbedder, we first need to prepare three inputs: regions_gdf, features_gdf, and joint_gdf.
These are the outputs of Regionalizers, Loaders, and Joiners, respectively.
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Lisboa, PT")
plot_regions(area_gdf)
Regionalize the area using an H3Regionalizer¶
regionalizer = H3Regionalizer(resolution=9, buffer=True)
regions_gdf = regionalizer.transform(area_gdf)
regions_gdf
region_id | geometry |
---|---|
89393362b5bffff | POLYGON ((-9.16603 38.70742, -9.16774 38.70608... |
893933629cfffff | POLYGON ((-9.15624 38.73631, -9.15795 38.73497... |
89393375b7bffff | POLYGON ((-9.11165 38.75726, -9.11335 38.75592... |
893933629b3ffff | POLYGON ((-9.14493 38.74756, -9.14664 38.74622... |
89393362db7ffff | POLYGON ((-9.16146 38.76515, -9.16317 38.76381... |
... | ... |
893933664cbffff | POLYGON ((-9.08867 38.79177, -9.09037 38.79043... |
893933676d3ffff | POLYGON ((-9.14157 38.72433, -9.14327 38.72298... |
89393360437ffff | POLYGON ((-9.22823 38.70081, -9.22993 38.69947... |
89393362c0fffff | POLYGON ((-9.19399 38.74662, -9.19569 38.74528... |
8939336293bffff | POLYGON ((-9.14124 38.73635, -9.14294 38.73501... |
830 rows × 1 columns
Download some objects from OpenStreetMap¶
You can use both OsmTagsFilter and GroupedOsmTagsFilter filters. In this example, a predefined GroupedOsmTagsFilter filter, BASE_OSM_GROUPS_FILTER, is used.
from srai.loaders.osm_loaders.filters import BASE_OSM_GROUPS_FILTER
loader = OSMPbfLoader()
features_gdf = loader.load(area_gdf, tags=BASE_OSM_GROUPS_FILTER)
features_gdf
Finished operation in 0:00:16
feature_id | geometry | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
node/1606370716 | POINT (-9.13287 38.72642) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=cafe | None | None | None |
node/1606908608 | POINT (-9.19198 38.75894) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=restaurant | None | None | None |
node/1606908881 | POINT (-9.18775 38.7605) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=restaurant | None | None | None |
node/1606909014 | POINT (-9.18739 38.76059) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=restaurant | None | None | None |
node/1606921033 | POINT (-9.17963 38.75975) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | highway=bus_stop | None |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
way/1122286620 | POLYGON ((-9.18723 38.77212, -9.18699 38.77187... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
way/1122286622 | POLYGON ((-9.18684 38.77176, -9.18684 38.77142... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
way/1122286633 | POLYGON ((-9.18873 38.77076, -9.18853 38.77056... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
way/1122286653 | POLYGON ((-9.18642 38.77196, -9.18634 38.77189... | None | None | None | None | None | None | None | landuse=grass | None | None | None | None | None | None | None | None | None | None |
way/1122286721 | POLYGON ((-9.19022 38.76877, -9.19029 38.76871... | None | None | None | None | None | None | None | natural=grassland | None | None | None | None | None | None | None | None | None | None |
32206 rows × 19 columns
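To make the idea of a grouped filter more concrete, the sketch below treats it as a plain mapping from a group name to a tags filter (OSM key to an accepted value or list of values). That structure is an assumption made for illustration; the group names and tags are taken from the output above, and custom_filter and group_for are hypothetical names that are not part of srai.

```python
# Assumed shape: {group name: {OSM key: accepted value or list of values}}.
custom_filter = {
    "sustenance": {"amenity": ["cafe", "restaurant"]},
    "greenery": {"landuse": "grass", "natural": "grassland"},
    "transportation": {"highway": "bus_stop", "amenity": "parking"},
}


def group_for(key, value, grouped_filter):
    """Return the first group whose filter accepts the key=value tag, else None."""
    for group, tags in grouped_filter.items():
        accepted = tags.get(key)
        if accepted is None:
            continue  # this group does not filter on the given OSM key
        if isinstance(accepted, str):
            accepted = [accepted]  # normalise a single value to a list
        if value in accepted:
            return group
    return None


print(group_for("amenity", "cafe", custom_filter))
print(group_for("amenity", "parking", custom_filter))
```

A dictionary of this shape should be usable as the tags argument in place of BASE_OSM_GROUPS_FILTER, although that is worth checking against the filter type definitions of the installed srai version.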
Join the objects with the regions they belong to¶
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf)
joint_gdf
region_id | feature_id |
---|---|
89393367683ffff | node/1606370716 |
8939337534fffff | node/1606908608 |
8939337537bffff | node/1606908881 |
8939337537bffff | node/1606909014 |
8939337536fffff | node/1606921033 |
... | ... |
89393375317ffff | way/1122286620 |
89393375317ffff | way/1122286622 |
89393375303ffff | way/1122286633 |
89393375317ffff | way/1122286653 |
89393375303ffff | way/1122286721 |
35694 rows × 0 columns
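The joint table is what links each region to the features it contains, and a count-based embedding starts from exactly these pairs. The following is a simplified sketch of that counting step, not srai's implementation: it uses only the ten pairs displayed above, assigns each feature the single group shown for it in the features table, and uses hypothetical names (feature_group, joint_pairs, counts).

```python
from collections import Counter, defaultdict

# Group of each displayed feature, read off the features table above.
feature_group = {
    "node/1606370716": "sustenance",
    "node/1606908608": "sustenance",
    "node/1606908881": "sustenance",
    "node/1606909014": "sustenance",
    "node/1606921033": "transportation",
    "way/1122286620": "transportation",
    "way/1122286622": "transportation",
    "way/1122286633": "transportation",
    "way/1122286653": "greenery",
    "way/1122286721": "greenery",
}

# (region_id, feature_id) pairs as displayed in the joint table.
joint_pairs = [
    ("89393367683ffff", "node/1606370716"),
    ("8939337534fffff", "node/1606908608"),
    ("8939337537bffff", "node/1606908881"),
    ("8939337537bffff", "node/1606909014"),
    ("8939337536fffff", "node/1606921033"),
    ("89393375317ffff", "way/1122286620"),
    ("89393375317ffff", "way/1122286622"),
    ("89393375303ffff", "way/1122286633"),
    ("89393375317ffff", "way/1122286653"),
    ("89393375303ffff", "way/1122286721"),
]

# Count features per group within each region.
counts = defaultdict(Counter)
for region_id, feature_id in joint_pairs:
    counts[region_id][feature_group[feature_id]] += 1

for region_id, region_counts in counts.items():
    print(region_id, dict(region_counts))
```

Because only the displayed rows are used, these counts are not the real totals for the regions. Note also that the joint table has more rows than there are features, since a single feature can intersect several regions.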
Embed using features existing in data¶
ContextualCountEmbedder extends the capabilities of the basic CountEmbedder by incorporating the neighbourhood of the embedded region. In this example we will use the H3Neighbourhood.
h3n = H3Neighbourhood()
Squashed vector version (default)¶
The embedder returns a vector of the same length as the one from CountEmbedder, but each value is a sum of the averaged counts from successive neighbourhoods, each diminished by the squared neighbour distance.
cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=False
)
embeddings = cce.transform(regions_gdf, features_gdf, joint_gdf)
embeddings
region_id | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393362b5bffff | 0.000000 | 0.000000 | 5.663186 | 0.453341 | 0.725279 | 0.011207 | 0.472664 | 4.831314 | 0.708613 | 6.582172 | 11.780038 | 4.690245 | 6.739947 | 1.305490 | 9.254783 | 8.385976 | 14.800506 | 0.800748 |
893933629cfffff | 0.005011 | 0.000598 | 15.428869 | 1.535625 | 2.272676 | 1.024988 | 1.878399 | 15.047844 | 6.272882 | 3.865572 | 7.960870 | 2.647368 | 55.462848 | 0.685802 | 29.627573 | 16.474613 | 52.735513 | 0.726965 |
89393375b7bffff | 0.004271 | 0.013067 | 2.976362 | 1.051782 | 1.552804 | 0.010054 | 0.126482 | 6.667774 | 2.308443 | 0.040939 | 4.732799 | 1.282191 | 5.669420 | 1.888389 | 1.652776 | 0.347283 | 17.207074 | 0.072614 |
893933629b3ffff | 0.001940 | 0.008534 | 5.130619 | 0.327924 | 3.031883 | 0.055303 | 0.975397 | 12.496807 | 3.163736 | 2.378963 | 6.365194 | 0.516051 | 13.354228 | 1.849755 | 11.401478 | 2.795933 | 47.826162 | 1.140431 |
89393362db7ffff | 0.002250 | 0.018835 | 2.118245 | 0.070224 | 0.926556 | 0.012719 | 0.491137 | 9.895716 | 1.806445 | 1.370928 | 8.631296 | 0.248820 | 4.813795 | 6.130957 | 1.746922 | 1.662846 | 14.309635 | 0.280951 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
893933664cbffff | 0.023845 | 0.003195 | 0.489175 | 0.016278 | 0.193005 | 0.025256 | 0.093011 | 2.206666 | 0.133242 | 0.029546 | 2.779800 | 0.082114 | 0.918128 | 0.835733 | 0.661078 | 0.345379 | 2.321816 | 1.605822 |
893933676d3ffff | 0.000966 | 0.000000 | 4.665993 | 1.512258 | 2.354209 | 0.063739 | 1.106986 | 7.236907 | 3.195631 | 11.056915 | 6.924263 | 3.132414 | 14.272332 | 2.618369 | 16.608182 | 7.114924 | 47.309663 | 2.393390 |
89393360437ffff | 0.000000 | 0.000000 | 1.681174 | 0.060760 | 0.438621 | 0.002066 | 0.041494 | 2.038276 | 2.283468 | 0.156731 | 4.498050 | 0.147617 | 2.123936 | 0.411544 | 2.975917 | 0.663531 | 18.046442 | 0.340629 |
89393362c0fffff | 0.004036 | 0.000000 | 0.731542 | 0.079646 | 3.888571 | 0.005502 | 0.427937 | 2.463572 | 0.391884 | 0.323954 | 12.113497 | 0.366877 | 5.407537 | 6.611111 | 2.312660 | 2.064079 | 18.510064 | 0.139758 |
8939336293bffff | 0.001016 | 0.001395 | 11.049073 | 0.297790 | 7.572795 | 0.031073 | 2.927453 | 5.285186 | 3.581726 | 2.869260 | 9.416693 | 0.872023 | 37.483913 | 0.661779 | 32.925108 | 3.188785 | 65.123258 | 0.183733 |
830 rows × 18 columns
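The squashing step can be pictured on a toy example. The sketch below places five regions in a row, treats the regions exactly d steps away as the neighbourhood at distance d, averages the counts in each neighbourhood, and adds them up with a weight that shrinks with the squared distance. The weight 1 / (d + 1)^2 is an assumption chosen so that the region's own counts (d = 0) get weight 1; the exact weighting used by the library may differ, and ring and squashed are hypothetical helper names.

```python
from statistics import mean

# Toy counts of a single feature for five regions laid out in a row.
counts = [4.0, 0.0, 2.0, 6.0, 0.0]


def ring(index, distance, size):
    """Indices exactly `distance` steps away on a 1-D row (toy neighbourhood)."""
    if distance == 0:
        return [index]
    candidates = (index - distance, index + distance)
    return [i for i in candidates if 0 <= i < size]


def squashed(index, max_distance):
    """Sum the per-distance averages, each diminished by the squared distance."""
    total = 0.0
    for d in range(max_distance + 1):
        neighbours = ring(index, d, len(counts))
        if not neighbours:
            continue  # no regions exist at this distance
        averaged = mean(counts[i] for i in neighbours)
        total += averaged / (d + 1) ** 2  # assumed weighting, see lead-in
    return total


print(squashed(2, 2))
```

The result has one value per feature regardless of how many distances are included, which is why the squashed output above has the same 18 columns as a plain count embedding.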
Concatenated vector version¶
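The embedder returns a vector of length n * (distance + 1), where n is the number of features from the CountEmbedder and distance is the neighbourhood_distance. Distance 0 is the region itself, so 18 features and a distance of 10 give the 198 columns in the output of this section.
Each feature name is postfixed with a _d string, where d is the current distance. Values are averaged from all neighbours at that distance.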
wide_cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=True
)
wide_embeddings = wide_cce.transform(regions_gdf, features_gdf, joint_gdf)
wide_embeddings
region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393362b5bffff | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 5.0 | ... | 0.942857 | 1.371429 | 5.057143 | 1.428571 | 6.228571 | 1.514286 | 7.942857 | 5.257143 | 15.257143 | 0.971429 |
893933629cfffff | 0.0 | 0.0 | 13.0 | 1.0 | 1.0 | 1.0 | 1.0 | 10.0 | 5.0 | 3.0 | ... | 0.650000 | 1.066667 | 3.633333 | 0.850000 | 5.400000 | 1.600000 | 7.700000 | 3.183333 | 13.816667 | 0.683333 |
89393375b7bffff | 0.0 | 0.0 | 2.0 | 1.0 | 1.0 | 0.0 | 0.0 | 3.0 | 2.0 | 0.0 | ... | 1.147059 | 0.735294 | 2.941176 | 0.970588 | 7.264706 | 0.617647 | 5.176471 | 2.117647 | 16.294118 | 0.205882 |
893933629b3ffff | 0.0 | 0.0 | 3.0 | 0.0 | 2.0 | 0.0 | 0.0 | 10.0 | 2.0 | 2.0 | ... | 1.033333 | 1.333333 | 3.433333 | 0.983333 | 6.183333 | 0.983333 | 5.933333 | 2.850000 | 12.166667 | 0.333333 |
89393362db7ffff | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 7.0 | 1.0 | 1.0 | ... | 0.833333 | 0.520833 | 4.020833 | 0.458333 | 6.937500 | 0.729167 | 3.312500 | 1.666667 | 13.083333 | 0.187500 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
893933664cbffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 1.000000 | 0.000000 | 5.857143 | 0.285714 | 3.428571 | 1.071429 | 1.714286 | 0.428571 | 15.000000 | 0.357143 |
893933676d3ffff | 0.0 | 0.0 | 2.0 | 1.0 | 1.0 | 0.0 | 0.0 | 4.0 | 2.0 | 9.0 | ... | 0.805556 | 0.527778 | 3.527778 | 0.416667 | 2.750000 | 1.361111 | 3.027778 | 2.861111 | 11.944444 | 0.888889 |
89393360437ffff | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | ... | 0.166667 | 1.083333 | 2.916667 | 0.333333 | 0.500000 | 0.666667 | 1.083333 | 1.500000 | 6.250000 | 0.750000 |
89393362c0fffff | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 0.971429 | 0.514286 | 4.428571 | 0.600000 | 5.000000 | 1.914286 | 3.342857 | 1.314286 | 14.000000 | 0.400000 |
8939336293bffff | 0.0 | 0.0 | 8.0 | 0.0 | 6.0 | 0.0 | 1.0 | 2.0 | 2.0 | 2.0 | ... | 0.653846 | 0.403846 | 3.346154 | 0.634615 | 3.500000 | 0.788462 | 2.846154 | 1.519231 | 10.096154 | 0.480769 |
830 rows × 198 columns
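The layout of the wide vector can be sketched with the same kind of toy neighbourhood: regions in a row, with the regions exactly d steps away forming the neighbourhood at distance d. This is an illustrative sketch rather than the library's code; counts, ring, and wide are hypothetical names, and filling an empty neighbourhood with 0.0 is an assumption.

```python
from statistics import mean

# Toy counts of two features for five regions laid out in a row.
counts = {
    "shops": [4.0, 0.0, 2.0, 6.0, 0.0],
    "water": [0.0, 1.0, 0.0, 0.0, 3.0],
}
SIZE = 5


def ring(index, distance):
    """Indices exactly `distance` steps away on a 1-D row (toy neighbourhood)."""
    if distance == 0:
        return [index]
    candidates = (index - distance, index + distance)
    return [i for i in candidates if 0 <= i < SIZE]


def wide(index, max_distance):
    """One averaged value per feature and distance, keyed as feature_distance."""
    row = {}
    for d in range(max_distance + 1):
        neighbours = ring(index, d)
        for feature, values in counts.items():
            # Assumption: an empty neighbourhood contributes 0.0.
            value = mean(values[i] for i in neighbours) if neighbours else 0.0
            row[f"{feature}_{d}"] = value
    return row


row = wide(2, 2)
print(list(row))
print(len(row))
```

With two features and a maximum distance of 2, the row has 2 * (2 + 1) = 6 entries, ordered by distance first and feature second, mirroring the column order in the table above.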
Plotting example features¶
plot_numeric_data(regions_gdf, "leisure", embeddings)
plot_numeric_data(regions_gdf, "transportation", embeddings)
Other types of aggregations¶
By default, the ContextualCountEmbedder averages the counts from neighbours. This aggregation_function can be changed to one of: median, sum, min, max.
It's best to combine it with the wide format (concatenate_vectors=True).
sum_cce = ContextualCountEmbedder(
neighbourhood=h3n,
neighbourhood_distance=10,
concatenate_vectors=True,
aggregation_function="sum",
)
sum_embeddings = sum_cce.transform(regions_gdf, features_gdf, joint_gdf)
sum_embeddings
region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393362b5bffff | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 5.0 | ... | 33.0 | 48.0 | 177.0 | 50.0 | 218.0 | 53.0 | 278.0 | 184.0 | 534.0 | 34.0 |
893933629cfffff | 0.0 | 0.0 | 13.0 | 1.0 | 1.0 | 1.0 | 1.0 | 10.0 | 5.0 | 3.0 | ... | 39.0 | 64.0 | 218.0 | 51.0 | 324.0 | 96.0 | 462.0 | 191.0 | 829.0 | 41.0 |
89393375b7bffff | 0.0 | 0.0 | 2.0 | 1.0 | 1.0 | 0.0 | 0.0 | 3.0 | 2.0 | 0.0 | ... | 39.0 | 25.0 | 100.0 | 33.0 | 247.0 | 21.0 | 176.0 | 72.0 | 554.0 | 7.0 |
893933629b3ffff | 0.0 | 0.0 | 3.0 | 0.0 | 2.0 | 0.0 | 0.0 | 10.0 | 2.0 | 2.0 | ... | 62.0 | 80.0 | 206.0 | 59.0 | 371.0 | 59.0 | 356.0 | 171.0 | 730.0 | 20.0 |
89393362db7ffff | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 7.0 | 1.0 | 1.0 | ... | 40.0 | 25.0 | 193.0 | 22.0 | 333.0 | 35.0 | 159.0 | 80.0 | 628.0 | 9.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
893933664cbffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 14.0 | 0.0 | 82.0 | 4.0 | 48.0 | 15.0 | 24.0 | 6.0 | 210.0 | 5.0 |
893933676d3ffff | 0.0 | 0.0 | 2.0 | 1.0 | 1.0 | 0.0 | 0.0 | 4.0 | 2.0 | 9.0 | ... | 29.0 | 19.0 | 127.0 | 15.0 | 99.0 | 49.0 | 109.0 | 103.0 | 430.0 | 32.0 |
89393360437ffff | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | ... | 2.0 | 13.0 | 35.0 | 4.0 | 6.0 | 8.0 | 13.0 | 18.0 | 75.0 | 9.0 |
89393362c0fffff | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 34.0 | 18.0 | 155.0 | 21.0 | 175.0 | 67.0 | 117.0 | 46.0 | 490.0 | 14.0 |
8939336293bffff | 0.0 | 0.0 | 8.0 | 0.0 | 6.0 | 0.0 | 1.0 | 2.0 | 2.0 | 2.0 | ... | 34.0 | 21.0 | 174.0 | 33.0 | 182.0 | 41.0 | 148.0 | 79.0 | 525.0 | 25.0 |
830 rows × 198 columns
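The effect of each aggregation can be compared on a single toy neighbourhood, and the averaged and summed tables above can be cross-checked against each other. In the sketch below the dictionary keys are labels for this example only (not necessarily the strings the library accepts), ring_counts is made-up data, and the healthcare_10 values are copied from the two outputs for the first two regions.

```python
from statistics import mean, median

# Labels are for this sketch only; ring_counts is made-up toy data.
aggregations = {"mean": mean, "median": median, "sum": sum, "min": min, "max": max}
ring_counts = [0, 2, 2, 7, 14]

results = {name: func(ring_counts) for name, func in aggregations.items()}
print(results)

# Cross-check with the tables: dividing a summed value by the averaged value
# for the same region and column recovers the number of neighbours used.
healthcare_10 = {
    "89393362b5bffff": {"sum": 33.0, "mean": 0.942857},
    "893933629cfffff": {"sum": 39.0, "mean": 0.650000},
}
ring_sizes = {
    region: round(values["sum"] / values["mean"])
    for region, values in healthcare_10.items()
}
print(ring_sizes)
```

A complete hexagonal ring at distance k contains 6k cells, so 60 at distance 10. The second region matches that, while the first region has fewer neighbours at this distance, presumably because part of its ring falls outside the regionalized area. This difference in neighbour counts is one reason summed values are not directly comparable between regions near the edge and regions in the centre.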