Contextual count embedder
from srai.embedders import ContextualCountEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.neighbourhoods import H3Neighbourhood
from srai.plotting.folium_wrapper import plot_numeric_data, plot_regions
from srai.regionalizers import H3Regionalizer
Data preparation
To use ContextualCountEmbedder, we first need to prepare some data, namely regions_gdf, features_gdf, and joint_gdf. These are the outputs of Regionalizers, Loaders, and Joiners, respectively.
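As a rough mental model (a plain-Python sketch, not srai's implementation), the three inputs fit together like this: joint_gdf pairs region ids with feature ids, and counting each feature's groups per region gives the kind of per-region counts a CountEmbedder produces. The ids, groups, and the count_per_region helper below are hypothetical.

```python
from collections import Counter

# Hypothetical stand-ins for the three inputs (not the Lisbon data).
# regions_gdf: region ids (geometries omitted).
regions = ["r1", "r2", "r3"]
# features_gdf: feature id -> {group: "key=value"} for its non-empty group columns.
features = {
    "node/1": {"sustenance": "amenity=restaurant"},
    "node/2": {"transportation": "highway=bus_stop"},
    "way/3": {"greenery": "landuse=grass"},
}
# joint_gdf: (region_id, feature_id) pairs; a feature crossing a region
# boundary appears once for every region it intersects.
joint = [("r1", "node/1"), ("r1", "node/2"), ("r2", "way/3"), ("r3", "way/3")]


def count_per_region(regions, features, joint):
    """Count features per group in each region (a simplified count embedding)."""
    counts = {region_id: Counter() for region_id in regions}
    for region_id, feature_id in joint:
        for group in features[feature_id]:
            counts[region_id][group] += 1
    return counts


counts = count_per_region(regions, features, joint)
print(dict(counts["r1"]))  # {'sustenance': 1, 'transportation': 1}
```

Regions without any joined feature keep an empty Counter, which corresponds to the all-zero rows seen in the embeddings later on.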
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Lisboa, PT")
plot_regions(area_gdf)
Regionalize the area using an H3Regionalizer
regionalizer = H3Regionalizer(resolution=9, buffer=True)
regions_gdf = regionalizer.transform(area_gdf)
regions_gdf
region_id | geometry |
---|---|
89393362a0bffff | POLYGON ((-9.18506 38.70656, -9.18676 38.70522... |
89393362e27ffff | POLYGON ((-9.18952 38.72659, -9.19123 38.72525... |
89393362b6fffff | POLYGON ((-9.1481 38.70507, -9.1498 38.70373, ... |
89393375343ffff | POLYGON ((-9.19552 38.76426, -9.19722 38.76291... |
8939337581bffff | POLYGON ((-9.13698 38.78205, -9.13868 38.78071... |
... | ... |
89393375b53ffff | POLYGON ((-9.12154 38.76125, -9.12325 38.75991... |
89393375e6bffff | POLYGON ((-9.15744 38.76597, -9.15914 38.76463... |
89393362e23ffff | POLYGON ((-9.19355 38.72578, -9.19525 38.72444... |
8939336019bffff | POLYGON ((-9.20039 38.69449, -9.20209 38.69315... |
89393375c73ffff | POLYGON ((-9.15166 38.79403, -9.15337 38.79269... |
830 rows × 1 columns
Download some objects from OpenStreetMap
You can use both OsmTagsFilter and GroupedOsmTagsFilter filters. In this example, a predefined GroupedOsmTagsFilter filter, BASE_OSM_GROUPS_FILTER, is used.
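Filters are plain mappings. The sketch below assumes that an OsmTagsFilter maps OSM keys to True (any value), a single value, or a list of values, and that a GroupedOsmTagsFilter maps group names to such filters. The filters and the match_groups helper are hypothetical; the helper only mimics how a grouped match ends up as a key=value string in a group column, as in the features_gdf output below, and is not how srai evaluates filters.

```python
# Hypothetical filters, assuming the shapes described above.
my_tags_filter = {"amenity": ["restaurant", "cafe"], "leisure": True}
my_grouped_filter = {
    "sustenance": {"amenity": ["restaurant", "cafe", "bar"]},
    "greenery": {"leisure": "park", "landuse": ["grass", "forest"]},
}


def match_groups(tags, grouped_filter):
    """Return {group: "key=value"} for every group whose filter matches the tags."""
    matches = {}
    for group, tag_filter in grouped_filter.items():
        for key, allowed in tag_filter.items():
            value = tags.get(key)
            if value is None:
                continue  # the feature does not have this key at all
            if allowed is True or value == allowed or (
                isinstance(allowed, list) and value in allowed
            ):
                matches[group] = f"{key}={value}"
                break  # first matching key wins within a group
    return matches


print(match_groups({"amenity": "restaurant"}, my_grouped_filter))
# {'sustenance': 'amenity=restaurant'}
print(match_groups({"landuse": "grass", "name": "x"}, my_grouped_filter))
# {'greenery': 'landuse=grass'}
```

A dictionary of this shape could presumably be passed to loader.load(area_gdf, tags=...) in place of BASE_OSM_GROUPS_FILTER; the sketch itself does not call the loader.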
from srai.loaders.osm_loaders.filters import BASE_OSM_GROUPS_FILTER
loader = OSMPbfLoader()
features_gdf = loader.load(area_gdf, tags=BASE_OSM_GROUPS_FILTER)
features_gdf
Finished operation in 0:00:16
feature_id | geometry | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
node/1598143514 | POINT (-9.20599 38.75113) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
node/1598576790 | POINT (-9.12053 38.77965) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=taxi | None |
node/1598576841 | POINT (-9.12009 38.78093) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | highway=bus_stop | None |
node/1599340087 | POINT (-9.1312 38.70964) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=restaurant | None | None | None |
node/1599340093 | POINT (-9.13028 38.71645) | None | None | None | None | None | None | None | None | None | None | None | None | shop=car_repair | None | None | None | None | None |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
way/1386986885 | POLYGON ((-9.21711 38.69309, -9.217 38.69307, ... | None | None | None | None | None | None | None | landuse=grass | None | None | None | None | None | None | None | None | None | None |
way/1386986886 | POLYGON ((-9.21498 38.69209, -9.21495 38.69205... | None | None | None | None | None | None | None | landuse=grass | None | None | None | None | None | None | None | None | None | None |
way/1386986887 | POLYGON ((-9.21658 38.69309, -9.2169 38.69311,... | None | None | None | None | None | None | None | landuse=grass | None | None | None | None | None | None | None | None | None | None |
way/1387098203 | POLYGON ((-9.09173 38.77477, -9.0917 38.77478,... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | tourism=attraction | None | None |
way/1387117391 | POLYGON ((-9.19772 38.70271, -9.19775 38.70275... | None | None | None | None | None | None | None | None | None | None | None | amenity=place_of_worship | None | None | None | None | None | None |
33482 rows × 19 columns
Join the objects with the regions they belong to
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf)
joint_gdf
region_id | feature_id |
---|---|
89393362cd7ffff | node/1598143514 |
89393375823ffff | node/1598576790 |
89393375823ffff | node/1598576841 |
8939336764fffff | node/1599340087 |
8939336760bffff | node/1599340093 |
... | ... |
8939336054fffff | way/1386986886 |
8939336054fffff | way/1386986887 |
893933666d7ffff | way/1387098203 |
8939336668bffff | way/1387098203 |
89393360527ffff | way/1387117391 |
37091 rows × 0 columns
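joint_gdf has more rows (37091) than features_gdf has features (33482) because some features intersect several regions; way/1387098203, for example, is paired with two region ids. The idea can be sketched with axis-aligned bounding boxes standing in for real geometries. This is a deliberate simplification, and boxes_intersect and intersection_join are hypothetical helpers, not the IntersectionJoiner implementation.

```python
# Boxes as (min_x, min_y, max_x, max_y): a simplification of real geometries.
regions = {
    "cell_a": (0.0, 0.0, 1.0, 1.0),
    "cell_b": (1.0, 0.0, 2.0, 1.0),
}
features = {
    "node/1": (0.2, 0.2, 0.2, 0.2),  # a point inside cell_a
    "way/2": (0.8, 0.4, 1.3, 0.6),   # a polygon crossing both cells
}


def boxes_intersect(a, b):
    """True if two boxes share at least one point (touching edges count)."""
    return a[0] <= b[2] and b[0] <= a[2] and a[1] <= b[3] and b[1] <= a[3]


def intersection_join(regions, features):
    """Return sorted (region_id, feature_id) pairs for every intersecting pair."""
    return sorted(
        (region_id, feature_id)
        for region_id, region_box in regions.items()
        for feature_id, feature_box in features.items()
        if boxes_intersect(region_box, feature_box)
    )


print(intersection_join(regions, features))
# [('cell_a', 'node/1'), ('cell_a', 'way/2'), ('cell_b', 'way/2')]
```

As in joint_gdf, the feature crossing the boundary (way/2) yields one pair per region it touches.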
Embed using features existing in data
ContextualCountEmbedder extends the capabilities of the basic CountEmbedder by incorporating the neighbourhood of each embedded region. In this example, we use the H3Neighbourhood.
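H3 cells are hexagons, so the regions at grid distance k from a given region form a ring. The sketch below uses a plain axial-coordinate hexagonal grid as a stand-in for H3 indexing (it is not the H3Neighbourhood API); hex_distance and neighbours_at_distance are hypothetical helpers. It shows that complete rings hold 6k cells, while a region at the edge of the area only sees a truncated ring.

```python
def hex_distance(a, b):
    """Grid distance between two axial hex coordinates (q, r)."""
    dq, dr = a[0] - b[0], a[1] - b[1]
    return (abs(dq) + abs(dr) + abs(dq + dr)) // 2


def neighbours_at_distance(cell, distance, cells):
    """Cells from `cells` that are exactly `distance` steps away from `cell`."""
    return {other for other in cells if hex_distance(cell, other) == distance}


# A hexagonal patch of radius 3 around the origin stands in for regions_gdf.
patch = {
    (q, r)
    for q in range(-3, 4)
    for r in range(-3, 4)
    if hex_distance((0, 0), (q, r)) <= 3
}
print(len(patch))  # 37
print([len(neighbours_at_distance((0, 0), d, patch)) for d in range(1, 4)])  # [6, 12, 18]
# A cell on the edge of the patch has an incomplete ring.
print(len(neighbours_at_distance((3, 0), 3, patch)))  # 7, versus 18 for a full ring
```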
h3n = H3Neighbourhood()
Squashed vector version (default)
The embedder returns a vector of the same length as the CountEmbedder output, but adds to each region's own counts the values averaged over the neighbours at each distance, weighted down by the squared neighbour distance.
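A minimal sketch of the squashing step in plain Python, assuming each ring's contribution is its per-feature average divided by the distance squared, as described above. The exact weighting used by ContextualCountEmbedder may be defined slightly differently, and squash is a hypothetical helper.

```python
from statistics import mean


def squash(own_counts, rings):
    """Collapse ring averages into one vector of the same length as own_counts.

    own_counts: counts for the region itself (distance 0).
    rings: {distance: [count vectors of the neighbours at that distance]}.
    Assumes a 1 / distance**2 weight, following the description above.
    """
    result = [float(c) for c in own_counts]
    for distance, neighbour_vectors in rings.items():
        if not neighbour_vectors:
            continue  # no neighbours found at this distance
        averaged = [mean(column) for column in zip(*neighbour_vectors)]
        for i, value in enumerate(averaged):
            result[i] += value / distance**2
    return result


# Two features (e.g. leisure, transportation) and two rings of hypothetical neighbours.
own = [1, 0]
rings = {1: [[2, 2], [0, 4]], 2: [[4, 8]]}
print(squash(own, rings))  # [3.0, 5.0]
```

This matches the pattern in the table below, where each squashed value is at least the region's own count (for example, 89393362a0bffff has 5 buildings of its own and a squashed value of about 6.02).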
cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=False
)
embeddings = cce.transform(regions_gdf, features_gdf, joint_gdf)
embeddings
region_id | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393362a0bffff | 0.000000 | 0.000000 | 6.022455 | 2.117986 | 3.024200 | 0.020750 | 2.142967 | 8.150722 | 4.420302 | 1.774605 | 7.197043 | 2.459471 | 31.322485 | 0.693493 | 15.907499 | 4.244004 | 20.017984 | 0.563893 |
89393362e27ffff | 0.001380 | 0.000000 | 0.196098 | 0.054762 | 0.146516 | 0.007496 | 0.044921 | 0.895289 | 0.092007 | 0.148622 | 4.057407 | 0.156758 | 0.630380 | 0.777530 | 0.421215 | 3.243138 | 3.300262 | 0.136742 |
89393362b6fffff | 0.000000 | 0.000000 | 2.504797 | 0.733041 | 0.636045 | 0.040088 | 1.282660 | 2.883642 | 0.763544 | 1.745204 | 1.379711 | 0.839346 | 8.590226 | 0.251331 | 14.677360 | 4.345886 | 13.727322 | 2.936274 |
89393375343ffff | 0.002690 | 0.000000 | 0.493201 | 0.252668 | 0.594209 | 0.011865 | 0.194433 | 15.806517 | 0.394453 | 0.425309 | 2.973838 | 0.342506 | 2.273863 | 0.590402 | 1.136522 | 0.527953 | 26.350712 | 0.065443 |
8939337581bffff | 0.000000 | 1.473146 | 0.272900 | 0.015302 | 0.170824 | 0.028214 | 0.050242 | 5.750420 | 0.084451 | 0.049497 | 0.803847 | 0.128850 | 0.470687 | 0.233202 | 0.294279 | 0.221757 | 2.516925 | 0.043820 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393375b53ffff | 0.002243 | 0.031630 | 1.568799 | 0.061863 | 0.700789 | 0.013569 | 0.185825 | 11.734186 | 1.355814 | 0.044838 | 2.064274 | 0.287119 | 3.252057 | 0.796593 | 1.063449 | 0.355370 | 8.703967 | 0.086051 |
89393375e6bffff | 0.001567 | 0.032820 | 5.053148 | 0.076424 | 1.778348 | 0.013618 | 3.527627 | 6.838790 | 3.843607 | 1.382875 | 10.712983 | 0.238387 | 9.714370 | 4.996098 | 5.573520 | 2.259766 | 27.203224 | 0.176199 |
89393362e23ffff | 0.001082 | 0.000000 | 0.163621 | 0.083763 | 0.160587 | 0.006812 | 0.036133 | 0.949848 | 0.079346 | 0.138334 | 7.069998 | 0.179100 | 0.542527 | 1.865509 | 0.409899 | 2.431085 | 8.365804 | 1.132930 |
8939336019bffff | 0.000000 | 0.000000 | 0.461556 | 0.099048 | 0.494264 | 0.005253 | 0.299842 | 4.358777 | 0.210466 | 0.533621 | 2.960745 | 0.356368 | 1.695530 | 0.474562 | 2.267539 | 2.115613 | 5.011760 | 2.161245 |
89393375c73ffff | 0.000000 | 0.058717 | 2.170058 | 0.050440 | 3.219512 | 0.005823 | 0.032879 | 5.672647 | 0.111694 | 0.078702 | 7.555094 | 1.239135 | 16.776454 | 1.667468 | 6.393979 | 0.185299 | 21.335172 | 0.065097 |
830 rows × 18 columns
Concatenated vector version
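The embedder returns a vector of length n * (distance + 1), where n is the number of features from the CountEmbedder and distance is the number of neighbourhood rings analysed; the extra block holds the region's own counts. Each feature name is suffixed with _n, where n is the current distance (_0 is the region itself). Values at each distance are averaged over all neighbours at that distance. Here, 18 features and a distance of 10 give 18 × 11 = 198 columns.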
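The wide layout can be sketched the same way. concatenate is a hypothetical helper that builds one row as a dictionary keyed by feature_distance, ordered distance by distance like the column header below; treating a ring with no neighbours as zeros is an assumption.

```python
from statistics import mean


def concatenate(own_counts, rings, feature_names, max_distance):
    """Build one wide row: <feature>_0 .. <feature>_<max_distance>, averaged per ring."""
    row = {f"{name}_0": float(c) for name, c in zip(feature_names, own_counts)}
    for distance in range(1, max_distance + 1):
        vectors = rings.get(distance, [])
        for i, name in enumerate(feature_names):
            # Assumption: an empty ring contributes zeros.
            row[f"{name}_{distance}"] = (
                float(mean(v[i] for v in vectors)) if vectors else 0.0
            )
    return row


names = ["leisure", "transportation"]
row = concatenate([1, 0], {1: [[2, 2], [0, 4]], 2: [[4, 8]]}, names, 2)
print(row)
# {'leisure_0': 1.0, 'transportation_0': 0.0, 'leisure_1': 1.0,
#  'transportation_1': 3.0, 'leisure_2': 4.0, 'transportation_2': 8.0}

# Column count check: 18 features and distance 10 give 18 * (10 + 1) columns.
print(len(concatenate([0] * 18, {}, [f"f{i}" for i in range(18)], 10)))  # 198
```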
wide_cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=True
)
wide_embeddings = wide_cce.transform(regions_gdf, features_gdf, joint_gdf)
wide_embeddings
region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393362a0bffff | 0.0 | 0.0 | 5.0 | 2.0 | 2.0 | 0.0 | 2.0 | 5.0 | 4.0 | 1.0 | ... | 1.090909 | 2.606061 | 3.545455 | 1.060606 | 8.818182 | 0.909091 | 13.090909 | 4.181818 | 13.393939 | 0.484848 |
89393362e27ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 1.404255 | 0.893617 | 5.446809 | 0.872340 | 11.085106 | 1.212766 | 6.744681 | 3.787234 | 13.489362 | 0.851064 |
89393362b6fffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 1.750000 | 1.285714 | 4.928571 | 1.071429 | 15.035714 | 1.107143 | 8.928571 | 2.964286 | 19.892857 | 0.464286 |
89393375343ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 12.0 | 0.0 | 0.0 | ... | 1.111111 | 0.481481 | 4.555556 | 0.259259 | 2.444444 | 2.333333 | 2.000000 | 2.296296 | 13.407407 | 0.407407 |
8939337581bffff | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 0.583333 | 0.305556 | 4.388889 | 0.388889 | 2.111111 | 1.944444 | 1.583333 | 1.666667 | 15.222222 | 0.500000 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393375b53ffff | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 6.0 | 1.0 | 0.0 | ... | 1.621622 | 0.432432 | 4.189189 | 0.540541 | 8.756757 | 1.729730 | 6.162162 | 2.567568 | 19.594595 | 0.459459 |
89393375e6bffff | 0.0 | 0.0 | 4.0 | 0.0 | 1.0 | 0.0 | 3.0 | 4.0 | 3.0 | 1.0 | ... | 0.977778 | 0.555556 | 4.733333 | 0.777778 | 6.555556 | 0.955556 | 3.577778 | 2.222222 | 16.044444 | 0.133333 |
89393362e23ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 1.111111 | 1.511111 | 5.133333 | 0.822222 | 10.244444 | 0.888889 | 5.533333 | 3.866667 | 12.688889 | 1.088889 |
8939336019bffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.368421 | 0.736842 | 3.473684 | 0.473684 | 1.157895 | 0.684211 | 0.631579 | 2.052632 | 5.894737 | 0.631579 |
89393375c73ffff | 0.0 | 0.0 | 2.0 | 0.0 | 3.0 | 0.0 | 0.0 | 4.0 | 0.0 | 0.0 | ... | 0.500000 | 0.227273 | 3.909091 | 0.136364 | 3.000000 | 0.909091 | 1.954545 | 0.454545 | 11.909091 | 0.363636 |
830 rows × 198 columns
Plotting example features
plot_numeric_data(regions_gdf, "leisure", embeddings)
plot_numeric_data(regions_gdf, "transportation", embeddings)
Other types of aggregations
By default, the ContextualCountEmbedder averages the counts from neighbours. The aggregation can be changed with the aggregation_function parameter to one of: median, sum, min, or max. It works best combined with the wide format (concatenate_vectors=True).
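A minimal sketch of what each option computes for a single feature within a single neighbour ring, using Python's built-ins and the statistics module as stand-ins. The "average" key for the default is a label of our own; the exact string the embedder uses for it is not shown here, and AGGREGATIONS and aggregate_ring are hypothetical.

```python
from statistics import mean, median

# Plain-Python equivalents of the aggregations (an illustration, not the embedder's internals).
AGGREGATIONS = {"average": mean, "median": median, "sum": sum, "min": min, "max": max}


def aggregate_ring(neighbour_values, how):
    """Aggregate one feature's counts over the neighbours in one ring."""
    return float(AGGREGATIONS[how](neighbour_values))


ring = [0, 1, 1, 6]  # hypothetical counts of one feature in four neighbours
for how in AGGREGATIONS:
    print(how, aggregate_ring(ring, how))
# average 2.0
# median 1.0
# sum 8.0
# min 0.0
# max 6.0
```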
sum_cce = ContextualCountEmbedder(
neighbourhood=h3n,
neighbourhood_distance=10,
concatenate_vectors=True,
aggregation_function="sum",
)
sum_embeddings = sum_cce.transform(regions_gdf, features_gdf, joint_gdf)
sum_embeddings
region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393362a0bffff | 0.0 | 0.0 | 5.0 | 2.0 | 2.0 | 0.0 | 2.0 | 5.0 | 4.0 | 1.0 | ... | 36.0 | 86.0 | 117.0 | 35.0 | 291.0 | 30.0 | 432.0 | 138.0 | 442.0 | 16.0 |
89393362e27ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 66.0 | 42.0 | 256.0 | 41.0 | 521.0 | 57.0 | 317.0 | 178.0 | 634.0 | 40.0 |
89393362b6fffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 49.0 | 36.0 | 138.0 | 30.0 | 421.0 | 31.0 | 250.0 | 83.0 | 557.0 | 13.0 |
89393375343ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 12.0 | 0.0 | 0.0 | ... | 30.0 | 13.0 | 123.0 | 7.0 | 66.0 | 63.0 | 54.0 | 62.0 | 362.0 | 11.0 |
8939337581bffff | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 21.0 | 11.0 | 158.0 | 14.0 | 76.0 | 70.0 | 57.0 | 60.0 | 548.0 | 18.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393375b53ffff | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 6.0 | 1.0 | 0.0 | ... | 60.0 | 16.0 | 155.0 | 20.0 | 324.0 | 64.0 | 228.0 | 95.0 | 725.0 | 17.0 |
89393375e6bffff | 0.0 | 0.0 | 4.0 | 0.0 | 1.0 | 0.0 | 3.0 | 4.0 | 3.0 | 1.0 | ... | 44.0 | 25.0 | 213.0 | 35.0 | 295.0 | 43.0 | 161.0 | 100.0 | 722.0 | 6.0 |
89393362e23ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 50.0 | 68.0 | 231.0 | 37.0 | 461.0 | 40.0 | 249.0 | 174.0 | 571.0 | 49.0 |
8939336019bffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 7.0 | 14.0 | 66.0 | 9.0 | 22.0 | 13.0 | 12.0 | 39.0 | 112.0 | 12.0 |
89393375c73ffff | 0.0 | 0.0 | 2.0 | 0.0 | 3.0 | 0.0 | 0.0 | 4.0 | 0.0 | 0.0 | ... | 11.0 | 5.0 | 86.0 | 3.0 | 66.0 | 20.0 | 43.0 | 10.0 | 262.0 | 8.0 |
830 rows × 198 columns
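The averaged and summed wide tables are consistent with a simple relationship: for a given region and distance, the sum equals the average multiplied by the number of neighbours found at that distance. A quick check with distance-10 values copied from the two tables recovers 33 neighbours for 89393362a0bffff and 19 for 8939336019bffff, both below the 6 × 10 = 60 cells of a complete hexagonal ring, presumably because those rings are clipped by the edge of the regionalized area.

```python
# Distance-10 values copied from the averaged (wide_embeddings) table.
average_rows = {
    "89393362a0bffff": {
        "healthcare_10": 1.090909, "leisure_10": 3.545455, "transportation_10": 13.393939,
    },
    "8939336019bffff": {
        "healthcare_10": 0.368421, "leisure_10": 3.473684, "transportation_10": 5.894737,
    },
}
# The same cells copied from the summed (sum_embeddings) table.
sum_rows = {
    "89393362a0bffff": {"healthcare_10": 36.0, "leisure_10": 117.0, "transportation_10": 442.0},
    "8939336019bffff": {"healthcare_10": 7.0, "leisure_10": 66.0, "transportation_10": 112.0},
}

for region_id, sums in sum_rows.items():
    # sum / average gives the number of neighbours found at distance 10.
    ratios = {col: sums[col] / average_rows[region_id][col] for col in sums}
    print(region_id, {col: round(r, 2) for col, r in ratios.items()})
# 89393362a0bffff {'healthcare_10': 33.0, 'leisure_10': 33.0, 'transportation_10': 33.0}
# 8939336019bffff {'healthcare_10': 19.0, 'leisure_10': 19.0, 'transportation_10': 19.0}
```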
plot_numeric_data(regions_gdf, "tourism_8", sum_embeddings)