Unverified · Commit bcf75132 authored by wanghuancoder, committed by GitHub

do not calc reduce_all in eager mode (#48199)

* do not calc reduce_all in eager mode

* refine python c cast list

* refine

* refine

* refine

* refine

* refine

* refine

* refine

* refine

* refine
Parent 0a9c1f59
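In broad strokes, this commit moves the reduce_all/axis normalization out of the eager (dygraph) fast path: the `_C_ops` kernels now recompute `reduce_all` from the input and the reduce dims themselves, while the legacy-dygraph and static-graph paths keep deriving it in Python. A self-contained model of the new dispatch shape (all names below are illustrative stand-ins, not Paddle APIs):

    def _get_reduce_axis(axis, ndim):
        # Mirrors the Python-side helper the legacy paths still use:
        # normalize axis to a list and flag full reductions.
        if axis is None or axis == []:
            return True, list(range(ndim))
        axis = [axis] if isinstance(axis, int) else list(axis)
        return len(axis) == ndim, axis

    def reduce_max(ndim, axis=None, eager=True):
        if eager:
            # Eager path: return before reduce_all is ever computed;
            # the C++ kernel recomputes it from its own inputs.
            return ('_C_ops.max', axis)
        reduce_all, axis = _get_reduce_axis(axis, ndim)
        return ('_legacy_C_ops.reduce_max', axis, reduce_all)

    print(reduce_max(3, axis=1))                  # eager: no reduce_all computed
    print(reduce_max(3, axis=None, eager=False))  # legacy: ([0, 1, 2], True)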
@@ -1462,7 +1462,7 @@ static PyObject* tensor_method_set_string_list(TensorObject* self,
                                                PyObject* kwargs) {
   EAGER_TRY
   using Strings = std::vector<std::string>;
-  auto strings = CastPyArg2Strings(PyTuple_GET_ITEM(args, 0), 0);
+  auto strings = CastPyArg2VectorOfString(PyTuple_GET_ITEM(args, 0), 0);
   auto var_tensor = std::make_shared<egr::VariableCompatTensor>();
   *var_tensor->GetMutable<Strings>() = strings;
   self->tensor.set_impl(var_tensor);
...
@@ -289,6 +289,9 @@ std::vector<paddle::experimental::Tensor> CastPyArg2VectorOfTensor(
     }
   } else if (obj == Py_None) {
     return {};
+  } else if (PyObject_IsInstance(obj,
+                                 reinterpret_cast<PyObject*>(p_tensor_type))) {
+    return {reinterpret_cast<TensorObject*>(obj)->tensor};
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "argument (position %d) must be "
@@ -335,6 +338,56 @@ std::vector<int> CastPyArg2VectorOfInt(PyObject* obj, size_t arg_pos) {
     }
   } else if (obj == Py_None) {
     return {};
+  } else if (PyObject_CheckLongOrConvertToLong(&obj)) {
+    return {static_cast<int>(PyLong_AsLong(obj))};
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "argument (position %d) must be "
+        "list or tuple, but got %s",
+        arg_pos + 1,
+        reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
+  }
+  return result;
+}
+
+std::vector<int64_t> CastPyArg2VectorOfInt64(PyObject* obj, size_t arg_pos) {
+  std::vector<int64_t> result;
+  if (PyList_Check(obj)) {
+    Py_ssize_t len = PyList_Size(obj);
+    PyObject* item = nullptr;
+    for (Py_ssize_t i = 0; i < len; i++) {
+      item = PyList_GET_ITEM(obj, i);
+      if (PyObject_CheckLongOrConvertToLong(&item)) {
+        result.emplace_back(static_cast<int64_t>(PyLong_AsLong(item)));
+      } else {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "argument (position %d) must be "
+            "list of int, but got %s at pos %d",
+            arg_pos + 1,
+            reinterpret_cast<PyTypeObject*>(item->ob_type)->tp_name,
+            i));
+      }
+    }
+  } else if (PyTuple_Check(obj)) {
+    Py_ssize_t len = PyTuple_Size(obj);
+    PyObject* item = nullptr;
+    for (Py_ssize_t i = 0; i < len; i++) {
+      item = PyTuple_GET_ITEM(obj, i);
+      if (PyObject_CheckLongOrConvertToLong(&item)) {
+        result.emplace_back(static_cast<int64_t>(PyLong_AsLong(item)));
+      } else {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "argument (position %d) must be "
+            "list of int, but got %s at pos %d",
+            arg_pos + 1,
+            reinterpret_cast<PyTypeObject*>(item->ob_type)->tp_name,
+            i));
+      }
+    }
+  } else if (obj == Py_None) {
+    return {};
+  } else if (PyObject_CheckLongOrConvertToLong(&obj)) {
+    return {static_cast<int64_t>(PyLong_AsLong(obj))};
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "argument (position %d) must be "
@@ -363,10 +416,30 @@ std::vector<size_t> CastPyArg2VectorOfSize_t(PyObject* obj, size_t arg_pos) {
             i));
       }
     }
+  } else if (PyTuple_Check(obj)) {
+    Py_ssize_t len = PyTuple_Size(obj);
+    PyObject* item = nullptr;
+    for (Py_ssize_t i = 0; i < len; i++) {
+      item = PyTuple_GET_ITEM(obj, i);
+      if (PyObject_CheckLongOrConvertToLong(&item)) {
+        result.emplace_back(PyLong_AsSize_t(item));
+      } else {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "argument (position %d) must be "
+            "list of size_t, but got %s at pos %d",
+            arg_pos + 1,
+            reinterpret_cast<PyTypeObject*>(item->ob_type)->tp_name,
+            i));
+      }
+    }
+  } else if (obj == Py_None) {
+    return {};
+  } else if (PyObject_CheckLongOrConvertToLong(&obj)) {
+    return {PyLong_AsSize_t(obj)};
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "argument (position %d) must be "
-        "list, but got %s",
+        "list of size_t, but got %s",
         arg_pos + 1,
         reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
   }
@@ -487,6 +560,9 @@ std::vector<phi::DenseTensor> CastPyArg2VectorOfTensorBase(PyObject* obj,
     }
   } else if (obj == Py_None) {
     return {};
+  } else if (PyObject_IsInstance(
+                 obj, reinterpret_cast<PyObject*>(g_framework_tensor_pytype))) {
+    return {::pybind11::handle(obj).cast<phi::DenseTensor>()};
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "argument (position %d) must be "
@@ -527,7 +603,8 @@ std::unordered_map<std::wstring, int> CastPyArg2Vocab(PyObject* obj,
   }
 }
 
-std::vector<std::string> CastPyArg2Strings(PyObject* obj, ssize_t arg_pos) {
+std::vector<std::string> CastPyArg2VectorOfString(PyObject* obj,
+                                                  ssize_t arg_pos) {
   if (PyList_Check(obj)) {
     return ::pybind11::handle(obj).cast<std::vector<std::string>>();
   } else {
@@ -1385,16 +1462,8 @@ std::vector<phi::Scalar> CastPyArg2ScalarArray(PyObject* obj,
 paddle::experimental::IntArray CastPyArg2IntArray(PyObject* obj,
                                                   const std::string& op_type,
                                                   ssize_t arg_pos) {
-  // In case of IntArray, only two possible PyObjects:
-  // 1. list of int
-  // 2. Tensor
   if (obj == Py_None) {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "%s(): argument (position %d) must be "
-        "list or Tensor, but got %s",
-        op_type,
-        arg_pos + 1,
-        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
+    return paddle::experimental::IntArray({});
   }
 
   // obj could be: int, float, bool, paddle.Tensor
@@ -1408,10 +1477,13 @@ paddle::experimental::IntArray CastPyArg2IntArray(PyObject* obj,
     paddle::experimental::Tensor& value = GetTensorFromPyObject(
         op_type, "" /*arg_name*/, obj, arg_pos, false /*dispensable*/);
     return paddle::experimental::IntArray(value);
+  } else if (PyObject_CheckLongOrConvertToLong(&obj)) {
+    return paddle::experimental::IntArray(
+        {static_cast<int64_t>(PyLong_AsLong(obj))});
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "%s(): argument (position %d) must be "
-        "list or Tensor, but got %s",
+        "list or int, but got %s",
         op_type,
         arg_pos + 1,
         ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
...
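With these cast changes, the pybind argument parsers accept a bare Python int (or a single Tensor) wherever a list of ints was previously required, and None now maps to an empty IntArray instead of raising. Assuming a Paddle build that includes this commit, both spellings below reach the same kernel:

    import paddle

    x = paddle.rand([2, 3, 4])
    a = paddle.sum(x, axis=1)    # bare int: handled by the new scalar branch
    b = paddle.sum(x, axis=[1])  # list of int: the pre-existing branch
    assert a.shape == b.shape == [2, 4]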
@@ -68,6 +68,7 @@ phi::DenseTensor CastPyArg2FrameworkTensor(PyObject* obj, ssize_t arg_pos);
 std::vector<phi::DenseTensor> CastPyArg2VectorOfTensorBase(PyObject* obj,
                                                            ssize_t arg_pos);
 std::vector<int> CastPyArg2VectorOfInt(PyObject* obj, size_t arg_pos);
+std::vector<int64_t> CastPyArg2VectorOfInt64(PyObject* obj, size_t arg_pos);
 std::vector<size_t> CastPyArg2VectorOfSize_t(PyObject* obj, size_t arg_pos);
 std::vector<std::vector<size_t>> CastPyArg2VectorOfVectorOfSize_t(
     PyObject* obj, size_t arg_pos);
@@ -75,7 +76,8 @@ framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj,
                                                     ssize_t arg_pos);
 std::unordered_map<std::wstring, int> CastPyArg2Vocab(PyObject* obj,
                                                       ssize_t arg_pos);
-std::vector<std::string> CastPyArg2Strings(PyObject* obj, ssize_t arg_pos);
+std::vector<std::string> CastPyArg2VectorOfString(PyObject* obj,
+                                                  ssize_t arg_pos);
 std::shared_ptr<jit::Function> CastPyArg2JitFunction(PyObject* obj,
                                                      ssize_t arg_pos);
...
@@ -461,7 +461,11 @@ std::vector<int64_t> CastPyArg2Longs(PyObject* obj,
             i));
       }
     }
-  } else if ((PyObject*)obj != Py_None) {  // NOLINT
+  } else if (obj == Py_None) {
+    return {};
+  } else if (PyObject_CheckLongOrToLong(&obj)) {
+    return {static_cast<int64_t>(PyLong_AsLong(obj))};
+  } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "%s(): argument (position %d) must be "
         "list or tuple, but got %s",
...
@@ -25,7 +25,7 @@ void ProdKernel(const Context& dev_ctx,
                 const IntArray& dims,
                 bool keep_dim,
                 DenseTensor* out) {
-  bool reduce_all = false;  // recompute_reduce_all(x, dims);
+  bool reduce_all = recompute_reduce_all(x, dims);
   ProdRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
...
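The kernel-side change relies on `recompute_reduce_all`, which derives the flag from the tensor and the requested dims instead of trusting a Python-computed attribute. A hypothetical pure-Python rendering of that contract (Paddle's actual helper is a C++ inline function):

    def recompute_reduce_all(ndim, dims, reduce_all=False):
        # Reduce over everything when no dims are given, when the dims
        # cover every axis, or when the caller already asked for it.
        return len(dims) == 0 or len(dims) == ndim or reduce_all

    assert recompute_reduce_all(3, [])        # empty dims: full reduction
    assert recompute_reduce_all(2, [0, 1])    # dims span all axes
    assert not recompute_reduce_all(3, [1])   # partial reduction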
@@ -29,13 +29,13 @@ np.random.seed(10)
 
 def mean_wrapper(x, axis=None, keepdim=False, reduce_all=False):
     if reduce_all:
-        return paddle.mean(x, range(len(x.shape)), keepdim)
+        return paddle.mean(x, list(range(len(x.shape))), keepdim)
     return paddle.mean(x, axis, keepdim)
 
 
 def reduce_mean_wrapper(x, axis=0, keepdim=False, reduce_all=False):
     if reduce_all:
-        return paddle.mean(x, range(len(x.shape)), keepdim)
+        return paddle.mean(x, list(range(len(x.shape))), keepdim)
     return paddle.mean(x, axis, keepdim)
...
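The test-wrapper fix matters because the reworked casts check for a list, a tuple, or a bare int, and a `range` object is none of those; materializing it with `list(...)` keeps the call valid on the eager path. Plain Python shows the distinction:

    axis = range(3)
    print(isinstance(axis, (list, tuple)))        # False: range is neither
    print(isinstance(list(axis), (list, tuple)))  # True once materialized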
@@ -465,12 +465,6 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
         ):
             if in_dygraph_mode():
                 out = _C_ops.abs(input)
-                reduce_all = (
-                    True if axis is None or axis == [] or asvector else False
-                )
-                axis = axis if axis is not None and axis != [] else [0]
-                if reduce_all:
-                    assert (axis == []) or (axis is None)
                 if porder == np.float64('inf'):
                     return _C_ops.max(out, axis, keepdim)
                 else:
@@ -844,27 +838,25 @@ def cond(x, p=None, name=None):
         Calculate the matrix norm of a square matrix or batches of square matrices,
         when porder is in (1, -1, inf, -inf)
         """
-        reduce_all = True if axis is None or axis == [] else False
-        axis = axis if axis is not None and axis != [] else [0]
-        keepdim = False
         if in_dygraph_mode():
             abs_out = _C_ops.abs(input)
-            sum_out = _C_ops.sum(abs_out, axis, None, keepdim)
+            sum_out = _C_ops.sum(abs_out, axis, None, False)
             if porder == 1 or porder == np.inf:
-                return _C_ops.max(sum_out, [-1], keepdim)
+                return _C_ops.max(sum_out, [-1], False)
             if porder == -1 or porder == -np.inf:
-                return _C_ops.min(sum_out, [-1], keepdim)
+                return _C_ops.min(sum_out, [-1], False)
         elif _in_legacy_dygraph():
+            reduce_all = True if axis is None or axis == [] else False
+            axis = axis if axis is not None and axis != [] else [0]
             abs_out = _legacy_C_ops.abs(input)
             sum_out = _legacy_C_ops.reduce_sum(
                 abs_out,
                 'dim',
                 axis,
                 'keepdim',
-                keepdim,
+                False,
                 'reduce_all',
                 reduce_all,
             )
@@ -874,7 +866,7 @@ def cond(x, p=None, name=None):
                 'dim',
                 [-1],
                 'keepdim',
-                keepdim,
+                False,
                 'reduce_all',
                 reduce_all,
             )
@@ -884,11 +876,13 @@ def cond(x, p=None, name=None):
                 'dim',
                 [-1],
                 'keepdim',
-                keepdim,
+                False,
                 'reduce_all',
                 reduce_all,
             )
         else:
+            reduce_all = True if axis is None or axis == [] else False
+            axis = axis if axis is not None and axis != [] else [0]
             block = LayerHelper('norm', **locals())
             abs_out = block.create_variable_for_type_inference(
                 dtype=block.input_dtype()
@@ -908,7 +902,7 @@ def cond(x, p=None, name=None):
                 outputs={'Out': sum_out},
                 attrs={
                     'dim': axis,
-                    'keep_dim': keepdim,
+                    'keep_dim': False,
                     'reduce_all': reduce_all,
                 },
             )
@@ -919,7 +913,7 @@ def cond(x, p=None, name=None):
                 outputs={'Out': out},
                 attrs={
                     'dim': [-1],
-                    'keep_dim': keepdim,
+                    'keep_dim': False,
                     'reduce_all': reduce_all,
                 },
             )
@@ -930,7 +924,7 @@ def cond(x, p=None, name=None):
                 outputs={'Out': out},
                 attrs={
                     'dim': [-1],
-                    'keep_dim': keepdim,
+                    'keep_dim': False,
                     'reduce_all': reduce_all,
                 },
             )
@@ -941,22 +935,20 @@ def cond(x, p=None, name=None):
         NOTE:
             Calculate the frobenius norm of a square matrix or batches of square matrices.
         """
-        reduce_all = True if axis is None or axis == [] else False
-        keepdim = False
         if in_dygraph_mode():
             pow_out = _C_ops.pow(input, porder)
-            sum_out_1 = _C_ops.sum(pow_out, axis, None, keepdim)
-            sum_out_2 = _C_ops.sum(sum_out_1, axis, None, keepdim)
+            sum_out_1 = _C_ops.sum(pow_out, axis, None, False)
+            sum_out_2 = _C_ops.sum(sum_out_1, axis, None, False)
             return _C_ops.pow(sum_out_2, float(1.0 / porder))
         elif paddle.in_dynamic_mode():
+            reduce_all = True if axis is None or axis == [] else False
             pow_out = _legacy_C_ops.pow(input, 'factor', porder)
             sum_out_1 = _legacy_C_ops.reduce_sum(
                 pow_out,
                 'dim',
                 axis,
                 'keepdim',
-                keepdim,
+                False,
                 'reduce_all',
                 reduce_all,
             )
@@ -965,12 +957,13 @@ def cond(x, p=None, name=None):
                 'dim',
                 axis,
                 'keepdim',
-                keepdim,
+                False,
                 'reduce_all',
                 reduce_all,
             )
             return _legacy_C_ops.pow(sum_out_2, 'factor', float(1.0 / porder))
 
+        reduce_all = True if axis is None or axis == [] else False
         block = LayerHelper('norm', **locals())
         pow_out = block.create_variable_for_type_inference(
             dtype=block.input_dtype()
@@ -994,13 +987,13 @@ def cond(x, p=None, name=None):
             type='reduce_sum',
             inputs={'X': pow_out},
             outputs={'Out': sum_out_1},
-            attrs={'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all},
+            attrs={'dim': axis, 'keep_dim': False, 'reduce_all': reduce_all},
         )
         block.append_op(
             type='reduce_sum',
             inputs={'X': sum_out_1},
             outputs={'Out': sum_out_2},
-            attrs={'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all},
+            attrs={'dim': axis, 'keep_dim': False, 'reduce_all': reduce_all},
         )
         block.append_op(
             type='pow',
@@ -1016,28 +1009,27 @@ def cond(x, p=None, name=None):
         Calculate the matrix norm, which is related to singular values, of a matrix
         or batches of matrices, including nuclear norm, 2-norm and (-2)-norm.
         """
-        reduce_all = True if axis is None or axis == [] else False
-        keepdim = False
+        if not in_dygraph_mode():
+            reduce_all = True if axis is None or axis == [] else False
         u, s, vh = svd(input, full_matrices=False)
         if _non_static_mode():
             if porder == "nuc":
                 if in_dygraph_mode():
-                    return _C_ops.sum(s, axis, None, keepdim)
+                    return _C_ops.sum(s, axis, None, False)
                 else:
                     return _legacy_C_ops.reduce_sum(
                         s,
                         'dim',
                         axis,
                         'keepdim',
-                        keepdim,
+                        False,
                         'reduce_all',
                         reduce_all,
                     )
             if in_dygraph_mode():
-                max_out = _C_ops.max(s, axis, keepdim)
-                min_out = _C_ops.min(s, axis, keepdim)
+                max_out = _C_ops.max(s, axis, False)
+                min_out = _C_ops.min(s, axis, False)
                 if porder == 2:
                     return _C_ops.divide(max_out, min_out)
                 if porder == -2:
@@ -1045,10 +1037,10 @@ def cond(x, p=None, name=None):
             else:
                 max_out = _legacy_C_ops.reduce_max(
-                    s, 'dim', axis, 'keepdim', keepdim, 'reduce_all', reduce_all
+                    s, 'dim', axis, 'keepdim', False, 'reduce_all', reduce_all
                 )
                 min_out = _legacy_C_ops.reduce_min(
-                    s, 'dim', axis, 'keepdim', keepdim, 'reduce_all', reduce_all
+                    s, 'dim', axis, 'keepdim', False, 'reduce_all', reduce_all
                 )
                 if porder == 2:
                     return _legacy_C_ops.elementwise_div(
@@ -1070,7 +1062,7 @@ def cond(x, p=None, name=None):
                 outputs={'Out': out},
                 attrs={
                     'dim': axis,
-                    'keep_dim': keepdim,
+                    'keep_dim': False,
                     'reduce_all': reduce_all,
                 },
             )
@@ -1085,13 +1077,13 @@ def cond(x, p=None, name=None):
             type='reduce_max',
             inputs={'X': s},
             outputs={'Out': max_out},
-            attrs={'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all},
+            attrs={'dim': axis, 'keep_dim': False, 'reduce_all': reduce_all},
         )
         block.append_op(
             type='reduce_min',
             inputs={'X': s},
             outputs={'Out': min_out},
-            attrs={'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all},
+            attrs={'dim': axis, 'keep_dim': False, 'reduce_all': reduce_all},
         )
         if porder == 2:
             block.append_op(
...
@@ -1303,7 +1303,6 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
             out8 = paddle.sum(x, axis=0)  # [1, 1, 1, 1]
             out9 = paddle.sum(x, axis=1)  # [4, 0]
     """
-    reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
     dtype_flag = False
     if dtype is not None:
@@ -1313,6 +1312,8 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
     if in_dygraph_mode():
         return _C_ops.sum(x, axis, dtype, keepdim)
 
+    reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
+
     if _in_legacy_dygraph():
         if dtype_flag:
             return _legacy_C_ops.reduce_sum(
@@ -2382,9 +2383,9 @@ def max(x, axis=None, keepdim=False, name=None):
             #[7., 8.], [[[0., 0.], [0., 0.]], [[0., 0.], [1., 1.]]]
     """
-    reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
     if in_dygraph_mode():
         return _C_ops.max(x, axis, keepdim)
 
+    reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
     if _in_legacy_dygraph():
         return _legacy_C_ops.reduce_max(
             x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all
@@ -2484,10 +2485,10 @@ def min(x, axis=None, keepdim=False, name=None):
             #[1., 2.], [[[1., 1.], [0., 0.]], [[0., 0.], [0., 0.]]]
     """
-    reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
     if in_dygraph_mode():
         return _C_ops.min(x, axis, keepdim)
 
+    reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
    if _in_legacy_dygraph():
         return _legacy_C_ops.reduce_min(
             x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all
@@ -2597,10 +2598,10 @@ def amax(x, axis=None, keepdim=False, name=None):
             print(result6, y.grad)
             #[0.9., 0.9], [[[0., 0.3333], [0.5, 0.3333]], [[0.5, 0.3333], [1., 1.]]]
     """
-    reduce_all, axis = _get_reduce_axis(axis, x)
     if in_dygraph_mode():
         return _C_ops.amax(x, axis, keepdim)
 
+    reduce_all, axis = _get_reduce_axis(axis, x)
     if _in_legacy_dygraph():
         return _legacy_C_ops.reduce_amax(
             x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all
@@ -2711,11 +2712,11 @@ def amin(x, axis=None, keepdim=False, name=None):
             print(result6, y.grad)
             #[0.1., 0.1], [[[0., 0.3333], [0.5, 0.3333]], [[0.5, 0.3333], [1., 1.]]]
     """
-    reduce_all, axis = _get_reduce_axis(axis, x)
     if in_dygraph_mode():
         return _C_ops.amin(x, axis, keepdim)
-    elif _in_legacy_dygraph():
+
+    reduce_all, axis = _get_reduce_axis(axis, x)
+    if _in_legacy_dygraph():
         return _legacy_C_ops.reduce_amin(
             x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all
         )
@@ -3860,11 +3861,10 @@ def all(x, axis=None, keepdim=False, name=None):
             print(out4)
     """
-    reduce_all, axis = _get_reduce_axis(axis, x)
     if in_dygraph_mode():
         return _C_ops.all(x, axis, keepdim)
 
+    reduce_all, axis = _get_reduce_axis(axis, x)
     if _in_legacy_dygraph():
         return _legacy_C_ops.reduce_all(
             x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all
@@ -3937,11 +3937,10 @@ def any(x, axis=None, keepdim=False, name=None):
             print(out4)
     """
-    reduce_all, axis = _get_reduce_axis(axis, x)
     if in_dygraph_mode():
         return _C_ops.any(x, axis, keepdim)
 
+    reduce_all, axis = _get_reduce_axis(axis, x)
     if _in_legacy_dygraph():
         return _legacy_C_ops.reduce_any(
             x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all
...
@@ -79,11 +79,10 @@ def mean(x, axis=None, keepdim=False, name=None):
             out4 = paddle.mean(x, axis=[0, 2])
             # [ 8.5 12.5 16.5]
     """
-    reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
     if in_dygraph_mode():
         return _C_ops.mean(x, axis, keepdim)
 
+    reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
     if _in_legacy_dygraph():
         return _legacy_C_ops.reduce_mean(
             x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all
...