Contextual count embedder
from srai.embedders import ContextualCountEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.neighbourhoods import H3Neighbourhood
from srai.plotting.folium_wrapper import plot_numeric_data, plot_regions
from srai.regionalizers import H3Regionalizer
Data preparation
In order to use ContextualCountEmbedder we need to prepare some data. Namely we need: regions_gdf, features_gdf, and joint_gdf.
These are the outputs of Regionalizers, Loaders and Joiners respectively.
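To make the relationship between the three inputs concrete, here is a minimal sketch that uses plain Python containers as stand-ins for the GeoDataFrames. The IDs, the tiny dataset and the `count_per_region` helper are all hypothetical. It only illustrates the kind of per-region counting a count-based embedder starts from, and is not srai's implementation.

```python
from collections import Counter, defaultdict

# Toy stand-ins for the three inputs (all values are hypothetical).
regions = ["region_a", "region_b"]  # plays the role of the regions_gdf index
features = {  # plays the role of features_gdf: feature_id -> category
    "node/1": "transportation",
    "node/2": "transportation",
    "node/3": "sustenance",
}
joint = [  # plays the role of joint_gdf: (region_id, feature_id) pairs
    ("region_a", "node/1"),
    ("region_a", "node/3"),
    ("region_b", "node/2"),
]


def count_per_region(regions, features, joint):
    """Count how many features of each category fall into each region."""
    counts = defaultdict(Counter)
    for region_id, feature_id in joint:
        counts[region_id][features[feature_id]] += 1
    # Every region gets an entry, even if nothing was joined to it.
    return {region_id: dict(counts[region_id]) for region_id in regions}


base_counts = count_per_region(regions, features, joint)
print(base_counts)
```

The joint table is what links a feature to the region (or regions) it intersects, which is why all three inputs are passed to the embedder together.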
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Lisboa, PT")
plot_regions(area_gdf)
Regionalize the area using an H3Regionalizer
regionalizer = H3Regionalizer(resolution=9, buffer=True)
regions_gdf = regionalizer.transform(area_gdf)
regions_gdf
region_id | geometry |
---|---|
89393362a1bffff | POLYGON ((-9.188 38.70896, -9.1897 38.70762, -... |
893933628a3ffff | POLYGON ((-9.17016 38.73948, -9.17187 38.73814... |
893933628dbffff | POLYGON ((-9.19137 38.7322, -9.19308 38.73086,... |
89393375823ffff | POLYGON ((-9.12196 38.78208, -9.12367 38.78074... |
89393375a6bffff | POLYGON ((-9.13688 38.74919, -9.13859 38.74785... |
... | ... |
893933759b7ffff | POLYGON ((-9.09747 38.79896, -9.09918 38.79763... |
89393362893ffff | POLYGON ((-9.18409 38.74264, -9.18579 38.7413,... |
893933664cbffff | POLYGON ((-9.08867 38.79177, -9.09037 38.79043... |
89393375a4bffff | POLYGON ((-9.14787 38.74996, -9.14957 38.74862... |
89393375877ffff | POLYGON ((-9.12708 38.77806, -9.12878 38.77672... |
830 rows × 1 columns
Download some objects from OpenStreetMap
You can use either an OsmTagsFilter or a GroupedOsmTagsFilter. In this example, the predefined GroupedOsmTagsFilter BASE_OSM_GROUPS_FILTER is used.
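As a rough illustration of the two filter shapes, the sketch below assumes that a flat filter maps an OSM key to a list of accepted values and that a grouped filter maps a group name to such a flat filter. The filter contents and the `match_groups` helper are hypothetical and written only for this example. They are meant to show why each group ends up as one column holding a `key=value` string, as in the loader output in this notebook.

```python
# Hypothetical flat filter: OSM key -> accepted values.
tags_filter = {"amenity": ["parking", "restaurant"], "leisure": ["garden"]}

# Hypothetical grouped filter: group name -> flat filter.
grouped_filter = {
    "transportation": {"amenity": ["parking"]},
    "sustenance": {"amenity": ["restaurant"]},
    "leisure": {"leisure": ["garden"]},
}


def match_groups(osm_tags, grouped):
    """Return {group: 'key=value' or None} for the tags of one OSM object."""
    row = {}
    for group, flat_filter in grouped.items():
        row[group] = None
        for key, accepted in flat_filter.items():
            if osm_tags.get(key) in accepted:
                row[group] = f"{key}={osm_tags[key]}"
                break
    return row


row = match_groups({"amenity": "parking", "name": "P1"}, grouped_filter)
print(row)
```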
from srai.loaders.osm_loaders.filters import BASE_OSM_GROUPS_FILTER
loader = OSMPbfLoader()
features_gdf = loader.load(area_gdf, tags=BASE_OSM_GROUPS_FILTER)
features_gdf
Finished operation in 0:00:16
feature_id | geometry | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
node/1662570350 | POINT (-9.13652 38.75509) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
node/1662570365 | POINT (-9.13501 38.75859) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
node/1662570400 | POINT (-9.13817 38.75503) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
node/1662570428 | POINT (-9.13233 38.75529) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
node/1662570468 | POINT (-9.13822 38.75534) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=restaurant | None | None | None |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
relation/12990755 | POLYGON ((-9.1692 38.70779, -9.16909 38.70789,... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
relation/13884079 | POLYGON ((-9.13321 38.74216, -9.13322 38.74213... | None | None | None | None | None | None | None | landuse=grass | None | None | None | None | None | None | None | None | None | None |
relation/8306287 | POLYGON ((-9.14249 38.71874, -9.14251 38.71877... | None | None | None | None | None | None | None | None | None | historic=castle | None | None | None | None | None | None | None | None |
relation/11318814 | POLYGON ((-9.21057 38.697, -9.2103 38.69651, -... | None | None | None | None | None | None | None | None | None | None | leisure=garden | None | None | None | None | None | None | None |
relation/15475183 | POLYGON ((-9.20112 38.69844, -9.20106 38.69825... | None | None | None | None | None | None | None | None | None | historic=castle | None | None | None | None | None | tourism=attraction | None | None |
33475 rows × 19 columns
Join the objects with the regions they belong to
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf)
joint_gdf
region_id | feature_id |
---|---|
89393375a0fffff | node/1662570350 |
89393375a07ffff | node/1662570365 |
89393375a0fffff | node/1662570400 |
89393375a3bffff | node/1662570428 |
89393375a0fffff | node/1662570468 |
... | ... |
89393362a27ffff | relation/12990755 |
893933674dbffff | relation/13884079 |
893933676dbffff | relation/8306287 |
89393360573ffff | relation/11318814 |
8939336052bffff | relation/15475183 |
37084 rows × 0 columns
Embed using features existing in data
ContextualCountEmbedder extends the capabilities of the basic CountEmbedder by incorporating the neighbourhood of the embedded region. In this example we will use the H3Neighbourhood.
h3n = H3Neighbourhood()
Squashed vector version (default)
The embedder returns a vector of the same length as CountEmbedder, but each value is a sum of the averaged values from the successive neighbourhoods, with each neighbourhood's contribution diminished by the square of its distance.
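The squashing step can be sketched in plain Python as follows. This is a minimal illustration, not the library's code: the `squash` helper and the sample numbers are hypothetical, and the weight `1 / (distance + 1) ** 2` is an assumption chosen so that the region itself (distance 0) keeps its full value.

```python
def squash(ring_averages):
    """Combine per-distance average vectors into a single vector.

    ring_averages[d] holds the average count vector of the neighbours at
    distance d, where d = 0 is the region itself. The 1 / (d + 1) ** 2
    weighting is an assumption made for this sketch.
    """
    n_features = len(ring_averages[0])
    result = [0.0] * n_features
    for distance, averages in enumerate(ring_averages):
        divisor = (distance + 1) ** 2
        for i, value in enumerate(averages):
            result[i] += value / divisor
    return result


# Hypothetical averages for two features over three distances.
ring_averages = [
    [3.0, 7.0],   # distance 0: the region's own counts
    [2.0, 4.0],   # distance 1: average over the first ring
    [9.0, 18.0],  # distance 2: average over the second ring
]
squashed = squash(ring_averages)
print(squashed)
```

Because distant rings are divided by a rapidly growing factor, the result stays close to the region's own counts while still reflecting its surroundings.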
cce = ContextualCountEmbedder(
    neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=False
)
embeddings = cce.transform(regions_gdf, features_gdf, joint_gdf)
embeddings
region_id | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393362a1bffff | 0.000000 | 0.000000 | 0.716885 | 0.176462 | 3.819916 | 0.032079 | 0.160717 | 10.420602 | 0.398342 | 0.616710 | 5.976354 | 0.418704 | 4.168007 | 1.566271 | 5.242159 | 1.027607 | 9.238555 | 0.394107 |
893933628a3ffff | 0.034722 | 0.000000 | 7.259548 | 0.136688 | 0.613641 | 0.060180 | 1.461331 | 12.443794 | 2.758658 | 0.471053 | 4.153300 | 0.444289 | 2.912707 | 0.701843 | 4.154008 | 1.963062 | 23.090063 | 0.387077 |
893933628dbffff | 0.002028 | 0.000000 | 0.332965 | 0.027268 | 0.135732 | 0.004393 | 0.048098 | 1.980143 | 0.096074 | 0.157226 | 5.026685 | 0.226721 | 0.642310 | 2.034423 | 0.364559 | 1.730458 | 6.289365 | 0.075055 |
89393375823ffff | 0.002057 | 0.240252 | 0.387444 | 0.057641 | 2.337761 | 0.027099 | 0.099275 | 38.663327 | 0.221804 | 0.194872 | 10.059764 | 0.096011 | 3.305428 | 3.358242 | 1.741261 | 0.279932 | 35.494369 | 0.042457 |
89393375a6bffff | 0.000950 | 0.011091 | 1.202681 | 1.151128 | 2.948880 | 0.018330 | 0.616595 | 6.311057 | 3.017699 | 1.230854 | 14.881460 | 1.382330 | 10.132226 | 1.920024 | 5.294255 | 2.198081 | 54.422227 | 0.094769 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
893933759b7ffff | 0.003893 | 0.005209 | 0.229245 | 0.007306 | 0.286064 | 0.014414 | 0.040057 | 2.198360 | 0.095918 | 0.023439 | 4.056901 | 0.058804 | 0.459404 | 0.921661 | 0.322675 | 0.170430 | 3.691681 | 1.456921 |
89393362893ffff | 0.013750 | 0.000000 | 0.633183 | 0.043639 | 0.560733 | 0.007607 | 0.153460 | 1.548746 | 0.333937 | 1.524210 | 4.109246 | 0.370130 | 2.848712 | 0.442937 | 1.266099 | 3.467760 | 4.452792 | 0.298909 |
893933664cbffff | 0.023845 | 0.003195 | 0.499302 | 0.011176 | 0.194958 | 0.027209 | 0.094964 | 1.308434 | 0.135196 | 0.029546 | 1.904503 | 0.082114 | 0.899838 | 0.844691 | 0.694791 | 0.370233 | 2.341689 | 1.668322 |
89393375a4bffff | 0.002999 | 0.010621 | 12.728852 | 2.195494 | 1.302893 | 0.074019 | 3.502840 | 20.346454 | 3.113338 | 1.440307 | 3.387658 | 1.407813 | 16.101068 | 0.889805 | 10.962789 | 3.856779 | 66.649478 | 0.207773 |
89393375877ffff | 0.001461 | 1.363276 | 5.562409 | 0.056164 | 1.224696 | 0.100400 | 0.132622 | 7.089663 | 0.134208 | 0.081780 | 10.303234 | 0.144550 | 0.886331 | 0.367130 | 0.602830 | 0.309512 | 44.635466 | 0.035165 |
830 rows × 18 columns
Concatenated vector version
The embedder returns a vector of length n * (distance + 1), where n is the number of features from the CountEmbedder and distance is the neighbourhood_distance. The extra 1 accounts for the region itself at distance 0, so this example produces 18 * 11 = 198 columns.
Each feature is suffixed with _d, where d is the current distance. Values are averaged over all neighbours at that distance.
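The column layout of the wide format can be reproduced with a short sketch. The feature names are the 18 groups visible in the squashed output of this notebook. The ordering (all features for distance 0, then all for distance 1, and so on) is inferred from the printed column headers and is not taken from the library's source.

```python
# The 18 feature groups shown in the embeddings table of this notebook.
features = [
    "aerialway", "airports", "buildings", "culture_art_entertainment",
    "education", "emergency", "finances", "greenery", "healthcare",
    "historic", "leisure", "other", "shops", "sport", "sustenance",
    "tourism", "transportation", "water",
]
neighbourhood_distance = 10

# Distances run from 0 (the region itself) to neighbourhood_distance inclusive.
columns = [
    f"{feature}_{distance}"
    for distance in range(neighbourhood_distance + 1)
    for feature in features
]

print(len(columns))
print(columns[0], columns[-1])
```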
wide_cce = ContextualCountEmbedder(
    neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=True
)
wide_embeddings = wide_cce.transform(regions_gdf, features_gdf, joint_gdf)
wide_embeddings
region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393362a1bffff | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | 7.0 | 0.0 | 0.0 | ... | 1.093750 | 2.406250 | 3.531250 | 1.000000 | 6.093750 | 0.593750 | 6.062500 | 3.562500 | 12.562500 | 0.437500 |
893933628a3ffff | 0.0 | 0.0 | 6.0 | 0.0 | 0.0 | 0.0 | 1.0 | 7.0 | 2.0 | 0.0 | ... | 1.366667 | 1.966667 | 4.983333 | 1.216667 | 12.883333 | 1.183333 | 10.866667 | 3.800000 | 17.066667 | 0.383333 |
893933628dbffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 1.644444 | 2.000000 | 6.266667 | 1.000000 | 7.044444 | 1.600000 | 7.822222 | 3.755556 | 13.844444 | 1.022222 |
89393375823ffff | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 34.0 | 0.0 | 0.0 | ... | 1.090909 | 0.242424 | 3.969697 | 0.393939 | 7.606061 | 1.272727 | 2.939394 | 1.333333 | 11.606061 | 0.484848 |
89393375a6bffff | 0.0 | 0.0 | 0.0 | 1.0 | 2.0 | 0.0 | 0.0 | 3.0 | 2.0 | 1.0 | ... | 1.350000 | 1.583333 | 4.266667 | 0.983333 | 7.350000 | 1.200000 | 6.083333 | 2.950000 | 15.666667 | 0.433333 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
893933759b7ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 0.571429 | 0.071429 | 6.285714 | 0.357143 | 10.571429 | 1.214286 | 4.071429 | 1.428571 | 20.285714 | 0.857143 |
89393362893ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | ... | 1.697674 | 1.418605 | 4.023256 | 0.860465 | 7.534884 | 1.116279 | 5.558140 | 4.046512 | 17.976744 | 0.372093 |
893933664cbffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 1.000000 | 0.000000 | 5.928571 | 0.285714 | 3.500000 | 1.214286 | 1.714286 | 0.428571 | 15.071429 | 0.357143 |
89393375a4bffff | 0.0 | 0.0 | 11.0 | 2.0 | 0.0 | 0.0 | 3.0 | 17.0 | 2.0 | 1.0 | ... | 0.683333 | 1.116667 | 4.383333 | 0.700000 | 5.300000 | 1.500000 | 4.200000 | 3.116667 | 13.316667 | 0.450000 |
89393375877ffff | 0.0 | 1.0 | 5.0 | 0.0 | 1.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 0.976744 | 0.418605 | 3.674419 | 0.348837 | 4.813953 | 1.139535 | 2.813953 | 1.813953 | 13.813953 | 0.767442 |
830 rows × 198 columns
Plotting example features
plot_numeric_data(regions_gdf, "leisure", embeddings)
plot_numeric_data(regions_gdf, "transportation", embeddings)
Other types of aggregations
By default, the ContextualCountEmbedder averages the counts from neighbours. This aggregation_function can be changed to one of: median, sum, min, max.
It's best to combine it with the wide format (concatenate_vectors=True).
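To show how the choice of aggregation changes the value computed for a single ring of neighbours, here is a small standalone sketch. The `AGGREGATIONS` mapping, the `aggregate_ring` helper and the sample counts are hypothetical, and the name `"average"` for the default is an assumption. The code mirrors the idea only and is not the library's implementation.

```python
from statistics import mean, median

# Hypothetical mapping from aggregation name to a Python function.
AGGREGATIONS = {
    "average": mean,  # assumed name of the default aggregation
    "median": median,
    "sum": sum,
    "min": min,
    "max": max,
}


def aggregate_ring(neighbour_counts, aggregation_function="average"):
    """Aggregate the counts of one feature over all neighbours in a ring."""
    return AGGREGATIONS[aggregation_function](neighbour_counts)


# Hypothetical counts of one feature in five neighbouring regions.
ring = [0, 2, 2, 5, 11]
results = {name: aggregate_ring(ring, name) for name in AGGREGATIONS}
print(results)
```

With `sum`, values grow with the number of regions in a ring, which is visible in the notebook's output where the distance 10 columns are much larger than the distance 0 columns.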
sum_cce = ContextualCountEmbedder(
    neighbourhood=h3n,
    neighbourhood_distance=10,
    concatenate_vectors=True,
    aggregation_function="sum",
)
sum_embeddings = sum_cce.transform(regions_gdf, features_gdf, joint_gdf)
sum_embeddings
region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393362a1bffff | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | 7.0 | 0.0 | 0.0 | ... | 35.0 | 77.0 | 113.0 | 32.0 | 195.0 | 19.0 | 194.0 | 114.0 | 402.0 | 14.0 |
893933628a3ffff | 0.0 | 0.0 | 6.0 | 0.0 | 0.0 | 0.0 | 1.0 | 7.0 | 2.0 | 0.0 | ... | 82.0 | 118.0 | 299.0 | 73.0 | 773.0 | 71.0 | 652.0 | 228.0 | 1024.0 | 23.0 |
893933628dbffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 74.0 | 90.0 | 282.0 | 45.0 | 317.0 | 72.0 | 352.0 | 169.0 | 623.0 | 46.0 |
89393375823ffff | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 34.0 | 0.0 | 0.0 | ... | 36.0 | 8.0 | 131.0 | 13.0 | 251.0 | 42.0 | 97.0 | 44.0 | 383.0 | 16.0 |
89393375a6bffff | 0.0 | 0.0 | 0.0 | 1.0 | 2.0 | 0.0 | 0.0 | 3.0 | 2.0 | 1.0 | ... | 81.0 | 95.0 | 256.0 | 59.0 | 441.0 | 72.0 | 365.0 | 177.0 | 940.0 | 26.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
893933759b7ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 8.0 | 1.0 | 88.0 | 5.0 | 148.0 | 17.0 | 57.0 | 20.0 | 284.0 | 12.0 |
89393362893ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | ... | 73.0 | 61.0 | 173.0 | 37.0 | 324.0 | 48.0 | 239.0 | 174.0 | 773.0 | 16.0 |
893933664cbffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 14.0 | 0.0 | 83.0 | 4.0 | 49.0 | 17.0 | 24.0 | 6.0 | 211.0 | 5.0 |
89393375a4bffff | 0.0 | 0.0 | 11.0 | 2.0 | 0.0 | 0.0 | 3.0 | 17.0 | 2.0 | 1.0 | ... | 41.0 | 67.0 | 263.0 | 42.0 | 318.0 | 90.0 | 252.0 | 187.0 | 799.0 | 27.0 |
89393375877ffff | 0.0 | 1.0 | 5.0 | 0.0 | 1.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 42.0 | 18.0 | 158.0 | 15.0 | 207.0 | 49.0 | 121.0 | 78.0 | 594.0 | 33.0 |
830 rows × 198 columns
plot_numeric_data(regions_gdf, "tourism_8", sum_embeddings)