Contextual count embedder
from srai.embedders import ContextualCountEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.neighbourhoods import H3Neighbourhood
from srai.plotting.folium_wrapper import plot_numeric_data, plot_regions
from srai.regionalizers import H3Regionalizer
Data preparation
In order to use ContextualCountEmbedder, we need to prepare some data. Namely, we need regions_gdf, features_gdf, and joint_gdf. These are the outputs of Regionalizers, Loaders, and Joiners, respectively.
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Lisboa, PT")
plot_regions(area_gdf)
Regionalize the area using an H3Regionalizer
regionalizer = H3Regionalizer(resolution=9, buffer=True)
regions_gdf = regionalizer.transform(area_gdf)
regions_gdf
region_id | geometry |
---|---|
89393362b97ffff | POLYGON ((-9.15951 38.72668, -9.16121 38.72534... |
89393375bd7ffff | POLYGON ((-9.1223 38.77006, -9.124 38.76872, -... |
89393362acfffff | POLYGON ((-9.19202 38.70815, -9.19372 38.70681... |
893933675d3ffff | POLYGON ((-9.10578 38.75247, -9.10748 38.75113... |
89393375857ffff | POLYGON ((-9.13807 38.77884, -9.13977 38.7775,... |
... | ... |
89393362a53ffff | POLYGON ((-9.19311 38.70494, -9.19481 38.70359... |
89393362857ffff | POLYGON ((-9.18148 38.72822, -9.18318 38.72688... |
89393375803ffff | POLYGON ((-9.13295 38.78286, -9.13466 38.78152... |
89393362813ffff | POLYGON ((-9.1793 38.73464, -9.181 38.7333, -9... |
89393375a7bffff | POLYGON ((-9.13982 38.75158, -9.14152 38.75024... |
830 rows × 1 columns
Download some objects from OpenStreetMap
You can use both OsmTagsFilter and GroupedOsmTagsFilter filters. In this example, the predefined GroupedOsmTagsFilter BASE_OSM_GROUPS_FILTER is used.
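A grouped filter maps each group name (these become the columns of the table below, such as transportation or water) to a set of OSM tag conditions, and a matching feature gets a key=value string in that group's column. The sketch below imitates that shape with a tiny hypothetical filter in plain Python; TOY_GROUPED_FILTER, accepts and match_groups are illustrative names, not part of srai, and the real loader's matching rules may differ (for example, in how several matching tags in one group are reported).

```python
# Hypothetical grouped filter: group name -> {OSM tag key: accepted value(s), or True for any value}.
# It only imitates the shape of a GroupedOsmTagsFilter; it is NOT the content of BASE_OSM_GROUPS_FILTER.
TOY_GROUPED_FILTER = {
    "transportation": {"public_transport": ["stop_position", "platform"]},
    "greenery": {"landuse": ["grass"]},
    "leisure": {"leisure": True},
    "water": {"natural": "water"},
}


def accepts(accepted, value):
    """Check one tag value against a filter entry (True/False, a single string, or a list)."""
    if isinstance(accepted, bool):
        return accepted
    if isinstance(accepted, str):
        return value == accepted
    return value in accepted


def match_groups(tags, grouped_filter):
    """Return {group: "key=value" or None} for one feature's OSM tags."""
    result = {}
    for group, conditions in grouped_filter.items():
        result[group] = None
        for key, accepted in conditions.items():
            value = tags.get(key)
            if value is not None and accepts(accepted, value):
                result[group] = f"{key}={value}"
                break  # this sketch keeps only the first matching tag per group
    return result


print(match_groups({"leisure": "marina", "natural": "water"}, TOY_GROUPED_FILTER))
# {'transportation': None, 'greenery': None, 'leisure': 'leisure=marina', 'water': 'natural=water'}
```

The printed row has the same shape as relation/18400952 in the table below, which carries leisure=marina and natural=water.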
from srai.loaders.osm_loaders.filters import BASE_OSM_GROUPS_FILTER
loader = OSMPbfLoader()
features_gdf = loader.load(area_gdf, tags=BASE_OSM_GROUPS_FILTER)
features_gdf
Finished operation in 0:00:16
feature_id | geometry | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
node/21433772 | POINT (-9.19059 38.7288) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/21433776 | POINT (-9.19376 38.72666) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/25414208 | POINT (-9.16663 38.74018) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/25414256 | POINT (-9.10286 38.74711) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/25414265 | POINT (-9.10273 38.74707) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
relation/11318814 | POLYGON ((-9.21057 38.697, -9.2103 38.69651, -... | None | None | None | None | None | None | None | None | None | None | leisure=garden | None | None | None | None | None | None | None |
relation/15475183 | POLYGON ((-9.20112 38.69844, -9.20106 38.69825... | None | None | None | None | None | None | None | None | None | historic=castle | None | None | None | None | None | tourism=attraction | None | None |
relation/18216967 | POLYGON ((-9.13568 38.71231, -9.13578 38.71233... | None | None | None | None | None | None | None | None | amenity=clinic | historic=castle | None | None | None | None | None | None | None | None |
relation/18338969 | POLYGON ((-9.16811 38.73178, -9.16811 38.73182... | None | None | None | None | None | None | None | landuse=grass | None | None | None | None | None | None | None | None | None | None |
relation/18400952 | POLYGON ((-9.16191 38.70211, -9.16191 38.70242... | None | None | None | None | None | None | None | None | None | None | leisure=marina | None | None | None | None | None | None | natural=water |
31953 rows × 19 columns
Join the objects with the regions they belong to
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf)
joint_gdf
region_id | feature_id |
---|---|
893933628dbffff | node/21433772 |
89393362e37ffff | node/21433776 |
893933628b7ffff | node/25414208 |
893933675c3ffff | node/25414256 |
893933675c3ffff | node/25414265 |
... | ... |
89393362a67ffff | relation/18400952 |
89393360c93ffff | relation/18400952 |
89393360c87ffff | relation/18400952 |
89393360c97ffff | relation/18400952 |
89393362b4bffff | relation/18400952 |
35430 rows × 0 columns
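Each joint_gdf row pairs one region with one feature that intersects it, so a feature crossing several hexagons, such as relation/18400952, appears once per region; this is how the join can have more rows (35,430) than there are features (31,953). Roughly speaking, the base CountEmbedder then counts, for every region, how many joined features have a non-empty value in each group column. The sketch below illustrates that counting on hypothetical toy data; count_embed is an illustrative helper, not the srai implementation.

```python
# Toy stand-ins with hypothetical IDs for regions_gdf, features_gdf and joint_gdf.
regions = ["r1", "r2", "r3"]
# feature_id -> {group: "key=value" or None}, like the group columns of features_gdf
features = {
    "node/1": {"transportation": "public_transport=stop_position", "water": None},
    "way/2": {"transportation": None, "water": "natural=water"},
}
# (region_id, feature_id) pairs, like joint_gdf's index; way/2 crosses two regions.
joint = [("r1", "node/1"), ("r1", "way/2"), ("r2", "way/2")]


def count_embed(regions, features, joint):
    """Count, per region and group, the joined features with a non-empty value."""
    groups = sorted({group for values in features.values() for group in values})
    counts = {region_id: dict.fromkeys(groups, 0) for region_id in regions}
    for region_id, feature_id in joint:
        for group, value in features[feature_id].items():
            if value is not None:
                counts[region_id][group] += 1
    return counts  # r3 has no joined features and keeps all-zero counts


print(count_embed(regions, features, joint))
# {'r1': {'transportation': 1, 'water': 1}, 'r2': {'transportation': 0, 'water': 1}, 'r3': {'transportation': 0, 'water': 0}}
```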
Embed using features existing in data
ContextualCountEmbedder extends the capabilities of the basic CountEmbedder by incorporating the neighbourhood of each embedded region. In this example, we will use H3Neighbourhood.
h3n = H3Neighbourhood()
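Conceptually, the neighbourhood tells the embedder which cells lie at each grid distance from a region. To show the idea without real H3 indexes, the sketch below assumes an idealised hexagonal grid in axial coordinates; hex_distance and ring are hypothetical helpers, not srai or h3 functions. On such a grid, the ring at distance d ≥ 1 holds 6·d cells, which H3 also gives away from its twelve pentagons, so neighbourhood_distance=10 covers up to 1 + 6·(1 + 2 + … + 10) = 331 cells including the region itself.

```python
def hex_distance(a, b):
    """Grid distance between two cells given as axial coordinates (q, r)."""
    dq, dr = a[0] - b[0], a[1] - b[1]
    return (abs(dq) + abs(dr) + abs(dq + dr)) // 2


def ring(center, distance):
    """All cells exactly `distance` steps away from `center` on an ideal hex grid."""
    cq, cr = center
    return [
        (cq + dq, cr + dr)
        for dq in range(-distance, distance + 1)
        for dr in range(-distance, distance + 1)
        if hex_distance((0, 0), (dq, dr)) == distance
    ]


print([len(ring((0, 0), d)) for d in range(4)])  # [1, 6, 12, 18]
print(sum(len(ring((0, 0), d)) for d in range(11)))  # 331
```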
Squashed vector version (default)
The embedder returns a vector of the same length as CountEmbedder, but it sums the values averaged over the neighbours at each distance, diminishing each distance's contribution by the square of that distance.
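One way to read the squashing described above: per feature group, take the region's own count and add, for each distance d from 1 to neighbourhood_distance, the average count over the neighbours at distance d divided by d². The sketch below implements that reading on toy numbers; squash is a hypothetical helper, and the exact 1/d² weight is an assumption taken from the prose, so srai's internal weighting may differ in detail.

```python
from statistics import mean


def squash(own, rings):
    """Combine a region's own counts with per-distance neighbour averages.

    own   -- counts for the region itself (one value per feature group)
    rings -- {distance: [count vectors of the neighbours at that distance]}
    The 1 / distance**2 weight follows the prose description and is an assumption here.
    """
    result = [float(v) for v in own]
    for distance, neighbour_counts in rings.items():
        if not neighbour_counts:
            continue  # no neighbours at this distance: nothing to add
        averaged = [mean(column) for column in zip(*neighbour_counts)]
        for i, value in enumerate(averaged):
            result[i] += value / distance**2
    return result


own = [2, 0]                    # e.g. [buildings, water] for one region
rings = {1: [[1, 0], [3, 2]],   # two neighbours at distance 1 -> mean [2, 1]
         2: [[4, 4]]}           # one neighbour at distance 2  -> mean [4, 4], weighted by 1/4
print(squash(own, rings))  # [5.0, 2.0]
```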
cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=False
)
embeddings = cce.transform(regions_gdf, features_gdf, joint_gdf)
embeddings
region_id | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393362b97ffff | 0.001641 | 0.000000 | 6.066429 | 0.254202 | 5.002115 | 0.077370 | 4.052422 | 2.722241 | 2.595638 | 2.494950 | 5.237621 | 0.920639 | 16.132251 | 1.902889 | 8.624746 | 5.115959 | 30.867446 | 0.243514 |
89393375bd7ffff | 0.002571 | 0.085539 | 0.427797 | 0.035389 | 1.326850 | 0.051411 | 0.173113 | 40.488552 | 0.315275 | 0.106968 | 6.364254 | 0.344346 | 1.716207 | 0.850299 | 0.720412 | 0.402315 | 10.502652 | 0.053958 |
89393362acfffff | 0.000000 | 0.000000 | 0.445893 | 0.172381 | 2.133525 | 0.022321 | 0.093988 | 24.149137 | 0.400552 | 0.486860 | 9.233918 | 1.598880 | 8.679023 | 0.800650 | 8.291982 | 1.957979 | 15.073100 | 0.359857 |
893933675d3ffff | 0.005254 | 0.004818 | 5.855983 | 0.052238 | 2.479438 | 0.010598 | 0.077559 | 5.305094 | 0.254478 | 0.051907 | 3.927930 | 1.209476 | 2.321852 | 1.753842 | 1.629721 | 1.449391 | 18.312358 | 0.122483 |
89393375857ffff | 0.000000 | 1.434967 | 0.164779 | 0.017893 | 0.191654 | 0.043936 | 0.063896 | 7.652478 | 0.103082 | 0.047342 | 0.840011 | 0.084684 | 0.527854 | 0.259583 | 0.368842 | 0.194046 | 2.373477 | 0.036881 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393362a53ffff | 0.000000 | 0.000000 | 0.578942 | 0.142874 | 2.148370 | 0.010417 | 0.141095 | 15.801858 | 1.494815 | 1.675593 | 10.629129 | 5.552768 | 38.600999 | 5.888087 | 14.419194 | 4.020281 | 38.676370 | 0.518865 |
89393362857ffff | 0.002445 | 0.000000 | 0.278415 | 0.045853 | 0.182825 | 0.004699 | 0.080235 | 1.034536 | 0.145839 | 0.208923 | 2.470832 | 0.172443 | 0.986676 | 0.351414 | 0.639774 | 3.831899 | 2.075804 | 0.087542 |
89393375803ffff | 0.000000 | 2.544239 | 0.188682 | 0.016963 | 0.151727 | 0.031715 | 0.050254 | 9.868554 | 0.076694 | 0.043882 | 1.131544 | 1.100015 | 0.450723 | 0.203909 | 0.277555 | 0.137604 | 7.063271 | 0.029055 |
89393362813ffff | 0.006852 | 0.000000 | 0.386279 | 0.042635 | 0.283326 | 0.005832 | 0.105647 | 2.431563 | 0.203350 | 0.326391 | 4.002112 | 0.190471 | 1.201029 | 0.489670 | 2.724601 | 10.332907 | 3.651948 | 1.115169 |
89393375a7bffff | 0.001329 | 0.016833 | 1.246564 | 0.221167 | 2.083606 | 0.019763 | 0.746359 | 5.400405 | 2.158624 | 0.232991 | 29.241764 | 0.412389 | 8.099991 | 27.175623 | 7.998454 | 4.385029 | 32.962403 | 0.124755 |
830 rows × 18 columns
Concatenated vector version
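The embedder returns a vector of length n × (distance + 1), where n is the number of features from CountEmbedder and distance is the neighbourhood_distance; the extra block holds distance 0, the region's own counts. Each feature name is suffixed with _d, where d is the distance its values come from, so here the columns run from aerialway_0 to water_10 (18 × 11 = 198 columns). Values at each distance are averaged over all neighbours at that distance.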
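The column layout visible in the table below has one block of feature columns per distance, from 0 (the region itself) up to neighbourhood_distance. The sketch below rebuilds the 198 column names for the 18 groups in this example; wide_column_names and concatenate are hypothetical helpers that only mimic the naming and ordering shown in the output, not srai code.

```python
GROUPS = [
    "aerialway", "airports", "buildings", "culture_art_entertainment", "education",
    "emergency", "finances", "greenery", "healthcare", "historic", "leisure", "other",
    "shops", "sport", "sustenance", "tourism", "transportation", "water",
]


def wide_column_names(feature_names, neighbourhood_distance):
    """One block of feature columns per distance, from 0 up to neighbourhood_distance."""
    return [
        f"{name}_{distance}"
        for distance in range(neighbourhood_distance + 1)
        for name in feature_names
    ]


def concatenate(per_distance_vectors):
    """Flatten [[values at d=0], [values at d=1], ...] into one wide row."""
    return [value for vector in per_distance_vectors for value in vector]


columns = wide_column_names(GROUPS, 10)
print(len(columns), columns[0], columns[-1])  # 198 aerialway_0 water_10
```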
wide_cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=True
)
wide_embeddings = wide_cce.transform(regions_gdf, features_gdf, joint_gdf)
wide_embeddings
region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393362b97ffff | 0.0 | 0.0 | 2.0 | 0.0 | 4.0 | 0.0 | 3.0 | 0.0 | 1.0 | 1.0 | ... | 1.106383 | 0.595745 | 5.106383 | 0.893617 | 5.680851 | 2.617021 | 4.914894 | 2.319149 | 14.255319 | 0.425532 |
89393375bd7ffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 33.0 | 0.0 | 0.0 | ... | 0.810811 | 0.729730 | 4.621622 | 0.729730 | 6.405405 | 2.000000 | 3.621622 | 2.189189 | 16.351351 | 0.810811 |
89393362acfffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 20.0 | 0.0 | 0.0 | ... | 1.240000 | 1.880000 | 3.320000 | 0.600000 | 8.120000 | 0.520000 | 8.880000 | 3.080000 | 7.960000 | 0.440000 |
893933675d3ffff | 0.0 | 0.0 | 5.0 | 0.0 | 2.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 1.633333 | 0.900000 | 4.533333 | 0.866667 | 13.466667 | 1.033333 | 7.566667 | 1.866667 | 20.500000 | 0.133333 |
89393375857ffff | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 5.0 | 0.0 | 0.0 | ... | 0.675676 | 0.189189 | 4.081081 | 0.459459 | 3.432432 | 1.540541 | 2.405405 | 0.972973 | 14.486486 | 0.324324 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393362a53ffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 12.0 | 1.0 | 1.0 | ... | 0.571429 | 1.047619 | 3.380952 | 0.809524 | 6.476190 | 1.047619 | 7.190476 | 2.952381 | 8.333333 | 0.476190 |
89393362857ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 1.313725 | 1.392157 | 4.607843 | 0.980392 | 10.431373 | 2.078431 | 7.392157 | 3.235294 | 14.392157 | 0.588235 |
89393375803ffff | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 7.0 | 0.0 | 0.0 | ... | 0.894737 | 0.605263 | 5.105263 | 0.473684 | 3.868421 | 2.394737 | 2.447368 | 1.447368 | 17.315789 | 0.394737 |
89393362813ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 1.461538 | 2.442308 | 5.326923 | 1.384615 | 10.000000 | 1.038462 | 10.730769 | 4.038462 | 17.730769 | 0.596154 |
89393375a7bffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 3.0 | 1.0 | 0.0 | ... | 1.083333 | 1.366667 | 3.733333 | 0.866667 | 6.316667 | 0.983333 | 5.483333 | 2.766667 | 15.316667 | 0.316667 |
830 rows × 198 columns
Plotting example features
plot_numeric_data(regions_gdf, "leisure", embeddings)
plot_numeric_data(regions_gdf, "transportation", embeddings)
Other types of aggregations
By default, ContextualCountEmbedder averages the counts from neighbours. The aggregation can be changed with the aggregation_function parameter to one of: median, sum, min, or max. It's best to combine it with the wide format (concatenate_vectors=True).
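The options differ only in how the count vectors of the neighbours at one distance are reduced, group by group. Comparing the two wide tables shows how averaging and summing relate: for region 89393362b97ffff, healthcare_10 is 1.106383 when averaged and 52.0 when summed, and 52 / 47 ≈ 1.106383, which implies that 47 neighbours at distance 10 contributed to that row. The sketch below reproduces the reductions with Python's statistics module and built-ins on a made-up set of three neighbours; AGGREGATIONS and aggregate_ring are hypothetical helpers, not srai functions, and the "mean" key is only this sketch's label for the default averaging.

```python
from statistics import mean, median

# Column-wise reductions; "median", "sum", "min" and "max" mirror the options listed above,
# while "mean" is this sketch's own name for the default averaging.
AGGREGATIONS = {
    "mean": mean,
    "median": median,
    "sum": sum,
    "min": min,
    "max": max,
}


def aggregate_ring(neighbour_counts, how):
    """Reduce the count vectors of all neighbours at one distance, group by group."""
    func = AGGREGATIONS[how]
    return [func(column) for column in zip(*neighbour_counts)]


neighbours = [[1, 0, 4], [3, 0, 2], [2, 6, 0]]  # three neighbours, three feature groups
for how in AGGREGATIONS:
    print(how, aggregate_ring(neighbours, how))
```

Because the sum equals the mean times the number of contributing neighbours, the summed columns grow with how many neighbours a region has at each distance, while the averaged ones do not.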
sum_cce = ContextualCountEmbedder(
neighbourhood=h3n,
neighbourhood_distance=10,
concatenate_vectors=True,
aggregation_function="sum",
)
sum_embeddings = sum_cce.transform(regions_gdf, features_gdf, joint_gdf)
sum_embeddings
region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393362b97ffff | 0.0 | 0.0 | 2.0 | 0.0 | 4.0 | 0.0 | 3.0 | 0.0 | 1.0 | 1.0 | ... | 52.0 | 28.0 | 240.0 | 42.0 | 267.0 | 123.0 | 231.0 | 109.0 | 670.0 | 20.0 |
89393375bd7ffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 33.0 | 0.0 | 0.0 | ... | 30.0 | 27.0 | 171.0 | 27.0 | 237.0 | 74.0 | 134.0 | 81.0 | 605.0 | 30.0 |
89393362acfffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 20.0 | 0.0 | 0.0 | ... | 31.0 | 47.0 | 83.0 | 15.0 | 203.0 | 13.0 | 222.0 | 77.0 | 199.0 | 11.0 |
893933675d3ffff | 0.0 | 0.0 | 5.0 | 0.0 | 2.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 49.0 | 27.0 | 136.0 | 26.0 | 404.0 | 31.0 | 227.0 | 56.0 | 615.0 | 4.0 |
89393375857ffff | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 5.0 | 0.0 | 0.0 | ... | 25.0 | 7.0 | 151.0 | 17.0 | 127.0 | 57.0 | 89.0 | 36.0 | 536.0 | 12.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393362a53ffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 12.0 | 1.0 | 1.0 | ... | 12.0 | 22.0 | 71.0 | 17.0 | 136.0 | 22.0 | 151.0 | 62.0 | 175.0 | 10.0 |
89393362857ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 67.0 | 71.0 | 235.0 | 50.0 | 532.0 | 106.0 | 377.0 | 165.0 | 734.0 | 30.0 |
89393375803ffff | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 7.0 | 0.0 | 0.0 | ... | 34.0 | 23.0 | 194.0 | 18.0 | 147.0 | 91.0 | 93.0 | 55.0 | 658.0 | 15.0 |
89393362813ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 76.0 | 127.0 | 277.0 | 72.0 | 520.0 | 54.0 | 558.0 | 210.0 | 922.0 | 31.0 |
89393375a7bffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 3.0 | 1.0 | 0.0 | ... | 65.0 | 82.0 | 224.0 | 52.0 | 379.0 | 59.0 | 329.0 | 166.0 | 919.0 | 19.0 |
830 rows × 198 columns