未验证 提交 0f3670e1 编写于 作者: V Ville Brofeldt 提交者: Ville Brofeldt

feat: support non-numeric columns in pivot table (#10389)

* fix: support non-numeric columns in pivot table

* bump package and add unit tests

* mypy
上级 bde453be
......@@ -29,7 +29,18 @@ import uuid
from collections import defaultdict, OrderedDict
from datetime import datetime, timedelta
from itertools import product
from typing import Any, cast, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
from typing import (
Any,
Callable,
cast,
Dict,
List,
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
import dataclasses
import geohash
......@@ -736,6 +747,7 @@ class PivotTableViz(BaseViz):
verbose_name = _("Pivot Table")
credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
is_timeseries = False
enforce_numerical_metrics = False
def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
......@@ -766,6 +778,18 @@ class PivotTableViz(BaseViz):
raise QueryObjectValidationError(_("Group By' and 'Columns' can't overlap"))
return d
@staticmethod
def get_aggfunc(
    metric: str, df: pd.DataFrame, form_data: Dict[str, Any]
) -> Union[str, Callable[[Any], Any]]:
    """
    Resolve the pandas aggregation to apply to a single pivot table metric.

    :param metric: name of the column in ``df`` that holds the metric values
    :param df: dataframe whose ``metric`` column dtype decides the result
    :param form_data: form data; ``pandas_aggfunc`` is read from it and
        falls back to ``"sum"`` when it is missing or empty
    :returns: for numeric columns, a null-aware sum callable when the
        requested aggregation is ``"sum"``, otherwise the requested
        aggregation name unchanged; for non-numeric columns ``"min"`` or
        ``"max"`` as requested, and ``"max"`` for anything else
    """
    aggfunc = form_data.get("pandas_aggfunc") or "sum"
    if pd.api.types.is_numeric_dtype(df[metric]):
        # Ensure that Pandas's sum function mimics that of SQL: with
        # min_count=1 a group holding only nulls sums to null instead of 0.
        if aggfunc == "sum":
            return lambda x: x.sum(min_count=1)
        # Any other aggregation (mean, median, ...) is valid for numeric
        # data and must not be replaced by the non-numeric fallback below.
        return aggfunc
    # only min and max work properly for non-numerics
    return aggfunc if aggfunc in ("min", "max") else "max"
def get_data(self, df: pd.DataFrame) -> VizData:
if df.empty:
return None
......@@ -773,22 +797,21 @@ class PivotTableViz(BaseViz):
if self.form_data.get("granularity") == "all" and DTTM_ALIAS in df:
del df[DTTM_ALIAS]
aggfunc = self.form_data.get("pandas_aggfunc") or "sum"
# Ensure that Pandas's sum function mimics that of SQL.
if aggfunc == "sum":
aggfunc = lambda x: x.sum(min_count=1)
metrics = [utils.get_metric_name(m) for m in self.form_data["metrics"]]
aggfuncs: Dict[str, Union[str, Callable[[Any], Any]]] = {}
for metric in metrics:
aggfuncs[metric] = self.get_aggfunc(metric, df, self.form_data)
groupby = self.form_data.get("groupby")
columns = self.form_data.get("columns")
if self.form_data.get("transpose_pivot"):
groupby, columns = columns, groupby
metrics = [utils.get_metric_name(m) for m in self.form_data["metrics"]]
df = df.pivot_table(
index=groupby,
columns=columns,
values=metrics,
aggfunc=aggfunc,
aggfunc=aggfuncs,
margins=self.form_data.get("pivot_margins"),
)
......
......@@ -1284,3 +1284,41 @@ class TestBigNumberViz(SupersetTestCase):
)
data = viz.BigNumberViz(datasource, {"metrics": ["y"]}).get_data(df)
assert np.isnan(data[2]["y"])
class TestPivotTableViz(SupersetTestCase):
    """Tests for the aggregation selection done by ``PivotTableViz.get_aggfunc``."""

    # Shared fixture: two numeric columns and one string column, each
    # ending with a null value.
    df = pd.DataFrame(
        data={
            "intcol": [1, 2, 3, None],
            "floatcol": [0.1, 0.2, 0.3, None],
            "strcol": ["a", "b", "c", None],
        }
    )

    def test_get_aggfunc_numeric(self):
        """Numeric columns get a callable sum by default and keep min/max."""
        resolve = viz.PivotTableViz.get_aggfunc

        # without an explicit aggregation the result is a sum function
        default_agg = resolve("intcol", self.df, {})
        assert callable(default_agg)
        assert default_agg(self.df["intcol"]) == 6

        min_agg = resolve("intcol", self.df, {"pandas_aggfunc": "min"})
        assert min_agg == "min"

        max_agg = resolve("floatcol", self.df, {"pandas_aggfunc": "max"})
        assert max_agg == "max"

    def test_get_aggfunc_non_numeric(self):
        """Non-numeric columns only ever resolve to ``min`` or ``max``."""
        resolve = viz.PivotTableViz.get_aggfunc

        expectations = (
            ({}, "max"),
            ({"pandas_aggfunc": "sum"}, "max"),
            ({"pandas_aggfunc": "min"}, "min"),
        )
        for form_data, expected in expectations:
            assert resolve("strcol", self.df, form_data) == expected
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册