diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 468aa460486275f78d240bdee40b9d73f07dbcda..dd0da03e4fd2816bcd28ea76cc4a0712451c3e39 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -321,6 +321,9 @@ def median(x, axis=None, keepdim=False, name=None): paddle.slice( tensor_topk, axes=[axis], starts=[kth], ends=[kth + 1]), dtype=dtype) + out_tensor = out_tensor + paddle.sum( + paddle.cast( + paddle.isnan(x), dtype=dtype) * x, axis=axis, keepdim=True) if not keepdim or is_flatten: if not is_flatten: newshape = x.shape[:axis] + x.shape[axis + 1:]