Skip to content

grid_classify

cellseg_gsontools.grid.grid_classify(grid, objs, metric_func, predicate, new_col_names, parallel=True, num_processes=-1, pbar=False, **kwargs)

Classify the grid based on objs inside the grid cells.

Parameters:

Name Type Description Default
grid GeoDataFrame

The grid of rectangles to classify.

required
objs GeoDataFrame

The objects to use for classification.

required
metric_func Callable

The metric/heuristic function to use for classification.

required
predicate str

The predicate to use for the spatial join. Allowed values are "intersects" and "within".

required
new_col_names Union[Tuple[str, ...], str]

The name of the new column(s) in the grid gdf.

required
parallel bool

Whether to use parallel processing.

True
num_processes int

The number of processes to use. If -1, uses all available cores. Ignored if parallel=False.

-1
pbar bool

Whether to show a progress bar. Ignored if parallel=False.

False

Returns:

Type Description
GeoDataFrame

gpd.GeoDataFrame: The grid with the new columns added.

Raises:

Type Description
ValueError

If predicate is not one of "intersects" or "within".

Examples:

Get the number of immune cells in each grid cell at the tumor stroma interface:

>>> import geopandas as gpd
>>> from shapely.geometry import Polygon
>>> from functools import partial
>>> from cellseg_gsontools import gdf_apply, read_gdf
>>> from cellseg_gsontools.grid import grid_classify, grid_overlay
>>> from cellseg_gsontools.context import InterfaceContext
>>> # Define a heuristic function to get the number of immune cells
>>> def get_immune_cell_cnt(gdf: gpd.GeoDataFrame, **kwargs) -> int:
...     try:
...         cnt = gdf.class_name.value_counts()["inflammatory"]
...     except KeyError:
...         cnt = 0
...     return int(cnt)
>>> # Read in the tissue areas and cells
>>> area_gdf = gpd.read_file("path/to/area.geojson")
>>> cell_gdf = gpd.read_file("path/to/cell.geojson")
>>> # Fit a tumor-stroma interface
>>> tumor_stroma_iface = InterfaceContext(
...     area_gdf=area_gdf,
...     cell_gdf=cell_gdf,
...     top_labels="area_cin",
...     bottom_labels="areastroma",
...     buffer_dist=250,
...     graph_type="distband",
...     dist_thresh=75,
...     patch_size=(128, 128),
...     stride=(128, 128),
...     min_area_size=50000,
... )
>>> tumor_stroma_iface.fit(verbose=False)
>>> # Get the grid and the cells at the interface
>>> iface_grid = grid_overlay(
...     tumor_stroma_iface.context2gdf("interface_area"),
...     patch_size=(128, 128),
...     stride=(128, 128),
... )
>>> cells = tumor_stroma_iface.context2gdf("interface_cells")
>>> # Classify the grid
>>> iface_grid = grid_classify(
...     grid=iface_grid,
...     objs=cells,
...     metric_func=get_immune_cell_cnt,
...     predicate="intersects",
...     new_col_names="immune_cnt",
...     parallel=True,
...     pbar=True,
...     num_processes=-1
... )
>>> iface_grid
geometry  immune_cnt
28  POLYGON ((20032.00000 54098.50000, 20160.00000... 15
29  POLYGON ((20160.00000 54098.50000, 20288.00000... 3
Source code in cellseg_gsontools/grid.py
def grid_classify(
    grid: gpd.GeoDataFrame,
    objs: gpd.GeoDataFrame,
    metric_func: Callable,
    predicate: str,
    new_col_names: Union[Tuple[str, ...], str],
    parallel: bool = True,
    num_processes: int = -1,
    pbar: bool = False,
    **kwargs,
) -> gpd.GeoDataFrame:
    """Classify the grid based on objs inside the grid cells.

    The result of applying `metric_func` to every grid cell is written into
    the column(s) `new_col_names` of the passed-in `grid`. Note that `grid`
    itself is modified in place and the same object is returned.

    Parameters:
        grid (gpd.GeoDataFrame):
            The grid of rectangles to classify.
        objs (gpd.GeoDataFrame):
            The objects to use for classification.
        metric_func (Callable):
            The metric/heuristic function to use for classification.
        predicate (str):
            The predicate to use for the spatial join. Allowed values are "intersects"
            and "within".
        new_col_names (Union[Tuple[str, ...], str]):
            The name of the new column(s) in the grid gdf. A single string is
            treated as a one-element list of column names.
        parallel (bool):
            Whether to use parallel processing.
        num_processes (int):
            The number of processes to use. If -1, uses all available cores.
            Ignored if parallel=False.
        pbar (bool):
            Whether to show a progress bar. Ignored if parallel=False.
        **kwargs:
            Extra keyword arguments that are bound to `metric_func` before it
            is applied to the grid cells.

    Returns:
        gpd.GeoDataFrame:
            The grid with the new columns added.

    Raises:
        ValueError: If predicate is not one of "intersects" or "within".

    Examples:
        Get the number of immune cells in each grid cell at the tumor stroma interface:
        >>> import geopandas as gpd
        >>> from cellseg_gsontools.grid import grid_classify, grid_overlay
        >>> from cellseg_gsontools.context import InterfaceContext
        >>> # Define a heuristic function to get the number of immune cells
        >>> def get_immune_cell_cnt(gdf: gpd.GeoDataFrame, **kwargs) -> int:
        ...     try:
        ...         cnt = gdf.class_name.value_counts()["inflammatory"]
        ...     except KeyError:
        ...         cnt = 0
        ...     return int(cnt)
        >>> # Read in the tissue areas and cells
        >>> area_gdf = gpd.read_file("path/to/area.geojson")
        >>> cell_gdf = gpd.read_file("path/to/cell.geojson")
        >>> # Fit a tumor-stroma interface
        >>> tumor_stroma_iface = InterfaceContext(
        ...     area_gdf=area_gdf,
        ...     cell_gdf=cell_gdf,
        ...     top_labels="area_cin",
        ...     bottom_labels="areastroma",
        ...     buffer_dist=250,
        ...     graph_type="distband",
        ...     dist_thresh=75,
        ...     patch_size=(128, 128),
        ...     stride=(128, 128),
        ...     min_area_size=50000,
        ... )
        >>> tumor_stroma_iface.fit(verbose=False)
        >>> # Get the grid and the cells at the interface
        >>> iface_grid = grid_overlay(
        ...     tumor_stroma_iface.context2gdf("interface_area"),
        ...     patch_size=(128, 128),
        ...     stride=(128, 128),
        ... )
        >>> cells = tumor_stroma_iface.context2gdf("interface_cells")
        >>> # Classify the grid
        >>> iface_grid = grid_classify(
        ...     grid=iface_grid,
        ...     objs=cells,
        ...     metric_func=get_immune_cell_cnt,
        ...     predicate="intersects",
        ...     new_col_names="immune_cnt",
        ...     parallel=True,
        ...     pbar=True,
        ...     num_processes=-1
        ... )
        >>> iface_grid
        geometry  immune_cnt
        28  POLYGON ((20032.00000 54098.50000, 20160.00000... 15
        29  POLYGON ((20160.00000 54098.50000, 20288.00000... 3
    """
    # Fail fast on an unsupported predicate before any work is scheduled.
    allowed = ["intersects", "within"]
    if predicate not in allowed:
        raise ValueError(f"predicate must be one of {allowed}. Got {predicate}")

    # Normalise a single column name so the assignment below always gets a list.
    if isinstance(new_col_names, str):
        new_col_names = [new_col_names]

    # Previously **kwargs was accepted but silently dropped. Bind the extra
    # arguments onto the metric function itself so they reach it no matter how
    # the per-rectangle helper invokes it. With no kwargs the callable is
    # passed through untouched.
    if kwargs:
        metric_func = partial(metric_func, **kwargs)

    # Everything except the rectangle is fixed up front; gdf_apply supplies the
    # per-row value taken from the "geometry" column.
    func = partial(
        get_rect_metric, objs=objs, predicate=predicate, metric_func=metric_func
    )
    # NOTE: this writes into the caller's `grid` (in-place mutation).
    grid.loc[:, list(new_col_names)] = gdf_apply(
        grid,
        func=func,
        parallel=parallel,
        pbar=pbar,
        num_processes=num_processes,
        columns=["geometry"],
    )

    return grid