Unverified commit c8fc3379, authored by zhouweiwei2014, committed by GitHub

[Zero-Dim] support input 0D Tensor for reduce_sum/reduce_mean (#47219)

Parent 81b93ebb
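A minimal dygraph sketch of what this change enables (illustrative only, not part of the diff; it mirrors the TestReduceAPI checks added below):

import paddle

# a 0-D (zero-dimensional) tensor: its shape is []
x = paddle.rand([])
x.stop_gradient = False

# after this change, reducing a 0-D input returns a 0-D tensor
out = paddle.sum(x)
out.backward()

print(out.shape)     # []
print(x.grad.shape)  # []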
@@ -2685,7 +2685,7 @@ DDim ReduceInferDim(const MetaTensor& x,
   bool full_dim = true;
   std::set<int64_t> dims_set(formated_axis.begin(), formated_axis.end());
-  for (int64_t i = 0; i < x.dims().size(); ++i) {
+  for (int64_t i = 0; i < x_rank; ++i) {
     if (dims_set.find(i) == dims_set.end()) {
       full_dim = false;
       break;
@@ -2695,7 +2695,7 @@ DDim ReduceInferDim(const MetaTensor& x,
   std::vector<int64_t> out_dim_vector;
   if (keep_dim) {
-    for (int64_t i = 0; i < x.dims().size(); ++i) {
+    for (int64_t i = 0; i < x_rank; ++i) {
       if (reduce_all || dims_set.find(i) != dims_set.end()) {
         out_dim_vector.push_back(1);
       } else {
@@ -2703,7 +2703,7 @@ DDim ReduceInferDim(const MetaTensor& x,
       }
     }
   } else {
-    for (int64_t i = 0; i < x.dims().size(); ++i) {
+    for (int64_t i = 0; i < x_rank; ++i) {
       if (reduce_all || dims_set.find(i) != dims_set.end()) {
         continue;
       } else {
@@ -2711,7 +2711,7 @@ DDim ReduceInferDim(const MetaTensor& x,
      }
    }
-    if (out_dim_vector.size() == 0) {
+    if (x_rank > 0 && out_dim_vector.size() == 0) {
      out_dim_vector.push_back(1);
    }
  }
@@ -3013,6 +3013,7 @@ void SetValueInferMeta(const MetaTensor& x, MetaTensor* out) {
       phi::errors::InvalidArgument(
           "The rank of input should be less than 7, but received %d.",
           in_dims.size()));
+  out->set_dims(in_dims);
 }

 void ShapeInferMeta(const MetaTensor& input, MetaTensor* out) {
......
@@ -44,7 +44,7 @@ struct DimensionsTransform {
     int64_t in_idx = 0;
     if (in_dim.size() < dim_size) {
       DimVector tmp_dim(dim_size, 1);
-      do {
+      for (; in_idx < in_dim.size();) {
         if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) {
           tmp_dim[axis] = in_dim[in_idx];
           in_idx++;
@@ -59,11 +59,11 @@ struct DimensionsTransform {
               out_dims[axis],
               in_dim[in_idx]));
         }
-      } while (in_idx < in_dim.size());
+      }
       in_dim.resize(dim_size);
       std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin());
     } else {
-      do {
+      for (; in_idx < dim_size;) {
         if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) {
           in_idx++;
         } else {
@@ -76,7 +76,7 @@ struct DimensionsTransform {
               out_dims[in_idx],
               in_dim[in_idx]));
         }
-      } while (in_idx < dim_size);
+      }
     }
     std::reverse(in_dim.begin(), in_dim.end());
   }
......
@@ -1063,6 +1063,14 @@ void ReduceKernel(const KPDevice& dev_ctx,
   dev_ctx.Alloc<Ty>(y);

   auto x_dim = phi::vectorize<int>(x.dims());
+
+  if (x_dim.size() == 0) {
+    std::vector<const DenseTensor*> inputs = {&x};
+    std::vector<DenseTensor*> outputs = {y};
+    funcs::ElementwiseKernel<Ty>(dev_ctx, inputs, &outputs, transform);
+    return;
+  }
+
   auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
   config.Run(dev_ctx);
   int numel = x.numel();
......
@@ -16,8 +16,8 @@

 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/broadcast_function.h"
 #include "paddle/phi/kernels/funcs/reduce_function.h"
-#include "paddle/phi/kernels/gpu/reduce_grad.h"

 namespace phi {
@@ -29,23 +29,34 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
+  // get reduce_dim and reduce_num for reduce_mean_grad
   int dim_size = x.dims().size();
+  if (dims.size() == 0) {
+    reduce_all = true;
+  }
   std::vector<int> reduce_dims =
       funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all);
+
+  auto update_dims = vectorize(x.dims());
   int reduce_num = 1;
   for (auto i : reduce_dims) {
     reduce_num *= (x.dims())[i];
+    update_dims[i] = 1;
   }

+  // make new tensor
+  DenseTensor new_out_grad(out_grad.dtype());
+  new_out_grad.ShareDataWith(out_grad);
+  new_out_grad.Resize(phi::make_ddim(update_dims));
+
+  // call BroadcastKernel
+  dev_ctx.Alloc(x_grad, x.dtype());
+  std::vector<const DenseTensor*> inputs = {&new_out_grad};
+  std::vector<DenseTensor*> outputs = {x_grad};
+
   using MPType = typename kps::details::MPTypeTrait<T>::Type;
-  ReduceGradKernel<T, T, Context, kps::DivideFunctor<T, MPType>>(
-      dev_ctx,
-      x,
-      out_grad,
-      dims.GetData(),
-      keep_dim,
-      reduce_all,
-      x_grad,
-      kps::DivideFunctor<T, MPType>(reduce_num));
+  funcs::BroadcastKernel<phi::ElementwiseType::kUnary, T, T>(
+      dev_ctx, inputs, &outputs, 0, kps::DivideFunctor<T, MPType>(reduce_num));
 }

 }  // namespace phi
......
@@ -29,42 +29,32 @@ void ReduceSumGradKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* x_grad) {
-  using MPType = typename kps::details::MPTypeTrait<T>::Type;
-  auto out_dtype = x.dtype();
-  auto* in_x = &x;
-  auto* d_out = &out_grad;
-  auto* d_x = x_grad;
-
-  // get reduce_dim and reduce_num for reduce_mean_grad
-  int dim_size = in_x->dims().size();
+  // get reduce_dim for reduce_mean_grad
+  int dim_size = x.dims().size();
   if (dims.size() == 0) {
     reduce_all = true;
   }
   std::vector<int> reduce_dims =
       funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all);
-  auto update_dims = vectorize(d_x->dims());
-  int reduce_num = 1;
+
+  auto update_dims = vectorize(x.dims());
   for (auto i : reduce_dims) {
-    reduce_num *= (in_x->dims())[i];
     update_dims[i] = 1;
   }

   // make new tensor
-  DenseTensor new_d_out(d_out->dtype());
-  new_d_out.ShareDataWith(*d_out);
-  new_d_out.Resize(phi::make_ddim(update_dims));
-  dev_ctx.Alloc(d_x, x.dtype());
-
-  auto pt_out_dtype = x.dtype();
-  auto pt_d_out = new_d_out;
-  auto pt_d_x = *d_x;
-  std::vector<const DenseTensor*> inputs = {&pt_d_out};
-  std::vector<DenseTensor*> outputs = {&pt_d_x};
+  DenseTensor new_out_grad(out_grad.dtype());
+  new_out_grad.ShareDataWith(out_grad);
+  new_out_grad.Resize(phi::make_ddim(update_dims));
+
+  // call ReduceGrad
+  dev_ctx.Alloc(x_grad, x.dtype());
+  using MPType = typename kps::details::MPTypeTrait<T>::Type;
   phi::ReduceGrad<T, kps::IdentityFunctor<T, MPType>>(
       dev_ctx,
-      &pt_d_out,
-      &pt_d_x,
-      pt_out_dtype,
+      &new_out_grad,
+      x_grad,
+      x.dtype(),
       kps::IdentityFunctor<T, MPType>());
 }
......
@@ -26,6 +26,9 @@ void MeanKernel(const Context& dev_ctx,
                 bool keep_dim,
                 DenseTensor* out) {
   bool reduce_all = false;
+  if (dims.size() == 0) {
+    reduce_all = true;
+  }
   MeanRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
......
@@ -5096,9 +5096,6 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
             fluid.layers.reduce_sum(y, dim=[0, 1]) # [16, 20]

     """
-    if dim is not None and not isinstance(dim, list):
-        dim = [dim]
-
     reduce_all, dim = _get_reduce_dim(dim, input)

     if in_dygraph_mode():
......
@@ -58,6 +58,21 @@ class TestMeanOp(OpTest):
         self.check_grad(['X'], 'Out', check_eager=True)


+class TestMeanOp_ZeroDim(OpTest):
+    def setUp(self):
+        self.op_type = "mean"
+        self.python_api = paddle.mean
+        self.dtype = np.float64
+        self.inputs = {'X': np.random.random([]).astype(self.dtype)}
+        self.outputs = {'Out': np.mean(self.inputs["X"])}
+
+    def test_check_output(self):
+        self.check_output(check_eager=True)
+
+    def test_checkout_grad(self):
+        self.check_grad(['X'], 'Out', check_eager=True)
+
+
 class TestMeanOpError(unittest.TestCase):
     def test_errors(self):
         with program_guard(Program(), Program()):
......
@@ -37,6 +37,21 @@ class TestSumOp(OpTest):
         self.check_grad(['X'], 'Out', check_eager=True)


+class TestSumOp_ZeroDim(OpTest):
+    def setUp(self):
+        self.python_api = paddle.sum
+        self.op_type = "reduce_sum"
+        self.inputs = {'X': np.random.random([]).astype("float64")}
+        self.outputs = {'Out': self.inputs['X'].sum(axis=None)}
+        self.attrs = {'dim': [], 'reduce_all': True}
+
+    def test_check_output(self):
+        self.check_output(check_eager=True)
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', check_eager=True)
+
+
 class TestSumOp_fp16(OpTest):
     def setUp(self):
         self.python_api = paddle.sum
......
@@ -17,6 +17,7 @@
 import paddle.fluid as fluid
 import numpy as np
 import unittest
+
 unary_api_list = [
     paddle.nn.functional.elu,
     paddle.nn.functional.gelu,
@@ -159,5 +160,55 @@ class TestUnaryAPI(unittest.TestCase):
         paddle.disable_static()


+reduce_api_list = [
+    paddle.sum,
+    paddle.mean,
+    paddle.nansum,
+    paddle.nanmean,
+]
+
+
+class TestReduceAPI(unittest.TestCase):
+    def test_dygraph(self):
+        paddle.disable_static()
+        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
+        for api in reduce_api_list:
+            x = paddle.rand([])
+            x.stop_gradient = False
+            out = api(x, None)
+            out.backward()
+
+            self.assertEqual(x.shape, [])
+            self.assertEqual(x.grad.shape, [])
+            self.assertEqual(out.shape, [])
+            self.assertEqual(out.grad.shape, [])
+
+        paddle.enable_static()
+
+    def test_static(self):
+        paddle.enable_static()
+        for api in reduce_api_list:
+            main_prog = fluid.Program()
+            with fluid.program_guard(main_prog, fluid.Program()):
+                x = paddle.rand([])
+                x.stop_gradient = False
+                out = api(x, None)
+                fluid.backward.append_backward(out)
+
+                # Test compile shape, grad is always [1]
+                self.assertEqual(x.shape, ())
+                self.assertEqual(out.shape, ())
+
+                exe = fluid.Executor()
+                result = exe.run(main_prog, fetch_list=[x, out])
+
+                # Test runtime shape
+                self.assertEqual(result[0].shape, ())
+                self.assertEqual(result[1].shape, ())
+
+        paddle.disable_static()
+
+
 if __name__ == "__main__":
     unittest.main()
@@ -1265,22 +1265,7 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
             out8 = paddle.sum(x, axis=0)  # [1, 1, 1, 1]
             out9 = paddle.sum(x, axis=1)  # [4, 0]
     """
-    if isinstance(axis, Variable):
-        reduce_all_flag = True if axis.shape[0] == len(x.shape) else False
-    else:
-        if axis is not None and not isinstance(axis, (list, tuple)):
-            axis = [axis]
-
-        if not axis:
-            axis = []
-
-        if len(axis) == 0:
-            reduce_all_flag = True
-        else:
-            if len(axis) == len(x.shape):
-                reduce_all_flag = True
-            else:
-                reduce_all_flag = False
+    reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)

     dtype_flag = False
     if dtype is not None:
@@ -1290,11 +1275,6 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
     if in_dygraph_mode():
         return _C_ops.sum(x, axis, dtype, keepdim)

-    if not isinstance(axis, Variable):
-        axis = axis if axis != None and axis != [] and axis != () else [0]
-        if utils._contain_var(axis):
-            axis = utils._convert_to_tensor_list(axis)
-
     if _in_legacy_dygraph():
         if dtype_flag:
             return _legacy_C_ops.reduce_sum(
@@ -1304,7 +1284,7 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
                 'keep_dim',
                 keepdim,
                 'reduce_all',
-                reduce_all_flag,
+                reduce_all,
                 'in_dtype',
                 x.dtype,
                 'out_dtype',
@@ -1318,10 +1298,10 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
                 'keep_dim',
                 keepdim,
                 'reduce_all',
-                reduce_all_flag,
+                reduce_all,
             )

-    attrs = {'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all_flag}
+    attrs = {'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all}
     if dtype_flag:
         attrs.update({'in_dtype': x.dtype, 'out_dtype': dtype})
@@ -2304,13 +2284,13 @@ def inverse(x, name=None):
     return out


-def _get_reduce_axis(axis):
+def _get_reduce_axis(axis, x):
     """
     Internal function for max, min, amax and amin.
     It computes the attribute reduce_all value based on axis.
     """
     if axis is not None and not isinstance(axis, list):
-        if isinstance(axis, tuple):
+        if isinstance(axis, (tuple, range)):
             axis = list(axis)
         elif isinstance(axis, int):
             axis = [axis]
@@ -2320,37 +2300,25 @@ def _get_reduce_axis(axis):
                 type(axis)
             )
         )
-    reduce_all = True if axis == None or axis == [] else False
-    if axis == None:
+    if axis is None:
         axis = []
+    if axis == [] or len(axis) == len(x.shape):
+        reduce_all = True
+    else:
+        reduce_all = False
     return reduce_all, axis


-def _get_reduce_axis_with_tensor(axis):
+def _get_reduce_axis_with_tensor(axis, x):
     if isinstance(axis, Variable):
-        return False, axis
-    return _get_reduce_axis(axis)
-
-
-def _get_reduce_all_value(axis):
-    """
-    Internal function for max, min, amax and amin.
-    It computes the attribute reduce_all value based on axis.
-    """
-    if axis is not None and not isinstance(axis, list):
-        if isinstance(axis, tuple):
-            axis = list(axis)
-        elif isinstance(axis, int):
-            axis = [axis]
-        else:
-            raise TypeError(
-                "The type of axis must be int, list or tuple, but received {}".format(
-                    type(axis)
-                )
-            )
-    reduce_all = True if axis == None or axis == [] else False
-    axis = axis if axis != None and axis != [] else [0]
+        if axis.shape[0] == len(x.shape):
+            reduce_all = True
+        else:
+            reduce_all = False
+    else:
+        reduce_all, axis = _get_reduce_axis(axis, x)
+        if utils._contain_var(axis):
+            axis = utils._convert_to_tensor_list(axis)
     return reduce_all, axis
@@ -2432,7 +2400,7 @@ def max(x, axis=None, keepdim=False, name=None):
             #[7., 8.], [[[0., 0.], [0., 0.]], [[0., 0.], [1., 1.]]]
     """

-    reduce_all, axis = _get_reduce_axis_with_tensor(axis)
+    reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
     if in_dygraph_mode():
         return _C_ops.max(x, axis, keepdim)
     if _in_legacy_dygraph():
@@ -2534,7 +2502,7 @@ def min(x, axis=None, keepdim=False, name=None):
             #[1., 2.], [[[1., 1.], [0., 0.]], [[0., 0.], [0., 0.]]]
     """

-    reduce_all, axis = _get_reduce_axis_with_tensor(axis)
+    reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
     if in_dygraph_mode():
         return _C_ops.min(x, axis, keepdim)
@@ -2650,7 +2618,7 @@ def amax(x, axis=None, keepdim=False, name=None):
             #[0.9., 0.9], [[[0., 0.3333], [0.5, 0.3333]], [[0.5, 0.3333], [1., 1.]]]
     """

-    reduce_all, axis = _get_reduce_axis(axis)
+    reduce_all, axis = _get_reduce_axis(axis, x)
     if in_dygraph_mode():
         return _C_ops.amax(x, axis, keepdim)
     if _in_legacy_dygraph():
@@ -2764,7 +2732,7 @@ def amin(x, axis=None, keepdim=False, name=None):
             #[0.1., 0.1], [[[0., 0.3333], [0.5, 0.3333]], [[0.5, 0.3333], [1., 1.]]]
     """

-    reduce_all, axis = _get_reduce_axis(axis)
+    reduce_all, axis = _get_reduce_axis(axis, x)
     if in_dygraph_mode():
         return _C_ops.amin(x, axis, keepdim)
     elif _in_legacy_dygraph():
......
@@ -20,9 +20,9 @@ from ..framework import core
 from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode
 from .search import where
 from ..fluid.data_feeder import check_type, check_variable_and_dtype
-from ..fluid.layers import utils
 import paddle
 from paddle import _C_ops, _legacy_C_ops
+from .math import _get_reduce_axis_with_tensor

 __all__ = []
@@ -80,22 +80,9 @@ def mean(x, axis=None, keepdim=False, name=None):
             # [ 8.5 12.5 16.5]
     """
-    if isinstance(axis, Variable):
-        reduce_all = True if axis.shape[0] == len(x.shape) else False
-    else:
-        if isinstance(axis, int):
-            axis = [axis]
-        reduce_all = (
-            True
-            if axis is None or len(axis) == 0 or len(axis) == len(x.shape)
-            else False
-        )
-        if axis is None or len(axis) == 0:
-            axis = [0]
+    reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)

     if in_dygraph_mode():
-        if reduce_all:
-            axis = list(range(len(x.shape)))
         return _C_ops.mean(x, axis, keepdim)

     if _in_legacy_dygraph():
         return _legacy_C_ops.reduce_mean(
@@ -122,8 +109,6 @@ def mean(x, axis=None, keepdim=False, name=None):
     helper = LayerHelper('mean', **locals())
-    if not isinstance(axis, Variable) and utils._contain_var(axis):
-        axis = utils._convert_to_tensor_list(axis)
     attrs = {'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all}
     out = helper.create_variable_for_type_inference(x.dtype)
     helper.append_op(
......