Count embedder
from shapely import geometry
import geopandas as gpd
from srai.constants import WGS84_CRS, REGIONS_INDEX
from srai.loaders.osm_loaders import OSMOnlineLoader
from srai.regionalizers import H3Regionalizer
from srai.joiners import IntersectionJoiner
from srai.embedders import CountEmbedder
from srai.plotting.folium_wrapper import plot_regions, plot_numeric_data
Data preparation
In order to use `CountEmbedder`, we need to prepare some data.
Namely, we need: `regions_gdf`, `features_gdf`, and `joint_gdf`.
These are the outputs of Regionalizers, Loaders and Joiners respectively.
Define the bounding box polygon
# Closed ring of (x=longitude, y=latitude) vertices in WGS84; the last vertex repeats the first.
bbox_polygon = geometry.Polygon(
[
[17.0198822, 51.1191217],
[17.017436, 51.105004],
[17.0485067, 51.1027944],
[17.0511246, 51.1175054],
[17.0198822, 51.1191217],
]
)
# Wrap the polygon in a single-row GeoDataFrame tagged with the WGS84 CRS,
# which is the area input for the regionalizer and loader below.
bbox_gdf = gpd.GeoDataFrame(geometry=[bbox_polygon], crs=WGS84_CRS)
bbox_gdf
geometry | |
---|---|
0 | POLYGON ((17.01988 51.11912, 17.01744 51.10500... |
Regionalize the area using an `H3Regionalizer`
# Split the bounding box into H3 cells at resolution 8.
# NOTE(review): buffer=True is intended to fully cover the area (including
# cells crossing the border) — confirm against the H3Regionalizer docs.
regionalizer = H3Regionalizer(resolution=8, buffer=True)
regions_gdf = regionalizer.transform(bbox_gdf)
# Draw the bounding box first, then overlay the regions on the same folium map.
folium_map = bbox_gdf.explore(tiles="CartoDB positron")
plot_regions(regions_gdf, map=folium_map)
Download some objects from OpenStreetMap
You can use both `OsmTagsFilter` and `GroupedOsmTagsFilter` filters. In this example, a simple `OsmTagsFilter` filter is used.
loader = OSMOnlineLoader()
# Simple tag filter: each OSM key maps to a list of accepted values or a single value.
tags = {
"leisure": ["playground", "adult_gaming_centre"],
"amenity": "pub",
}
# Download the OSM features inside the bounding box that match the tag filter.
features_gdf = loader.load(bbox_gdf, tags=tags)
features_gdf
Downloading amenity: pub : 100%|██████████| 3/3 [00:00<00:00, 6.66it/s]
geometry | leisure | amenity | |
---|---|---|---|
feature_id | |||
node/300461010 | POINT (17.03086 51.11136) | None | pub |
node/300461141 | POINT (17.03012 51.10761) | None | pub |
node/320922580 | POINT (17.02893 51.11182) | None | pub |
node/551466140 | POINT (17.02916 51.10994) | None | pub |
node/1093885062 | POINT (17.02947 51.11013) | None | pub |
... | ... | ... | ... |
way/1064113400 | POLYGON ((17.02815 51.11514, 17.02821 51.11520... | playground | None |
way/1074595740 | POLYGON ((17.02759 51.10969, 17.02761 51.10968... | playground | None |
way/1075488712 | POLYGON ((17.03706 51.11154, 17.03710 51.11153... | playground | None |
way/1116839835 | POLYGON ((17.04233 51.11126, 17.04247 51.11126... | playground | None |
way/1202309655 | POLYGON ((17.02259 51.11876, 17.02259 51.11874... | playground | None |
104 rows × 3 columns
# Draw all regions in a uniform light gray, then overlay the downloaded features.
folium_map = plot_regions(regions_gdf, tiles_style="CartoDB positron", colormap=["lightgray"])
features_gdf.explore(m=folium_map)
Join the objects with the regions they belong to
# Pair regions with the features that intersect them. The result is indexed by
# (region_id, feature_id); return_geom=True keeps the feature geometry column.
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf, return_geom=True)
joint_gdf
geometry | ||
---|---|---|
region_id | feature_id | |
881e204089fffff | node/300461010 | POINT (17.03086 51.11136) |
node/300461141 | POINT (17.03012 51.10761) | |
node/320922580 | POINT (17.02893 51.11182) | |
node/551466140 | POINT (17.02916 51.10994) | |
node/1093885062 | POINT (17.02947 51.11013) | |
... | ... | ... |
881e20408dfffff | way/1053871552 | POLYGON ((17.01859 51.10537, 17.01854 51.10536... |
way/1053871553 | POLYGON ((17.01814 51.10537, 17.01817 51.10540... | |
way/1053871565 | POLYGON ((17.01830 51.10599, 17.01828 51.10602... | |
way/1053871566 | POLYGON ((17.01781 51.10578, 17.01779 51.10580... | |
way/1057696622 | POLYGON ((17.01832 51.10705, 17.01833 51.10704... |
104 rows × 1 columns
from plotly.express import colors
# Transparent fill ("rgba(0,0,0,0)") so only the region borders are visible.
folium_map = plot_regions(regions_gdf, tiles_style="CartoDB positron", colormap=["rgba(0,0,0,0)"])
# reset_index() turns the region index into a column so features can be colored
# by the region they were joined to.
joint_gdf.reset_index().explore(m=folium_map, column=REGIONS_INDEX, cmap=colors.qualitative.Bold)
Embed using features existing in the data
`CountEmbedder` can count features on a higher level (tag key) or separately for each value (tag key and value). Both examples are shown below.
# count_subcategories=True: one column per tag key/value pair (e.g. leisure_playground).
# Counts are computed per region from the features and the region-feature join.
wide_embedder = CountEmbedder(count_subcategories=True)
wide_embedding = wide_embedder.transform(regions_gdf, features_gdf, joint_gdf)
wide_embedding
leisure_adult_gaming_centre | leisure_playground | amenity_pub | |
---|---|---|---|
region_id | |||
881e2040d5fffff | 0 | 10 | 10 |
881e20409dfffff | 0 | 4 | 0 |
881e20408bfffff | 0 | 11 | 3 |
881e204081fffff | 0 | 1 | 1 |
881e20408dfffff | 0 | 5 | 0 |
881e2040c3fffff | 0 | 9 | 1 |
881e2040d7fffff | 0 | 1 | 0 |
881e204089fffff | 1 | 7 | 35 |
881e2040d1fffff | 0 | 2 | 1 |
881e2040ddfffff | 0 | 0 | 0 |
881e2040c7fffff | 0 | 1 | 1 |
# count_subcategories=False: one column per tag key (leisure, amenity), summing all its values.
dense_embedder = CountEmbedder(count_subcategories=False)
dense_embedding = dense_embedder.transform(regions_gdf, features_gdf, joint_gdf)
dense_embedding
leisure | amenity | |
---|---|---|
region_id | ||
881e2040d5fffff | 10 | 10 |
881e20409dfffff | 4 | 0 |
881e20408bfffff | 11 | 3 |
881e204081fffff | 1 | 1 |
881e20408dfffff | 5 | 0 |
881e2040c3fffff | 9 | 1 |
881e2040d7fffff | 1 | 0 |
881e204089fffff | 8 | 35 |
881e2040d1fffff | 2 | 1 |
881e2040ddfffff | 0 | 0 |
881e2040c7fffff | 1 | 1 |
Embed with expected output features specified
# Fix the output columns up front: expected features missing from the data are
# added with zero counts, and features present but not listed are dropped.
embedder = CountEmbedder(
expected_output_features=[
"amenity_parking",
"leisure_park",
"leisure_playground",
"amenity_pub",
]
)
embedding_expected_features = embedder.transform(regions_gdf, features_gdf, joint_gdf)
embedding_expected_features
amenity_parking | leisure_park | leisure_playground | amenity_pub | |
---|---|---|---|---|
region_id | ||||
881e2040d5fffff | 0 | 0 | 10 | 10 |
881e20409dfffff | 0 | 0 | 4 | 0 |
881e20408bfffff | 0 | 0 | 11 | 3 |
881e204081fffff | 0 | 0 | 1 | 1 |
881e20408dfffff | 0 | 0 | 5 | 0 |
881e2040c3fffff | 0 | 0 | 9 | 1 |
881e2040d7fffff | 0 | 0 | 1 | 0 |
881e204089fffff | 0 | 0 | 7 | 35 |
881e2040d1fffff | 0 | 0 | 2 | 1 |
881e2040ddfffff | 0 | 0 | 0 | 0 |
881e2040c7fffff | 0 | 0 | 1 | 1 |
# Color each region by its value in the "leisure_playground" column of the embedding.
plot_numeric_data(regions_gdf, "leisure_playground", embedding_expected_features)
The resulting embedding contains only the columns specified in `expected_output_features`.
The ones that were not present in the data (`leisure_park`, `amenity_parking`) are added and filled with zeros.
The features that are both expected and present in the data are counted as usual.
The ones that are present in the data but are not expected (`leisure_adult_gaming_centre`) are discarded.