Contextual count embedder
from srai.embedders import ContextualCountEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.neighbourhoods import H3Neighbourhood
from srai.plotting.folium_wrapper import plot_numeric_data, plot_regions
from srai.regionalizers import H3Regionalizer
Data preparation
To use ContextualCountEmbedder, we need to prepare some data: regions_gdf, features_gdf, and joint_gdf. These are the outputs of Regionalizers, Loaders, and Joiners, respectively.
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Lisboa, PT")
plot_regions(area_gdf)
Regionalize the area using an H3Regionalizer
regionalizer = H3Regionalizer(resolution=9, buffer=True)
regions_gdf = regionalizer.transform(area_gdf)
regions_gdf
region_id | geometry |
---|---|
893933629afffff | POLYGON ((-9.13906 38.74277, -9.14076 38.74143... |
8939336050bffff | POLYGON ((-9.21322 38.70086, -9.21492 38.69952... |
89393362977ffff | POLYGON ((-9.14233 38.73314, -9.14403 38.7318,... |
89393375b93ffff | POLYGON ((-9.12012 38.77648, -9.12182 38.77514... |
89393375a03ffff | POLYGON ((-9.14058 38.7604, -9.14228 38.75906,... |
... | ... |
893933628bbffff | POLYGON ((-9.17419 38.73866, -9.17589 38.73732... |
8939336232bffff | POLYGON ((-9.2228 38.71687, -9.2245 38.71552, ... |
89393362dcbffff | POLYGON ((-9.18083 38.75227, -9.18253 38.75093... |
893933759bbffff | POLYGON ((-9.10259 38.79495, -9.10429 38.79361... |
893933674cfffff | POLYGON ((-9.12808 38.74199, -9.12978 38.74065... |
830 rows × 1 columns
Download some objects from OpenStreetMap
You can use both OsmTagsFilter and GroupedOsmTagsFilter filters. In this example, the predefined GroupedOsmTagsFilter BASE_OSM_GROUPS_FILTER is used.
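The difference between the two filter shapes can be illustrated with plain dictionaries. The sketch below assumes that an OsmTagsFilter maps an OSM key to True (any value), a single value, or a list of values, and that a GroupedOsmTagsFilter maps a group name to such a filter. The names CUSTOM_TAGS_FILTER, CUSTOM_GROUPED_FILTER, rule_matches and match_group are hypothetical; match_group only mimics how a group column in features_gdf ends up holding a "key=value" string (such as public_transport=stop_position) or None, and is not srai's internal implementation.

```python
from typing import Optional, Union

# Assumed shapes, modelled on srai's filter types (not imported from srai):
# an OsmTagsFilter maps an OSM key to True (any value), one value, or a list of values,
# and a GroupedOsmTagsFilter maps a group name to such a filter.
TagRule = Union[bool, str, list[str]]
TagsFilter = dict[str, TagRule]

# Hypothetical custom filters for illustration only.
CUSTOM_TAGS_FILTER: TagsFilter = {"amenity": ["cafe", "restaurant"], "leisure": True}
CUSTOM_GROUPED_FILTER: dict[str, TagsFilter] = {
    "sustenance": {"amenity": ["cafe", "restaurant", "bar"]},
    "transportation": {"public_transport": "stop_position", "amenity": "parking"},
    "greenery": {"leisure": "park"},
}


def rule_matches(rule: TagRule, value: str) -> bool:
    """Check a single OSM tag value against one filter rule."""
    if isinstance(rule, bool):
        return rule
    if isinstance(rule, str):
        return value == rule
    return value in rule


def match_group(tags: dict[str, str], tags_filter: TagsFilter) -> Optional[str]:
    """Return the first matching tag as 'key=value', or None when nothing matches."""
    for key, rule in tags_filter.items():
        if key in tags and rule_matches(rule, tags[key]):
            return f"{key}={tags[key]}"
    return None


# One OSM object becomes one row with a value (or None) per group column.
stop_tags = {"public_transport": "stop_position", "bus": "yes"}
row = {group: match_group(stop_tags, rules) for group, rules in CUSTOM_GROUPED_FILTER.items()}
print(row)
# {'sustenance': None, 'transportation': 'public_transport=stop_position', 'greenery': None}
```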
from srai.loaders.osm_loaders.filters import BASE_OSM_GROUPS_FILTER
loader = OSMPbfLoader()
features_gdf = loader.load(area_gdf, tags=BASE_OSM_GROUPS_FILTER)
features_gdf
Finished operation in 0:00:16
feature_id | geometry | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
node/21433772 | POINT (-9.19059 38.7288) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/21433776 | POINT (-9.19376 38.72666) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/25414208 | POINT (-9.16663 38.74018) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/25414256 | POINT (-9.10286 38.74711) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/25414265 | POINT (-9.10273 38.74707) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
way/1361840806 | POLYGON ((-9.12207 38.71376, -9.12202 38.71374... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
way/1361840807 | POLYGON ((-9.10015 38.75626, -9.10018 38.75599... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
way/1361840808 | POLYGON ((-9.10019 38.75587, -9.10025 38.75523... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
way/1361844388 | POLYGON ((-9.17727 38.77199, -9.1772 38.77202,... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=charging_station | None |
way/1361910272 | POLYGON ((-9.15012 38.71775, -9.15011 38.71776... | None | None | None | None | None | None | None | None | None | None | None | None | shop=ticket | None | None | None | None | None |
32399 rows × 19 columns
Join the objects with the regions they belong to
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf)
joint_gdf
region_id | feature_id |
---|---|
893933628dbffff | node/21433772 |
89393362e37ffff | node/21433776 |
893933628b7ffff | node/25414208 |
893933675c3ffff | node/25414256 |
893933675c3ffff | node/25414265 |
... | ... |
89393367677ffff | way/1361840806 |
8939336759bffff | way/1361840807 |
8939336759bffff | way/1361840808 |
89393375337ffff | way/1361844388 |
89393362babffff | way/1361910272 |
35912 rows × 0 columns
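joint_gdf has more rows (35912) than features_gdf (32399) because a single feature can intersect several regions, for example a long road or a large park crossing hexagon boundaries. The sketch below shows that many-to-many pairing in plain Python, with axis-aligned boxes standing in for real geometries. The Box type, boxes_intersect and intersection_join are hypothetical simplifications for illustration; IntersectionJoiner works on the actual region and feature geometries.

```python
from itertools import product

# Simplified geometry: an axis-aligned box (min_x, min_y, max_x, max_y).
# Real regions are H3 hexagons; real features are points, lines or polygons.
Box = tuple[float, float, float, float]


def boxes_intersect(a: Box, b: Box) -> bool:
    """True when two closed boxes share at least one point (touching counts)."""
    return a[0] <= b[2] and b[0] <= a[2] and a[1] <= b[3] and b[1] <= a[3]


def intersection_join(regions: dict[str, Box], features: dict[str, Box]) -> list[tuple[str, str]]:
    """Return every (region_id, feature_id) pair whose geometries intersect."""
    return [
        (region_id, feature_id)
        for (region_id, region), (feature_id, feature) in product(regions.items(), features.items())
        if boxes_intersect(region, feature)
    ]


regions = {"r1": (0.0, 0.0, 1.0, 1.0), "r2": (1.0, 0.0, 2.0, 1.0)}
features = {
    "node/1": (0.5, 0.5, 0.5, 0.5),  # a point inside r1
    "way/2": (0.8, 0.2, 1.5, 0.4),  # a polygon spanning r1 and r2
}
pairs = intersection_join(regions, features)
print(pairs)  # [('r1', 'node/1'), ('r1', 'way/2'), ('r2', 'way/2')]
```

Because way/2 appears twice, the join has more rows than there are features, which is the same effect seen in joint_gdf above.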
Embed using features existing in data
ContextualCountEmbedder extends the capabilities of the basic CountEmbedder by incorporating the neighbourhood of each embedded region. In this example, we use the H3Neighbourhood.
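As background for what is being extended, a count embedding can be pictured as a per-region vector with one entry per feature group, counting the joined features that have a value in that group column. The plain-Python sketch below works under that assumption; count_embed and its inputs are hypothetical and are not srai's API.

```python
from collections import Counter


def count_embed(
    region_ids: list[str],
    joint_pairs: list[tuple[str, str]],
    feature_groups: dict[str, list[str]],
    columns: list[str],
) -> dict[str, list[int]]:
    """Count, per region, the joined features that have a value in each group column."""
    counts: dict[str, Counter] = {region_id: Counter() for region_id in region_ids}
    for region_id, feature_id in joint_pairs:
        # A feature contributes once to every group it belongs to.
        for group in feature_groups.get(feature_id, []):
            counts[region_id][group] += 1
    # Regions without any joined feature still get an all-zero vector.
    return {region_id: [counts[region_id][column] for column in columns] for region_id in region_ids}


columns = ["leisure", "transportation"]
feature_groups = {
    "node/1": ["transportation"],
    "way/2": ["leisure"],
    "way/3": ["leisure", "transportation"],
}
joint_pairs = [("r1", "node/1"), ("r1", "way/2"), ("r2", "way/2"), ("r2", "way/3")]
print(count_embed(["r1", "r2", "r3"], joint_pairs, feature_groups, columns))
# {'r1': [1, 1], 'r2': [2, 1], 'r3': [0, 0]}
```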
h3n = H3Neighbourhood()
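For a sense of scale: on a hexagonal grid, the ring of cells exactly d steps away from a hexagon (d ≥ 1) contains 6d cells, so a neighbourhood distance of 10 reaches up to 1 + 6·(1 + 2 + … + 10) = 331 cells. The helper names below are hypothetical, and the formula ignores H3's rare pentagon cells. Near the edge of the area, fewer ring cells may be present in regions_gdf.

```python
def hex_ring_size(distance: int) -> int:
    """Number of hexagons exactly `distance` steps from a hexagonal cell (no pentagons)."""
    if distance < 0:
        raise ValueError("distance must be non-negative")
    return 1 if distance == 0 else 6 * distance


def hex_disk_size(max_distance: int) -> int:
    """Number of hexagons within `max_distance` steps, including the centre cell."""
    return sum(hex_ring_size(d) for d in range(max_distance + 1))


print([hex_ring_size(d) for d in range(4)])  # [1, 6, 12, 18]
print(hex_disk_size(10))  # 331
```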
Squashed vector version (default)
The embedder returns a vector of the same length as the CountEmbedder output, but adds to each region's own counts the averaged counts of its neighbours at each distance, diminished by the square of that distance.
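Read literally, the squashing description means: start from the region's own counts, then, for each distance d from 1 to neighbourhood_distance, add the average counts of the neighbours at distance d divided by d². The sketch below implements that reading in plain Python. The 1/d² weight is taken from the wording above and is an assumption; the library's exact weighting and edge handling may differ. squashed_embedding, counts and rings are hypothetical names.

```python
from statistics import mean

Counts = dict[str, list[float]]
Rings = dict[str, dict[int, list[str]]]


def squashed_embedding(
    region_id: str, counts: Counts, rings: Rings, neighbourhood_distance: int
) -> list[float]:
    """Own counts plus, for d = 1..D, the ring-d average divided by d ** 2 (assumed weight)."""
    result = [float(value) for value in counts[region_id]]
    for distance in range(1, neighbourhood_distance + 1):
        # Assumption: only neighbours that are themselves regions (have counts) are averaged.
        neighbours = [n for n in rings[region_id].get(distance, []) if n in counts]
        if not neighbours:
            continue
        for i in range(len(result)):
            ring_average = mean(counts[n][i] for n in neighbours)
            result[i] += ring_average / distance**2
    return result


counts = {"a": [1, 0], "b": [2, 4], "c": [4, 0], "d": [9, 9]}
# Hypothetical rings around "a"; "x" lies outside the regions and is ignored.
rings = {"a": {1: ["b", "c"], 2: ["d", "x"]}}
print(squashed_embedding("a", counts, rings, neighbourhood_distance=2))  # [6.25, 4.25]
```

The output keeps one value per feature group, which is why the squashed table below has the same 18 columns as a plain count embedding.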
cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=False
)
embeddings = cce.transform(regions_gdf, features_gdf, joint_gdf)
embeddings
region_id | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
893933629afffff | 0.000968 | 0.004215 | 9.210701 | 1.332202 | 4.483444 | 0.027215 | 3.900896 | 16.760416 | 5.307989 | 2.651618 | 10.304817 | 2.524370 | 50.310050 | 0.621947 | 18.785604 | 4.655189 | 41.325174 | 0.214329 |
8939336050bffff | 0.000000 | 0.000000 | 3.352052 | 0.198770 | 4.722795 | 0.004007 | 0.118546 | 12.011503 | 0.302143 | 1.608151 | 31.019329 | 1.436394 | 1.670880 | 0.869973 | 1.832754 | 1.399747 | 10.318437 | 0.354708 |
89393362977ffff | 0.001085 | 0.000551 | 9.422613 | 1.321392 | 4.373407 | 0.144466 | 10.803394 | 4.729329 | 6.569501 | 2.919233 | 7.926855 | 3.048693 | 67.893493 | 0.612283 | 31.450490 | 11.834250 | 50.679015 | 1.196972 |
89393375b93ffff | 0.002691 | 0.103449 | 0.466508 | 0.069929 | 1.412970 | 0.031768 | 0.124175 | 14.779854 | 0.375991 | 0.412953 | 17.666340 | 0.160566 | 3.765288 | 3.581164 | 1.841701 | 1.389265 | 17.048949 | 0.049666 |
89393375a03ffff | 0.000828 | 0.120331 | 3.238466 | 0.129990 | 0.950880 | 0.021773 | 0.500047 | 2.512212 | 4.676677 | 0.170123 | 3.554763 | 1.398667 | 9.579542 | 1.117970 | 7.355305 | 6.650903 | 30.322923 | 0.109591 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
893933628bbffff | 0.028935 | 0.000000 | 1.048404 | 0.105819 | 1.459426 | 0.025381 | 0.237978 | 7.223528 | 0.540667 | 1.498192 | 6.262551 | 0.341938 | 2.053848 | 0.604239 | 1.386508 | 7.919693 | 6.898022 | 0.306944 |
8939336232bffff | 0.000000 | 0.000000 | 0.295683 | 0.094611 | 0.230648 | 0.001764 | 0.045277 | 1.759844 | 0.108550 | 0.085326 | 3.501056 | 0.087700 | 0.412910 | 0.645712 | 0.395166 | 0.728049 | 1.487938 | 0.121718 |
89393362dcbffff | 0.021343 | 0.000207 | 6.079677 | 0.059702 | 1.907662 | 0.012725 | 0.380122 | 7.609776 | 5.558000 | 0.302973 | 6.628561 | 1.224536 | 10.668523 | 2.155046 | 6.157355 | 2.981010 | 29.290849 | 0.178429 |
893933759bbffff | 0.004639 | 0.015479 | 0.229054 | 0.011722 | 0.289314 | 0.014038 | 0.043173 | 1.678683 | 0.089461 | 0.028874 | 1.781890 | 0.061830 | 0.542207 | 0.755674 | 0.339131 | 0.185497 | 3.106649 | 0.406582 |
893933674cfffff | 0.000000 | 0.003084 | 0.874130 | 0.182189 | 1.841272 | 0.008524 | 0.491945 | 12.207426 | 0.840007 | 0.263875 | 5.389719 | 1.790228 | 8.952044 | 1.056324 | 4.145434 | 2.424540 | 32.571635 | 0.188061 |
830 rows × 18 columns
Concatenated vector version
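The embedder returns a vector of length n * (distance + 1), where n is the number of features from the CountEmbedder and distance is the number of neighbourhood rings analysed (neighbourhood_distance). The extra block holds the region's own counts at distance 0, so here the output has 18 * 11 = 198 columns. Each feature name is suffixed with _d, where d is the distance of the ring (from _0 to _10 in this example). Values are averaged over all neighbours at that distance.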
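The wide layout keeps one block of columns per distance, starting with the region itself at distance 0, and names each column by the feature group plus the distance. The plain-Python sketch below illustrates that layout; wide_embedding, counts and rings are hypothetical names, and skipping neighbours outside the regions (and writing 0.0 for an empty ring) are assumptions, not documented library behaviour.

```python
from statistics import mean

Counts = dict[str, list[float]]
Rings = dict[str, dict[int, list[str]]]


def wide_embedding(
    region_id: str, counts: Counts, rings: Rings, neighbourhood_distance: int, columns: list[str]
) -> dict[str, float]:
    """One block of columns per distance 0..D, named '<column>_<distance>'."""
    row: dict[str, float] = {}
    for distance in range(neighbourhood_distance + 1):
        if distance == 0:
            members = [region_id]  # distance 0 is the region itself
        else:
            # Assumption: neighbours without counts (outside the regions) are skipped.
            members = [n for n in rings[region_id].get(distance, []) if n in counts]
        for i, column in enumerate(columns):
            values = [counts[m][i] for m in members]
            # Assumption: an empty ring yields 0.0.
            row[f"{column}_{distance}"] = float(mean(values)) if values else 0.0
    return row


counts = {"a": [1, 0], "b": [2, 4], "c": [4, 0], "d": [9, 9]}
rings = {"a": {1: ["b", "c"], 2: ["d", "x"]}}  # hypothetical rings around "a"
wide = wide_embedding("a", counts, rings, neighbourhood_distance=2, columns=["leisure", "shops"])
print(wide)
# {'leisure_0': 1.0, 'shops_0': 0.0, 'leisure_1': 3.0, 'shops_1': 2.0, 'leisure_2': 9.0, 'shops_2': 9.0}
```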
wide_cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=True
)
wide_embeddings = wide_cce.transform(regions_gdf, features_gdf, joint_gdf)
wide_embeddings
region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
893933629afffff | 0.0 | 0.0 | 7.0 | 1.0 | 3.0 | 0.0 | 3.0 | 14.0 | 4.0 | 2.0 | ... | 1.105263 | 1.912281 | 3.333333 | 1.228070 | 12.315789 | 0.947368 | 11.157895 | 5.438596 | 11.140351 | 0.456140 |
8939336050bffff | 0.0 | 0.0 | 2.0 | 0.0 | 4.0 | 0.0 | 0.0 | 9.0 | 0.0 | 1.0 | ... | 0.312500 | 0.250000 | 3.625000 | 0.437500 | 1.562500 | 2.000000 | 2.500000 | 2.000000 | 8.625000 | 0.562500 |
89393362977ffff | 0.0 | 0.0 | 6.0 | 1.0 | 3.0 | 0.0 | 9.0 | 1.0 | 5.0 | 2.0 | ... | 0.622222 | 0.688889 | 3.866667 | 0.355556 | 2.266667 | 1.222222 | 1.822222 | 1.755556 | 7.333333 | 0.444444 |
89393375b93ffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 9.0 | 0.0 | 0.0 | ... | 1.090909 | 0.303030 | 5.000000 | 0.363636 | 4.030303 | 2.636364 | 2.545455 | 0.636364 | 13.393939 | 0.424242 |
89393375a03ffff | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 4.0 | 0.0 | ... | 0.727273 | 0.581818 | 4.018182 | 0.727273 | 6.218182 | 1.254545 | 3.763636 | 3.581818 | 18.381818 | 0.745455 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
893933628bbffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 4.0 | 0.0 | 1.0 | ... | 1.603774 | 2.207547 | 5.603774 | 1.226415 | 12.433962 | 1.622642 | 10.094340 | 3.150943 | 18.509434 | 0.528302 |
8939336232bffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.150000 | 1.100000 | 3.450000 | 0.250000 | 0.300000 | 1.100000 | 0.600000 | 0.700000 | 3.900000 | 0.450000 |
89393362dcbffff | 0.0 | 0.0 | 5.0 | 0.0 | 1.0 | 0.0 | 0.0 | 5.0 | 5.0 | 0.0 | ... | 1.450000 | 0.425000 | 3.025000 | 0.525000 | 9.075000 | 1.075000 | 5.625000 | 3.275000 | 12.975000 | 0.325000 |
893933759bbffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.764706 | 0.117647 | 2.470588 | 0.529412 | 3.823529 | 0.705882 | 2.470588 | 1.470588 | 11.941176 | 0.882353 |
893933674cfffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 9.0 | 0.0 | 0.0 | ... | 1.627907 | 2.651163 | 5.093023 | 1.255814 | 10.860465 | 1.930233 | 13.511628 | 4.651163 | 20.139535 | 0.465116 |
830 rows × 198 columns
Plotting example features
plot_numeric_data(regions_gdf, "leisure", embeddings)
plot_numeric_data(regions_gdf, "transportation", embeddings)
Other types of aggregations
By default, ContextualCountEmbedder averages the counts from neighbours. This aggregation_function can be changed to one of: median, sum, min, or max. It works best combined with the wide format (concatenate_vectors=True), where the aggregate for each distance keeps its own column.
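Each aggregation reduces the counts of one feature over the neighbours in one ring to a single number. The sketch below maps the documented options to plain-Python reducers; AGGREGATIONS and aggregate_ring are hypothetical names, the "average" key for the default is an assumption about its name, and returning 0.0 for an empty ring is also an assumption.

```python
import statistics
from typing import Callable

# Hypothetical mapping from aggregation names to plain-Python reducers;
# the "average" key for the default is an assumed name.
AGGREGATIONS: dict[str, Callable[[list[float]], float]] = {
    "average": statistics.mean,
    "median": statistics.median,
    "sum": sum,
    "min": min,
    "max": max,
}


def aggregate_ring(values: list[float], aggregation_function: str = "average") -> float:
    """Reduce the counts of one feature over the neighbours in one ring."""
    if aggregation_function not in AGGREGATIONS:
        raise ValueError(f"unknown aggregation_function: {aggregation_function!r}")
    if not values:
        return 0.0  # assumption: an empty ring contributes zero
    return float(AGGREGATIONS[aggregation_function](values))


ring_counts = [0, 2, 3, 7]  # e.g. "shops" counts of four neighbours at one distance
print({name: aggregate_ring(ring_counts, name) for name in AGGREGATIONS})
# {'average': 3.0, 'median': 2.5, 'sum': 12.0, 'min': 0.0, 'max': 7.0}
```

The wide tables in this section are consistent with this picture: for region 893933629afffff, healthcare_10 is 63.0 with aggregation_function="sum" and 1.105263 with the default averaging, and 63 / 57 ≈ 1.105263. This suggests that 57 neighbours at distance 10 were aggregated for that region, out of at most 60 cells in a full hexagonal ring.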
sum_cce = ContextualCountEmbedder(
neighbourhood=h3n,
neighbourhood_distance=10,
concatenate_vectors=True,
aggregation_function="sum",
)
sum_embeddings = sum_cce.transform(regions_gdf, features_gdf, joint_gdf)
sum_embeddings
region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
893933629afffff | 0.0 | 0.0 | 7.0 | 1.0 | 3.0 | 0.0 | 3.0 | 14.0 | 4.0 | 2.0 | ... | 63.0 | 109.0 | 190.0 | 70.0 | 702.0 | 54.0 | 636.0 | 310.0 | 635.0 | 26.0 |
8939336050bffff | 0.0 | 0.0 | 2.0 | 0.0 | 4.0 | 0.0 | 0.0 | 9.0 | 0.0 | 1.0 | ... | 5.0 | 4.0 | 58.0 | 7.0 | 25.0 | 32.0 | 40.0 | 32.0 | 138.0 | 9.0 |
89393362977ffff | 0.0 | 0.0 | 6.0 | 1.0 | 3.0 | 0.0 | 9.0 | 1.0 | 5.0 | 2.0 | ... | 28.0 | 31.0 | 174.0 | 16.0 | 102.0 | 55.0 | 82.0 | 79.0 | 330.0 | 20.0 |
89393375b93ffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 9.0 | 0.0 | 0.0 | ... | 36.0 | 10.0 | 165.0 | 12.0 | 133.0 | 87.0 | 84.0 | 21.0 | 442.0 | 14.0 |
89393375a03ffff | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 4.0 | 0.0 | ... | 40.0 | 32.0 | 221.0 | 40.0 | 342.0 | 69.0 | 207.0 | 197.0 | 1011.0 | 41.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
893933628bbffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 4.0 | 0.0 | 1.0 | ... | 85.0 | 117.0 | 297.0 | 65.0 | 659.0 | 86.0 | 535.0 | 167.0 | 981.0 | 28.0 |
8939336232bffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 3.0 | 22.0 | 69.0 | 5.0 | 6.0 | 22.0 | 12.0 | 14.0 | 78.0 | 9.0 |
89393362dcbffff | 0.0 | 0.0 | 5.0 | 0.0 | 1.0 | 0.0 | 0.0 | 5.0 | 5.0 | 0.0 | ... | 58.0 | 17.0 | 121.0 | 21.0 | 363.0 | 43.0 | 225.0 | 131.0 | 519.0 | 13.0 |
893933759bbffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 13.0 | 2.0 | 42.0 | 9.0 | 65.0 | 12.0 | 42.0 | 25.0 | 203.0 | 15.0 |
893933674cfffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 9.0 | 0.0 | 0.0 | ... | 70.0 | 114.0 | 219.0 | 54.0 | 467.0 | 83.0 | 581.0 | 200.0 | 866.0 | 20.0 |
830 rows × 198 columns
plot_numeric_data(regions_gdf, "tourism_8", sum_embeddings)