Count embedder
import geopandas as gpd
from shapely import geometry
from srai.constants import REGIONS_INDEX, WGS84_CRS
from srai.embedders import CountEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMOnlineLoader
from srai.plotting.folium_wrapper import plot_numeric_data, plot_regions
from srai.regionalizers import H3Regionalizer
Data preparation
To use CountEmbedder, we first need to prepare three GeoDataFrames: regions_gdf, features_gdf, and joint_gdf.
These are the outputs of a Regionalizer, a Loader, and a Joiner, respectively.
Define the bounding box polygon
bbox_polygon = geometry.Polygon(
    [
        [17.0198822, 51.1191217],
        [17.017436, 51.105004],
        [17.0485067, 51.1027944],
        [17.0511246, 51.1175054],
        [17.0198822, 51.1191217],
    ]
)
bbox_gdf = gpd.GeoDataFrame(geometry=[bbox_polygon], crs=WGS84_CRS)
bbox_gdf
| | geometry |
|---|---|
| 0 | POLYGON ((17.01988 51.11912, 17.01744 51.105, ... |
Regionalize the area using an H3Regionalizer
regionalizer = H3Regionalizer(resolution=8, buffer=True)
regions_gdf = regionalizer.transform(bbox_gdf)
folium_map = bbox_gdf.explore(tiles="CartoDB positron")
plot_regions(regions_gdf, map=folium_map)
Download some objects from OpenStreetMap
Both OsmTagsFilter and GroupedOsmTagsFilter filters are supported. This example uses a simple OsmTagsFilter.
loader = OSMOnlineLoader()
tags = {
    "leisure": ["playground", "adult_gaming_centre"],
    "amenity": "pub",
}
features_gdf = loader.load(bbox_gdf, tags=tags)
features_gdf
| feature_id | geometry | leisure | amenity |
|---|---|---|---|
| node/300461010 | POINT (17.03086 51.11136) | None | pub |
| node/300461141 | POINT (17.03012 51.10761) | None | pub |
| node/320922580 | POINT (17.02893 51.11182) | None | pub |
| node/551466140 | POINT (17.02916 51.10994) | None | pub |
| node/1093885062 | POINT (17.02947 51.11013) | None | pub |
| ... | ... | ... | ... |
| way/1075064799 | POLYGON ((17.03563 51.11154, 17.03576 51.11151... | playground | None |
| way/1075488712 | POLYGON ((17.03706 51.11154, 17.0371 51.11153,... | playground | None |
| way/1116839835 | POLYGON ((17.04233 51.11126, 17.04247 51.11126... | playground | None |
| way/1202309655 | POLYGON ((17.02257 51.11876, 17.02257 51.11873... | playground | None |
| way/1286669934 | POLYGON ((17.03622 51.11049, 17.03668 51.1104,... | playground | None |
107 rows × 3 columns
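For comparison with the flat filter used above, a grouped filter adds one level of nesting: each group name maps to its own tag filter. The snippet below is a minimal sketch in plain Python dictionaries, not srai code. The type aliases are simplified assumptions about the filter shapes, the group names `fun` and `nightlife` are made up for illustration, and `flatten_grouped_filter` is a hypothetical helper shown only to make the relationship between the two shapes explicit.

```python
from typing import Dict, List, Union

# Assumed, simplified shapes:
# a flat filter maps an OSM tag key to one value or a list of values;
# a grouped filter maps a group name (our own label) to a flat filter.
TagsFilter = Dict[str, Union[str, List[str]]]
GroupedTagsFilter = Dict[str, TagsFilter]

# The flat filter from this example.
flat_tags: TagsFilter = {
    "leisure": ["playground", "adult_gaming_centre"],
    "amenity": "pub",
}

# Hypothetical grouping of the same tags; the group names are illustrative.
grouped_tags: GroupedTagsFilter = {
    "fun": {"leisure": ["playground", "adult_gaming_centre"]},
    "nightlife": {"amenity": "pub"},
}


def flatten_grouped_filter(grouped: GroupedTagsFilter) -> Dict[str, List[str]]:
    """Merge all groups into one flat filter with list values (illustrative helper)."""
    merged: Dict[str, List[str]] = {}
    for tags in grouped.values():
        for key, values in tags.items():
            as_list = [values] if isinstance(values, str) else list(values)
            bucket = merged.setdefault(key, [])
            for value in as_list:
                # Keep insertion order and skip duplicates.
                if value not in bucket:
                    bucket.append(value)
    return merged


flattened = flatten_grouped_filter(grouped_tags)
print(flattened)
```

Flattening the grouped filter recovers the same keys and values as the flat one, which shows that grouping only changes how the tags are organised, not which objects are requested.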
folium_map = plot_regions(regions_gdf, tiles_style="CartoDB positron", colormap=["lightgray"])
features_gdf.explore(m=folium_map)
Join the objects with the regions they belong to
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf, return_geom=True)
joint_gdf
| region_id | feature_id | geometry |
|---|---|---|
| 881e204089fffff | node/300461010 | POINT (17.03086 51.11136) |
| | node/300461141 | POINT (17.03012 51.10761) |
| | node/320922580 | POINT (17.02893 51.11182) |
| | node/551466140 | POINT (17.02916 51.10994) |
| | node/1093885062 | POINT (17.02947 51.11013) |
| ... | ... | ... |
| 881e2040d5fffff | way/1075064799 | POLYGON ((17.03576 51.11151, 17.03571 51.11139... |
| | way/1075488712 | POLYGON ((17.0371 51.11153, 17.03712 51.11153,... |
| | way/1116839835 | POLYGON ((17.04233 51.1114, 17.04247 51.1114, ... |
| 881e2040c3fffff | way/1202309655 | POLYGON ((17.02263 51.11876, 17.02263 51.11873... |
| 881e2040d5fffff | way/1286669934 | POLYGON ((17.03668 51.1104, 17.03664 51.11033,... |
107 rows × 1 columns
from plotly.express import colors
folium_map = plot_regions(regions_gdf, tiles_style="CartoDB positron", colormap=["rgba(0,0,0,0)"])
joint_gdf.reset_index().explore(m=folium_map, column=REGIONS_INDEX, cmap=colors.qualitative.Bold)
Embed using features existing in data
CountEmbedder can count features at the tag-key level (for example, all leisure objects together) or separately for each tag key and value pair (for example, leisure_playground). Both modes are shown below.
wide_embedder = CountEmbedder(count_subcategories=True)
wide_embedding = wide_embedder.transform(regions_gdf, features_gdf, joint_gdf)
wide_embedding
| region_id | leisure_adult_gaming_centre | leisure_playground | amenity_pub |
|---|---|---|---|
| 881e2040d5fffff | 0 | 12 | 10 |
| 881e2040d7fffff | 0 | 1 | 0 |
| 881e2040d1fffff | 0 | 2 | 1 |
| 881e20409dfffff | 0 | 4 | 0 |
| 881e2040ddfffff | 0 | 0 | 0 |
| 881e2040c3fffff | 0 | 11 | 1 |
| 881e20408dfffff | 0 | 5 | 0 |
| 881e204081fffff | 0 | 1 | 1 |
| 881e2040c7fffff | 0 | 1 | 2 |
| 881e204089fffff | 1 | 7 | 33 |
| 881e20408bfffff | 0 | 11 | 3 |
dense_embedder = CountEmbedder(count_subcategories=False)
dense_embedding = dense_embedder.transform(regions_gdf, features_gdf, joint_gdf)
dense_embedding
| region_id | leisure | amenity |
|---|---|---|
| 881e2040d7fffff | 1 | 0 |
| 881e2040d5fffff | 12 | 10 |
| 881e2040d1fffff | 2 | 1 |
| 881e2040c3fffff | 11 | 1 |
| 881e20409dfffff | 4 | 0 |
| 881e2040ddfffff | 0 | 0 |
| 881e2040c7fffff | 1 | 2 |
| 881e20408dfffff | 5 | 0 |
| 881e204081fffff | 1 | 1 |
| 881e20408bfffff | 11 | 3 |
| 881e204089fffff | 8 | 33 |
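The difference between the two counting modes above can be sketched in plain Python. This is only an illustration of the counting logic under simplifying assumptions, not how CountEmbedder is implemented: the `joined_rows` data, the region ids `A` and `B`, and the `count_features` helper are all made up for this example.

```python
from collections import Counter
from typing import Dict, List, Optional

# Toy stand-in for the joined data: one dict per (region, feature) pair.
# Region ids and tag values are invented for illustration.
joined_rows: List[Dict[str, Optional[str]]] = [
    {"region_id": "A", "leisure": "playground", "amenity": None},
    {"region_id": "A", "leisure": "adult_gaming_centre", "amenity": None},
    {"region_id": "A", "leisure": None, "amenity": "pub"},
    {"region_id": "B", "leisure": "playground", "amenity": None},
]


def count_features(
    rows: List[Dict[str, Optional[str]]], count_subcategories: bool
) -> Dict[str, Counter]:
    """Count tagged features per region, either by key or by key_value."""
    counts: Dict[str, Counter] = {}
    for row in rows:
        region_counts = counts.setdefault(row["region_id"], Counter())
        for key, value in row.items():
            # Skip the region label and tags the feature does not carry.
            if key == "region_id" or value is None:
                continue
            column = f"{key}_{value}" if count_subcategories else key
            region_counts[column] += 1
    return counts


wide = count_features(joined_rows, count_subcategories=True)
dense = count_features(joined_rows, count_subcategories=False)
print(dict(wide["A"]))
print(dict(dense["A"]))
```

Region `A` holds two leisure features with different values, so the key-level count is 2 while the key-and-value mode splits them into two columns. The same relationship appears in the real output for region 881e204089fffff: 7 playgrounds plus 1 adult gaming centre give a leisure count of 8.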
Embed with specified expected output features
embedder = CountEmbedder(
    expected_output_features=[
        "amenity_parking",
        "leisure_park",
        "leisure_playground",
        "amenity_pub",
    ]
)
embedding_expected_features = embedder.transform(regions_gdf, features_gdf, joint_gdf)
embedding_expected_features
| region_id | amenity_parking | leisure_park | leisure_playground | amenity_pub |
|---|---|---|---|---|
| 881e2040d1fffff | 0 | 0 | 2 | 1 |
| 881e2040d5fffff | 0 | 0 | 12 | 10 |
| 881e2040d7fffff | 0 | 0 | 1 | 0 |
| 881e20409dfffff | 0 | 0 | 4 | 0 |
| 881e2040ddfffff | 0 | 0 | 0 | 0 |
| 881e2040c3fffff | 0 | 0 | 11 | 1 |
| 881e2040c7fffff | 0 | 0 | 1 | 2 |
| 881e20408dfffff | 0 | 0 | 5 | 0 |
| 881e204081fffff | 0 | 0 | 1 | 1 |
| 881e204089fffff | 0 | 0 | 7 | 33 |
| 881e20408bfffff | 0 | 0 | 11 | 3 |
plot_numeric_data(regions_gdf, "leisure_playground", embedding_expected_features)
The resulting embedding contains only the columns specified in expected_output_features.
The ones that were not present in the data (leisure_park, amenity_parking) are added and filled with zeros.
The features that are both expected and present in the data are counted as usual.
The ones that are present in the data but are not expected (leisure_adult_gaming_centre) are discarded.
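The three rules above (add missing columns as zeros, keep expected ones, drop unexpected ones) amount to aligning each region's counts to a fixed column list. A minimal sketch in plain Python, assuming a per-region dictionary of counts; the `align_to_expected` helper is hypothetical and is not part of srai:

```python
from typing import Dict, List


def align_to_expected(counts: Dict[str, int], expected: List[str]) -> Dict[str, int]:
    """Keep only expected columns, in the given order, filling gaps with 0."""
    return {name: counts.get(name, 0) for name in expected}


# Counts for region 881e204089fffff, copied from the per-value embedding above.
region_counts = {
    "leisure_adult_gaming_centre": 1,
    "leisure_playground": 7,
    "amenity_pub": 33,
}

expected = [
    "amenity_parking",
    "leisure_park",
    "leisure_playground",
    "amenity_pub",
]

aligned = align_to_expected(region_counts, expected)
print(aligned)
```

The result has exactly the expected columns in the requested order, which is what makes embeddings from different areas comparable even when their raw tag sets differ.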