Contextual count embedder
from srai.embedders import ContextualCountEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.neighbourhoods import H3Neighbourhood
from srai.plotting.folium_wrapper import plot_numeric_data, plot_regions
from srai.regionalizers import H3Regionalizer
Data preparation
In order to use ContextualCountEmbedder, we need to prepare some data. Namely, we need regions_gdf, features_gdf and joint_gdf. These are the outputs of Regionalizers, Loaders and Joiners, respectively.
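To make the relationship between the three inputs concrete, here is a minimal stdlib sketch. It uses hypothetical toy data: plain dicts and tuples stand in for the GeoDataFrames, and the ids and group names are invented for illustration. The joint table is a list of (region_id, feature_id) pairs. Combining it with the group columns set on each feature yields per-region counts, which is roughly what a count-based embedding starts from.

```python
from collections import Counter, defaultdict

# Hypothetical stand-in for the regions_gdf index.
regions = ["r1", "r2", "r3"]

# Hypothetical stand-in for features_gdf: feature id -> group columns that are set.
features = {
    "node/1": {"transportation": "public_transport=stop_position"},
    "node/2": {"tourism": "tourism=artwork"},
    "way/3": {"transportation": "amenity=parking"},
}

# Hypothetical stand-in for the joint_gdf index: (region_id, feature_id) pairs.
joint = [("r1", "node/1"), ("r1", "way/3"), ("r2", "node/2"), ("r2", "way/3")]


def count_groups(regions, features, joint):
    """Count, for every region, how many joined features have each group set."""
    counts = defaultdict(Counter)
    for region_id, feature_id in joint:
        for group in features[feature_id]:
            counts[region_id][group] += 1
    # Regions without any joined feature still get an (empty) entry.
    return {region_id: dict(counts[region_id]) for region_id in regions}


print(count_groups(regions, features, joint))
# {'r1': {'transportation': 2}, 'r2': {'tourism': 1, 'transportation': 1}, 'r3': {}}
```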
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Lisboa, PT")
plot_regions(area_gdf)
Regionalize the area using an H3Regionalizer
regionalizer = H3Regionalizer(resolution=9, buffer=True)
regions_gdf = regionalizer.transform(area_gdf)
regions_gdf
region_id | geometry |
---|---|
893933629dbffff | POLYGON ((-9.1632 38.73789, -9.16491 38.73655,... |
893933753a3ffff | POLYGON ((-9.18312 38.77872, -9.18482 38.77738... |
89393367447ffff | POLYGON ((-9.12221 38.7372, -9.12391 38.73586,... |
89393362b4fffff | POLYGON ((-9.15908 38.70584, -9.16078 38.7045,... |
89393375bdbffff | POLYGON ((-9.12741 38.76604, -9.12912 38.7647,... |
... | ... |
89393362a27ffff | POLYGON ((-9.16897 38.70982, -9.17067 38.70848... |
89393362c63ffff | POLYGON ((-9.19214 38.74102, -9.19384 38.73968... |
893933670dbffff | POLYGON ((-9.1145 38.7268, -9.1162 38.72546, -... |
89393362977ffff | POLYGON ((-9.14233 38.73314, -9.14403 38.7318,... |
893933628a7ffff | POLYGON ((-9.16614 38.74029, -9.16784 38.73895... |
830 rows × 1 columns
Download some objects from OpenStreetMap
You can use both OsmTagsFilter and GroupedOsmTagsFilter filters. In this example, the predefined GroupedOsmTagsFilter BASE_OSM_GROUPS_FILTER is used.
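A grouped filter can be pictured as a mapping from a group name to an ordinary tags filter. A tags filter maps an OSM key to True (any value), a single value, or a list of accepted values. The sketch below assumes those shapes. It uses a small hypothetical grouped filter, not the contents of BASE_OSM_GROUPS_FILTER. It also defines a hypothetical helper, group_columns, that mimics how a feature ends up with a "key=value" string in each matching group column, as seen in the features_gdf output. How srai resolves several matching tags within one group is not shown in the source; here the first match wins, which is an assumption.

```python
from typing import Optional, Union

# Assumed shapes (illustrative type aliases, not imported from srai).
TagsFilter = dict[str, Union[bool, str, list[str]]]
GroupedTagsFilter = dict[str, TagsFilter]

# Hypothetical grouped filter; NOT the real BASE_OSM_GROUPS_FILTER.
EXAMPLE_GROUPS: GroupedTagsFilter = {
    "tourism": {"tourism": ["artwork", "museum"]},
    "transportation": {
        "public_transport": "stop_position",
        "amenity": ["parking", "charging_station"],
    },
    "shops": {"shop": True},
}


def _matches(value: str, accepted: Union[bool, str, list[str]]) -> bool:
    """True if a tag value is accepted by one filter entry."""
    if accepted is True:
        return True
    if isinstance(accepted, str):
        return value == accepted
    if isinstance(accepted, list):
        return value in accepted
    return False


def group_columns(tags: dict[str, str], groups: GroupedTagsFilter) -> dict[str, Optional[str]]:
    """Return one 'key=value' string (or None) per group for a feature's tags.

    Hypothetical helper: if several tags match one group, the first match wins.
    """
    row: dict[str, Optional[str]] = {}
    for group, tags_filter in groups.items():
        row[group] = None
        for key, accepted in tags_filter.items():
            if key in tags and _matches(tags[key], accepted):
                row[group] = f"{key}={tags[key]}"
                break
    return row


print(group_columns({"shop": "ticket"}, EXAMPLE_GROUPS))
# {'tourism': None, 'transportation': None, 'shops': 'shop=ticket'}
```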
from srai.loaders.osm_loaders.filters import BASE_OSM_GROUPS_FILTER
loader = OSMPbfLoader()
features_gdf = loader.load(area_gdf, tags=BASE_OSM_GROUPS_FILTER)
features_gdf
Finished operation in 0:00:16
feature_id | geometry | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
node/1026402361 | POINT (-9.20574 38.74378) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/1026408767 | POINT (-9.19757 38.74976) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/1087174563 | POINT (-9.20027 38.69665) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | tourism=artwork | None | None |
node/1090129267 | POINT (-9.13516 38.73433) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/1104467812 | POINT (-9.20891 38.75276) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
way/1361840806 | POLYGON ((-9.12207 38.71376, -9.12202 38.71374... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
way/1361840807 | POLYGON ((-9.10015 38.75626, -9.10018 38.75599... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
way/1361840808 | POLYGON ((-9.10019 38.75587, -9.10025 38.75523... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
way/1361844388 | POLYGON ((-9.17727 38.77199, -9.1772 38.77202,... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=charging_station | None |
way/1361910272 | POLYGON ((-9.15012 38.71775, -9.15011 38.71776... | None | None | None | None | None | None | None | None | None | None | None | None | shop=ticket | None | None | None | None | None |
31934 rows × 19 columns
Join the objects with the regions they belong to
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf)
joint_gdf
region_id | feature_id |
---|---|
89393362c53ffff | node/1026402361 |
89393362c13ffff | node/1026408767 |
89393360193ffff | node/1087174563 |
89393362923ffff | node/1090129267 |
8939337526fffff | node/1104467812 |
... | ... |
89393367677ffff | way/1361840806 |
8939336759bffff | way/1361840807 |
8939336759bffff | way/1361840808 |
89393375337ffff | way/1361844388 |
89393362babffff | way/1361910272 |
35410 rows × 0 columns
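The joint table holds one row per intersecting (region, feature) pair, so a feature that crosses several hexagons appears several times. That is consistent with the output above having more rows (35410) than features_gdf (31934). The stdlib sketch below illustrates the idea with axis-aligned rectangles instead of real hexagons and geometries. This simplification avoids depending on a geometry library; IntersectionJoiner itself works on the actual geometries. Box and intersection_join are hypothetical names.

```python
from typing import NamedTuple


class Box(NamedTuple):
    """Axis-aligned rectangle (minx, miny, maxx, maxy); a point is a zero-size box."""

    minx: float
    miny: float
    maxx: float
    maxy: float

    def intersects(self, other: "Box") -> bool:
        # Closed intervals: boxes that only touch along an edge count as intersecting.
        return (
            self.minx <= other.maxx
            and other.minx <= self.maxx
            and self.miny <= other.maxy
            and other.miny <= self.maxy
        )


def intersection_join(regions: dict[str, Box], features: dict[str, Box]) -> list[tuple[str, str]]:
    """Return every (region_id, feature_id) pair whose shapes intersect (brute force)."""
    return [
        (region_id, feature_id)
        for region_id, region in regions.items()
        for feature_id, feature in features.items()
        if region.intersects(feature)
    ]


regions = {"A": Box(0, 0, 1, 1), "B": Box(1, 0, 2, 1)}
features = {
    "node/1": Box(0.5, 0.5, 0.5, 0.5),  # point inside A only
    "way/2": Box(0.8, 0.2, 1.5, 0.4),  # polygon crossing A and B
    "node/3": Box(5, 5, 5, 5),  # outside every region, so it is not joined
}
print(intersection_join(regions, features))
# [('A', 'node/1'), ('A', 'way/2'), ('B', 'way/2')]
```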
Embed using features existing in data
ContextualCountEmbedder extends the capabilities of the basic CountEmbedder by incorporating the neighbourhood of each embedded region. In this example, we will use the H3Neighbourhood.
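H3Neighbourhood provides the neighbours of an H3 cell at a given distance. Because H3 cells tile the plane as hexagons, the shape of those rings can be sketched on a plain axial-coordinate hex grid. This is a simplified model for illustration only. It ignores H3 indexes, resolutions and pentagons, and the helper names below (hex_distance, ring) are hypothetical, not srai or h3 APIs.

```python
# Axial coordinates (q, r) on an unbounded hexagonal grid; the six neighbour offsets.
DIRECTIONS = [(1, 0), (1, -1), (0, -1), (-1, 0), (-1, 1), (0, 1)]


def hex_distance(a: tuple[int, int], b: tuple[int, int]) -> int:
    """Number of steps between two cells in axial coordinates."""
    dq, dr = a[0] - b[0], a[1] - b[1]
    return (abs(dq) + abs(dr) + abs(dq + dr)) // 2


def ring(center: tuple[int, int], k: int) -> list[tuple[int, int]]:
    """Cells at exactly distance k from center (k=0 is the cell itself)."""
    if k == 0:
        return [center]
    # Start k steps away along one direction, then walk k steps along each of the 6 sides.
    q = center[0] + DIRECTIONS[4][0] * k
    r = center[1] + DIRECTIONS[4][1] * k
    cells = []
    for dq, dr in DIRECTIONS:
        for _ in range(k):
            cells.append((q, r))
            q, r = q + dq, r + dr
    return cells


print(len(ring((0, 0), 10)))  # 60
```

On an unbounded grid, the ring at distance k holds 6k cells. A neighbourhood_distance of 10 can therefore involve up to 6 + 12 + ... + 60 = 330 neighbours per region.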
h3n = H3Neighbourhood()
Squashed vector version (default)
The embedder returns a vector of the same length as the CountEmbedder output. To the region's own counts, it adds the averaged counts of the neighbours at each distance, scaled down by the square of that distance.
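A minimal sketch of the squashed variant, following the description above. It starts from the region's own counts. Then, for every distance d from 1 to the neighbourhood distance, it adds the average of the neighbours' counts at that distance divided by d². The exact weighting srai applies internally may differ in detail. The neighbour rings are supplied by a hypothetical function over a toy line of three regions, and squashed_embedding is a hypothetical name.

```python
from statistics import mean


def squashed_embedding(counts, rings, max_distance):
    """Own counts plus distance-weighted neighbour averages (same length as counts).

    counts: region_id -> list of per-feature counts.
    rings: function (region_id, d) -> neighbour ids at exactly distance d.
    """
    result = {}
    for region_id, own in counts.items():
        vector = [float(x) for x in own]
        for d in range(1, max_distance + 1):
            # Only neighbours that actually have counts are averaged.
            neighbours = [n for n in rings(region_id, d) if n in counts]
            if not neighbours:
                continue
            for i in range(len(vector)):
                vector[i] += mean(counts[n][i] for n in neighbours) / d**2
        result[region_id] = vector
    return result


# Hypothetical regions on a line: a - b - c, with two feature counts each.
positions = {"a": 0, "b": 1, "c": 2}
counts = {"a": [4, 0], "b": [0, 2], "c": [2, 2]}


def line_rings(region_id, d):
    return [r for r, p in positions.items() if abs(p - positions[region_id]) == d]


print(squashed_embedding(counts, line_rings, max_distance=2))
# {'a': [4.5, 2.5], 'b': [3.0, 3.0], 'c': [3.0, 4.0]}
```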
cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=False
)
embeddings = cce.transform(regions_gdf, features_gdf, joint_gdf)
embeddings
region_id | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
893933629dbffff | 0.013750 | 0.000138 | 6.713473 | 1.240218 | 2.191891 | 0.029518 | 1.874939 | 40.896288 | 1.036923 | 1.489784 | 10.306784 | 1.554523 | 4.201561 | 1.582757 | 6.677550 | 7.769784 | 33.018778 | 1.281439 |
893933753a3ffff | 0.000000 | 0.000826 | 0.293625 | 0.030676 | 0.245908 | 0.005601 | 0.075546 | 1.878151 | 0.140223 | 0.175842 | 1.103479 | 0.089404 | 0.702728 | 0.351986 | 0.442084 | 0.261290 | 2.643735 | 0.217655 |
89393367447ffff | 0.000000 | 0.000689 | 1.734393 | 1.167020 | 1.493297 | 0.005614 | 0.366134 | 5.831349 | 0.561800 | 0.283602 | 7.010606 | 0.691809 | 6.312255 | 1.676644 | 2.867984 | 0.869743 | 22.008776 | 0.238744 |
89393362b4fffff | 0.000000 | 0.000000 | 11.183750 | 0.660906 | 1.661866 | 0.014191 | 0.630589 | 1.544520 | 0.510877 | 1.515652 | 4.823492 | 0.914754 | 5.390523 | 0.247458 | 11.766384 | 3.526176 | 10.150023 | 2.725521 |
89393375bdbffff | 0.001628 | 0.170667 | 0.310464 | 0.072480 | 0.387735 | 0.048992 | 0.131007 | 13.867602 | 0.228329 | 0.089832 | 4.501894 | 0.187014 | 1.286211 | 1.422231 | 0.633377 | 0.389451 | 8.009958 | 0.176131 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393362a27ffff | 0.000000 | 0.000000 | 9.945413 | 0.185046 | 3.681188 | 0.009856 | 0.524008 | 25.532300 | 3.713598 | 5.390983 | 18.470300 | 0.794890 | 21.645819 | 0.443083 | 7.799327 | 6.186430 | 21.767432 | 4.428234 |
89393362c63ffff | 0.004403 | 0.000000 | 0.315241 | 0.031791 | 0.324059 | 0.003140 | 0.108105 | 3.249457 | 0.167538 | 0.284080 | 2.606313 | 0.189521 | 1.288283 | 0.428919 | 0.566665 | 1.286643 | 2.727175 | 0.123863 |
893933670dbffff | 0.000000 | 0.000000 | 0.798359 | 0.244065 | 1.460577 | 0.005518 | 1.216059 | 2.399769 | 2.423647 | 4.573910 | 5.055359 | 2.466280 | 5.932807 | 1.354793 | 7.241812 | 4.882934 | 32.879666 | 2.240181 |
89393362977ffff | 0.001085 | 0.000551 | 9.352661 | 2.316874 | 4.425208 | 0.100760 | 10.720274 | 4.710214 | 6.577959 | 2.913418 | 7.916567 | 3.050283 | 66.991276 | 0.610351 | 30.668436 | 11.808555 | 50.492072 | 1.191267 |
893933628a7ffff | 0.034722 | 0.000138 | 10.356942 | 0.159881 | 2.677771 | 0.052990 | 4.604784 | 19.586611 | 3.869217 | 0.392186 | 5.112295 | 1.404984 | 12.635087 | 0.685524 | 14.375391 | 3.802911 | 52.825518 | 0.472163 |
830 rows × 18 columns
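Concatenated vector version
The embedder returns a vector of length n * (distance + 1). Here n is the number of features from the CountEmbedder and distance is the neighbourhood distance, with distance 0 being the region itself. With 18 feature groups and neighbourhood_distance=10, that gives 18 * 11 = 198 columns. Each feature name is suffixed with _d, where d is the distance. The values at each distance are averaged over all neighbours at that distance.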
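The wide layout can be sketched in the same toy setting: one block of columns per distance from 0 to the neighbourhood distance, named "<feature>_<d>". This is an illustrative sketch with hypothetical names and data. Filling a block with zeros when a region has no neighbours at that distance is an assumption, not something the source states.

```python
from statistics import mean


def concatenated_embedding(counts, rings, feature_names, max_distance):
    """One block of per-feature averages per distance 0..max_distance."""
    columns = [f"{name}_{d}" for d in range(max_distance + 1) for name in feature_names]
    rows = {}
    for region_id, own in counts.items():
        row = [float(x) for x in own]  # distance 0: the region's own counts
        for d in range(1, max_distance + 1):
            neighbours = [n for n in rings(region_id, d) if n in counts]
            # Assumption: no neighbours at this distance -> a block of zeros.
            row.extend(
                float(mean(counts[n][i] for n in neighbours)) if neighbours else 0.0
                for i in range(len(feature_names))
            )
        rows[region_id] = row
    return columns, rows


# Hypothetical regions on a line: a - b - c.
positions = {"a": 0, "b": 1, "c": 2}
counts = {"a": [4, 0], "b": [0, 2], "c": [2, 2]}


def line_rings(region_id, d):
    return [r for r, p in positions.items() if abs(p - positions[region_id]) == d]


columns, rows = concatenated_embedding(counts, line_rings, ["shops", "sport"], max_distance=2)
print(columns)  # ['shops_0', 'sport_0', 'shops_1', 'sport_1', 'shops_2', 'sport_2']
print(rows["b"])  # [0.0, 2.0, 3.0, 1.0, 0.0, 0.0]
```

The column count is n * (distance + 1): here 2 * 3 = 6, just as 18 * 11 = 198 in the real output.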
wide_cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=True
)
wide_embeddings = wide_cce.transform(regions_gdf, features_gdf, joint_gdf)
wide_embeddings
region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
893933629dbffff | 0.0 | 0.0 | 5.0 | 1.0 | 1.0 | 0.0 | 1.0 | 35.0 | 0.0 | 1.0 | ... | 1.433333 | 1.883333 | 4.033333 | 1.466667 | 13.100000 | 0.966667 | 12.600000 | 5.166667 | 12.600000 | 0.216667 |
893933753a3ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.900000 | 0.533333 | 5.233333 | 0.766667 | 5.433333 | 2.566667 | 2.733333 | 1.400000 | 12.666667 | 0.300000 |
89393367447ffff | 0.0 | 0.0 | 1.0 | 1.0 | 1.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 1.833333 | 2.722222 | 3.694444 | 1.361111 | 10.583333 | 1.333333 | 12.277778 | 4.166667 | 18.055556 | 0.388889 |
89393362b4fffff | 0.0 | 0.0 | 9.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 1.322581 | 0.806452 | 5.064516 | 1.129032 | 11.516129 | 1.161290 | 7.967742 | 4.000000 | 19.612903 | 0.354839 |
89393375bdbffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 8.0 | 0.0 | 0.0 | ... | 1.586957 | 0.326087 | 4.847826 | 0.434783 | 9.217391 | 2.021739 | 4.913043 | 1.347826 | 13.956522 | 0.630435 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393362a27ffff | 0.0 | 0.0 | 8.0 | 0.0 | 3.0 | 0.0 | 0.0 | 23.0 | 3.0 | 4.0 | ... | 1.162162 | 1.189189 | 4.540541 | 1.351351 | 6.621622 | 1.270270 | 9.216216 | 6.108108 | 15.000000 | 0.918919 |
89393362c63ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | ... | 0.850000 | 0.750000 | 6.325000 | 0.800000 | 5.650000 | 1.775000 | 3.925000 | 2.425000 | 13.375000 | 0.875000 |
893933670dbffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 1.0 | 2.0 | 4.0 | ... | 2.344828 | 2.000000 | 4.275862 | 1.241379 | 12.241379 | 2.103448 | 15.000000 | 6.000000 | 19.827586 | 0.310345 |
89393362977ffff | 0.0 | 0.0 | 6.0 | 2.0 | 3.0 | 0.0 | 9.0 | 1.0 | 5.0 | 2.0 | ... | 0.622222 | 0.688889 | 3.844444 | 0.355556 | 2.266667 | 1.222222 | 1.822222 | 1.711111 | 7.355556 | 0.444444 |
893933628a7ffff | 0.0 | 0.0 | 9.0 | 0.0 | 2.0 | 0.0 | 4.0 | 13.0 | 3.0 | 0.0 | ... | 1.233333 | 1.800000 | 4.666667 | 1.216667 | 9.133333 | 0.966667 | 8.900000 | 4.333333 | 12.983333 | 0.516667 |
830 rows × 198 columns
Plotting example features
plot_numeric_data(regions_gdf, "leisure", embeddings)
plot_numeric_data(regions_gdf, "transportation", embeddings)
Other types of aggregations
By default, the ContextualCountEmbedder averages the counts from neighbours. The aggregation_function parameter can be changed to one of: median, sum, min or max. It works best combined with the wide format (concatenate_vectors=True), where the aggregate for each distance keeps its own columns.
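Per distance, each aggregation reduces every feature column over the neighbours at that distance. The sketch below maps the names listed above to Python callables. AGGREGATIONS and aggregate_ring are hypothetical names, and "average" as the identifier of the default mean is an assumption.

```python
import statistics

# Hypothetical mapping from aggregation names to callables ("average" = assumed default).
AGGREGATIONS = {
    "average": statistics.mean,
    "median": statistics.median,
    "sum": sum,
    "min": min,
    "max": max,
}


def aggregate_ring(neighbour_counts, how="average"):
    """Aggregate each feature column over the neighbours at one distance."""
    func = AGGREGATIONS[how]
    return [func(column) for column in zip(*neighbour_counts)]


# Hypothetical counts of two features for three neighbours at one distance.
ring_counts = [[1, 0], [3, 4], [8, 2]]
print(aggregate_ring(ring_counts, "average"))  # [4, 2]
print(aggregate_ring(ring_counts, "sum"))  # [12, 6]
```

For a given distance, the sum equals the average times the number of neighbours. The two wide tables in this notebook show that relationship. For region 893933629dbffff, healthcare_10 is 86.0 as a sum and 1.433333 as an average, which implies 60 neighbours. That matches the size of a complete hexagonal ring at distance 10 (6 * 10).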
sum_cce = ContextualCountEmbedder(
neighbourhood=h3n,
neighbourhood_distance=10,
concatenate_vectors=True,
aggregation_function="sum",
)
sum_embeddings = sum_cce.transform(regions_gdf, features_gdf, joint_gdf)
sum_embeddings
region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
893933629dbffff | 0.0 | 0.0 | 5.0 | 1.0 | 1.0 | 0.0 | 1.0 | 35.0 | 0.0 | 1.0 | ... | 86.0 | 113.0 | 242.0 | 88.0 | 786.0 | 58.0 | 756.0 | 310.0 | 756.0 | 13.0 |
893933753a3ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 27.0 | 16.0 | 157.0 | 23.0 | 163.0 | 77.0 | 82.0 | 42.0 | 380.0 | 9.0 |
89393367447ffff | 0.0 | 0.0 | 1.0 | 1.0 | 1.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 66.0 | 98.0 | 133.0 | 49.0 | 381.0 | 48.0 | 442.0 | 150.0 | 650.0 | 14.0 |
89393362b4fffff | 0.0 | 0.0 | 9.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 41.0 | 25.0 | 157.0 | 35.0 | 357.0 | 36.0 | 247.0 | 124.0 | 608.0 | 11.0 |
89393375bdbffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 8.0 | 0.0 | 0.0 | ... | 73.0 | 15.0 | 223.0 | 20.0 | 424.0 | 93.0 | 226.0 | 62.0 | 642.0 | 29.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393362a27ffff | 0.0 | 0.0 | 8.0 | 0.0 | 3.0 | 0.0 | 0.0 | 23.0 | 3.0 | 4.0 | ... | 43.0 | 44.0 | 168.0 | 50.0 | 245.0 | 47.0 | 341.0 | 226.0 | 555.0 | 34.0 |
89393362c63ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | ... | 34.0 | 30.0 | 253.0 | 32.0 | 226.0 | 71.0 | 157.0 | 97.0 | 535.0 | 35.0 |
893933670dbffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 1.0 | 2.0 | 4.0 | ... | 68.0 | 58.0 | 124.0 | 36.0 | 355.0 | 61.0 | 435.0 | 174.0 | 575.0 | 9.0 |
89393362977ffff | 0.0 | 0.0 | 6.0 | 2.0 | 3.0 | 0.0 | 9.0 | 1.0 | 5.0 | 2.0 | ... | 28.0 | 31.0 | 173.0 | 16.0 | 102.0 | 55.0 | 82.0 | 77.0 | 331.0 | 20.0 |
893933628a7ffff | 0.0 | 0.0 | 9.0 | 0.0 | 2.0 | 0.0 | 4.0 | 13.0 | 3.0 | 0.0 | ... | 74.0 | 108.0 | 280.0 | 73.0 | 548.0 | 58.0 | 534.0 | 260.0 | 779.0 | 31.0 |
830 rows × 198 columns