未验证 提交 a7509ce3 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] support input 0D Tensor for min/max/amin/amax/prod/logsumexp/all/any (#47501)

上级 ef7d966a
......@@ -1805,84 +1805,14 @@ void LogsumexpInferMeta(const MetaTensor& input,
bool keepdim,
bool reduce_all,
MetaTensor* out) {
auto x_dims = input.dims();
auto x_rank = x_dims.size();
std::vector<int64_t> formated_axis = axis;
PADDLE_ENFORCE_LE(x_rank,
auto input_rank = input.dims().size();
// only supoort 0~4D, due to eigen template compile slow
PADDLE_ENFORCE_LE(
input_rank,
4,
errors::InvalidArgument(
"The input tensor X's dimensions of logsumexp "
"should be less or equal than 4. But received X's "
"dimensions = %d, X's shape = [%s].",
x_rank,
x_dims));
PADDLE_ENFORCE_GT(
axis.size(),
0,
errors::InvalidArgument(
"The size of axis of logsumexp "
"should be greater than 0. But received the size of axis "
"of logsumexp is %d.",
axis.size()));
for (size_t i = 0; i < axis.size(); i++) {
PADDLE_ENFORCE_LT(axis[i],
x_rank,
errors::InvalidArgument(
"axis[%d] should be in the "
"range [-D, D), where D is the dimensions of X and "
"D is %d. But received axis[%d] = %d.",
i,
x_rank,
i,
axis[i]));
PADDLE_ENFORCE_GE(axis[i],
-x_rank,
errors::InvalidArgument(
"axis[%d] should be in the "
"range [-D, D), where D is the dimensions of X and "
"D is %d. But received axis[%d] = %d.",
i,
x_rank,
i,
axis[i]));
if (axis[i] < 0) {
formated_axis[i] += x_rank;
}
}
auto dims_vector = vectorize(x_dims);
if (reduce_all) {
if (keepdim)
out->set_dims(phi::make_ddim(std::vector<int64_t>(x_rank, 1)));
else
out->set_dims({1});
} else {
auto dims_vector = vectorize(x_dims);
if (keepdim) {
for (size_t i = 0; i < formated_axis.size(); ++i) {
dims_vector[formated_axis[i]] = 1;
}
} else {
const int kDelFlag = -1;
for (size_t i = 0; i < formated_axis.size(); ++i) {
dims_vector[formated_axis[i]] = kDelFlag;
}
dims_vector.erase(
std::remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
}
if (!keepdim && dims_vector.size() == 0) {
dims_vector.push_back(1);
}
auto out_dims = phi::make_ddim(dims_vector);
out->set_dims(out_dims);
if (formated_axis.size() > 0 && formated_axis[0] != 0) {
// Only pass LoD when not reducing on the first dim.
out->share_lod(input);
}
}
out->set_dtype(input.dtype());
errors::InvalidArgument("The input tensor X's dimensions of logsumexp "
"should be less or equal than 4. "));
ReduceInferMetaBase(input, axis, keepdim, reduce_all, out);
}
void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out) {
......
......@@ -60,8 +60,9 @@ void LogsumexpGradKernel(const Context& dev_ctx,
DenseTensor* in_grad) {
dev_ctx.template Alloc<T>(in_grad);
const auto input_dim_size = in.dims().size();
reduce_all |= (static_cast<int>(axis.size()) == input_dim_size);
if (axis.size() == 0 || static_cast<int>(axis.size()) == in.dims().size()) {
reduce_all = true;
}
if (reduce_all) {
auto x = phi::EigenVector<T>::Flatten(in);
......
......@@ -69,9 +69,9 @@ void LogsumexpKernel(const Context& dev_ctx,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
const auto& input_dim_size = x.dims().size();
// The dims has full dim, set the reduce_all is True
reduce_all |= (static_cast<int>(axis.size()) == input_dim_size);
if (axis.size() == 0 || static_cast<int>(axis.size()) == x.dims().size()) {
reduce_all = true;
}
if (reduce_all) {
// Flatten and reduce 1-D tensor
......@@ -81,7 +81,7 @@ void LogsumexpKernel(const Context& dev_ctx,
auto reduce_dim = Eigen::array<int, 1>({{0}});
LogsumexpFunctor<T>()(place, &input, &output, reduce_dim);
} else {
int ndim = input_dim_size;
int ndim = x.dims().size();
int rdim = axis.size();
if (ndim > 4) {
PADDLE_THROW(phi::errors::Unimplemented(
......
......@@ -26,6 +26,9 @@ void AnyKernel(const Context& dev_ctx,
bool keep_dim,
DenseTensor* out) {
bool reduce_all = false;
if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
reduce_all = true;
}
AnyRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
......
......@@ -103,6 +103,12 @@ class TestLogsumexp(OpTest):
return dy * np.exp(x - y)
class TestLogsumexp_ZeroDim(TestLogsumexp):
def set_attrs(self):
self.shape = []
self.axis = []
class TestLogsumexp_shape(TestLogsumexp):
def set_attrs(self):
self.shape = [4, 5, 6]
......
......@@ -136,6 +136,15 @@ class TestMaxMinAmaxAminAPI(unittest.TestCase):
# test two minimum or maximum elements
class TestMaxMinAmaxAminAPI_ZeroDim(TestMaxMinAmaxAminAPI):
def init_case(self):
self.x_np = np.array(0.5)
self.shape = []
self.dtype = 'float64'
self.axis = None
self.keepdim = False
class TestMaxMinAmaxAminAPI2(TestMaxMinAmaxAminAPI):
def init_case(self):
self.x_np = np.array([[0.2, 0.3, 0.9, 0.9], [0.1, 0.1, 0.6, 0.7]])
......
......@@ -217,6 +217,22 @@ class TestMaxOp(OpTest):
self.check_output(check_eager=True)
class TestMaxOp_ZeroDim(OpTest):
"""Remove Max with subgradient from gradient check to confirm the success of CI."""
def setUp(self):
self.op_type = "reduce_max"
self.python_api = paddle.max
self.inputs = {'X': np.random.random([]).astype("float64")}
self.attrs = {'dim': []}
self.outputs = {
'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim']))
}
def test_check_output(self):
self.check_output(check_eager=True)
@skip_check_grad_ci(
reason="reduce_min is discontinuous non-derivable function,"
" its gradient check is not supported by unittest framework."
......@@ -237,6 +253,22 @@ class TestMinOp(OpTest):
self.check_output(check_eager=True)
class TestMinOp_ZeroDim(OpTest):
"""Remove Min with subgradient from gradient check to confirm the success of CI."""
def setUp(self):
self.op_type = "reduce_min"
self.python_api = paddle.min
self.inputs = {'X': np.random.random([]).astype("float64")}
self.attrs = {'dim': []}
self.outputs = {
'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim']))
}
def test_check_output(self):
self.check_output(check_eager=True)
class TestMin6DOp(OpTest):
"""Remove Min with subgradient from gradient check to confirm the success of CI."""
......@@ -297,6 +329,21 @@ class TestProdOp(OpTest):
self.check_grad(['X'], 'Out', check_eager=True)
class TestProdOp_ZeroDim(OpTest):
def setUp(self):
self.python_api = paddle.prod
self.op_type = "reduce_prod"
self.inputs = {'X': np.random.random([]).astype("float64")}
self.outputs = {'Out': self.inputs['X'].prod()}
self.attrs = {'dim': [], 'reduce_all': True}
def test_check_output(self):
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)
class TestProd6DOp(OpTest):
def setUp(self):
self.op_type = "reduce_prod"
......@@ -361,6 +408,18 @@ class TestAllOp(OpTest):
self.check_output(check_eager=True)
class TestAllOp_ZeroDim(OpTest):
def setUp(self):
self.python_api = paddle.all
self.op_type = "reduce_all"
self.inputs = {'X': np.random.randint(0, 2, []).astype("bool")}
self.outputs = {'Out': self.inputs['X'].all()}
self.attrs = {'dim': [], 'reduce_all': True}
def test_check_output(self):
self.check_output(check_eager=True)
class TestAll8DOp(OpTest):
def setUp(self):
self.op_type = "reduce_all"
......@@ -464,6 +523,18 @@ class TestAnyOp(OpTest):
self.check_output(check_eager=True)
class TestAnyOp_ZeroDim(OpTest):
def setUp(self):
self.python_api = paddle.any
self.op_type = "reduce_any"
self.inputs = {'X': np.random.randint(0, 2, []).astype("bool")}
self.outputs = {'Out': self.inputs['X'].any()}
self.attrs = {'dim': [], 'reduce_all': True}
def test_check_output(self):
self.check_output(check_eager=True)
class TestAny8DOp(OpTest):
def setUp(self):
self.op_type = "reduce_any"
......
......@@ -165,6 +165,14 @@ reduce_api_list = [
paddle.mean,
paddle.nansum,
paddle.nanmean,
paddle.min,
paddle.max,
paddle.amin,
paddle.amax,
paddle.prod,
paddle.logsumexp,
paddle.all,
paddle.any,
]
......@@ -173,6 +181,12 @@ class TestReduceAPI(unittest.TestCase):
paddle.disable_static()
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
for api in reduce_api_list:
if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, []).astype('bool')
out = api(x, None)
self.assertEqual(x.shape, [])
self.assertEqual(out.shape, [])
else:
x = paddle.rand([])
x.stop_gradient = False
out = api(x, None)
......@@ -190,11 +204,13 @@ class TestReduceAPI(unittest.TestCase):
for api in reduce_api_list:
main_prog = fluid.Program()
with fluid.program_guard(main_prog, fluid.Program()):
if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, []).astype('bool')
else:
x = paddle.rand([])
x.stop_gradient = False
out = api(x, None)
fluid.backward.append_backward(out)
# Test compile shape, grad is always [1]
self.assertEqual(x.shape, ())
......
......@@ -95,6 +95,44 @@ _supported_float_dtype_ = [
]
def _get_reduce_axis(axis, x):
"""
Internal function for max, min, amax and amin.
It computes the attribute reduce_all value based on axis.
"""
if axis is not None and not isinstance(axis, list):
if isinstance(axis, (tuple, range)):
axis = list(axis)
elif isinstance(axis, int):
axis = [axis]
else:
raise TypeError(
"The type of axis must be int, list or tuple, but received {}".format(
type(axis)
)
)
if axis is None:
axis = []
if axis == [] or len(axis) == len(x.shape):
reduce_all = True
else:
reduce_all = False
return reduce_all, axis
def _get_reduce_axis_with_tensor(axis, x):
if isinstance(axis, Variable):
if axis.shape[0] == len(x.shape):
reduce_all = True
else:
reduce_all = False
else:
reduce_all, axis = _get_reduce_axis(axis, x)
if utils._contain_var(axis):
axis = utils._convert_to_tensor_list(axis)
return reduce_all, axis
def log(x, name=None):
r"""
Calculates the natural log of the given input Tensor, element-wise.
......@@ -2204,19 +2242,9 @@ def logsumexp(x, axis=None, keepdim=False, name=None):
out2 = paddle.logsumexp(x, 1) # [2.15317821, 3.15684602]
"""
if isinstance(axis, int):
axis = [axis]
reduce_all = (
True
if axis is None or len(axis) == 0 or len(axis) == len(x.shape)
else False
)
if axis is None or len(axis) == 0:
axis = [0]
reduce_all, axis = _get_reduce_axis(axis, x)
if in_dygraph_mode():
if reduce_all:
axis = range(len(x.shape))
return _C_ops.logsumexp(x, axis, keepdim, reduce_all)
if _in_legacy_dygraph():
return _legacy_C_ops.logsumexp(
......@@ -2284,44 +2312,6 @@ def inverse(x, name=None):
return out
def _get_reduce_axis(axis, x):
"""
Internal function for max, min, amax and amin.
It computes the attribute reduce_all value based on axis.
"""
if axis is not None and not isinstance(axis, list):
if isinstance(axis, (tuple, range)):
axis = list(axis)
elif isinstance(axis, int):
axis = [axis]
else:
raise TypeError(
"The type of axis must be int, list or tuple, but received {}".format(
type(axis)
)
)
if axis is None:
axis = []
if axis == [] or len(axis) == len(x.shape):
reduce_all = True
else:
reduce_all = False
return reduce_all, axis
def _get_reduce_axis_with_tensor(axis, x):
if isinstance(axis, Variable):
if axis.shape[0] == len(x.shape):
reduce_all = True
else:
reduce_all = False
else:
reduce_all, axis = _get_reduce_axis(axis, x)
if utils._contain_var(axis):
axis = utils._convert_to_tensor_list(axis)
return reduce_all, axis
def max(x, axis=None, keepdim=False, name=None):
"""
......@@ -2515,8 +2505,6 @@ def min(x, axis=None, keepdim=False, name=None):
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int32', 'int64'], 'min'
)
if not isinstance(axis, Variable) and utils._contain_var(axis):
axis = utils._convert_to_tensor_list(axis)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
......@@ -3681,35 +3669,13 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None):
if x.dtype != convert_np_dtype_to_dtype_(dtype):
x = cast(x, dtype)
dim = axis
if isinstance(dim, Variable):
reduce_all = True if axis.shape[0] == len(x.shape) else False
else:
if dim is not None and not isinstance(dim, list):
if isinstance(dim, tuple):
dim = list(dim)
elif isinstance(dim, int):
dim = [dim]
else:
raise TypeError(
"The type of axis must be int, list or tuple, but received {}".format(
type(dim)
)
)
reduce_all = (
True
if dim is None or len(dim) == 0 or len(dim) == len(x.shape)
else False
)
if dim is None or len(dim) == 0:
dim = [0]
reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
if in_dygraph_mode():
return _C_ops.prod(x, dim, keepdim, reduce_all)
return _C_ops.prod(x, axis, keepdim, reduce_all)
if _in_legacy_dygraph():
return _legacy_C_ops.reduce_prod(
x, 'dim', dim, 'keep_dim', keepdim, 'reduce_all', reduce_all
x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all
)
helper = LayerHelper('reduce_prod', **locals())
......@@ -3717,13 +3683,11 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None):
x, 'x/input', ['float32', 'float64', 'int32', 'int64'], 'reduce_prod'
)
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
if not isinstance(dim, Variable) and utils._contain_var(dim):
dim = utils._convert_to_tensor_list(dim)
helper.append_op(
type='reduce_prod',
inputs={'X': x},
outputs={'Out': out},
attrs={'dim': dim, 'keep_dim': keepdim, 'reduce_all': reduce_all},
attrs={'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all},
)
return out
......@@ -3904,32 +3868,20 @@ def all(x, axis=None, keepdim=False, name=None):
print(out4)
"""
if axis is not None and not isinstance(axis, (list, tuple)):
axis = [axis]
if not axis:
reduce_all_flag = True
else:
if len(axis) == len(x.shape):
reduce_all_flag = True
else:
reduce_all_flag = False
reduce_all, axis = _get_reduce_axis(axis, x)
if in_dygraph_mode():
if reduce_all_flag:
axis = range(len(x.shape))
return _C_ops.all(x, axis, keepdim)
if _in_legacy_dygraph():
axis = axis if axis is not None and axis != [] else [0]
return _legacy_C_ops.reduce_all(
x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all_flag
x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all
)
attrs = {
'dim': axis if axis is not None and axis != [] and axis != () else [0],
'dim': axis,
'keep_dim': keepdim,
'reduce_all': reduce_all_flag,
'reduce_all': reduce_all,
}
check_variable_and_dtype(x, 'x', ['bool'], 'all')
......@@ -3993,32 +3945,20 @@ def any(x, axis=None, keepdim=False, name=None):
print(out4)
"""
if axis is not None and not isinstance(axis, (list, tuple)):
axis = [axis]
if not axis:
reduce_all_flag = True
else:
if len(axis) == len(x.shape):
reduce_all_flag = True
else:
reduce_all_flag = False
reduce_all, axis = _get_reduce_axis(axis, x)
if in_dygraph_mode():
if reduce_all_flag:
axis = range(len(x.shape))
return _C_ops.any(x, axis, keepdim)
if _in_legacy_dygraph():
axis = axis if axis is not None and axis != [] else [0]
return _legacy_C_ops.reduce_any(
x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all_flag
x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all
)
attrs = {
'dim': axis if axis is not None and axis != [] and axis != () else [0],
'dim': axis,
'keep_dim': keepdim,
'reduce_all': reduce_all_flag,
'reduce_all': reduce_all,
}
check_variable_and_dtype(x, 'x', ['bool'], 'any')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册