未验证 提交 6e5670b8 编写于 作者: Z zhupengyang 提交者: GitHub

mean: not support int32, int64; add check for axis (#26401)

上级 6e6567f3
...@@ -103,11 +103,7 @@ REGISTER_OP_CPU_KERNEL(reduce_mean, ...@@ -103,11 +103,7 @@ REGISTER_OP_CPU_KERNEL(reduce_mean,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, ops::ReduceKernel<paddle::platform::CPUDeviceContext,
float, ops::MeanFunctor>, float, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, ops::ReduceKernel<paddle::platform::CPUDeviceContext,
double, ops::MeanFunctor>, double, ops::MeanFunctor>);
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
int, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::MeanFunctor>);
template <typename T> template <typename T>
using CPUReduceMeanGradKernel = using CPUReduceMeanGradKernel =
...@@ -115,6 +111,4 @@ using CPUReduceMeanGradKernel = ...@@ -115,6 +111,4 @@ using CPUReduceMeanGradKernel =
ops::MeanGradFunctor, true>; ops::MeanGradFunctor, true>;
REGISTER_OP_CPU_KERNEL(reduce_mean_grad, CPUReduceMeanGradKernel<float>, REGISTER_OP_CPU_KERNEL(reduce_mean_grad, CPUReduceMeanGradKernel<float>,
CPUReduceMeanGradKernel<double>, CPUReduceMeanGradKernel<double>);
CPUReduceMeanGradKernel<int>,
CPUReduceMeanGradKernel<int64_t>);
...@@ -66,6 +66,4 @@ class ReduceMeanKernel : public framework::OpKernel<T> { ...@@ -66,6 +66,4 @@ class ReduceMeanKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
REGISTER_OP_CUDA_KERNEL(reduce_mean, ops::ReduceMeanKernel<float>, REGISTER_OP_CUDA_KERNEL(reduce_mean, ops::ReduceMeanKernel<float>,
ops::ReduceMeanKernel<double>, ops::ReduceMeanKernel<double>);
ops::ReduceMeanKernel<int>,
ops::ReduceMeanKernel<int64_t>);
...@@ -334,6 +334,12 @@ class ReduceOp : public framework::OperatorWithKernel { ...@@ -334,6 +334,12 @@ class ReduceOp : public framework::OperatorWithKernel {
"range [-dimension(X), dimension(X)] " "range [-dimension(X), dimension(X)] "
"which dimesion = %d. But received dim index = %d.", "which dimesion = %d. But received dim index = %d.",
i, x_rank, dims[i])); i, x_rank, dims[i]));
PADDLE_ENFORCE_GE(dims[i], -x_rank,
platform::errors::InvalidArgument(
"The reduce dim index %d should be in the "
"range [-dimension(X), dimension(X)] "
"which dimesion = %d. But received dim index = %d.",
i, x_rank, dims[i]));
if (dims[i] < 0) dims[i] = x_rank + dims[i]; if (dims[i] < 0) dims[i] = x_rank + dims[i];
} }
sort(dims.begin(), dims.end()); sort(dims.begin(), dims.end());
......
...@@ -129,9 +129,14 @@ class TestMeanAPI(unittest.TestCase): ...@@ -129,9 +129,14 @@ class TestMeanAPI(unittest.TestCase):
paddle.enable_static() paddle.enable_static()
def test_errors(self): def test_errors(self):
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 12]).astype('float32')
x = paddle.to_tensor(x)
self.assertRaises(Exception, paddle.mean, x, -3)
self.assertRaises(Exception, paddle.mean, x, 2)
paddle.enable_static() paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', [10, 12], 'int8') x = paddle.data('X', [10, 12], 'int32')
self.assertRaises(TypeError, paddle.mean, x) self.assertRaises(TypeError, paddle.mean, x)
......
...@@ -32,8 +32,7 @@ def mean(x, axis=None, keepdim=False, name=None): ...@@ -32,8 +32,7 @@ def mean(x, axis=None, keepdim=False, name=None):
Computes the mean of the input tensor's elements along ``axis``. Computes the mean of the input tensor's elements along ``axis``.
Args: Args:
x (Tensor): The input Tensor with data type float32, float64, int32, x (Tensor): The input Tensor with data type float32, float64.
int64.
axis (int|list|tuple, optional): The axis along which to perform mean axis (int|list|tuple, optional): The axis along which to perform mean
calculations. ``axis`` should be int, list(int) or tuple(int). If calculations. ``axis`` should be int, list(int) or tuple(int). If
``axis`` is a list/tuple of dimension(s), mean is calculated along ``axis`` is a list/tuple of dimension(s), mean is calculated along
...@@ -97,9 +96,12 @@ def mean(x, axis=None, keepdim=False, name=None): ...@@ -97,9 +96,12 @@ def mean(x, axis=None, keepdim=False, name=None):
return core.ops.reduce_mean(x, 'dim', axis, 'keep_dim', keepdim, return core.ops.reduce_mean(x, 'dim', axis, 'keep_dim', keepdim,
'reduce_all', reduce_all) 'reduce_all', reduce_all)
check_variable_and_dtype(x, 'x/input', check_variable_and_dtype(x, 'x/input', ['float32', 'float64'],
['float32', 'float64', 'int32', 'int64'],
'mean/reduce_mean') 'mean/reduce_mean')
check_type(axis, 'axis/dim', (int, list, tuple), 'mean/reduce_mean')
if isinstance(axis, (list, tuple)):
for item in axis:
check_type(item, 'elements of axis/dim', (int), 'mean/reduce_mean')
helper = LayerHelper('mean', **locals()) helper = LayerHelper('mean', **locals())
attrs = {'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all} attrs = {'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册