Contextual count embedder
from srai.embedders import ContextualCountEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.neighbourhoods import H3Neighbourhood
from srai.plotting.folium_wrapper import plot_numeric_data, plot_regions
from srai.regionalizers import H3Regionalizer
Data preparation
In order to use ContextualCountEmbedder, we need to prepare some data. Namely, we need regions_gdf, features_gdf, and joint_gdf. These are the outputs of Regionalizers, Loaders, and Joiners, respectively.
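To make the relationship between the three inputs concrete, here is a minimal plain-pandas sketch with hypothetical toy data (the IDs `r1`, `f1`, etc. are made up, and geometries are omitted). It treats joint_gdf as an index-only frame of (region_id, feature_id) pairs and counts non-null group values per region. This is only a simplified approximation of the per-region counts a CountEmbedder produces, not srai's actual implementation.

```python
import pandas as pd

# Hypothetical stand-ins for regions_gdf, features_gdf and joint_gdf (no geometries).
regions = pd.DataFrame(index=pd.Index(["r1", "r2", "r3"], name="region_id"))

features = pd.DataFrame(
    {
        "leisure": ["leisure=park", None, "leisure=garden", None],
        "transportation": [None, "amenity=parking", None, "amenity=parking"],
    },
    index=pd.Index(["f1", "f2", "f3", "f4"], name="feature_id"),
)

# The joiner output: an index-only frame of (region_id, feature_id) pairs.
joint = pd.DataFrame(
    index=pd.MultiIndex.from_tuples(
        [("r1", "f1"), ("r1", "f2"), ("r2", "f2"), ("r2", "f3"), ("r2", "f4")],
        names=["region_id", "feature_id"],
    )
)

# Attach each feature's group columns to every region it intersects,
# then count the non-null values per group for each region.
pairs = joint.reset_index().join(features, on="feature_id")
counts = (
    pairs.drop(columns="feature_id")
    .groupby("region_id")
    .count()
    .reindex(regions.index, fill_value=0)  # regions with no features get zeros
)
print(counts)
```

Region `r3` has no joined features, so reindexing against the regions keeps it in the result with all-zero counts.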
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Lisboa, PT")
plot_regions(area_gdf)
Regionalize the area using an H3Regionalizer
regionalizer = H3Regionalizer(resolution=9, buffer=True)
regions_gdf = regionalizer.transform(area_gdf)
regions_gdf
region_id | geometry |
---|---|
89393375353ffff | POLYGON ((-9.19846 38.76665, -9.20016 38.76531... |
89393375c5bffff | POLYGON ((-9.16374 38.7916, -9.16545 38.79026,... |
893933601b7ffff | POLYGON ((-9.1843 38.69775, -9.186 38.69641, -... |
8939336764bffff | POLYGON ((-9.13494 38.71072, -9.13665 38.70937... |
89393375ba3ffff | POLYGON ((-9.1062 38.7733, -9.1079 38.77196, -... |
... | ... |
893933674cbffff | POLYGON ((-9.1321 38.74118, -9.13381 38.73984,... |
89393375a4fffff | POLYGON ((-9.14384 38.75077, -9.14555 38.74943... |
89393375e27ffff | POLYGON ((-9.14612 38.77722, -9.14782 38.77588... |
89393362e43ffff | POLYGON ((-9.21073 38.71931, -9.21243 38.71797... |
893933670c3ffff | POLYGON ((-9.11047 38.72761, -9.11218 38.72627... |
830 rows × 1 columns
Download some objects from OpenStreetMap
You can use either OsmTagsFilter or GroupedOsmTagsFilter filters. In this example, the predefined GroupedOsmTagsFilter BASE_OSM_GROUPS_FILTER is used. Each group in the filter becomes one column of features_gdf.
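The sketch below illustrates how a grouped filter can assign the `key=value` strings seen in the group columns of features_gdf. The type aliases only roughly mirror srai's filter shapes (an OSM key mapped to `True`, a single value, or a list of values, nested under group names); `EXAMPLE_GROUPS`, `match_group`, and `match_groups` are hypothetical and are not the contents of BASE_OSM_GROUPS_FILTER. The real matching happens inside the loader and may differ, for example in which tag wins when several match.

```python
from typing import Optional, Union

# Assumed shapes: an OSM key maps to True (any value), one value, or a list of values;
# a grouped filter maps a group name to such a filter.
OsmTagsFilter = dict[str, Union[bool, str, list[str]]]
GroupedOsmTagsFilter = dict[str, OsmTagsFilter]

# Hypothetical grouped filter (NOT BASE_OSM_GROUPS_FILTER).
EXAMPLE_GROUPS: GroupedOsmTagsFilter = {
    "leisure": {"leisure": ["park", "garden"]},
    "transportation": {"amenity": "parking", "public_transport": True},
}


def match_group(tags: dict[str, str], tag_filter: OsmTagsFilter) -> Optional[str]:
    """Return the first matching 'key=value' string, or None (hypothetical helper)."""
    for key, allowed in tag_filter.items():
        value = tags.get(key)
        if value is None:
            continue
        if allowed is True or value == allowed or (
            isinstance(allowed, list) and value in allowed
        ):
            return f"{key}={value}"
    return None


def match_groups(
    tags: dict[str, str], groups: GroupedOsmTagsFilter
) -> dict[str, Optional[str]]:
    """Evaluate every group of the filter against one OSM object's tags."""
    return {name: match_group(tags, tag_filter) for name, tag_filter in groups.items()}


print(match_groups({"public_transport": "stop_position", "name": "X"}, EXAMPLE_GROUPS))
```

For the stop position above, only the `transportation` group matches, which mirrors rows like node/21433772 in the table below.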
from srai.loaders.osm_loaders.filters import BASE_OSM_GROUPS_FILTER
loader = OSMPbfLoader()
features_gdf = loader.load(area_gdf, tags=BASE_OSM_GROUPS_FILTER)
features_gdf
Finished operation in 0:00:16
feature_id | geometry | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
node/21433772 | POINT (-9.19059 38.7288) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/21433776 | POINT (-9.19376 38.72666) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/25414208 | POINT (-9.16663 38.74018) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/25414256 | POINT (-9.10286 38.74711) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
node/25414265 | POINT (-9.10273 38.74707) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | public_transport=stop_position | None |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
way/1388343170 | POLYGON ((-9.09895 38.75713, -9.0989 38.75718,... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=charging_station | None |
way/1388655571 | POLYGON ((-9.15816 38.75962, -9.15789 38.75971... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=charging_station | None |
way/1388744343 | POLYGON ((-9.17958 38.7405, -9.17879 38.74066,... | None | None | None | None | None | None | None | None | None | None | leisure=garden | None | None | None | None | None | None | None |
way/1388817423 | POLYGON ((-9.21143 38.70477, -9.21141 38.70478... | None | None | None | None | None | None | None | None | None | None | leisure=swimming_pool | None | None | None | None | None | None | None |
way/1388932907 | POLYGON ((-9.18675 38.76115, -9.18728 38.76118... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
33643 rows × 19 columns
Join the objects with the regions they belong to
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf)
joint_gdf
region_id | feature_id |
---|---|
893933628dbffff | node/21433772 |
89393362e37ffff | node/21433776 |
893933628b7ffff | node/25414208 |
893933675c3ffff | node/25414256 |
893933675c3ffff | node/25414265 |
... | ... |
89393367593ffff | way/1388343170 |
89393362da7ffff | way/1388655571 |
89393362897ffff | way/1388744343 |
89393360517ffff | way/1388817423 |
8939337537bffff | way/1388932907 |
37279 rows × 0 columns
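A feature whose geometry intersects several hexagons gets one row per intersected region, which is likely why joint_gdf has 37,279 rows while features_gdf has 33,643. The sketch below uses a hypothetical toy joint frame (made-up IDs `hexA`, `way/3`, etc.) to count regions per feature and features per region from the MultiIndex.

```python
import pandas as pd

# Hypothetical joiner output: "way/3" crosses two hexagons, so it appears twice.
joint = pd.DataFrame(
    index=pd.MultiIndex.from_tuples(
        [
            ("hexA", "node/1"),
            ("hexA", "node/2"),
            ("hexB", "way/3"),
            ("hexC", "way/3"),
        ],
        names=["region_id", "feature_id"],
    )
)

feature_ids = joint.index.get_level_values("feature_id")
region_ids = joint.index.get_level_values("region_id")

# How many regions each feature intersects, and how many features each region holds.
regions_per_feature = feature_ids.value_counts()
features_per_region = region_ids.value_counts()

print(len(joint), feature_ids.nunique())          # pairs vs. distinct features
print(regions_per_feature[regions_per_feature > 1])  # features spanning several regions
```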
Embed using features existing in data
ContextualCountEmbedder extends the capabilities of the basic CountEmbedder by incorporating the neighbourhood of each embedded region. In this example, we will use the H3Neighbourhood.
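With neighbourhood_distance=10 (used below), each region looks at up to ten rings of hexagons around it. On an H3 grid the ring at distance k holds 6k cells (ignoring the rare pentagons), and all rings up to k together hold 3k(k+1) + 1 cells. The helper names below are illustrative, not srai functions. Only neighbours that are themselves among the regions appear to contribute, so regions near the edge of the area see fewer than these upper bounds.

```python
def ring_size(k: int) -> int:
    """Number of hexagons exactly k steps from a centre cell (pentagons ignored)."""
    return 1 if k == 0 else 6 * k


def disk_size(k: int) -> int:
    """Number of hexagons within k steps, centre included: 3k(k+1) + 1."""
    return 3 * k * (k + 1) + 1


for k in (1, 2, 10):
    print(k, ring_size(k), disk_size(k))
```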
h3n = H3Neighbourhood()
Squashed vector version (default)
The embedder returns a vector of the same length as the CountEmbedder output. For each region, it adds to the region's own counts the values averaged over the neighbours at each distance, diminished by the squared neighbour distance.
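A minimal NumPy sketch of the squashed idea for a single region, using hypothetical counts for three features and three neighbour distances. The `squash` helper and its weight of 1 / distance² follow the wording above as an assumption; srai's implementation may apply a slightly different formula (for example an offset in the denominator), so treat this as an illustration only.

```python
import numpy as np

# Hypothetical values for one region: its own counts, and the averaged counts
# of its neighbours at distances 1..3 (one row per distance).
own_counts = np.array([2.0, 0.0, 1.0])
ring_means = np.array(
    [
        [4.0, 1.0, 0.0],  # distance 1
        [3.0, 0.0, 2.0],  # distance 2
        [1.0, 2.0, 0.0],  # distance 3
    ]
)


def squash(own: np.ndarray, means: np.ndarray) -> np.ndarray:
    """Add each ring's mean, weighted by 1 / distance**2 (assumed weighting)."""
    distances = np.arange(1, len(means) + 1, dtype=float)
    weights = 1.0 / distances**2
    return own + (weights[:, None] * means).sum(axis=0)


print(squash(own_counts, ring_means))
```

Because the weights shrink quadratically, distant rings contribute little, which is why squashed values stay close to the region's own counts plus its nearest neighbours.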
cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=False
)
embeddings = cce.transform(regions_gdf, features_gdf, joint_gdf)
embeddings
region_id | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393375353ffff | 0.002223 | 0.000000 | 0.351832 | 0.151862 | 0.524396 | 0.008693 | 0.121492 | 6.623513 | 0.250484 | 1.212328 | 3.599159 | 1.200150 | 1.691857 | 0.429628 | 0.738429 | 0.588625 | 8.980530 | 0.184155 |
89393375c5bffff | 0.000000 | 0.023091 | 0.139393 | 0.016246 | 0.251120 | 0.004224 | 0.056136 | 4.298787 | 0.111478 | 0.113048 | 1.635843 | 0.191396 | 0.577604 | 0.655485 | 0.291406 | 0.136175 | 3.552983 | 0.083029 |
893933601b7ffff | 0.000000 | 0.000000 | 1.037624 | 0.078534 | 0.833861 | 0.065901 | 0.121951 | 3.703926 | 0.366703 | 0.883133 | 2.757708 | 0.395566 | 1.796840 | 0.349990 | 2.419751 | 1.284807 | 6.338681 | 1.671753 |
8939336764bffff | 0.000000 | 0.000000 | 17.175437 | 1.713251 | 1.726555 | 0.069127 | 4.961731 | 14.376833 | 1.888211 | 9.336623 | 7.608650 | 8.542790 | 42.242624 | 0.315564 | 77.696275 | 36.105004 | 69.036305 | 0.695638 |
89393375ba3ffff | 0.016903 | 0.022867 | 1.874377 | 0.160945 | 3.585424 | 0.071746 | 0.306038 | 7.351886 | 0.346161 | 0.148156 | 5.723265 | 0.267499 | 7.332034 | 2.869054 | 2.374208 | 0.861172 | 29.015417 | 0.460180 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
893933674cbffff | 0.000184 | 0.002810 | 1.481423 | 2.151396 | 2.124653 | 0.012870 | 0.908884 | 9.726462 | 4.148521 | 0.452135 | 4.907228 | 3.668173 | 35.475242 | 0.771647 | 26.127810 | 1.862023 | 77.156327 | 0.227506 |
89393375a4fffff | 0.001940 | 0.014055 | 7.712679 | 1.266049 | 2.021333 | 0.037944 | 1.865514 | 15.451573 | 4.322060 | 0.381092 | 6.603319 | 0.436802 | 17.492262 | 2.914704 | 6.890157 | 7.517803 | 63.547545 | 0.143889 |
89393375e27ffff | 0.000000 | 1.245880 | 3.519266 | 0.027100 | 1.418931 | 0.014412 | 0.067230 | 5.969597 | 1.176406 | 0.060223 | 6.236678 | 0.115966 | 0.716845 | 3.470381 | 0.522107 | 0.382200 | 12.785797 | 0.131707 |
89393362e43ffff | 0.000000 | 0.000000 | 1.316745 | 0.069588 | 1.282352 | 0.004937 | 0.043374 | 2.132622 | 1.160093 | 0.111364 | 5.270236 | 0.139457 | 0.749612 | 2.276102 | 0.634365 | 2.831470 | 7.522308 | 0.171103 |
893933670c3ffff | 0.000000 | 0.000000 | 4.609627 | 1.147257 | 0.370635 | 0.006689 | 0.272832 | 1.344363 | 0.401850 | 3.600557 | 1.171878 | 1.457707 | 2.472827 | 0.318745 | 3.060596 | 3.230610 | 20.617232 | 1.323177 |
830 rows × 18 columns
Concatenated vector version
The embedder returns a vector of length n × (distance + 1), where n is the number of features from the CountEmbedder and distance is the number of neighbourhood rings analysed; the extra block holds the region's own counts at distance 0 (here 18 × 11 = 198 columns). Each feature name is suffixed with _n, where n is the current distance. Values are averaged over all neighbours at that distance.
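The sketch below shows how per-distance values could be flattened into the wide layout, with hypothetical values for two features and a maximum distance of 2. Row 0 plays the role of the region's own counts; the column naming and distance-major order match the wide_embeddings table below, but the code itself is an illustration, not srai's implementation.

```python
import numpy as np
import pandas as pd

features = ["leisure", "transportation"]
max_distance = 2

# Hypothetical per-distance values for one region: row 0 is the region's own
# counts, rows 1..max_distance are neighbour averages at that distance.
per_distance = np.array(
    [
        [1.0, 3.0],
        [0.5, 2.0],
        [0.25, 4.0],
    ]
)

# Flatten distance-major and suffix each feature name with _<distance>.
columns = [f"{name}_{d}" for d in range(max_distance + 1) for name in features]
wide = pd.Series(per_distance.ravel(), index=columns)

print(wide)
print(len(wide) == len(features) * (max_distance + 1))
print(18 * (10 + 1))  # 198, the column count of wide_embeddings
```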
wide_cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=True
)
wide_embeddings = wide_cce.transform(regions_gdf, features_gdf, joint_gdf)
wide_embeddings
region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393375353ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 1.0 | ... | 0.760000 | 0.680000 | 5.040000 | 0.440000 | 1.840000 | 2.840000 | 1.600000 | 1.720000 | 12.120000 | 0.600000 |
89393375c5bffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | ... | 0.739130 | 0.130435 | 3.695652 | 0.391304 | 4.217391 | 0.782609 | 2.608696 | 0.826087 | 9.565217 | 0.217391 |
893933601b7ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 0.566667 | 0.966667 | 5.966667 | 0.566667 | 5.000000 | 1.033333 | 6.533333 | 3.466667 | 9.500000 | 0.666667 |
8939336764bffff | 0.0 | 0.0 | 12.0 | 1.0 | 1.0 | 0.0 | 3.0 | 12.0 | 1.0 | 6.0 | ... | 1.642857 | 1.214286 | 4.607143 | 1.071429 | 12.428571 | 1.035714 | 7.571429 | 2.571429 | 24.678571 | 0.500000 |
89393375ba3ffff | 0.0 | 0.0 | 1.0 | 0.0 | 3.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 0.416667 | 0.291667 | 2.708333 | 0.333333 | 0.625000 | 0.458333 | 0.875000 | 2.166667 | 5.916667 | 0.166667 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
893933674cbffff | 0.0 | 0.0 | 0.0 | 2.0 | 1.0 | 0.0 | 0.0 | 6.0 | 3.0 | 0.0 | ... | 1.066667 | 2.333333 | 5.422222 | 1.088889 | 7.311111 | 2.288889 | 10.977778 | 4.066667 | 16.688889 | 0.422222 |
89393375a4fffff | 0.0 | 0.0 | 6.0 | 1.0 | 1.0 | 0.0 | 1.0 | 12.0 | 3.0 | 0.0 | ... | 1.016667 | 1.550000 | 4.500000 | 0.566667 | 5.400000 | 1.050000 | 4.450000 | 3.183333 | 14.750000 | 0.366667 |
89393375e27ffff | 0.0 | 1.0 | 3.0 | 0.0 | 1.0 | 0.0 | 0.0 | 3.0 | 1.0 | 0.0 | ... | 1.242424 | 0.181818 | 3.818182 | 0.393939 | 4.424242 | 1.121212 | 3.242424 | 0.939394 | 17.212121 | 0.181818 |
89393362e43ffff | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | ... | 0.652174 | 0.565217 | 4.304348 | 0.565217 | 2.913043 | 1.347826 | 2.956522 | 2.652174 | 9.869565 | 0.304348 |
893933670c3ffff | 0.0 | 0.0 | 4.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | ... | 1.678571 | 3.392857 | 5.357143 | 1.321429 | 17.357143 | 1.571429 | 19.285714 | 7.142857 | 26.071429 | 0.500000 |
830 rows × 198 columns
Plotting example features
plot_numeric_data(regions_gdf, "leisure", embeddings)
plot_numeric_data(regions_gdf, "transportation", embeddings)
Other types of aggregations
By default, the ContextualCountEmbedder averages the counts from neighbours. The aggregation_function parameter can change this to one of: median, sum, min, or max. It's best to combine it with the wide format (concatenate_vectors=True).
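To see what each aggregation does to the neighbours found at one distance, here is a small pandas sketch with hypothetical counts for three neighbour regions. pandas calls the default averaging `"mean"`; the strings accepted by aggregation_function itself are the ones listed above, so this only illustrates the arithmetic, not srai's parameter values.

```python
import pandas as pd

# Hypothetical counts of the neighbours found at one distance (rows = neighbour regions).
neighbours_at_d = pd.DataFrame(
    {"leisure": [0, 2, 7], "transportation": [1, 1, 4]},
    index=pd.Index(["n1", "n2", "n3"], name="region_id"),
)

# Each aggregation collapses the neighbour rows into one value per feature.
aggregated = {
    name: neighbours_at_d.agg(name)
    for name in ("mean", "median", "sum", "min", "max")
}
for name, values in aggregated.items():
    print(name, values.to_dict())
```

A single outlier neighbour (7 for leisure) pulls the mean and max up but leaves the median unchanged, which is one reason to choose an aggregation other than the default.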
sum_cce = ContextualCountEmbedder(
neighbourhood=h3n,
neighbourhood_distance=10,
concatenate_vectors=True,
aggregation_function="sum",
)
sum_embeddings = sum_cce.transform(regions_gdf, features_gdf, joint_gdf)
sum_embeddings
region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393375353ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 1.0 | ... | 19.0 | 17.0 | 126.0 | 11.0 | 46.0 | 71.0 | 40.0 | 43.0 | 303.0 | 15.0 |
89393375c5bffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | ... | 17.0 | 3.0 | 85.0 | 9.0 | 97.0 | 18.0 | 60.0 | 19.0 | 220.0 | 5.0 |
893933601b7ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 17.0 | 29.0 | 179.0 | 17.0 | 150.0 | 31.0 | 196.0 | 104.0 | 285.0 | 20.0 |
8939336764bffff | 0.0 | 0.0 | 12.0 | 1.0 | 1.0 | 0.0 | 3.0 | 12.0 | 1.0 | 6.0 | ... | 46.0 | 34.0 | 129.0 | 30.0 | 348.0 | 29.0 | 212.0 | 72.0 | 691.0 | 14.0 |
89393375ba3ffff | 0.0 | 0.0 | 1.0 | 0.0 | 3.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 10.0 | 7.0 | 65.0 | 8.0 | 15.0 | 11.0 | 21.0 | 52.0 | 142.0 | 4.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
893933674cbffff | 0.0 | 0.0 | 0.0 | 2.0 | 1.0 | 0.0 | 0.0 | 6.0 | 3.0 | 0.0 | ... | 48.0 | 105.0 | 244.0 | 49.0 | 329.0 | 103.0 | 494.0 | 183.0 | 751.0 | 19.0 |
89393375a4fffff | 0.0 | 0.0 | 6.0 | 1.0 | 1.0 | 0.0 | 1.0 | 12.0 | 3.0 | 0.0 | ... | 61.0 | 93.0 | 270.0 | 34.0 | 324.0 | 63.0 | 267.0 | 191.0 | 885.0 | 22.0 |
89393375e27ffff | 0.0 | 1.0 | 3.0 | 0.0 | 1.0 | 0.0 | 0.0 | 3.0 | 1.0 | 0.0 | ... | 41.0 | 6.0 | 126.0 | 13.0 | 146.0 | 37.0 | 107.0 | 31.0 | 568.0 | 6.0 |
89393362e43ffff | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | ... | 15.0 | 13.0 | 99.0 | 13.0 | 67.0 | 31.0 | 68.0 | 61.0 | 227.0 | 7.0 |
893933670c3ffff | 0.0 | 0.0 | 4.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | ... | 47.0 | 95.0 | 150.0 | 37.0 | 486.0 | 44.0 | 540.0 | 200.0 | 730.0 | 14.0 |
830 rows × 198 columns
plot_numeric_data(regions_gdf, "tourism_8", sum_embeddings)
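Since an average is a sum divided by the number of neighbours, comparing the averaged wide embeddings with the summed ones recovers how many neighbours contributed at a given distance. The values below are copied from the two tables above for region 89393375353ffff at distance 10. The ratio is 25, well below the 60 cells of a full H3 ring at that distance, which suggests that only neighbours present in regions_gdf are counted.

```python
# Values copied from the mean (wide_embeddings) and sum (sum_embeddings) tables
# for region 89393375353ffff at distance 10.
mean_row = {"historic_10": 0.68, "leisure_10": 5.04, "transportation_10": 12.12}
sum_row = {"historic_10": 17.0, "leisure_10": 126.0, "transportation_10": 303.0}

# sum / mean gives the number of neighbours that contributed at that distance.
for column in mean_row:
    print(column, round(sum_row[column] / mean_row[column]))
```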