Contextual count embedder
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.regionalizers import H3Regionalizer
from srai.joiners import IntersectionJoiner
from srai.embedders import ContextualCountEmbedder
from srai.plotting.folium_wrapper import plot_regions, plot_numeric_data
from srai.neighbourhoods import H3Neighbourhood
Data preparation
In order to use ContextualCountEmbedder, we need to prepare some data. Namely, we need regions_gdf, features_gdf, and joint_gdf. These are the outputs of Regionalizers, Loaders, and Joiners, respectively.
from srai.regionalizers import geocode_to_region_gdf
area_gdf = geocode_to_region_gdf("Lisboa, PT")
plot_regions(area_gdf)
Regionalize the area using an H3Regionalizer
regionalizer = H3Regionalizer(resolution=9, buffer=True)
regions_gdf = regionalizer.transform(area_gdf)
regions_gdf
region_id | geometry |
---|---|
89393367497ffff | POLYGON ((-9.12372 38.75483, -9.12542 38.75349... |
89393362a07ffff | POLYGON ((-9.17995 38.71059, -9.18165 38.70925... |
893933674afffff | POLYGON ((-9.11089 38.74845, -9.11260 38.74711... |
89393362c2bffff | POLYGON ((-9.18703 38.74504, -9.18873 38.74370... |
89393360593ffff | POLYGON ((-9.21986 38.71447, -9.22157 38.71313... |
... | ... |
89393375a57ffff | POLYGON ((-9.14569 38.75638, -9.14739 38.75504... |
89393375b13ffff | POLYGON ((-9.11534 38.76847, -9.11704 38.76713... |
893933629bbffff | POLYGON ((-9.14602 38.74435, -9.14772 38.74301... |
89393362bb7ffff | POLYGON ((-9.14853 38.72591, -9.15023 38.72457... |
8939336755bffff | POLYGON ((-9.10209 38.74125, -9.10380 38.73991... |
830 rows × 1 columns
Download some objects from OpenStreetMap
Either an OsmTagsFilter or a GroupedOsmTagsFilter can be used. In this example, the predefined GroupedOsmTagsFilter BASE_OSM_GROUPS_FILTER is used.
from srai.loaders.osm_loaders.filters import BASE_OSM_GROUPS_FILTER
loader = OSMPbfLoader()
features_gdf = loader.load(area_gdf, tags=BASE_OSM_GROUPS_FILTER)
features_gdf
[Lisbon, Portugal] Downloading pbf file #1 (Elements): 100%|██████████| 1297575/1297575 [00:02<00:00, 509917.79it/s]
ccfc4ec912ac803c97b939feba28e0a57de61e0e543e36e51c010f9f0a167e37.osm.pbf: 100%|██████████| 7.30M/7.30M [00:00<00:00, 10.4MiB/s]
[Lisbon, Portugal] Counting pbf features: 772876it [00:03, 251762.17it/s]
[Lisbon, Portugal] Parsing pbf file #1: 100%|██████████| 772876/772876 [00:14<00:00, 51775.07it/s]
Grouping features: 100%|██████████| 23868/23868 [00:07<00:00, 3153.03it/s]
feature_id | geometry | aerialway | airports | sustenance | education | transportation | finances | healthcare | culture_art_entertainment | other | buildings | emergency | historic | leisure | shops | sport | tourism | greenery |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
node/21433772 | POINT (-9.19059 38.72880) | NaN | NaN | NaN | NaN | public_transport=stop_position | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
node/21433776 | POINT (-9.19376 38.72666) | NaN | NaN | NaN | NaN | public_transport=stop_position | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
node/25414208 | POINT (-9.16568 38.74047) | NaN | NaN | NaN | NaN | public_transport=stop_position | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
node/25414256 | POINT (-9.10320 38.74623) | NaN | NaN | NaN | NaN | public_transport=stop_position | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
node/25414265 | POINT (-9.10243 38.74785) | NaN | NaN | NaN | NaN | public_transport=stop_position | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
relation/16068976 | MULTIPOLYGON (((-9.17274 38.71184, -9.17224 38... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | landuse=meadow |
relation/16068971 | MULTIPOLYGON (((-9.17183 38.71125, -9.17176 38... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | landuse=grass |
relation/16071109 | MULTIPOLYGON (((-9.17010 38.71514, -9.17009 38... | NaN | NaN | NaN | NaN | amenity=parking | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
relation/8131598 | MULTIPOLYGON (((-9.14676 38.74328, -9.14662 38... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | leisure=garden | NaN | NaN | NaN | NaN |
relation/16158578 | MULTIPOLYGON (((-9.15193 38.72702, -9.15123 38... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | landuse=grass |
23868 rows × 18 columns
Join the objects with the regions they belong to
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf)
joint_gdf
region_id | feature_id |
---|---|
89393367497ffff | node/7826714718 |
89393367497ffff | way/148898258 |
89393367483ffff | way/148898258 |
89393367497ffff | node/6751599820 |
89393367497ffff | way/401124568 |
... | ... |
8939336755bffff | node/3109770683 |
8939336755bffff | node/10583618730 |
8939336755bffff | node/7569587785 |
8939336755bffff | node/2486499700 |
8939336755bffff | node/6468985985 |
26702 rows × 0 columns
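The joint table has no data columns: each row is a (region_id, feature_id) pair, and a feature that intersects several cells appears once per cell (for example, way/148898258 is joined to both 89393367497ffff and 89393367483ffff). The sketch below shows, under that reading, how per-region counts in the spirit of CountEmbedder can be derived from such pairs by counting non-empty group values. JOINT, FEATURES, and count_per_region are hypothetical toy stand-ins, not srai objects, and unlike the real embedder the sketch does not emit zero rows for regions without any joined features.

```python
from collections import defaultdict
from typing import Dict, Iterable, Optional, Tuple

# Toy stand-in for joint_gdf: (region_id, feature_id) pairs with made-up IDs.
JOINT = [
    ("r1", "node/1"),
    ("r1", "way/2"),
    ("r2", "way/2"),  # way/2 crosses a cell border, so it is joined to both cells
    ("r2", "node/3"),
]

# Toy stand-in for the group columns of features_gdf (None plays the role of NaN).
FEATURES: Dict[str, Dict[str, Optional[str]]] = {
    "node/1": {"transportation": "public_transport=stop_position", "greenery": None},
    "way/2": {"transportation": None, "greenery": "landuse=grass"},
    "node/3": {"transportation": "amenity=parking", "greenery": None},
}


def count_per_region(
    joint: Iterable[Tuple[str, str]],
    features: Dict[str, Dict[str, Optional[str]]],
) -> Dict[str, Dict[str, int]]:
    """Count, per region, how many joined features have a non-empty value in each group."""
    groups = sorted({g for cols in features.values() for g in cols})
    counts: Dict[str, Dict[str, int]] = defaultdict(lambda: {g: 0 for g in groups})
    for region_id, feature_id in joint:
        for group, value in features[feature_id].items():
            if value is not None:
                counts[region_id][group] += 1
    return dict(counts)


if __name__ == "__main__":
    print(count_per_region(JOINT, FEATURES))
```

Because way/2 appears in two pairs, its greenery value is counted once in each region.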
Embed using features existing in data
ContextualCountEmbedder extends the capabilities of the basic CountEmbedder by incorporating the neighbourhood of each embedded region. In this example, we use H3Neighbourhood.
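H3 cells are hexagons, so apart from H3's twelve pentagons, the ring of cells at grid distance k around a cell holds 6k cells, and all rings up to distance k hold 1 + 3k(k + 1) cells together (331 for a neighbourhood_distance of 10). The following is a minimal sketch of that ring structure on an idealised hexagonal grid in axial coordinates; it does not call H3Neighbourhood or the h3 library, and near the edge of the study area fewer of these cells will exist in regions_gdf. The names neighbours and rings are hypothetical.

```python
from typing import List, Set, Tuple

Hex = Tuple[int, int]  # axial coordinates (q, r) on an idealised hexagonal grid

# The six axial-coordinate offsets of a hexagon's direct neighbours.
DIRECTIONS: List[Hex] = [(1, 0), (1, -1), (0, -1), (-1, 0), (-1, 1), (0, 1)]


def neighbours(cell: Hex) -> List[Hex]:
    """Direct (distance 1) neighbours of a cell."""
    q, r = cell
    return [(q + dq, r + dr) for dq, dr in DIRECTIONS]


def rings(center: Hex, max_distance: int) -> List[Set[Hex]]:
    """Breadth-first search: result[d] holds the cells exactly d steps from center."""
    result: List[Set[Hex]] = [{center}]
    seen: Set[Hex] = {center}
    for _ in range(max_distance):
        next_ring: Set[Hex] = set()
        for cell in result[-1]:
            for n in neighbours(cell):
                if n not in seen:
                    seen.add(n)
                    next_ring.add(n)
        result.append(next_ring)
    return result


if __name__ == "__main__":
    sizes = [len(ring) for ring in rings((0, 0), 10)]
    print(sizes)  # [1, 6, 12, 18, 24, 30, 36, 42, 48, 54, 60]
```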
h3n = H3Neighbourhood()
Squashed vector version (default)
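The embedder returns a vector of the same length as the one from CountEmbedder: each region's own counts are summed with the averaged counts of its neighbours at every distance, with the contribution of each distance diminished by the square of that neighbour distance.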
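A minimal sketch of the squashing step, assuming the per-distance averages have already been computed: index 0 holds the region's own counts, and each later index d holds the mean count vector over the neighbours at distance d. Following the wording above, distance d is weighted by 1 / d²; that exact divisor is an assumption, and the library may offset the distance differently. The function name squash is hypothetical.

```python
from typing import List, Sequence


def squash(ring_means: Sequence[Sequence[float]]) -> List[float]:
    """Collapse per-distance vectors into a single vector of the same width.

    ring_means[0] is the region's own count vector; ring_means[d] for d >= 1
    is the average count vector over its neighbours at distance d.
    Assumption: distance d is weighted by 1 / d**2.
    """
    result = [float(v) for v in ring_means[0]]
    for d in range(1, len(ring_means)):
        weight = 1.0 / d**2  # farther rings contribute less
        for i, value in enumerate(ring_means[d]):
            result[i] += weight * value
    return result


if __name__ == "__main__":
    # Two count columns, distances 0..2 (toy numbers).
    print(squash([[10, 0], [4, 2], [8, 1]]))  # [16.0, 2.25]
```

Because the region's own counts enter with weight 1, every squashed value is at least the plain count for that region, which is consistent with comparing the two output tables below.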
cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=False
)
embeddings = cce.transform(regions_gdf, features_gdf, joint_gdf)
embeddings
Generating embeddings: 100%|██████████| 830/830 [00:04<00:00, 201.24it/s]
region_id | aerialway | airports | sustenance | education | transportation | finances | healthcare | culture_art_entertainment | other | buildings | emergency | historic | leisure | shops | sport | tourism | greenery |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393367497ffff | 0.000965 | 0.017389 | 0.497551 | 2.462877 | 15.005424 | 0.108012 | 0.218776 | 1.185060 | 1.111254 | 0.202870 | 0.001306 | 0.046683 | 5.118919 | 0.711557 | 1.865506 | 0.437898 | 10.927808 |
89393362a07ffff | 0.000000 | 0.000000 | 1.869293 | 0.418422 | 13.230423 | 0.229170 | 0.432691 | 1.082004 | 0.465417 | 0.535593 | 0.000000 | 0.499773 | 9.436378 | 1.780298 | 1.485251 | 2.883786 | 6.468164 |
893933674afffff | 0.002505 | 0.004717 | 0.446957 | 2.421084 | 14.597619 | 0.070648 | 0.303332 | 0.039901 | 1.126671 | 1.364482 | 0.002931 | 0.072529 | 4.952377 | 0.590369 | 0.665959 | 0.371455 | 9.027821 |
89393362c2bffff | 0.009398 | 0.000000 | 1.386980 | 0.631465 | 5.830317 | 0.286869 | 0.288885 | 0.061974 | 0.308920 | 0.894384 | 0.000623 | 1.321903 | 6.893187 | 2.992993 | 1.398104 | 5.402610 | 1.430750 |
89393360593ffff | 0.000000 | 0.000000 | 0.421329 | 0.334847 | 2.008824 | 0.085256 | 0.130598 | 0.188825 | 0.098657 | 0.332108 | 0.000000 | 0.139421 | 4.434659 | 0.270520 | 0.880034 | 2.918944 | 2.736094 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393375a57ffff | 0.001516 | 0.031819 | 11.346867 | 2.667361 | 41.629460 | 3.634535 | 1.893502 | 1.084906 | 1.457439 | 1.624066 | 0.003657 | 1.301203 | 7.552573 | 8.090817 | 0.938979 | 0.852490 | 2.250354 |
89393375b13ffff | 0.004608 | 0.033494 | 1.366909 | 1.346210 | 16.974248 | 0.168203 | 2.279197 | 0.031094 | 0.417852 | 0.316410 | 0.006749 | 0.097403 | 4.887185 | 2.850432 | 1.868945 | 1.328426 | 23.347434 |
893933629bbffff | 0.001940 | 0.005279 | 19.991188 | 1.736646 | 44.364427 | 2.685290 | 2.890765 | 1.249996 | 1.354267 | 5.197155 | 0.001064 | 1.443580 | 5.854441 | 13.585506 | 3.360394 | 4.718295 | 2.768193 |
89393362bb7ffff | 0.001475 | 0.000000 | 27.390038 | 5.606997 | 35.802529 | 11.672613 | 6.297369 | 0.541473 | 5.711992 | 39.285628 | 0.000000 | 11.397079 | 4.823785 | 31.570164 | 1.622318 | 25.943659 | 27.815267 |
8939336755bffff | 0.002682 | 0.000295 | 1.700278 | 0.159898 | 8.197234 | 0.052176 | 0.136261 | 1.099558 | 0.239986 | 4.706807 | 0.002019 | 1.192580 | 3.427525 | 1.491377 | 1.262322 | 0.513310 | 2.084276 |
830 rows × 17 columns
Concatenated vector version
The embedder returns a vector of length n * (distance + 1), where n is the number of features from CountEmbedder and distance is the neighbourhood_distance; the extra block holds the region's own counts at distance 0. Each column name is suffixed with _d, where d is the distance (0 for the region itself). Values at each distance d ≥ 1 are averaged over all neighbours at that distance.
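The concatenated layout can be sketched the same way. Here mean_vector and concatenate are hypothetical helpers: distance 0 carries the region's own counts, each further distance carries the element-wise mean over the neighbours at that distance, and a distance with no neighbours is assumed to yield zeros. With the 17 groups of BASE_OSM_GROUPS_FILTER and a neighbourhood_distance of 10, a row has 17 × 11 = 187 entries, matching the column count of wide_embeddings.

```python
from typing import Dict, List, Sequence

# The 17 group names seen in the tables of this example.
GROUPS = [
    "aerialway", "airports", "sustenance", "education", "transportation",
    "finances", "healthcare", "culture_art_entertainment", "other", "buildings",
    "emergency", "historic", "leisure", "shops", "sport", "tourism", "greenery",
]


def mean_vector(vectors: Sequence[Sequence[float]], width: int) -> List[float]:
    """Element-wise mean of neighbour vectors; zeros when there are none (assumption)."""
    if not vectors:
        return [0.0] * width
    return [sum(column) / len(vectors) for column in zip(*vectors)]


def concatenate(
    own_counts: Sequence[float],
    neighbours_by_distance: Sequence[Sequence[Sequence[float]]],
    feature_names: Sequence[str],
) -> Dict[str, float]:
    """Build one wide row: <feature>_0 holds own counts, <feature>_d the mean at distance d."""
    width = len(feature_names)
    blocks = [[float(v) for v in own_counts]]
    blocks += [mean_vector(vs, width) for vs in neighbours_by_distance]
    row: Dict[str, float] = {}
    for d, block in enumerate(blocks):
        for name, value in zip(feature_names, block):
            row[f"{name}_{d}"] = value
    return row


if __name__ == "__main__":
    # Two neighbours at distance 1, none at distance 2 (toy numbers).
    print(concatenate([3, 0], [[[1, 2], [3, 0]], []], ["shops", "sport"]))
    wide = concatenate([0] * len(GROUPS), [[] for _ in range(10)], GROUPS)
    print(len(wide))  # 187
```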
wide_cce = ContextualCountEmbedder(
neighbourhood=h3n, neighbourhood_distance=10, concatenate_vectors=True
)
wide_embeddings = wide_cce.transform(regions_gdf, features_gdf, joint_gdf)
wide_embeddings
Generating embeddings: 100%|██████████| 830/830 [00:03<00:00, 210.51it/s]
region_id | aerialway_0 | airports_0 | sustenance_0 | education_0 | transportation_0 | finances_0 | healthcare_0 | culture_art_entertainment_0 | other_0 | buildings_0 | ... | culture_art_entertainment_10 | other_10 | buildings_10 | emergency_10 | historic_10 | leisure_10 | shops_10 | sport_10 | tourism_10 | greenery_10 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
89393367497ffff | 0.0 | 0.0 | 0.0 | 2.0 | 10.0 | 0.0 | 0.0 | 1.0 | 1.0 | 0.0 | ... | 0.200000 | 0.750000 | 1.600000 | 0.025000 | 0.725000 | 2.600000 | 4.000000 | 1.000000 | 3.425000 | 4.375000 |
89393362a07ffff | 0.0 | 0.0 | 0.0 | 0.0 | 8.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 0.540541 | 1.135135 | 2.756757 | 0.000000 | 2.162162 | 6.270270 | 7.594595 | 1.135135 | 5.648649 | 5.891892 |
893933674afffff | 0.0 | 0.0 | 0.0 | 2.0 | 9.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 | ... | 0.323529 | 0.911765 | 2.294118 | 0.058824 | 0.470588 | 3.411765 | 6.617647 | 0.441176 | 2.294118 | 3.029412 |
89393362c2bffff | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.219512 | 0.682927 | 1.390244 | 0.000000 | 1.341463 | 3.390244 | 4.365854 | 0.756098 | 2.756098 | 4.000000 |
89393360593ffff | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.157895 | 0.157895 | 0.473684 | 0.000000 | 0.631579 | 2.842105 | 0.368421 | 0.789474 | 0.631579 | 2.894737 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
89393375a57ffff | 0.0 | 0.0 | 9.0 | 2.0 | 33.0 | 3.0 | 1.0 | 1.0 | 1.0 | 1.0 | ... | 0.050000 | 0.633333 | 2.433333 | 0.000000 | 0.666667 | 4.033333 | 4.266667 | 1.466667 | 1.833333 | 4.166667 |
89393375b13ffff | 0.0 | 0.0 | 1.0 | 1.0 | 9.0 | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | ... | 0.147059 | 0.235294 | 0.352941 | 0.000000 | 0.382353 | 3.294118 | 2.058824 | 0.970588 | 1.176471 | 3.441176 |
893933629bbffff | 0.0 | 0.0 | 16.0 | 1.0 | 36.0 | 2.0 | 2.0 | 1.0 | 1.0 | 4.0 | ... | 0.266667 | 1.050000 | 0.850000 | 0.016667 | 1.816667 | 2.933333 | 5.550000 | 0.616667 | 2.950000 | 3.433333 |
89393362bb7ffff | 0.0 | 0.0 | 19.0 | 5.0 | 26.0 | 10.0 | 5.0 | 0.0 | 5.0 | 36.0 | ... | 0.195122 | 0.634146 | 0.463415 | 0.000000 | 0.829268 | 3.585366 | 3.121951 | 1.000000 | 1.536585 | 5.585366 |
8939336755bffff | 0.0 | 0.0 | 1.0 | 0.0 | 4.0 | 0.0 | 0.0 | 1.0 | 0.0 | 4.0 | ... | 0.321429 | 1.000000 | 1.500000 | 0.000000 | 0.571429 | 4.642857 | 7.000000 | 1.250000 | 2.535714 | 7.071429 |
830 rows × 187 columns
Plotting example features
plot_numeric_data(regions_gdf, "leisure", embeddings)
plot_numeric_data(regions_gdf, "transportation", embeddings)