From 68f3f7221e3d39891a26800cfc983859c8870a23 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Tue, 3 Mar 2026 00:54:12 +0100 Subject: [PATCH 1/5] Fix datashader failing with single-category Categorical (#483) When a pd.Categorical column has only 1 category, datashader's ds.by() still produces a 3D DataArray requiring a color_key. The >1 guard was causing color_key to be None and groups to fall back to count-based aggregation, both leading to errors or incorrect output. Changes: - Change color_key guard from >1 to >0 for both shapes and points - Change groups guard from >1 to >=1 for both shapes and points - Use _hex_no_alpha in points color_key construction for consistency Co-Authored-By: Claude Opus 4.6 --- src/spatialdata_plot/pl/render.py | 14 +++++--------- tests/pl/test_render_points.py | 27 +++++++++++++++++++++++++++ tests/pl/test_render_shapes.py | 19 +++++++++++++++++++ 3 files changed, 51 insertions(+), 9 deletions(-) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 014c3cc5..56821c75 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -301,7 +301,7 @@ def _render_shapes( # Render shapes with datashader color_by_categorical = col_for_color is not None and color_source_vector is not None aggregate_with_reduction = None - if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1): + if col_for_color is not None and (render_params.groups is None or len(render_params.groups) >= 1): if color_by_categorical: agg = cvs.polygons( transformed_element, @@ -359,7 +359,7 @@ def _render_shapes( color_key = ( [_hex_no_alpha(x) for x in color_vector.categories.values] if (type(color_vector) is pd.core.arrays.categorical.Categorical) - and (len(color_vector.categories.values) > 1) + and (len(color_vector.categories.values) > 0) else None ) @@ -815,7 +815,7 @@ def _render_points( if color_by_categorical and transformed_element[col_for_color].values.dtype == object: transformed_element[col_for_color] = transformed_element[col_for_color].astype("category") aggregate_with_reduction = None - if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1): + if col_for_color is not None and (render_params.groups is None or len(render_params.groups) >= 1): if color_by_categorical: agg = cvs.points(transformed_element, "x", "y", agg=ds.by(col_for_color, ds.count())) else: @@ -853,15 +853,11 @@ def _render_points( agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) color_key: list[str] | None = ( - list(color_vector.categories.values) + [_hex_no_alpha(x) for x in color_vector.categories.values] if (type(color_vector) is pd.core.arrays.categorical.Categorical) - and (len(color_vector.categories.values) > 1) + and (len(color_vector.categories.values) > 0) else None ) - - # remove alpha from color if it's hex - if color_key is not None and all(len(x) == 9 for x in color_key) and color_key[0][0] == "#": - color_key = [x[:-2] for x in color_key] if isinstance(color_vector[0], str) and ( color_vector is not None and all(len(x) == 9 for x in color_vector) and color_vector[0][0] == "#" ): diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index 34b63e94..cad05dda 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -572,3 +572,30 @@ def test_datashader_colors_points_from_table_obs(sdata_blobs: SpatialData): method="datashader", size=5, ).pl.show() + + +def test_plot_datashader_single_category_points(sdata_blobs: SpatialData): + """Datashader should handle a Categorical column with only 1 category (#483).""" + n_obs = len(sdata_blobs["blobs_points"]) + obs = pd.DataFrame( + { + "instance_id": np.arange(n_obs), + "region": pd.Categorical(["blobs_points"] * n_obs), + "foo": pd.Categorical(["only_cat"] * n_obs), + } + ) + table = TableModel.parse( + adata=AnnData(get_standard_RNG().normal(size=(n_obs, 3)), obs=obs), + region="blobs_points", + region_key="region", + instance_key="instance_id", + ) + sdata_blobs["single_cat_table"] = table + + sdata_blobs.pl.render_points( + "blobs_points", + color="foo", + table_name="single_cat_table", + method="datashader", + size=5, + ).pl.show() diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index fce236d6..5b5240ac 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -973,3 +973,22 @@ def test_plot_can_handle_mixed_numeric_and_color_data(sdata_blobs: SpatialData): # Mixed numeric / non-numeric values should raise a TypeError with pytest.raises(TypeError, match="contains both numeric and non-numeric values"): sdata_blobs.pl.render_shapes(element="blobs_circles", color="mixed_data", na_color="gray").pl.show() + + +def test_plot_datashader_single_category(sdata_blobs: SpatialData): + """Datashader should handle a Categorical column with only 1 category (#483).""" + n_obs = len(sdata_blobs["blobs_polygons"]) + adata = AnnData(get_standard_RNG().normal(size=(n_obs, 10))) + adata.obs = pd.DataFrame(get_standard_RNG().normal(size=(n_obs, 3)), columns=["a", "b", "c"]) + adata.obs["category"] = pd.Categorical(["only_cat"] * n_obs) + adata.obs["instance_id"] = list(range(n_obs)) + adata.obs["region"] = "blobs_polygons" + table = TableModel.parse( + adata=adata, + region_key="region", + instance_key="instance_id", + region="blobs_polygons", + ) + sdata_blobs["table"] = table + + sdata_blobs.pl.render_shapes(element="blobs_polygons", color="category", method="datashader").pl.show() From f2e2cffbd87191074f9ef5fa4932396dcbfb8efc Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Tue, 3 Mar 2026 01:01:52 +0100 Subject: [PATCH 2/5] Guard _hex_no_alpha calls for non-hex color names color_vector.categories.values can contain named colors like "lightgreen", not just hex strings. Only apply _hex_no_alpha when the value starts with "#". Co-Authored-By: Claude Opus 4.6 --- src/spatialdata_plot/pl/render.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 56821c75..d335ee1b 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -357,7 +357,10 @@ def _render_shapes( agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) color_key = ( - [_hex_no_alpha(x) for x in color_vector.categories.values] + [ + _hex_no_alpha(x) if isinstance(x, str) and x.startswith("#") else x + for x in color_vector.categories.values + ] if (type(color_vector) is pd.core.arrays.categorical.Categorical) and (len(color_vector.categories.values) > 0) else None @@ -853,7 +856,10 @@ def _render_points( agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) color_key: list[str] | None = ( - [_hex_no_alpha(x) for x in color_vector.categories.values] + [ + _hex_no_alpha(x) if isinstance(x, str) and x.startswith("#") else x + for x in color_vector.categories.values + ] if (type(color_vector) is pd.core.arrays.categorical.Categorical) and (len(color_vector.categories.values) > 0) else None From 752fa770e3c6582ab2bc07e6f99bbdbd77a16b1b Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Tue, 3 Mar 2026 01:10:23 +0100 Subject: [PATCH 3/5] Add issue link and detail to regression tests, remove non-reproducing test The test_datashader_single_group_named_color test didn't reproduce a failure on main (it only caught an intermediate regression within this PR), so remove it. Add the issue URL and detailed docstrings to the two tests that do reproduce the #483 bug on main. Co-Authored-By: Claude Opus 4.6 --- tests/pl/test_render_points.py | 9 ++++++++- tests/pl/test_render_shapes.py | 9 ++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index cad05dda..e0574137 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -575,7 +575,14 @@ def test_datashader_colors_points_from_table_obs(sdata_blobs: SpatialData): def test_plot_datashader_single_category_points(sdata_blobs: SpatialData): - """Datashader should handle a Categorical column with only 1 category (#483).""" + """Datashader with a single-category Categorical must not raise. + + Regression test for https://github.com/scverse/spatialdata-plot/issues/483. + Before the fix, color_key was None when there was only 1 category, but ds.by() + still produced a 3D DataArray, causing datashader to raise: + ValueError: Color key must be provided, with at least as many colors as + there are categorical fields + """ n_obs = len(sdata_blobs["blobs_points"]) obs = pd.DataFrame( { diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index 5b5240ac..8065ca9e 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -976,7 +976,14 @@ def test_plot_can_handle_mixed_numeric_and_color_data(sdata_blobs: SpatialData): def test_plot_datashader_single_category(sdata_blobs: SpatialData): - """Datashader should handle a Categorical column with only 1 category (#483).""" + """Datashader with a single-category Categorical must not raise. + + Regression test for https://github.com/scverse/spatialdata-plot/issues/483. + Before the fix, color_key was None when there was only 1 category, but ds.by() + still produced a 3D DataArray, causing datashader to raise: + ValueError: Color key must be provided, with at least as many colors as + there are categorical fields + """ n_obs = len(sdata_blobs["blobs_polygons"]) adata = AnnData(get_standard_RNG().normal(size=(n_obs, 10))) adata.obs = pd.DataFrame(get_standard_RNG().normal(size=(n_obs, 3)), columns=["a", "b", "c"]) From fe7421cacc43860d000d7ebc8d790ce4aa0521ec Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Tue, 3 Mar 2026 01:13:24 +0100 Subject: [PATCH 4/5] Simplify groups guard to remove dead else branch The condition `groups is None or len(groups) >= 1` was always true since groups is either None or a non-empty list. Simplify to just `col_for_color is not None`, which makes the intent clearer: use categorical/reduction aggregation when there's a color column, otherwise fall back to plain count aggregation. Also note: render_labels does not use datashader, so despite the #483 issue title mentioning labels, this bug only affects shapes and points. Co-Authored-By: Claude Opus 4.6 --- src/spatialdata_plot/pl/render.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index d335ee1b..8ff65505 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -301,7 +301,7 @@ def _render_shapes( # Render shapes with datashader color_by_categorical = col_for_color is not None and color_source_vector is not None aggregate_with_reduction = None - if col_for_color is not None and (render_params.groups is None or len(render_params.groups) >= 1): + if col_for_color is not None: if color_by_categorical: agg = cvs.polygons( transformed_element, @@ -818,7 +818,7 @@ def _render_points( if color_by_categorical and transformed_element[col_for_color].values.dtype == object: transformed_element[col_for_color] = transformed_element[col_for_color].astype("category") aggregate_with_reduction = None - if col_for_color is not None and (render_params.groups is None or len(render_params.groups) >= 1): + if col_for_color is not None: if color_by_categorical: agg = cvs.points(transformed_element, "x", "y", agg=ds.by(col_for_color, ds.count())) else: From 3539fc7b0eb806db6cca7b9a45a686602471fe4a Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Tue, 3 Mar 2026 01:25:39 +0100 Subject: [PATCH 5/5] Extract _build_color_key_from_categorical helper to deduplicate The color_key construction logic was duplicated between _render_shapes and _render_points. Extract into a shared helper that also caches the categories lookup and keeps the hex-vs-named-color guard in one place. Co-Authored-By: Claude Opus 4.6 --- src/spatialdata_plot/pl/render.py | 35 +++++++++++++++---------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 8ff65505..b89d65a8 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -62,6 +62,21 @@ _Normalize = Normalize | abc.Sequence[Normalize] +def _build_color_key_from_categorical(color_vector: object) -> list[str] | None: + """Build a datashader ``color_key`` list from a categorical color vector. + + Returns ``None`` when *color_vector* is not a :class:`pd.Categorical` or + has no categories. Hex colours are stripped of their alpha channel; + named colours (e.g. ``"red"``) are passed through unchanged. + """ + if type(color_vector) is not pd.core.arrays.categorical.Categorical: + return None + cat_values = color_vector.categories.values + if len(cat_values) == 0: + return None + return [_hex_no_alpha(x) if isinstance(x, str) and x.startswith("#") else x for x in cat_values] + + def _split_colorbar_params(params: dict[str, object] | None) -> tuple[dict[str, object], dict[str, object], str | None]: """Split colorbar params into layout hints, Matplotlib kwargs, and label override.""" layout: dict[str, object] = {} @@ -356,15 +371,7 @@ def _render_shapes( agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2) agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) - color_key = ( - [ - _hex_no_alpha(x) if isinstance(x, str) and x.startswith("#") else x - for x in color_vector.categories.values - ] - if (type(color_vector) is pd.core.arrays.categorical.Categorical) - and (len(color_vector.categories.values) > 0) - else None - ) + color_key = _build_color_key_from_categorical(color_vector) if color_by_categorical or col_for_color is None: ds_cmap = None @@ -855,15 +862,7 @@ def _render_points( agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2) agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) - color_key: list[str] | None = ( - [ - _hex_no_alpha(x) if isinstance(x, str) and x.startswith("#") else x - for x in color_vector.categories.values - ] - if (type(color_vector) is pd.core.arrays.categorical.Categorical) - and (len(color_vector.categories.values) > 0) - else None - ) + color_key = _build_color_key_from_categorical(color_vector) if isinstance(color_vector[0], str) and ( color_vector is not None and all(len(x) == 9 for x in color_vector) and color_vector[0][0] == "#" ):