未验证 提交 6e5670b8 编写于 作者: Z zhupengyang 提交者: GitHub

mean: not support int32, int64; add check for axis (#26401)

上级 6e6567f3
......@@ -103,11 +103,7 @@ REGISTER_OP_CPU_KERNEL(reduce_mean,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
float, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
double, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
int, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::MeanFunctor>);
double, ops::MeanFunctor>);
template <typename T>
using CPUReduceMeanGradKernel =
......@@ -115,6 +111,4 @@ using CPUReduceMeanGradKernel =
ops::MeanGradFunctor, true>;
REGISTER_OP_CPU_KERNEL(reduce_mean_grad, CPUReduceMeanGradKernel<float>,
CPUReduceMeanGradKernel<double>,
CPUReduceMeanGradKernel<int>,
CPUReduceMeanGradKernel<int64_t>);
CPUReduceMeanGradKernel<double>);
......@@ -66,6 +66,4 @@ class ReduceMeanKernel : public framework::OpKernel<T> {
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(reduce_mean, ops::ReduceMeanKernel<float>,
ops::ReduceMeanKernel<double>,
ops::ReduceMeanKernel<int>,
ops::ReduceMeanKernel<int64_t>);
ops::ReduceMeanKernel<double>);
......@@ -334,6 +334,12 @@ class ReduceOp : public framework::OperatorWithKernel {
"range [-dimension(X), dimension(X)] "
"which dimesion = %d. But received dim index = %d.",
i, x_rank, dims[i]));
PADDLE_ENFORCE_GE(dims[i], -x_rank,
platform::errors::InvalidArgument(
"The reduce dim index %d should be in the "
"range [-dimension(X), dimension(X)] "
"which dimesion = %d. But received dim index = %d.",
i, x_rank, dims[i]));
if (dims[i] < 0) dims[i] = x_rank + dims[i];
}
sort(dims.begin(), dims.end());
......
......@@ -129,9 +129,14 @@ class TestMeanAPI(unittest.TestCase):
paddle.enable_static()
def test_errors(self):
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 12]).astype('float32')
x = paddle.to_tensor(x)
self.assertRaises(Exception, paddle.mean, x, -3)
self.assertRaises(Exception, paddle.mean, x, 2)
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', [10, 12], 'int8')
x = paddle.data('X', [10, 12], 'int32')
self.assertRaises(TypeError, paddle.mean, x)
......
......@@ -32,8 +32,7 @@ def mean(x, axis=None, keepdim=False, name=None):
Computes the mean of the input tensor's elements along ``axis``.
Args:
x (Tensor): The input Tensor with data type float32, float64, int32,
int64.
x (Tensor): The input Tensor with data type float32, float64.
axis (int|list|tuple, optional): The axis along which to perform mean
calculations. ``axis`` should be int, list(int) or tuple(int). If
``axis`` is a list/tuple of dimension(s), mean is calculated along
......@@ -97,9 +96,12 @@ def mean(x, axis=None, keepdim=False, name=None):
return core.ops.reduce_mean(x, 'dim', axis, 'keep_dim', keepdim,
'reduce_all', reduce_all)
check_variable_and_dtype(x, 'x/input',
['float32', 'float64', 'int32', 'int64'],
check_variable_and_dtype(x, 'x/input', ['float32', 'float64'],
'mean/reduce_mean')
check_type(axis, 'axis/dim', (int, list, tuple), 'mean/reduce_mean')
if isinstance(axis, (list, tuple)):
for item in axis:
check_type(item, 'elements of axis/dim', (int), 'mean/reduce_mean')
helper = LayerHelper('mean', **locals())
attrs = {'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册