Unverified commit bcf75132, authored by wanghuancoder and committed by GitHub

do not calc reduce_all in eager mode (#48199)

* do not calc reduce_all in eager mode

* refine python c cast list

* refine

* refine

* refine

* refine

* refine

* refine

* refine

* refine

* refine
Parent 0a9c1f59
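
At the Python API layer the change follows one pattern across `sum`, `mean`, `max`, `min`, `amax`, `amin`, `all` and `any`: the `reduce_all` flag is no longer derived before dispatch, only for the legacy-dygraph and static-graph branches, while eager mode forwards `axis` straight to `_C_ops`. A minimal sketch of that pattern (not the verbatim Paddle source; the import paths are assumptions about Paddle's internal layout at the time, and the helper names are taken from the diff below):

```python
# Sketch only -- assumed import paths for Paddle's internal helpers of this era.
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
from paddle.tensor.math import _get_reduce_axis_with_tensor

def sum_sketch(x, axis=None, keepdim=False):
    if in_dygraph_mode():
        # Eager mode: forward axis as-is; reduce_all is not computed here,
        # the C++ side derives it when it needs it.
        return _C_ops.sum(x, axis, None, keepdim)

    # Legacy dygraph / static graph still need reduce_all and a normalized axis.
    reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
    if _in_legacy_dygraph():
        return _legacy_C_ops.reduce_sum(
            x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all
        )
    # Static-graph path (omitted): append a reduce_sum op with the same attrs.
```
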
......@@ -1462,7 +1462,7 @@ static PyObject* tensor_method_set_string_list(TensorObject* self,
PyObject* kwargs) {
EAGER_TRY
using Strings = std::vector<std::string>;
auto strings = CastPyArg2Strings(PyTuple_GET_ITEM(args, 0), 0);
auto strings = CastPyArg2VectorOfString(PyTuple_GET_ITEM(args, 0), 0);
auto var_tensor = std::make_shared<egr::VariableCompatTensor>();
*var_tensor->GetMutable<Strings>() = strings;
self->tensor.set_impl(var_tensor);
......
......@@ -289,6 +289,9 @@ std::vector<paddle::experimental::Tensor> CastPyArg2VectorOfTensor(
}
} else if (obj == Py_None) {
return {};
} else if (PyObject_IsInstance(obj,
reinterpret_cast<PyObject*>(p_tensor_type))) {
return {reinterpret_cast<TensorObject*>(obj)->tensor};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be "
......@@ -335,6 +338,56 @@ std::vector<int> CastPyArg2VectorOfInt(PyObject* obj, size_t arg_pos) {
}
} else if (obj == Py_None) {
return {};
} else if (PyObject_CheckLongOrConvertToLong(&obj)) {
return {static_cast<int>(PyLong_AsLong(obj))};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be "
"list or tuple, but got %s",
arg_pos + 1,
reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
}
return result;
}
std::vector<int64_t> CastPyArg2VectorOfInt64(PyObject* obj, size_t arg_pos) {
std::vector<int64_t> result;
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GET_ITEM(obj, i);
if (PyObject_CheckLongOrConvertToLong(&item)) {
result.emplace_back(static_cast<int64_t>(PyLong_AsLong(item)));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be "
"list of int, but got %s at pos %d",
arg_pos + 1,
reinterpret_cast<PyTypeObject*>(item->ob_type)->tp_name,
i));
}
}
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GET_ITEM(obj, i);
if (PyObject_CheckLongOrConvertToLong(&item)) {
result.emplace_back(static_cast<int64_t>(PyLong_AsLong(item)));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be "
"list of int, but got %s at pos %d",
arg_pos + 1,
reinterpret_cast<PyTypeObject*>(item->ob_type)->tp_name,
i));
}
}
} else if (obj == Py_None) {
return {};
} else if (PyObject_CheckLongOrConvertToLong(&obj)) {
return {static_cast<int64_t>(PyLong_AsLong(obj))};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be "
......@@ -363,10 +416,30 @@ std::vector<size_t> CastPyArg2VectorOfSize_t(PyObject* obj, size_t arg_pos) {
i));
}
}
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GET_ITEM(obj, i);
if (PyObject_CheckLongOrConvertToLong(&item)) {
result.emplace_back(PyLong_AsSize_t(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be "
"list of size_t, but got %s at pos %d",
arg_pos + 1,
reinterpret_cast<PyTypeObject*>(item->ob_type)->tp_name,
i));
}
}
} else if (obj == Py_None) {
return {};
} else if (PyObject_CheckLongOrConvertToLong(&obj)) {
return {PyLong_AsSize_t(obj)};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be "
"list, but got %s",
"list of size_t, but got %s",
arg_pos + 1,
reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
}
......@@ -487,6 +560,9 @@ std::vector<phi::DenseTensor> CastPyArg2VectorOfTensorBase(PyObject* obj,
}
} else if (obj == Py_None) {
return {};
} else if (PyObject_IsInstance(
obj, reinterpret_cast<PyObject*>(g_framework_tensor_pytype))) {
return {::pybind11::handle(obj).cast<phi::DenseTensor>()};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be "
......@@ -527,7 +603,8 @@ std::unordered_map<std::wstring, int> CastPyArg2Vocab(PyObject* obj,
}
}
std::vector<std::string> CastPyArg2Strings(PyObject* obj, ssize_t arg_pos) {
std::vector<std::string> CastPyArg2VectorOfString(PyObject* obj,
ssize_t arg_pos) {
if (PyList_Check(obj)) {
return ::pybind11::handle(obj).cast<std::vector<std::string>>();
} else {
......@@ -1385,16 +1462,8 @@ std::vector<phi::Scalar> CastPyArg2ScalarArray(PyObject* obj,
paddle::experimental::IntArray CastPyArg2IntArray(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
// In case of IntArray, only two possible PyObjects:
// 1. list of int
// 2. Tensor
if (obj == Py_None) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list or Tensor, but got %s",
op_type,
arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
return paddle::experimental::IntArray({});
}
// obj could be: int, float, bool, paddle.Tensor
......@@ -1408,10 +1477,13 @@ paddle::experimental::IntArray CastPyArg2IntArray(PyObject* obj,
paddle::experimental::Tensor& value = GetTensorFromPyObject(
op_type, "" /*arg_name*/, obj, arg_pos, false /*dispensable*/);
return paddle::experimental::IntArray(value);
} else if (PyObject_CheckLongOrConvertToLong(&obj)) {
return paddle::experimental::IntArray(
{static_cast<int64_t>(PyLong_AsLong(obj))});
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list or Tensor, but got %s",
"list or int, but got %s",
op_type,
arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
......
......@@ -68,6 +68,7 @@ phi::DenseTensor CastPyArg2FrameworkTensor(PyObject* obj, ssize_t arg_pos);
std::vector<phi::DenseTensor> CastPyArg2VectorOfTensorBase(PyObject* obj,
ssize_t arg_pos);
std::vector<int> CastPyArg2VectorOfInt(PyObject* obj, size_t arg_pos);
std::vector<int64_t> CastPyArg2VectorOfInt64(PyObject* obj, size_t arg_pos);
std::vector<size_t> CastPyArg2VectorOfSize_t(PyObject* obj, size_t arg_pos);
std::vector<std::vector<size_t>> CastPyArg2VectorOfVectorOfSize_t(
PyObject* obj, size_t arg_pos);
......@@ -75,7 +76,8 @@ framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj,
ssize_t arg_pos);
std::unordered_map<std::wstring, int> CastPyArg2Vocab(PyObject* obj,
ssize_t arg_pos);
std::vector<std::string> CastPyArg2Strings(PyObject* obj, ssize_t arg_pos);
std::vector<std::string> CastPyArg2VectorOfString(PyObject* obj,
ssize_t arg_pos);
std::shared_ptr<jit::Function> CastPyArg2JitFunction(PyObject* obj,
ssize_t arg_pos);
......
......@@ -461,7 +461,11 @@ std::vector<int64_t> CastPyArg2Longs(PyObject* obj,
i));
}
}
} else if ((PyObject*)obj != Py_None) { // NOLINT
} else if (obj == Py_None) {
return {};
} else if (PyObject_CheckLongOrToLong(&obj)) {
return {static_cast<int64_t>(PyLong_AsLong(obj))};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list or tuple, but got %s",
......
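
The cast helpers changed above (CastPyArg2VectorOfInt, CastPyArg2VectorOfInt64, CastPyArg2VectorOfSize_t, CastPyArg2Longs and CastPyArg2IntArray) are what let the Python layer forward `axis` untouched: besides a list or tuple, a single integer, `None`, or a Tensor is now accepted and promoted to the expected vector or IntArray on the C++ side. A quick usage-level check, assuming a Paddle build that contains this patch and runs in eager mode:

```python
import paddle

x = paddle.rand([2, 3, 4])
print(paddle.sum(x, axis=1).shape)       # scalar int axis, promoted to a one-element vector
print(paddle.sum(x, axis=(0, 2)).shape)  # tuple axis
print(paddle.sum(x, axis=None).shape)    # None: reduce over every dimension
```
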
......@@ -25,7 +25,7 @@ void ProdKernel(const Context& dev_ctx,
const IntArray& dims,
bool keep_dim,
DenseTensor* out) {
bool reduce_all = false; // recompute_reduce_all(x, dims);
bool reduce_all = recompute_reduce_all(x, dims);
ProdRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
......
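
Because the Python side no longer precomputes `reduce_all`, ProdKernel derives it again from the (possibly empty) dims via `recompute_reduce_all`, which is what the line above restores. A small behavioural check under the same assumptions as before:

```python
import paddle

x = paddle.rand([2, 3])
out = paddle.prod(x)          # default dims -> kernel recomputes reduce_all and collapses everything
print(out.shape, float(out))  # a single fully reduced value
```
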
......@@ -29,13 +29,13 @@ np.random.seed(10)
def mean_wrapper(x, axis=None, keepdim=False, reduce_all=False):
if reduce_all:
return paddle.mean(x, range(len(x.shape)), keepdim)
return paddle.mean(x, list(range(len(x.shape))), keepdim)
return paddle.mean(x, axis, keepdim)
def reduce_mean_wrapper(x, axis=0, keepdim=False, reduce_all=False):
if reduce_all:
return paddle.mean(x, range(len(x.shape)), keepdim)
return paddle.mean(x, list(range(len(x.shape))), keepdim)
return paddle.mean(x, axis, keepdim)
......
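
The test wrappers above now wrap `range()` in `list()`: with `axis` forwarded directly to the eager op, the argument has to be one of the forms the C++ casts accept (list, tuple, int, Tensor, or None), not an arbitrary iterable. For example:

```python
import paddle

x = paddle.rand([2, 3, 4])
axis = list(range(len(x.shape)))    # [0, 1, 2] -- a concrete list, not a range object
print(paddle.mean(x, axis, False))  # reduces over every dimension
```
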
......@@ -465,12 +465,6 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
):
if in_dygraph_mode():
out = _C_ops.abs(input)
reduce_all = (
True if axis is None or axis == [] or asvector else False
)
axis = axis if axis is not None and axis != [] else [0]
if reduce_all:
assert (axis == []) or (axis is None)
if porder == np.float64('inf'):
return _C_ops.max(out, axis, keepdim)
else:
......@@ -844,27 +838,25 @@ def cond(x, p=None, name=None):
Calculate the matrix norm of a square matrix or batches of square matrices,
when porder is in (1, -1, inf, -inf)
"""
reduce_all = True if axis is None or axis == [] else False
axis = axis if axis is not None and axis != [] else [0]
keepdim = False
if in_dygraph_mode():
abs_out = _C_ops.abs(input)
sum_out = _C_ops.sum(abs_out, axis, None, keepdim)
sum_out = _C_ops.sum(abs_out, axis, None, False)
if porder == 1 or porder == np.inf:
return _C_ops.max(sum_out, [-1], keepdim)
return _C_ops.max(sum_out, [-1], False)
if porder == -1 or porder == -np.inf:
return _C_ops.min(sum_out, [-1], keepdim)
return _C_ops.min(sum_out, [-1], False)
elif _in_legacy_dygraph():
reduce_all = True if axis is None or axis == [] else False
axis = axis if axis is not None and axis != [] else [0]
abs_out = _legacy_C_ops.abs(input)
sum_out = _legacy_C_ops.reduce_sum(
abs_out,
'dim',
axis,
'keepdim',
keepdim,
False,
'reduce_all',
reduce_all,
)
......@@ -874,7 +866,7 @@ def cond(x, p=None, name=None):
'dim',
[-1],
'keepdim',
keepdim,
False,
'reduce_all',
reduce_all,
)
......@@ -884,11 +876,13 @@ def cond(x, p=None, name=None):
'dim',
[-1],
'keepdim',
keepdim,
False,
'reduce_all',
reduce_all,
)
else:
reduce_all = True if axis is None or axis == [] else False
axis = axis if axis is not None and axis != [] else [0]
block = LayerHelper('norm', **locals())
abs_out = block.create_variable_for_type_inference(
dtype=block.input_dtype()
......@@ -908,7 +902,7 @@ def cond(x, p=None, name=None):
outputs={'Out': sum_out},
attrs={
'dim': axis,
'keep_dim': keepdim,
'keep_dim': False,
'reduce_all': reduce_all,
},
)
......@@ -919,7 +913,7 @@ def cond(x, p=None, name=None):
outputs={'Out': out},
attrs={
'dim': [-1],
'keep_dim': keepdim,
'keep_dim': False,
'reduce_all': reduce_all,
},
)
......@@ -930,7 +924,7 @@ def cond(x, p=None, name=None):
outputs={'Out': out},
attrs={
'dim': [-1],
'keep_dim': keepdim,
'keep_dim': False,
'reduce_all': reduce_all,
},
)
......@@ -941,22 +935,20 @@ def cond(x, p=None, name=None):
NOTE:
Calculate the frobenius norm of a square matrix or batches of square matrices.
"""
reduce_all = True if axis is None or axis == [] else False
keepdim = False
if in_dygraph_mode():
pow_out = _C_ops.pow(input, porder)
sum_out_1 = _C_ops.sum(pow_out, axis, None, keepdim)
sum_out_2 = _C_ops.sum(sum_out_1, axis, None, keepdim)
sum_out_1 = _C_ops.sum(pow_out, axis, None, False)
sum_out_2 = _C_ops.sum(sum_out_1, axis, None, False)
return _C_ops.pow(sum_out_2, float(1.0 / porder))
elif paddle.in_dynamic_mode():
reduce_all = True if axis is None or axis == [] else False
pow_out = _legacy_C_ops.pow(input, 'factor', porder)
sum_out_1 = _legacy_C_ops.reduce_sum(
pow_out,
'dim',
axis,
'keepdim',
keepdim,
False,
'reduce_all',
reduce_all,
)
......@@ -965,12 +957,13 @@ def cond(x, p=None, name=None):
'dim',
axis,
'keepdim',
keepdim,
False,
'reduce_all',
reduce_all,
)
return _legacy_C_ops.pow(sum_out_2, 'factor', float(1.0 / porder))
reduce_all = True if axis is None or axis == [] else False
block = LayerHelper('norm', **locals())
pow_out = block.create_variable_for_type_inference(
dtype=block.input_dtype()
......@@ -994,13 +987,13 @@ def cond(x, p=None, name=None):
type='reduce_sum',
inputs={'X': pow_out},
outputs={'Out': sum_out_1},
attrs={'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all},
attrs={'dim': axis, 'keep_dim': False, 'reduce_all': reduce_all},
)
block.append_op(
type='reduce_sum',
inputs={'X': sum_out_1},
outputs={'Out': sum_out_2},
attrs={'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all},
attrs={'dim': axis, 'keep_dim': False, 'reduce_all': reduce_all},
)
block.append_op(
type='pow',
......@@ -1016,28 +1009,27 @@ def cond(x, p=None, name=None):
Calculate the matrix norm, which is related to singular values, of a matrix
or batches of matrices, including nuclear norm, 2-norm and (-2)-norm.
"""
reduce_all = True if axis is None or axis == [] else False
keepdim = False
if not in_dygraph_mode():
reduce_all = True if axis is None or axis == [] else False
u, s, vh = svd(input, full_matrices=False)
if _non_static_mode():
if porder == "nuc":
if in_dygraph_mode():
return _C_ops.sum(s, axis, None, keepdim)
return _C_ops.sum(s, axis, None, False)
else:
return _legacy_C_ops.reduce_sum(
s,
'dim',
axis,
'keepdim',
keepdim,
False,
'reduce_all',
reduce_all,
)
if in_dygraph_mode():
max_out = _C_ops.max(s, axis, keepdim)
min_out = _C_ops.min(s, axis, keepdim)
max_out = _C_ops.max(s, axis, False)
min_out = _C_ops.min(s, axis, False)
if porder == 2:
return _C_ops.divide(max_out, min_out)
if porder == -2:
......@@ -1045,10 +1037,10 @@ def cond(x, p=None, name=None):
else:
max_out = _legacy_C_ops.reduce_max(
s, 'dim', axis, 'keepdim', keepdim, 'reduce_all', reduce_all
s, 'dim', axis, 'keepdim', False, 'reduce_all', reduce_all
)
min_out = _legacy_C_ops.reduce_min(
s, 'dim', axis, 'keepdim', keepdim, 'reduce_all', reduce_all
s, 'dim', axis, 'keepdim', False, 'reduce_all', reduce_all
)
if porder == 2:
return _legacy_C_ops.elementwise_div(
......@@ -1070,7 +1062,7 @@ def cond(x, p=None, name=None):
outputs={'Out': out},
attrs={
'dim': axis,
'keep_dim': keepdim,
'keep_dim': False,
'reduce_all': reduce_all,
},
)
......@@ -1085,13 +1077,13 @@ def cond(x, p=None, name=None):
type='reduce_max',
inputs={'X': s},
outputs={'Out': max_out},
attrs={'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all},
attrs={'dim': axis, 'keep_dim': False, 'reduce_all': reduce_all},
)
block.append_op(
type='reduce_min',
inputs={'X': s},
outputs={'Out': min_out},
attrs={'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all},
attrs={'dim': axis, 'keep_dim': False, 'reduce_all': reduce_all},
)
if porder == 2:
block.append_op(
......
......@@ -1303,7 +1303,6 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
out8 = paddle.sum(x, axis=0) # [1, 1, 1, 1]
out9 = paddle.sum(x, axis=1) # [4, 0]
"""
reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
dtype_flag = False
if dtype is not None:
......@@ -1313,6 +1312,8 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
if in_dygraph_mode():
return _C_ops.sum(x, axis, dtype, keepdim)
reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
if _in_legacy_dygraph():
if dtype_flag:
return _legacy_C_ops.reduce_sum(
......@@ -2382,9 +2383,9 @@ def max(x, axis=None, keepdim=False, name=None):
#[7., 8.], [[[0., 0.], [0., 0.]], [[0., 0.], [1., 1.]]]
"""
reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
if in_dygraph_mode():
return _C_ops.max(x, axis, keepdim)
reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
if _in_legacy_dygraph():
return _legacy_C_ops.reduce_max(
x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all
......@@ -2484,10 +2485,10 @@ def min(x, axis=None, keepdim=False, name=None):
#[1., 2.], [[[1., 1.], [0., 0.]], [[0., 0.], [0., 0.]]]
"""
reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
if in_dygraph_mode():
return _C_ops.min(x, axis, keepdim)
reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
if _in_legacy_dygraph():
return _legacy_C_ops.reduce_min(
x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all
......@@ -2597,10 +2598,10 @@ def amax(x, axis=None, keepdim=False, name=None):
print(result6, y.grad)
#[0.9., 0.9], [[[0., 0.3333], [0.5, 0.3333]], [[0.5, 0.3333], [1., 1.]]]
"""
reduce_all, axis = _get_reduce_axis(axis, x)
if in_dygraph_mode():
return _C_ops.amax(x, axis, keepdim)
reduce_all, axis = _get_reduce_axis(axis, x)
if _in_legacy_dygraph():
return _legacy_C_ops.reduce_amax(
x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all
......@@ -2711,11 +2712,11 @@ def amin(x, axis=None, keepdim=False, name=None):
print(result6, y.grad)
#[0.1., 0.1], [[[0., 0.3333], [0.5, 0.3333]], [[0.5, 0.3333], [1., 1.]]]
"""
reduce_all, axis = _get_reduce_axis(axis, x)
if in_dygraph_mode():
return _C_ops.amin(x, axis, keepdim)
elif _in_legacy_dygraph():
reduce_all, axis = _get_reduce_axis(axis, x)
if _in_legacy_dygraph():
return _legacy_C_ops.reduce_amin(
x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all
)
......@@ -3860,11 +3861,10 @@ def all(x, axis=None, keepdim=False, name=None):
print(out4)
"""
reduce_all, axis = _get_reduce_axis(axis, x)
if in_dygraph_mode():
return _C_ops.all(x, axis, keepdim)
reduce_all, axis = _get_reduce_axis(axis, x)
if _in_legacy_dygraph():
return _legacy_C_ops.reduce_all(
x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all
......@@ -3937,11 +3937,10 @@ def any(x, axis=None, keepdim=False, name=None):
print(out4)
"""
reduce_all, axis = _get_reduce_axis(axis, x)
if in_dygraph_mode():
return _C_ops.any(x, axis, keepdim)
reduce_all, axis = _get_reduce_axis(axis, x)
if _in_legacy_dygraph():
return _legacy_C_ops.reduce_any(
x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all
......
......@@ -79,11 +79,10 @@ def mean(x, axis=None, keepdim=False, name=None):
out4 = paddle.mean(x, axis=[0, 2])
# [ 8.5 12.5 16.5]
"""
reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
if in_dygraph_mode():
return _C_ops.mean(x, axis, keepdim)
reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
if _in_legacy_dygraph():
return _legacy_C_ops.reduce_mean(
x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all
......