Contextual count embedder
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.regionalizers import H3Regionalizer
from srai.joiners import IntersectionJoiner
from srai.embedders import ContextualCountEmbedder
from srai.plotting.folium_wrapper import plot_regions, plot_numeric_data
from srai.neighbourhoods import H3Neighbourhood
Data preparation
In order to use ContextualCountEmbedder, we need to prepare some data. Namely, we need regions_gdf, features_gdf, and joint_gdf. These are the outputs of Regionalizers, Loaders and Joiners, respectively.
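The three inputs fit together as follows: the joiner records which feature lies in which region, and the feature columns record which category each feature belongs to. The sketch below is a minimal pure-Python illustration of the per-region counting that CountEmbedder builds on. The region ids, feature ids and tag values are made up, and the rule "count a feature in a column when that column is not empty" is an assumption used for illustration, not srai's actual code.

```python
# Hypothetical toy inputs mirroring the three tables used in this tutorial.
regions = ["r1", "r2", "r3"]  # region ids (geometries are not needed for counting)

# feature_id -> {feature column: "key=value" string, or None when empty}
features = {
    "node/1": {"transportation": "public_transport=stop_position", "leisure": None},
    "way/2": {"transportation": None, "leisure": "leisure=garden"},
    "way/3": {"transportation": None, "leisure": "leisure=garden"},
}

# (region_id, feature_id) pairs, as produced by a joiner
joint = [("r1", "node/1"), ("r1", "way/2"), ("r2", "way/2"), ("r2", "way/3")]


def count_features(regions, features, joint, columns):
    """Count, per region, how many joined features have a value in each column."""
    # Every region starts at zero, so regions without features still get a row.
    counts = {region_id: {col: 0 for col in columns} for region_id in regions}
    for region_id, feature_id in joint:
        for col in columns:
            if features[feature_id].get(col) is not None:
                counts[region_id][col] += 1
    return counts


counts = count_features(regions, features, joint, ["transportation", "leisure"])
print(counts)
```

Note that way/2 is counted in both r1 and r2 because the join lists it for both regions, while r3 keeps all-zero counts.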
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Lisboa, PT")
plot_regions(area_gdf)
Regionalize the area using an H3Regionalizer
regionalizer = H3Regionalizer(resolution=9, buffer=True)
regions_gdf = regionalizer.transform(area_gdf)
regions_gdf
region_id | geometry |
---|---|
89393362d53ffff | POLYGON ((-9.17789 38.74988, -9.17959 38.74854... |
89393362a23ffff | POLYGON ((-9.17299 38.70901, -9.17470 38.70766... |
89393362b77ffff | POLYGON ((-9.14994 38.71068, -9.15165 38.70933... |
89393375e5bffff | POLYGON ((-9.17136 38.76913, -9.17307 38.76779... |
89393375a7bffff | POLYGON ((-9.13982 38.75158, -9.14152 38.75024... |
... | ... |
89393362b63ffff | POLYGON ((-9.15103 38.70747, -9.15273 38.70612... |
89393362d3bffff | POLYGON ((-9.16179 38.75313, -9.16349 38.75179... |
893933676a3ffff | POLYGON ((-9.12145 38.72838, -9.12316 38.72704... |
89393362b33ffff | POLYGON ((-9.14777 38.71709, -9.14947 38.71575... |
89393367473ffff | POLYGON ((-9.11819 38.73801, -9.11989 38.73667... |
830 rows × 1 columns
Download some objects from OpenStreetMap
You can use both OsmTagsFilter and GroupedOsmTagsFilter filters. In this example, the predefined GroupedOsmTagsFilter filter BASE_OSM_GROUPS_FILTER is used.
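Both filter types are nested mappings. The sketch below assumes the shapes suggested by their names: an OsmTagsFilter-like dict maps an OSM key to True (any value), a single value, or a list of values, and a GroupedOsmTagsFilter-like dict maps a group name to such a filter. The groups shown are hypothetical and are not the contents of BASE_OSM_GROUPS_FILTER. The assign_groups helper is also hypothetical: it only illustrates why each group column of features_gdf holds either a "key=value" string or NaN, and its first-match rule is an assumption rather than srai's documented behaviour.

```python
# Hypothetical filters in the assumed shape of OsmTagsFilter / GroupedOsmTagsFilter.
tags_filter = {"amenity": ["restaurant", "cafe"], "leisure": True}

grouped_filter = {
    "sustenance": {"amenity": ["restaurant", "cafe", "bar"]},
    "transportation": {"public_transport": "stop_position", "highway": "bus_stop"},
    "leisure": {"leisure": True},
}


def matches(key, value, allowed):
    """Check one OSM tag value against one entry of a tags filter."""
    if isinstance(allowed, bool):
        return allowed  # True accepts any value for this key
    if isinstance(allowed, str):
        return value == allowed
    return value in allowed  # list of accepted values


def assign_groups(tags, grouped_filter):
    """Return {group: "key=value"} for the first matching tag in each group, else None."""
    result = {}
    for group, group_filter in grouped_filter.items():
        result[group] = None
        for key, allowed in group_filter.items():
            if key in tags and matches(key, tags[key], allowed):
                result[group] = f"{key}={tags[key]}"
                break
    return result


print(matches("amenity", "cafe", tags_filter["amenity"]))  # True
print(assign_groups({"public_transport": "stop_position", "bus": "yes"}, grouped_filter))
# {'sustenance': None, 'transportation': 'public_transport=stop_position', 'leisure': None}
```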
from srai.loaders.osm_loaders.filters import BASE_OSM_GROUPS_FILTER
loader = OSMPbfLoader()
features_gdf = loader.load(area_gdf, tags=BASE_OSM_GROUPS_FILTER)
features_gdf
feature_id | geometry | aerialway | airports | sustenance | education | transportation | finances | healthcare | culture_art_entertainment | other | buildings | emergency | historic | leisure | shops | sport | tourism | greenery |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
node/21433772 | POINT (-9.19059 38.72880) | NaN | NaN | NaN | NaN | public_transport=stop_position | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
node/21433776 | POINT (-9.19376 38.72666) | NaN | NaN | NaN | NaN | public_transport=stop_position | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
node/25414208 | POINT (-9.16663 38.74018) | NaN | NaN | NaN | NaN | public_transport=stop_position | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
node/25414256 | POINT (-9.10286 38.74711) | NaN | NaN | NaN | NaN | public_transport=stop_position | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
node/25414265 | POINT (-9.10273 38.74707) | NaN | NaN | NaN | NaN | public_transport=stop_position | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
way/1233527783 | MULTIPOLYGON (((-9.17772 38.70171, -9.17727 38... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | leisure=garden | NaN | NaN | NaN | NaN |
way/1233527784 | MULTIPOLYGON (((-9.17816 38.70149, -9.17813 38... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | leisure=garden | NaN | NaN | NaN | NaN |
way/1233529819 | MULTIPOLYGON (((-9.17752 38.70182, -9.17707 38... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | leisure=garden | NaN | NaN | NaN | NaN |
way/1233529822 | MULTIPOLYGON (((-9.17768 38.70092, -9.17763 38... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | leisure=garden | NaN | NaN | NaN | NaN |
way/1233529823 | MULTIPOLYGON (((-9.17685 38.70133, -9.17680 38... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | leisure=garden | NaN | NaN | NaN | NaN |
24525 rows × 18 columns
Join the objects with the regions they belong to
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf)
joint_gdf
region_id | feature_id |
---|---|
89393362d53ffff | way/846013172 |
89393362d53ffff | way/846013173 |
89393362c27ffff | way/846013173 |
89393362d53ffff | way/767722585 |
89393362d53ffff | way/468486162 |
... | ... |
89393367473ffff | node/7874358949 |
89393367473ffff | node/7874358950 |
89393367473ffff | node/7874358938 |
89393367473ffff | node/7874358937 |
89393367473ffff | node/2484075271 |
27426 rows × 0 columns
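Each row of joint_gdf is a (region_id, feature_id) pair stored in its index, which is why the table has 0 columns. An intersection join pairs every region with every feature whose geometry intersects it, so a feature crossing a cell border (such as way/846013173 above) appears once per region it touches. The sketch below shows that logic under a simplifying assumption: regions and features are axis-aligned boxes instead of H3 hexagons and OSM geometries, and the join is a brute-force double loop, whereas a real spatial join would typically rely on a spatial index.

```python
# Toy intersection join. Geometries are closed axis-aligned boxes
# (xmin, ymin, xmax, ymax); a point is a box with zero width and height.


def boxes_intersect(a, b):
    """True if two closed axis-aligned boxes share at least one point."""
    return a[0] <= b[2] and b[0] <= a[2] and a[1] <= b[3] and b[1] <= a[3]


def intersection_join(regions, features):
    """Return sorted (region_id, feature_id) pairs whose geometries intersect."""
    return sorted(
        (region_id, feature_id)
        for region_id, region in regions.items()
        for feature_id, feature in features.items()
        if boxes_intersect(region, feature)
    )


regions = {"A": (0, 0, 1, 1), "B": (1, 0, 2, 1)}
features = {
    "node/1": (0.5, 0.5, 0.5, 0.5),  # point inside A only
    "way/2": (0.8, 0.2, 1.5, 0.4),   # crosses the A/B border
    "node/3": (5, 5, 5, 5),          # outside every region, so it is dropped
}
pairs = intersection_join(regions, features)
print(pairs)  # [('A', 'node/1'), ('A', 'way/2'), ('B', 'way/2')]
```

Features spanning several regions are one reason the joint table can have more rows than features_gdf.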
Embed using features existing in data
ContextualCountEmbedder extends the capabilities of the basic CountEmbedder by incorporating the neighbourhood of each embedded region. In this example, we will use the H3Neighbourhood.
h3n = H3Neighbourhood()
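On a hexagonal grid, the neighbours at distance k of a cell form a ring around it. The sketch below uses a plain axial-coordinate hexagon grid as a stand-in for H3 (an assumption: H3 has its own index scheme and a few pentagons per resolution) to show how ring size grows with k. On a regular hexagon grid the ring at distance k holds 6k cells, so with neighbourhood_distance=10 a region can have up to 6 + 12 + ... + 60 = 330 neighbours; regions near the edge of the area have fewer.

```python
def hex_distance(a, b):
    """Grid distance between two hexagons given in axial coordinates (q, r)."""
    dq, dr = a[0] - b[0], a[1] - b[1]
    return (abs(dq) + abs(dr) + abs(dq + dr)) // 2


def neighbours_at_distance(center, k):
    """All cells exactly k steps from center (a ring); k=0 returns the center itself."""
    q0, r0 = center
    # Every cell at distance k has |dq| <= k and |dr| <= k, so this scan is complete.
    return {
        (q0 + dq, r0 + dr)
        for dq in range(-k, k + 1)
        for dr in range(-k, k + 1)
        if hex_distance((0, 0), (dq, dr)) == k
    }


for k in range(4):
    print(k, len(neighbours_at_distance((0, 0), k)))  # prints 0 1, 1 6, 2 12, 3 18
```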
Squashed vector version (default)
The embedder returns a vector of the same length as the CountEmbedder output, but adds to each region's own counts the values averaged over its neighbours at each distance, diminished by the square of that neighbour distance.
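Following the description above, the sketch below adds, for each distance d, the mean count vector of the known neighbours at that distance weighted by 1/d². This weighting is taken from the text; srai's exact weighting and its handling of missing neighbours may differ, so treat it as an illustration only. The neighbourhood is passed in as a function; line_neighbours is a hypothetical 1-D neighbourhood used to keep the example small.

```python
def squashed_embedding(region, counts, neighbours_at_distance, max_distance):
    """Own counts plus, for each distance d, the mean neighbour counts divided by d**2.

    Assumption: weight 1 / d**2 as described in the text. Neighbours missing from
    `counts` (outside the area) are skipped; a distance with no known neighbours
    contributes nothing.
    """
    result = [float(v) for v in counts[region]]
    for d in range(1, max_distance + 1):
        known = [counts[n] for n in neighbours_at_distance(region, d) if n in counts]
        if not known:
            continue
        for i in range(len(result)):
            mean_i = sum(vec[i] for vec in known) / len(known)
            result[i] += mean_i / d ** 2
    return result


def line_neighbours(region, d):
    """Hypothetical 1-D neighbourhood: the regions exactly d steps left and right."""
    return {region - d, region + d}


counts = {0: [1, 0], 1: [0, 2], 2: [4, 0]}  # region -> per-feature counts
print(squashed_embedding(1, counts, line_neighbours, max_distance=2))  # [2.5, 2.0]
print(squashed_embedding(0, counts, line_neighbours, max_distance=2))  # [2.0, 2.0]
```

Because the weights shrink quadratically, a distant ring can shift a region's values but rarely dominates its own counts, which keeps the output the same width as the CountEmbedder output.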
cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=False
)
embeddings = cce.transform(regions_gdf, features_gdf, joint_gdf)
embeddings
region_id | aerialway | airports | sustenance | education | transportation | finances | healthcare | culture_art_entertainment | other | buildings | emergency | historic | leisure | shops | sport | tourism | greenery |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393362d53ffff | 0.072917 | 0.000192 | 7.049973 | 0.798421 | 9.700568 | 0.376339 | 7.723772 | 0.051618 | 1.298173 | 1.803744 | 0.001134 | 0.295827 | 7.027701 | 6.225010 | 0.726625 | 1.390425 | 14.189274 |
89393362a23ffff | 0.000000 | 0.000000 | 14.885541 | 1.344104 | 48.244612 | 4.316979 | 2.496115 | 0.124103 | 1.583231 | 6.068717 | 0.000000 | 3.941206 | 7.706329 | 22.917968 | 0.350639 | 6.295888 | 13.523924 |
89393362b77ffff | 0.000000 | 0.000000 | 41.794498 | 0.817481 | 17.149590 | 3.704593 | 1.975281 | 0.607149 | 2.388330 | 4.971299 | 0.000000 | 9.116613 | 5.255371 | 10.385850 | 0.186026 | 17.611171 | 2.707732 |
89393375e5bffff | 0.002171 | 0.009580 | 5.129058 | 1.525253 | 18.706515 | 2.268491 | 0.483091 | 0.019396 | 0.216797 | 3.665843 | 0.042625 | 0.330743 | 8.577515 | 4.669144 | 2.615211 | 0.286662 | 2.110102 |
89393375a7bffff | 0.001329 | 0.016833 | 5.873040 | 1.608596 | 30.109064 | 0.748181 | 1.900893 | 0.150605 | 0.327157 | 0.600536 | 0.000926 | 0.234317 | 29.058290 | 3.469994 | 27.098608 | 3.983954 | 5.125023 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393362b63ffff | 0.000000 | 0.000000 | 15.335653 | 0.440664 | 25.063135 | 1.439152 | 0.747858 | 1.509269 | 0.880665 | 2.780649 | 0.000000 | 2.121481 | 1.780097 | 8.344467 | 0.170434 | 5.663130 | 3.509038 |
89393362d3bffff | 0.011944 | 0.005152 | 3.672109 | 0.817639 | 34.353666 | 1.412718 | 2.627522 | 0.090118 | 0.178404 | 0.631495 | 0.009826 | 3.367098 | 36.785965 | 1.227560 | 34.190321 | 8.005572 | 18.420294 |
893933676a3ffff | 0.000000 | 0.000000 | 4.830824 | 0.446900 | 21.033386 | 0.318262 | 1.504342 | 0.109247 | 1.468130 | 0.536393 | 0.000000 | 1.476001 | 1.199374 | 7.817281 | 0.406053 | 1.002460 | 1.510879 |
89393362b33ffff | 0.000000 | 0.000000 | 48.121657 | 2.737643 | 17.053382 | 5.264418 | 3.062970 | 1.886075 | 7.390598 | 7.330694 | 0.000000 | 21.306956 | 7.298391 | 18.230879 | 1.272581 | 17.073928 | 5.119436 |
89393367473ffff | 0.000000 | 0.000708 | 1.241623 | 0.365919 | 11.587839 | 0.234045 | 0.225897 | 0.135530 | 0.232671 | 0.357372 | 0.000303 | 0.255232 | 2.827974 | 1.440492 | 0.682045 | 0.479371 | 7.522781 |
830 rows × 17 columns
Concatenated vector version
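The embedder returns a vector of length n × (distance + 1), where n is the number of features from the CountEmbedder and distance is the number of neighbourhood rings analysed (neighbourhood_distance); the extra block holds the region's own counts at distance 0. Each column name is suffixed with _d, where d is the distance of that block (0 for the region itself). Values at each distance d ≥ 1 are averaged over all neighbours at that distance. With 17 features and neighbourhood_distance=10, this gives 17 × 11 = 187 columns.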
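The sketch below illustrates the concatenated layout: the region's own counts come first, followed by one block of averaged neighbour counts per distance, with column names suffixed by the distance in the same order as the table that follows. Filling a distance that has no known neighbours with zeros is an assumption of this sketch, and line_neighbours is the same kind of hypothetical 1-D neighbourhood used for illustration, not an srai API.

```python
FEATURE_NAMES = [
    "aerialway", "airports", "sustenance", "education", "transportation",
    "finances", "healthcare", "culture_art_entertainment", "other", "buildings",
    "emergency", "historic", "leisure", "shops", "sport", "tourism", "greenery",
]


def concatenated_columns(feature_names, max_distance):
    """Column names grouped by distance, e.g. 'aerialway_0', ..., 'greenery_10'."""
    return [f"{name}_{d}" for d in range(max_distance + 1) for name in feature_names]


def concatenated_embedding(region, counts, neighbours_at_distance, max_distance):
    """Own counts (distance 0) followed by mean neighbour counts for d = 1..max_distance."""
    n = len(counts[region])
    result = [float(v) for v in counts[region]]
    for d in range(1, max_distance + 1):
        known = [counts[r] for r in neighbours_at_distance(region, d) if r in counts]
        if known:
            result.extend(sum(vec[i] for vec in known) / len(known) for i in range(n))
        else:
            result.extend([0.0] * n)  # assumption: empty rings are zero-filled
    return result


def line_neighbours(region, d):
    """Hypothetical 1-D neighbourhood: the regions exactly d steps left and right."""
    return {region - d, region + d}


counts = {0: [1, 0], 1: [0, 2], 2: [4, 0]}
print(concatenated_embedding(1, counts, line_neighbours, max_distance=2))
# [0.0, 2.0, 2.5, 0.0, 0.0, 0.0]
print(len(concatenated_columns(FEATURE_NAMES, 10)))  # 187
```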
wide_cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=True
)
wide_embeddings = wide_cce.transform(regions_gdf, features_gdf, joint_gdf)
wide_embeddings
region_id | aerialway_0 | airports_0 | sustenance_0 | education_0 | transportation_0 | finances_0 | healthcare_0 | culture_art_entertainment_0 | other_0 | buildings_0 | ... | culture_art_entertainment_10 | other_10 | buildings_10 | emergency_10 | historic_10 | leisure_10 | shops_10 | sport_10 | tourism_10 | greenery_10 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393362d53ffff | 0.0 | 0.0 | 5.0 | 0.0 | 4.0 | 0.0 | 7.0 | 0.0 | 1.0 | 1.0 | ... | 0.302326 | 0.558140 | 3.023256 | 0.0 | 0.930233 | 4.441860 | 4.883721 | 2.023256 | 3.093023 | 4.255814 |
89393362a23ffff | 0.0 | 0.0 | 12.0 | 1.0 | 41.0 | 4.0 | 2.0 | 0.0 | 1.0 | 5.0 | ... | 0.648649 | 0.972973 | 1.918919 | 0.0 | 1.729730 | 4.810811 | 3.810811 | 0.621622 | 5.000000 | 4.945946 |
89393362b77ffff | 0.0 | 0.0 | 25.0 | 0.0 | 7.0 | 2.0 | 1.0 | 0.0 | 1.0 | 3.0 | ... | 0.290323 | 0.580645 | 0.645161 | 0.0 | 0.935484 | 2.903226 | 4.064516 | 0.612903 | 2.290323 | 4.774194 |
89393375e5bffff | 0.0 | 0.0 | 4.0 | 1.0 | 13.0 | 2.0 | 0.0 | 0.0 | 0.0 | 3.0 | ... | 0.131579 | 0.500000 | 0.631579 | 0.0 | 0.236842 | 4.026316 | 2.105263 | 1.815789 | 1.894737 | 3.289474 |
89393375a7bffff | 0.0 | 0.0 | 3.0 | 1.0 | 19.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | ... | 0.250000 | 0.833333 | 1.783333 | 0.0 | 1.383333 | 3.066667 | 4.550000 | 0.766667 | 1.833333 | 4.916667 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393362b63ffff | 0.0 | 0.0 | 2.0 | 0.0 | 16.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | ... | 0.200000 | 0.633333 | 1.433333 | 0.0 | 1.266667 | 4.133333 | 6.933333 | 1.133333 | 1.900000 | 3.400000 |
89393362d3bffff | 0.0 | 0.0 | 2.0 | 0.0 | 25.0 | 1.0 | 2.0 | 0.0 | 0.0 | 0.0 | ... | 0.101695 | 0.610169 | 1.406780 | 0.0 | 0.745763 | 2.983051 | 2.830508 | 0.813559 | 2.152542 | 5.491525 |
893933676a3ffff | 0.0 | 0.0 | 2.0 | 0.0 | 15.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | ... | 0.218750 | 0.593750 | 2.031250 | 0.0 | 1.187500 | 4.656250 | 5.468750 | 1.406250 | 2.937500 | 4.156250 |
89393362b33ffff | 0.0 | 0.0 | 30.0 | 2.0 | 8.0 | 4.0 | 2.0 | 1.0 | 6.0 | 5.0 | ... | 0.171429 | 0.542857 | 1.257143 | 0.0 | 0.457143 | 3.314286 | 2.742857 | 0.857143 | 1.628571 | 3.628571 |
89393367473ffff | 0.0 | 0.0 | 0.0 | 0.0 | 6.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.485714 | 1.000000 | 5.257143 | 0.0 | 1.857143 | 2.771429 | 10.085714 | 0.542857 | 6.228571 | 8.857143 |
830 rows × 187 columns
Plotting example features
plot_numeric_data(regions_gdf, "leisure", embeddings)
plot_numeric_data(regions_gdf, "transportation", embeddings)