未验证 提交 26ede6e0 编写于 作者: Z zhulei 提交者: GitHub

Add median api. (#28310)

* Add median api.

* Add median api.

* Add median api.

* Add median api.

* Add median api.
上级 8cd1c102
......@@ -248,6 +248,7 @@ from .tensor.stat import std #DEFINE_ALIAS
from .tensor.stat import var #DEFINE_ALIAS
# from .fluid.data import data
from .tensor.stat import numel #DEFINE_ALIAS
from .tensor.stat import median #DEFINE_ALIAS
from .device import get_cudnn_version
from .device import set_device
from .device import get_device
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from paddle.static import Program, program_guard
# Upper bound allowed for the summed squared difference between two results.
DELTA = 1e-6


class TestMedian(unittest.TestCase):
    """Compare ``paddle.median`` with ``np.median`` in static and dygraph mode."""

    @staticmethod
    def _median_cases():
        """Build the ``[x, axis, keepdims]`` combinations shared by both modes."""
        height, width, depth = 3, 4, 2
        x = np.arange(height * width * depth).reshape([height, width, depth])
        return [[x, axis, keepdims]
                for axis in [-1, 0, 1, 2, None]
                for keepdims in [False, True]]

    def check_numpy_res(self, np1, np2):
        """Assert that two arrays have equal shapes and nearly equal values.

        The distance is the sum of squared element-wise differences and must
        not exceed ``DELTA``.
        """
        self.assertEqual(np1.shape, np2.shape)
        diff = np1 - np2
        mismatch = np.sum(diff * diff)
        # The third positional parameter of assertAlmostEqual is ``places``;
        # the tolerance therefore has to be given through the ``delta`` keyword.
        self.assertAlmostEqual(mismatch, 0, delta=DELTA)

    def static_single_test_median(self, lis_test):
        """Run one ``[x, axis, keepdims]`` case through a static program."""
        paddle.enable_static()
        try:
            x, axis, keepdims = lis_test
            res_np = np.median(x, axis=axis, keepdims=keepdims)
            # A non-array result is wrapped into a one-element array before
            # the comparison.
            if not isinstance(res_np, np.ndarray):
                res_np = np.array([res_np])
            main_program = Program()
            startup_program = Program()
            exe = paddle.static.Executor()
            with program_guard(main_program, startup_program):
                x_in = paddle.fluid.data(
                    shape=x.shape, dtype=x.dtype, name='x')
                y = paddle.median(x_in, axis, keepdims)
                [res_pd] = exe.run(feed={'x': x}, fetch_list=[y])
                self.check_numpy_res(res_pd, res_np)
        finally:
            # Always leave static mode, even when the comparison fails, so a
            # failure here does not change the mode seen by later tests.
            paddle.disable_static()

    def dygraph_single_test_median(self, lis_test):
        """Run one ``[x, axis, keepdims]`` case in dygraph mode."""
        x, axis, keepdims = lis_test
        res_np = np.median(x, axis=axis, keepdims=keepdims)
        if not isinstance(res_np, np.ndarray):
            res_np = np.array([res_np])
        res_pd = paddle.median(paddle.to_tensor(x), axis, keepdims)
        self.check_numpy_res(res_pd.numpy(), res_np)

    def test_median_static(self):
        """Check every axis/keepdims combination in static mode."""
        for lis_test in self._median_cases():
            self.static_single_test_median(lis_test)

    def test_median_dygraph(self):
        """Check every axis/keepdims combination in dygraph mode."""
        paddle.disable_static()
        for lis_test in self._median_cases():
            self.dygraph_single_test_median(lis_test)

    def test_median_exception(self):
        """Check the errors raised for a non-Tensor input and for a bad axis."""
        paddle.disable_static()
        x = [1, 2, 3, 4]
        self.assertRaises(TypeError, paddle.median, x)
        x = paddle.arange(12).reshape([3, 4])
        # A float axis and an out-of-range axis are both rejected.
        self.assertRaises(ValueError, paddle.median, x, 1.0)
        self.assertRaises(ValueError, paddle.median, x, 2)
if __name__ == '__main__':
    # Executed as a script: hand control to the unittest runner.
    unittest.main()
......@@ -190,6 +190,7 @@ from .stat import mean #DEFINE_ALIAS
from .stat import std #DEFINE_ALIAS
from .stat import var #DEFINE_ALIAS
from .stat import numel #DEFINE_ALIAS
from .stat import median #DEFINE_ALIAS
# from .tensor import Tensor #DEFINE_ALIAS
# from .tensor import LoDTensor #DEFINE_ALIAS
# from .tensor import LoDTensorArray #DEFINE_ALIAS
......
......@@ -14,7 +14,7 @@
# TODO: define statistical functions of a tensor
__all__ = ['mean', 'std', 'var', 'numel']
__all__ = ['mean', 'std', 'var', 'numel', 'median']
import numpy as np
from ..fluid.framework import Variable
......@@ -258,3 +258,89 @@ def numel(x, name=None):
dtype=core.VarDesc.VarType.INT64)
helper.append_op(type='size', inputs={'Input': x}, outputs={'Out': out})
return out
def median(x, axis=None, keepdim=False, name=None):
    """
    Compute the median along the specified axis.

    Args:
        x (Tensor): The input Tensor, its data type can be bool, float16, float32, float64, int32, int64.
        axis (int, optional): The axis along which to perform median calculations ``axis`` should be int.
            ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
            If ``axis`` is less than 0, it works the same way as :math:`axis + D`.
            If ``axis`` is None, median is calculated over all elements of ``x``. Default is None.
        keepdim (bool, optional): Whether to reserve the reduced dimension(s)
            in the output Tensor. If ``keepdim`` is True, the dimensions of
            the output Tensor is the same as ``x`` except in the reduced
            dimensions(it is of size 1 in this case). Otherwise, the shape of
            the output Tensor is squeezed in ``axis`` . Default is False.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, results of median along ``axis`` of ``x``. If data type of ``x`` is float64, data type of results will be float64, otherwise data type will be float32.

    Raises:
        TypeError: If ``x`` is not a Tensor.
        ValueError: If ``axis`` is neither None nor an int in range [-D, D).

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.arange(12).reshape([3, 4])
            # x is [[0 , 1 , 2 , 3 ],
            #       [4 , 5 , 6 , 7 ],
            #       [8 , 9 , 10, 11]]

            y1 = paddle.median(x)
            # y1 is [5.5]

            y2 = paddle.median(x, axis=0)
            # y2 is [4., 5., 6., 7.]

            y3 = paddle.median(x, axis=1)
            # y3 is [1.5, 5.5, 9.5]

            y4 = paddle.median(x, axis=0, keepdim=True)
            # y4 is [[4., 5., 6., 7.]]

    """
    if not isinstance(x, Variable):
        raise TypeError("In median, the input x should be a Tensor.")

    is_flatten = axis is None
    # Rank of the original input; still needed after flattening to build the
    # keepdim shape.
    dims = len(x.shape)
    if is_flatten:
        x = paddle.flatten(x)
        axis = 0
    else:
        if not isinstance(axis, int) or not -dims <= axis < dims:
            raise ValueError(
                "In median, axis should be none or an integer in range [-rank(x), rank(x))."
            )
        if axis < 0:
            axis += dims

    # NOTE(review): ``sz`` is used as a plain Python integer below, so this
    # presumably requires a concrete size along ``axis`` - confirm for
    # static graphs.
    sz = x.shape[axis]
    kth = sz >> 1
    # The kth + 1 smallest values along ``axis``; the median is read from the
    # last one (odd size) or the last two (even size) of them.
    tensor_topk, _ = paddle.topk(x, kth + 1, axis=axis, largest=False)
    dtype = 'float64' if x.dtype == core.VarDesc.VarType.FP64 else 'float32'

    upper = paddle.cast(
        paddle.slice(
            tensor_topk, axes=[axis], starts=[kth], ends=[kth + 1]),
        dtype=dtype)
    if sz & 1 == 0:
        # Cast each middle value before adding them, so the sum is computed in
        # the floating-point result type rather than in the input dtype.
        lower = paddle.cast(
            paddle.slice(
                tensor_topk, axes=[axis], starts=[kth - 1], ends=[kth]),
            dtype=dtype)
        out_tensor = (lower + upper) / 2
    else:
        out_tensor = upper

    if is_flatten:
        newshape = [1] * dims if keepdim else [1]
    elif keepdim:
        newshape = out_tensor.shape
    else:
        newshape = x.shape[:axis] + x.shape[axis + 1:]
        if not newshape:
            # A 1-D input reduced along its only axis leaves no dimensions;
            # use [1], the same shape the axis=None path produces.
            newshape = [1]
    out_tensor = out_tensor.reshape(newshape, name=name)
    return out_tensor
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册