Contextual count embedder
from srai.embedders import ContextualCountEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.neighbourhoods import H3Neighbourhood
from srai.plotting.folium_wrapper import plot_numeric_data, plot_regions
from srai.regionalizers import H3Regionalizer
Data preparation
To use ContextualCountEmbedder, we first need to prepare three inputs: regions_gdf, features_gdf, and joint_gdf.
These are the outputs of Regionalizers, Loaders and Joiners, respectively.
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Lisboa, PT")
plot_regions(area_gdf)
Regionalize the area using an H3Regionalizer
regionalizer = H3Regionalizer(resolution=9, buffer=True)
regions_gdf = regionalizer.transform(area_gdf)
regions_gdf
region_id | geometry |
---|---|
8939337581bffff | POLYGON ((-9.13698 38.78205, -9.13868 38.78071... |
89393375e03ffff | POLYGON ((-9.16113 38.77718, -9.16284 38.77584... |
89393360183ffff | POLYGON ((-9.19636 38.6953, -9.19807 38.69396,... |
89393375a87ffff | POLYGON ((-9.14133 38.76921, -9.14304 38.76787... |
89393375ab7ffff | POLYGON ((-9.13328 38.77084, -9.13499 38.7695,... |
... | ... |
89393362a23ffff | POLYGON ((-9.17299 38.70901, -9.1747 38.70766,... |
8939337596bffff | POLYGON ((-9.10108 38.77732, -9.10279 38.77598... |
89393375a17ffff | POLYGON ((-9.13949 38.76361, -9.14119 38.76227... |
89393362d77ffff | POLYGON ((-9.16288 38.74992, -9.16458 38.74858... |
8939337594fffff | POLYGON ((-9.10804 38.77891, -9.10975 38.77757... |
830 rows × 1 columns
Download some objects from OpenStreetMap
You can use both OsmTagsFilter and GroupedOsmTagsFilter filters. In this example, the predefined GroupedOsmTagsFilter filter BASE_OSM_GROUPS_FILTER is used.
from srai.loaders.osm_loaders.filters import BASE_OSM_GROUPS_FILTER
loader = OSMPbfLoader()
features_gdf = loader.load(area_gdf, tags=BASE_OSM_GROUPS_FILTER)
features_gdf
feature_id | geometry | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
node/1491303660 | POINT (-9.14435 38.70854) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | tourism=hostel | None | None |
node/1558080346 | POINT (-9.19934 38.70374) | None | None | None | None | None | None | None | None | amenity=pharmacy | None | None | None | None | None | None | None | None | None |
node/1585677447 | POINT (-9.19387 38.70566) | None | None | None | None | None | None | None | None | None | None | None | None | shop=kiosk | None | None | None | None | None |
node/1585677479 | POINT (-9.19582 38.70509) | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=cafe | None | None | None |
node/1585677486 | POINT (-9.19443 38.70547) | None | None | None | None | None | None | None | None | amenity=clinic | None | None | None | None | None | None | None | None | None |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
way/749249324 | POLYGON ((-9.22085 38.69514, -9.22069 38.69507... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=restaurant | None | None | None |
way/749834057 | POLYGON ((-9.15022 38.75482, -9.15026 38.7548,... | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | None | amenity=parking | None |
way/749983967 | POLYGON ((-9.1005 38.74179, -9.10044 38.74184,... | None | None | None | None | None | None | None | landuse=grass | None | None | None | None | None | None | None | None | None | None |
way/749983968 | POLYGON ((-9.09966 38.74289, -9.09945 38.7436,... | None | None | None | None | None | None | None | landuse=grass | None | None | None | None | None | None | None | None | None | None |
way/753144796 | POLYGON ((-9.14522 38.77637, -9.14568 38.77526... | None | None | None | None | None | None | None | None | None | None | leisure=sports_centre | None | None | None | None | None | None | None |
31762 rows × 19 columns
Join the objects with the regions they belong to
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf)
joint_gdf
region_id | feature_id |
---|---|
8939337581bffff | way/8053332 |
8939337581bffff | way/1017692509 |
8939337581bffff | way/336112741 |
8939337581bffff | way/8053313 |
89393375e03ffff | node/11287438897 |
... | ... |
89393362d77ffff | way/71571983 |
89393362d77ffff | node/5457844759 |
8939337594fffff | way/1302618848 |
8939337594fffff | way/574867509 |
8939337594fffff | way/819189800 |
35201 rows × 0 columns
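The joint index pairs each region with the features that intersect it. As a rough illustration of how such pairs can be turned into per-region counts (the quantity a count-based embedder starts from), here is a minimal pure-Python sketch. The identifiers and categories are made up for the example, the helper `count_per_region` is hypothetical and not part of srai, and each toy feature is assumed to belong to exactly one category.

```python
from collections import Counter, defaultdict

# Toy stand-in for joint_gdf: (region_id, feature_id) pairs. Hypothetical values.
joint_pairs = [
    ("region_a", "node/1"),
    ("region_a", "way/2"),
    ("region_a", "way/3"),
    ("region_b", "way/3"),  # one feature may intersect more than one region
]

# Toy stand-in for features_gdf: the single category of each feature (assumption).
feature_category = {
    "node/1": "tourism",
    "way/2": "greenery",
    "way/3": "greenery",
}


def count_per_region(pairs, categories):
    """Count how many features of each category fall into each region."""
    counts = defaultdict(Counter)
    for region_id, feature_id in pairs:
        counts[region_id][categories[feature_id]] += 1
    return {region: dict(counter) for region, counter in counts.items()}


region_counts = count_per_region(joint_pairs, feature_category)
print(region_counts)
```

In the real data a feature can carry values in several category columns at once, so this sketch only conveys the general idea.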
Embed using features existing in data
ContextualCountEmbedder extends the capabilities of the basic CountEmbedder by incorporating the neighbourhood of each embedded region. In this example we will use the H3Neighbourhood.
h3n = H3Neighbourhood()
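To give an intuition for what "neighbourhood at a given distance" means on a hexagonal grid, the sketch below builds rings of cells in plain axial coordinates and averages a count over the cells that actually exist in the data. This is only a stand-in: H3 uses its own cell indexing, the functions `hex_distance`, `ring` and `ring_average` are hypothetical helpers rather than srai or h3 calls, and averaging only over cells present in the region set is an assumption.

```python
def hex_distance(a, b):
    """Grid distance between two hexagons in axial (q, r) coordinates."""
    dq, dr = a[0] - b[0], a[1] - b[1]
    return (abs(dq) + abs(dr) + abs(dq + dr)) // 2


def ring(center, k):
    """All cells exactly k steps away from center (the center itself for k = 0)."""
    q0, r0 = center
    return [
        (q0 + dq, r0 + dr)
        for dq in range(-k, k + 1)
        for dr in range(-k, k + 1)
        if hex_distance((0, 0), (dq, dr)) == k
    ]


def ring_average(counts, center, k):
    """Average a count over ring cells present in counts (assumption: missing cells are skipped)."""
    present = [counts[cell] for cell in ring(center, k) if cell in counts]
    if not present:
        return 0.0
    return sum(present) / len(present)


# Toy counts for a handful of cells; the other neighbours are outside the area.
counts = {(0, 0): 3, (1, 0): 2, (0, 1): 4, (-1, 1): 0}

ring_sizes = [len(ring((0, 0), k)) for k in range(4)]
avg_at_1 = ring_average(counts, (0, 0), 1)
print(ring_sizes, avg_at_1)
```

On a regular hexagonal grid the ring at distance k holds 6k cells, so regions near the edge of the study area have fewer neighbours to average over than regions in the interior.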
Squashed vector version (default)
The embedder returns a vector of the same length as the one from CountEmbedder, but sums in the averaged values from each neighbourhood, each diminished by the square of its neighbour distance.
cce = ContextualCountEmbedder(
    neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=False
)
embeddings = cce.transform(regions_gdf, features_gdf, joint_gdf)
embeddings
region_id | aerialway | airports | buildings | culture_art_entertainment | education | emergency | finances | greenery | healthcare | historic | leisure | other | shops | sport | sustenance | tourism | transportation | water |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
8939337581bffff | 0.000000 | 1.473146 | 0.160375 | 0.016119 | 0.171692 | 0.027985 | 0.052390 | 5.735442 | 0.081760 | 0.050183 | 0.777338 | 0.123850 | 0.467361 | 0.217815 | 0.280668 | 0.215247 | 2.453377 | 0.041174 |
89393375e03ffff | 0.001435 | 0.030241 | 2.510878 | 0.081341 | 3.693383 | 1.005127 | 0.425964 | 24.361112 | 1.785552 | 0.453471 | 5.087195 | 0.216154 | 19.222494 | 0.577140 | 8.548444 | 0.933652 | 40.878750 | 0.306236 |
89393360183ffff | 0.000000 | 0.000000 | 0.465605 | 0.074146 | 0.605504 | 0.005601 | 0.326114 | 3.641193 | 0.309361 | 0.643598 | 3.514902 | 0.276305 | 1.837293 | 0.434031 | 2.215174 | 2.767140 | 4.813069 | 1.991171 |
89393375a87ffff | 0.000000 | 1.390355 | 0.345082 | 0.042415 | 0.241024 | 0.026329 | 0.141298 | 7.662708 | 0.186119 | 0.060771 | 0.865226 | 0.094489 | 1.164539 | 0.305983 | 0.769630 | 0.278473 | 2.876438 | 0.042715 |
89393375ab7ffff | 0.000359 | 1.397100 | 0.290631 | 0.027281 | 0.180623 | 0.262207 | 0.251055 | 9.127187 | 0.166464 | 0.092108 | 1.011516 | 0.127033 | 0.993106 | 0.306369 | 1.021618 | 0.276196 | 5.086090 | 0.039360 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393362a23ffff | 0.000000 | 0.000000 | 7.696887 | 0.149256 | 2.615428 | 0.008694 | 4.328706 | 13.583682 | 3.718496 | 3.952870 | 7.848077 | 1.643482 | 24.964655 | 0.391164 | 19.929864 | 6.430694 | 50.260973 | 3.454038 |
8939337596bffff | 0.035980 | 0.018373 | 1.820851 | 0.095708 | 1.523136 | 0.122336 | 0.373613 | 2.747567 | 0.427851 | 0.055042 | 6.402048 | 0.237872 | 6.479561 | 2.639019 | 2.331755 | 1.010371 | 16.311080 | 0.490610 |
89393375a17ffff | 0.000318 | 1.224772 | 5.623927 | 2.039577 | 0.438353 | 0.017783 | 0.251006 | 3.535609 | 0.515442 | 0.098433 | 1.818002 | 0.192040 | 2.569314 | 0.840038 | 2.528746 | 0.756099 | 7.549133 | 0.057774 |
89393362d77ffff | 0.028935 | 0.003153 | 2.829857 | 0.094466 | 1.736450 | 0.022118 | 2.419017 | 11.130072 | 5.709749 | 2.365105 | 6.096784 | 0.226076 | 2.035800 | 2.572874 | 4.235222 | 4.359241 | 14.602468 | 0.351905 |
8939337594fffff | 0.009398 | 0.029759 | 0.402041 | 1.095480 | 1.385251 | 0.023677 | 0.144846 | 4.017949 | 0.263237 | 0.181543 | 2.475304 | 0.124859 | 1.465004 | 0.744785 | 0.801487 | 0.304296 | 9.033977 | 0.136048 |
830 rows × 18 columns
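The squashing step can be pictured as a weighted sum over distances. The sketch below is not the library's code: the function `squash` is a hypothetical helper, the input numbers are invented, and the weight 1 / (distance + 1)^2 (which keeps the region's own counts at full weight) is an assumption made for illustration; the exact weighting used by ContextualCountEmbedder may differ.

```python
def squash(per_distance_averages):
    """Collapse per-distance averaged vectors into a single vector.

    per_distance_averages[d] holds the averaged feature values at distance d,
    with d = 0 being the region itself.
    """
    n_features = len(per_distance_averages[0])
    squashed = [0.0] * n_features
    for distance, averages in enumerate(per_distance_averages):
        # ASSUMPTION: divide by (distance + 1) squared so distance 0 is not scaled.
        divisor = (distance + 1) ** 2
        for i, value in enumerate(averages):
            squashed[i] += value / divisor
    return squashed


# Invented averages for two features over distances 0, 1 and 2.
per_distance = [
    [1.0, 3.0],  # distance 0: the region's own counts
    [2.0, 4.0],  # distance 1: average over the first ring
    [0.0, 9.0],  # distance 2: average over the second ring
]

squashed = squash(per_distance)
print(squashed)
```

Whatever the exact weights, the effect is the same in spirit: nearby rings contribute more than distant ones, and the output keeps one column per feature.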
Concatenated vector version
The embedder returns a vector of length n * d, where n is the number of features from the CountEmbedder and d is the number of neighbourhood distances analysed, counting distance 0 (the region itself). In this example, 18 features over distances 0 to 10 give 18 * 11 = 198 columns.
Each feature name is suffixed with _n, where n is the distance. Values are averaged over all neighbours at that distance.
wide_cce = ContextualCountEmbedder(
    neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=True
)
wide_embeddings = wide_cce.transform(regions_gdf, features_gdf, joint_gdf)
wide_embeddings
region_id | aerialway_0 | airports_0 | buildings_0 | culture_art_entertainment_0 | education_0 | emergency_0 | finances_0 | greenery_0 | healthcare_0 | historic_0 | ... | healthcare_10 | historic_10 | leisure_10 | other_10 | shops_10 | sport_10 | sustenance_10 | tourism_10 | transportation_10 | water_10 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
8939337581bffff | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | ... | 0.555556 | 0.333333 | 4.166667 | 0.388889 | 2.166667 | 1.861111 | 1.583333 | 1.416667 | 13.944444 | 0.500000 |
89393375e03ffff | 0.0 | 0.0 | 2.0 | 0.0 | 3.0 | 1.0 | 0.0 | 21.0 | 1.0 | 0.0 | ... | 1.033333 | 0.500000 | 4.233333 | 0.400000 | 3.733333 | 1.733333 | 2.266667 | 2.433333 | 14.266667 | 1.000000 |
89393360183ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.150000 | 0.800000 | 3.600000 | 0.250000 | 0.700000 | 0.450000 | 1.250000 | 1.600000 | 3.550000 | 0.300000 |
89393375a87ffff | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 5.0 | 0.0 | 0.0 | ... | 0.975610 | 0.365854 | 4.341463 | 0.560976 | 5.829268 | 1.390244 | 3.731707 | 1.341463 | 13.585366 | 0.292683 |
89393375ab7ffff | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 6.0 | 0.0 | 0.0 | ... | 1.108696 | 0.326087 | 4.913043 | 0.413043 | 6.543478 | 2.195652 | 4.326087 | 1.782609 | 14.847826 | 0.826087 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393362a23ffff | 0.0 | 0.0 | 6.0 | 0.0 | 2.0 | 0.0 | 4.0 | 10.0 | 3.0 | 3.0 | ... | 1.297297 | 1.729730 | 5.081081 | 1.000000 | 6.297297 | 0.702703 | 7.918919 | 5.405405 | 13.810811 | 0.405405 |
8939337596bffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.285714 | 0.000000 | 2.809524 | 0.095238 | 0.904762 | 0.857143 | 0.523810 | 0.523810 | 5.380952 | 0.142857 |
89393375a17ffff | 0.0 | 1.0 | 5.0 | 2.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 1.423077 | 0.673077 | 3.192308 | 0.519231 | 9.884615 | 1.115385 | 5.519231 | 2.250000 | 16.576923 | 0.326923 |
89393362d77ffff | 0.0 | 0.0 | 2.0 | 0.0 | 1.0 | 0.0 | 2.0 | 6.0 | 5.0 | 2.0 | ... | 1.033333 | 0.916667 | 3.316667 | 0.716667 | 7.350000 | 0.866667 | 4.133333 | 2.266667 | 12.050000 | 0.333333 |
8939337594fffff | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 0.217391 | 0.043478 | 3.217391 | 0.304348 | 0.217391 | 1.347826 | 0.347826 | 0.391304 | 7.739130 | 0.130435 |
830 rows × 198 columns
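The column layout of the concatenated output follows directly from the table above: one block of feature columns per distance, from 0 up to neighbourhood_distance. The sketch below reproduces that naming scheme in plain Python; `concatenated_columns` is a hypothetical helper written for this illustration, not an srai function.

```python
def concatenated_columns(feature_names, neighbourhood_distance):
    """Build column names feature_distance for distances 0..neighbourhood_distance."""
    return [
        f"{name}_{distance}"
        for distance in range(neighbourhood_distance + 1)
        for name in feature_names
    ]


# The 18 feature groups shown in the embedding tables above.
feature_names = [
    "aerialway", "airports", "buildings", "culture_art_entertainment",
    "education", "emergency", "finances", "greenery", "healthcare",
    "historic", "leisure", "other", "shops", "sport", "sustenance",
    "tourism", "transportation", "water",
]

columns = concatenated_columns(feature_names, neighbourhood_distance=10)
print(len(columns), columns[0], columns[-1])
```

Because the width grows linearly with the distance, the concatenated variant is best suited to cases where a downstream model should learn its own weighting of near and far context.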
Plotting example features
plot_numeric_data(regions_gdf, "leisure", embeddings)
plot_numeric_data(regions_gdf, "transportation", embeddings)