From 0c33c47ee752befb54b6a16f6608cb3c411506d9 Mon Sep 17 00:00:00 2001 From: ronnywang <524019753@qq.com> Date: Tue, 8 Mar 2022 10:21:48 +0800 Subject: [PATCH] fix paddle.median torch diff (#40118) --- python/paddle/tensor/stat.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 468aa46048..dd0da03e4f 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -321,6 +321,9 @@ def median(x, axis=None, keepdim=False, name=None): paddle.slice( tensor_topk, axes=[axis], starts=[kth], ends=[kth + 1]), dtype=dtype) + out_tensor = out_tensor + paddle.sum( + paddle.cast( + paddle.isnan(x), dtype=dtype) * x, axis=axis, keepdim=True) if not keepdim or is_flatten: if not is_flatten: newshape = x.shape[:axis] + x.shape[axis + 1:] -- GitLab