未验证 提交 0c33c47e 编写于 作者: R ronnywang 提交者: GitHub

fix paddle.median torch diff (#40118)

上级 fe1cc8bd
......@@ -321,6 +321,9 @@ def median(x, axis=None, keepdim=False, name=None):
paddle.slice(
tensor_topk, axes=[axis], starts=[kth], ends=[kth + 1]),
dtype=dtype)
out_tensor = out_tensor + paddle.sum(
paddle.cast(
paddle.isnan(x), dtype=dtype) * x, axis=axis, keepdim=True)
if not keepdim or is_flatten:
if not is_flatten:
newshape = x.shape[:axis] + x.shape[axis + 1:]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册