Count embedder
import geopandas as gpd
from shapely import geometry
from srai.constants import REGIONS_INDEX, WGS84_CRS
from srai.embedders import CountEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMOnlineLoader
from srai.plotting.folium_wrapper import plot_numeric_data, plot_regions
from srai.regionalizers import H3Regionalizer
Data preparation
To use `CountEmbedder`, we need to prepare some data. Namely, we need `regions_gdf`, `features_gdf`, and `joint_gdf`. These are the outputs of Regionalizers, Loaders, and Joiners, respectively.
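As a rough mental model, the three inputs can be pictured as plain Python structures keyed by the same IDs. The sketch below is stdlib-only and is not srai code: the toy `regions`, `features`, and `joint` values reuse IDs and tags from the tables in this tutorial, geometries are omitted, and `joint_is_consistent` is a hypothetical helper.

```python
# Stdlib-only toy model of CountEmbedder's three inputs (not srai code).
# IDs and tag values are copied from rows shown in this tutorial's tables.

# regions_gdf: one row per region, indexed by region_id (geometries omitted)
regions = {"881e204089fffff", "881e2040d5fffff"}

# features_gdf: one row per OSM object, indexed by feature_id,
# with one column per tag key (None when the object lacks that key)
features = {
    "node/300461010": {"leisure": None, "amenity": "pub"},
    "way/1075488712": {"leisure": "playground", "amenity": None},
}

# joint_gdf: one row per (region_id, feature_id) pair
joint = [
    ("881e204089fffff", "node/300461010"),
    ("881e2040d5fffff", "way/1075488712"),
]


def joint_is_consistent(regions, features, joint):
    """Hypothetical check: every joint pair references a known region and feature."""
    return all(
        region_id in regions and feature_id in features
        for region_id, feature_id in joint
    )


print(joint_is_consistent(regions, features, joint))  # True
```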
Define the bounding box polygon
bbox_polygon = geometry.Polygon(
[
[17.0198822, 51.1191217],
[17.017436, 51.105004],
[17.0485067, 51.1027944],
[17.0511246, 51.1175054],
[17.0198822, 51.1191217],
]
)
bbox_gdf = gpd.GeoDataFrame(geometry=[bbox_polygon], crs=WGS84_CRS)
bbox_gdf
| | geometry |
|---|---|
| 0 | POLYGON ((17.01988 51.11912, 17.01744 51.10500... |
Regionalize the area using an H3Regionalizer
regionalizer = H3Regionalizer(resolution=8, buffer=True)
regions_gdf = regionalizer.transform(bbox_gdf)
folium_map = bbox_gdf.explore(tiles="CartoDB positron")
plot_regions(regions_gdf, map=folium_map)
Download some objects from OpenStreetMap
You can use either an `OsmTagsFilter` or a `GroupedOsmTagsFilter`. In this example, a simple `OsmTagsFilter` is used.
loader = OSMOnlineLoader()
tags = {
"leisure": ["playground", "adult_gaming_centre"],
"amenity": "pub",
}
features_gdf = loader.load(bbox_gdf, tags=tags)
features_gdf
| feature_id | geometry | leisure | amenity |
|---|---|---|---|
| node/300461010 | POINT (17.03086 51.11136) | None | pub |
| node/300461141 | POINT (17.03012 51.10761) | None | pub |
| node/320922580 | POINT (17.02893 51.11182) | None | pub |
| node/551466140 | POINT (17.02916 51.10994) | None | pub |
| node/1093885062 | POINT (17.02947 51.11013) | None | pub |
| ... | ... | ... | ... |
| way/1074595740 | POLYGON ((17.02759 51.10969, 17.02761 51.10968... | playground | None |
| way/1075488712 | POLYGON ((17.03706 51.11154, 17.03710 51.11153... | playground | None |
| way/1116839835 | POLYGON ((17.04233 51.11126, 17.04247 51.11126... | playground | None |
| way/1202309655 | POLYGON ((17.02259 51.11876, 17.02259 51.11874... | playground | None |
| way/1286669934 | POLYGON ((17.03622 51.11049, 17.03668 51.11041... | playground | None |

106 rows × 3 columns
folium_map = plot_regions(regions_gdf, tiles_style="CartoDB positron", colormap=["lightgray"])
features_gdf.explore(m=folium_map)
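The `tags` dict used above asks for three (key, value) combinations. To make the difference between the two filter shapes concrete, here is a stdlib sketch under stated assumptions: an `OsmTagsFilter` is modelled as `{key: value | [values] | True}`, and a `GroupedOsmTagsFilter` is assumed to map a group name to such a filter. The helpers `requested_pairs` and `requested_pairs_grouped`, the `"*"` wildcard, and the group names are all hypothetical and not part of srai.

```python
from typing import Dict, List, Set, Tuple, Union

# Assumed shapes (modelled on this tutorial, not imported from srai):
#   OsmTagsFilter        -> {tag_key: "value" | ["value", ...] | True}
#   GroupedOsmTagsFilter -> {group_name: OsmTagsFilter}
TagsFilter = Dict[str, Union[str, List[str], bool]]
GroupedTagsFilter = Dict[str, TagsFilter]


def requested_pairs(tags: TagsFilter) -> Set[Tuple[str, str]]:
    """Hypothetical helper: list the (key, value) pairs a tags filter asks for."""
    pairs = set()
    for key, values in tags.items():
        if values is True:
            pairs.add((key, "*"))  # "*" stands for "any value of this key"
        elif isinstance(values, str):
            pairs.add((key, values))
        elif isinstance(values, list):
            pairs.update((key, value) for value in values)
        # values is False: the key is not requested at all
    return pairs


def requested_pairs_grouped(grouped: GroupedTagsFilter) -> Set[Tuple[str, str]]:
    """Hypothetical helper: union of the pairs requested by every group."""
    result: Set[Tuple[str, str]] = set()
    for group_tags in grouped.values():
        result |= requested_pairs(group_tags)
    return result


tags = {"leisure": ["playground", "adult_gaming_centre"], "amenity": "pub"}
grouped = {
    "recreation": {"leisure": ["playground", "adult_gaming_centre"]},  # hypothetical group
    "nightlife": {"amenity": "pub"},  # hypothetical group
}

print(sorted(requested_pairs(tags)))
# [('amenity', 'pub'), ('leisure', 'adult_gaming_centre'), ('leisure', 'playground')]
print(requested_pairs(tags) == requested_pairs_grouped(grouped))  # True
```

The grouped form requests the same tags here; grouping only adds a named layer on top of them.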
Join the objects with the regions they belong to
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf, return_geom=True)
joint_gdf
| region_id | feature_id | geometry |
|---|---|---|
| 881e204089fffff | node/300461010 | POINT (17.03086 51.11136) |
| | node/300461141 | POINT (17.03012 51.10761) |
| | node/320922580 | POINT (17.02893 51.11182) |
| | node/551466140 | POINT (17.02916 51.10994) |
| | node/1093885062 | POINT (17.02947 51.11013) |
| ... | ... | ... |
| | way/1074595740 | POLYGON ((17.02761 51.10971, 17.02764 51.10973... |
| 881e2040d5fffff | way/1075488712 | POLYGON ((17.03710 51.11153, 17.03712 51.11153... |
| | way/1116839835 | POLYGON ((17.04233 51.11140, 17.04247 51.11140... |
| 881e2040c3fffff | way/1202309655 | POLYGON ((17.02265 51.11876, 17.02265 51.11874... |
| 881e2040d5fffff | way/1286669934 | POLYGON ((17.03668 51.11041, 17.03664 51.11033... |

106 rows × 1 columns
from plotly.express import colors
folium_map = plot_regions(regions_gdf, tiles_style="CartoDB positron", colormap=["rgba(0,0,0,0)"])
joint_gdf.reset_index().explore(m=folium_map, column=REGIONS_INDEX, cmap=colors.qualitative.Bold)
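The idea behind an intersection join is to emit one (region, feature) pair for every region a feature intersects, so an object straddling a border is listed under each region it touches. The stdlib sketch below illustrates that idea only: it swaps H3 hexagons and OSM geometries for axis-aligned boxes `(min_x, min_y, max_x, max_y)`, and the IDs and coordinates are hypothetical. It is not how `IntersectionJoiner` is implemented.

```python
from typing import Dict, List, Tuple

# Axis-aligned boxes stand in for real geometries; a point is a degenerate box.
Box = Tuple[float, float, float, float]


def boxes_intersect(a: Box, b: Box) -> bool:
    """Closed boxes intersect unless one lies entirely to one side of the other."""
    return a[0] <= b[2] and b[0] <= a[2] and a[1] <= b[3] and b[1] <= a[3]


def intersection_join(regions: Dict[str, Box], features: Dict[str, Box]) -> List[Tuple[str, str]]:
    """Return every (region_id, feature_id) pair whose boxes intersect, sorted."""
    pairs = [
        (region_id, feature_id)
        for region_id, region_box in regions.items()
        for feature_id, feature_box in features.items()
        if boxes_intersect(region_box, feature_box)
    ]
    return sorted(pairs)


regions = {"A": (0.0, 0.0, 1.0, 1.0), "B": (1.0, 0.0, 2.0, 1.0)}
features = {
    "point": (0.5, 0.5, 0.5, 0.5),     # lies inside A only
    "park": (0.8, 0.2, 1.3, 0.4),      # straddles the A/B border
    "far_away": (5.0, 5.0, 5.0, 5.0),  # outside both regions
}

print(intersection_join(regions, features))
# [('A', 'park'), ('A', 'point'), ('B', 'park')]
```

The nested loop is quadratic; it is chosen here for readability, whereas a real spatial join would normally use a spatial index.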
Embed using the features present in the data
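`CountEmbedder` can count features at a higher level (per tag key, e.g. `leisure`) or separately for each value (per tag key and value, e.g. `leisure_playground`). The `count_subcategories` parameter selects between the two; both variants are shown below.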
wide_embedder = CountEmbedder(count_subcategories=True)
wide_embedding = wide_embedder.transform(regions_gdf, features_gdf, joint_gdf)
wide_embedding
| region_id | leisure_adult_gaming_centre | leisure_playground | amenity_pub |
|---|---|---|---|
| 881e204081fffff | 0 | 1 | 1 |
| 881e20408bfffff | 0 | 11 | 3 |
| 881e2040d1fffff | 0 | 2 | 1 |
| 881e2040c7fffff | 0 | 1 | 1 |
| 881e204089fffff | 1 | 7 | 36 |
| 881e2040c3fffff | 0 | 9 | 1 |
| 881e20408dfffff | 0 | 5 | 0 |
| 881e2040d5fffff | 0 | 11 | 10 |
| 881e20409dfffff | 0 | 4 | 0 |
| 881e2040d7fffff | 0 | 1 | 0 |
| 881e2040ddfffff | 0 | 0 | 0 |
dense_embedder = CountEmbedder(count_subcategories=False)
dense_embedding = dense_embedder.transform(regions_gdf, features_gdf, joint_gdf)
dense_embedding
| region_id | leisure | amenity |
|---|---|---|
| 881e204081fffff | 1 | 1 |
| 881e20408bfffff | 11 | 3 |
| 881e2040d1fffff | 2 | 1 |
| 881e2040c7fffff | 1 | 1 |
| 881e204089fffff | 8 | 36 |
| 881e2040c3fffff | 9 | 1 |
| 881e20408dfffff | 5 | 0 |
| 881e2040d5fffff | 11 | 10 |
| 881e20409dfffff | 4 | 0 |
| 881e2040d7fffff | 1 | 0 |
| 881e2040ddfffff | 0 | 0 |
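The two tables are related by summation over values: in region 881e204089fffff, 1 adult gaming centre plus 7 playgrounds gives 8 under `leisure`. A minimal stdlib sketch of both counting modes follows, assuming features are stored as key-to-value dicts and the join as (region_id, feature_id) pairs. `count_features` and the toy IDs `r1`, `r2`, `r3`, `f1`, `f2`, `f3` are hypothetical; this illustrates the counting idea and is not srai's implementation.

```python
from collections import Counter


def count_features(regions, features, joint, count_subcategories=True):
    """Toy per-region counts; regions without any feature still get zeros."""
    counts = {region_id: Counter() for region_id in regions}
    columns = set()
    for region_id, feature_id in joint:
        for key, value in features[feature_id].items():
            if value is None:
                continue  # the feature does not carry this tag key
            column = f"{key}_{value}" if count_subcategories else key
            counts[region_id][column] += 1
            columns.add(column)
    # Give every region every observed column (a Counter returns 0 if absent)
    return {
        region_id: {c: counts[region_id][c] for c in sorted(columns)}
        for region_id in regions
    }


features = {
    "f1": {"leisure": "playground", "amenity": None},
    "f2": {"leisure": "adult_gaming_centre", "amenity": None},
    "f3": {"leisure": None, "amenity": "pub"},
}
joint = [("r1", "f1"), ("r1", "f2"), ("r1", "f3"), ("r2", "f3")]
regions = ["r1", "r2", "r3"]  # r3 has no features, like 881e2040ddfffff

wide = count_features(regions, features, joint, count_subcategories=True)
dense = count_features(regions, features, joint, count_subcategories=False)
print(wide["r1"])   # {'amenity_pub': 1, 'leisure_adult_gaming_centre': 1, 'leisure_playground': 1}
print(dense["r1"])  # {'amenity': 1, 'leisure': 2}
print(dense["r3"])  # {'amenity': 0, 'leisure': 0}
```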
Embed with specified expected output features
embedder = CountEmbedder(
expected_output_features=[
"amenity_parking",
"leisure_park",
"leisure_playground",
"amenity_pub",
]
)
embedding_expected_features = embedder.transform(regions_gdf, features_gdf, joint_gdf)
embedding_expected_features
| region_id | amenity_parking | leisure_park | leisure_playground | amenity_pub |
|---|---|---|---|---|
| 881e204081fffff | 0 | 0 | 1 | 1 |
| 881e20408bfffff | 0 | 0 | 11 | 3 |
| 881e2040d1fffff | 0 | 0 | 2 | 1 |
| 881e2040c7fffff | 0 | 0 | 1 | 1 |
| 881e204089fffff | 0 | 0 | 7 | 36 |
| 881e2040c3fffff | 0 | 0 | 9 | 1 |
| 881e20408dfffff | 0 | 0 | 5 | 0 |
| 881e2040d5fffff | 0 | 0 | 11 | 10 |
| 881e20409dfffff | 0 | 0 | 4 | 0 |
| 881e2040d7fffff | 0 | 0 | 1 | 0 |
| 881e2040ddfffff | 0 | 0 | 0 | 0 |
plot_numeric_data(regions_gdf, "leisure_playground", embedding_expected_features)
The resulting embedding contains only the columns specified in `expected_output_features`. Expected features that are not present in the data (`leisure_park`, `amenity_parking`) are added and filled with zeros. Features that are both expected and present in the data are counted as usual. Features that are present in the data but not expected (`leisure_adult_gaming_centre`) are discarded.
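These three rules amount to aligning each region's counts to a fixed column list. The stdlib sketch below reproduces the row for region 881e204089fffff using its counts from the `count_subcategories=True` table; `align_to_expected` is a hypothetical helper, not srai's implementation. For a pandas DataFrame of counts, the same alignment can be written as `df.reindex(columns=expected, fill_value=0)`.

```python
def align_to_expected(row, expected):
    """Toy version: keep expected columns in order; absent ones become 0, others are dropped."""
    return {column: row.get(column, 0) for column in expected}


expected = ["amenity_parking", "leisure_park", "leisure_playground", "amenity_pub"]
# Wide counts for region 881e204089fffff, copied from the table above
row = {"leisure_adult_gaming_centre": 1, "leisure_playground": 7, "amenity_pub": 36}

print(align_to_expected(row, expected))
# {'amenity_parking': 0, 'leisure_park': 0, 'leisure_playground': 7, 'amenity_pub': 36}
```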