未验证 提交 2d6cc0b1 编写于 作者: W wawltor 提交者: GitHub

support the tuple for attribute of axis in min, max for api2.0

Update the code for the min,max, test=develop
上级 68203566
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <set>
#include <string>
#include <vector>
......@@ -98,6 +99,18 @@ class ReduceKernel : public framework::OpKernel<T> {
int out_dtype = context.Attr<int>("out_dtype");
framework::proto::VarType::Type cast_out_dtype;
// The dims has full dim, set the reduce_all is True
const auto& input_dim_size = context.Input<Tensor>("X")->dims().size();
std::set<int> dims_set(dims.begin(), dims.end());
bool full_dim = true;
for (auto i = 0; i < input_dim_size; i++) {
if (dims_set.find(i) == dims_set.end()) {
full_dim = false;
break;
}
}
reduce_all = (reduce_all || full_dim);
if (out_dtype < 0) {
auto* cast_input = context.Input<Tensor>("X");
cast_out_dtype =
......@@ -137,6 +150,18 @@ class BoolReduceKernel : public framework::OpKernel<OutT> {
auto dims = context.Attr<std::vector<int>>("dim");
bool keep_dim = context.Attr<bool>("keep_dim");
// The dims has full dim, set the reduce_all is True
const auto& input_dim_size = context.Input<Tensor>("X")->dims().size();
std::set<int> dims_set(dims.begin(), dims.end());
bool full_dim = true;
for (auto i = 0; i < input_dim_size; i++) {
if (dims_set.find(i) == dims_set.end()) {
full_dim = false;
break;
}
}
reduce_all = (reduce_all || full_dim);
if (reduce_all) {
// Flatten and reduce 1-D tensor
auto x = EigenVector<OutT>::Flatten(*input);
......@@ -183,6 +208,17 @@ class ReduceGradKernel : public framework::OpKernel<T> {
auto* output = context.Output<Tensor>(framework::GradVarName("X"));
output->mutable_data<T>(context.GetPlace());
// The dims has full dim, set the reduce_all is True
const auto& input_dim_size = context.Input<Tensor>("X")->dims().size();
std::set<int> dims_set(dims.begin(), dims.end());
bool full_dim = true;
for (auto i = 0; i < input_dim_size; i++) {
if (dims_set.find(i) == dims_set.end()) {
full_dim = false;
break;
}
}
reduce_all = (reduce_all || full_dim);
// NOTE: EigenTensor::From() uses tensor->data()
// if op has NoNeedBufferVarsInferer, the corresponding kNoNeedBufferX or
// kNoNeedBufferY should set true
......
......@@ -48,6 +48,15 @@ class ApiMaxTest(unittest.TestCase):
res, = exe.run(feed={"data": input_data}, fetch_list=[result_max])
self.assertEqual((res == np.max(input_data, axis=0)).all(), True)
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data = paddle.nn.data("data", shape=[10, 10], dtype="int64")
result_max = paddle.max(x=data, axis=(0, 1))
exe = paddle.static.Executor(self.place)
input_data = np.random.randint(10, size=(10, 10)).astype(np.int64)
res, = exe.run(feed={"data": input_data}, fetch_list=[result_max])
self.assertEqual((res == np.max(input_data, axis=(0, 1))).all(), True)
def test_errors(self):
paddle.enable_static()
......@@ -59,6 +68,15 @@ class ApiMaxTest(unittest.TestCase):
self.assertRaises(TypeError, test_input_type)
def test_axis_type():
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data = paddle.nn.data("data", shape=[10, 10], dtype="int64")
axis = paddle.nn.data("axis", shape=[10, 10], dtype="int64")
result_min = paddle.min(data, axis)
self.assertRaises(TypeError, test_axis_type)
def test_imperative_api(self):
paddle.disable_static()
np_x = np.array([10, 10]).astype('float64')
......
......@@ -48,6 +48,15 @@ class ApiMinTest(unittest.TestCase):
res, = exe.run(feed={"data": input_data}, fetch_list=[result_min])
self.assertEqual((res == np.min(input_data, axis=0)).all(), True)
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data = paddle.nn.data("data", shape=[10, 10], dtype="int64")
result_min = paddle.min(x=data, axis=(0, 1))
exe = paddle.static.Executor(self.place)
input_data = np.random.randint(10, size=(10, 10)).astype(np.int64)
res, = exe.run(feed={"data": input_data}, fetch_list=[result_min])
self.assertEqual((res == np.min(input_data, axis=(0, 1))).all(), True)
def test_errors(self):
paddle.enable_static()
......@@ -59,6 +68,15 @@ class ApiMinTest(unittest.TestCase):
self.assertRaises(TypeError, test_input_type)
def test_axis_type():
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data = paddle.nn.data("data", shape=[10, 10], dtype="int64")
axis = paddle.nn.data("axis", shape=[10, 10], dtype="int64")
result_min = paddle.min(data, axis)
self.assertRaises(TypeError, test_axis_type)
def test_imperative_api(self):
paddle.disable_static()
np_x = np.array([10, 10]).astype('float64')
......
......@@ -1177,19 +1177,19 @@ def max(x, axis=None, keepdim=False, name=None):
float64, int32, int64.
axis(list|int, optional): The axis along which the maximum is computed.
If :attr:`None`, compute the maximum over all elements of
:attr:`input` and return a Tensor variable with a single element,
`x` and return a Tensor variable with a single element,
otherwise must be in the range :math:`[-x.ndim(x), x.ndim(x))`.
If :math:`axis[i] < 0`, the axis to reduce is :math:`x.ndim + axis[i]`.
keepdim(bool, optional): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have one fewer dimension
than the :attr:`input` unless :attr:`keepdim` is true, default
than the `x` unless :attr:`keepdim` is true, default
value is False.
name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`
Returns:
Tensor, results of maximum on the specified axis of input tensor,
it's data type is the same as input's Tensor.
it's data type is the same as `x`.
Examples:
.. code-block:: python
......@@ -1232,7 +1232,14 @@ def max(x, axis=None, keepdim=False, name=None):
"""
if axis is not None and not isinstance(axis, list):
axis = [axis]
if isinstance(axis, tuple):
axis = list(axis)
elif isinstance(axis, int):
axis= [axis]
else:
raise TypeError(
"The type of axis must be int, list or tuple, but received {}".format(type(axis)))
reduce_all = True if axis == None or axis == [] else False
axis = axis if axis != None and axis != [] else [0]
if in_dygraph_mode():
......@@ -1265,12 +1272,12 @@ def min(x, axis=None, keepdim=False, name=None):
x(Tensor): A tensor, the data type is float32, float64, int32, int64.
axis(list|int, optional): The axis along which the minimum is computed.
If :attr:`None`, compute the minimum over all elements of
:attr:`input` and return a Tensor variable with a single element,
`x` and return a Tensor variable with a single element,
otherwise must be in the range :math:`[-x.ndim, x.ndim)`.
If :math:`axis[i] < 0`, the axis to reduce is :math:`x.ndim + axis[i]`.
keepdim(bool, optional): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have one fewer dimension
than the :attr:`input` unless :attr:`keepdim` is true, default
than the `x` unless :attr:`keepdim` is true, default
value is False.
name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`
......@@ -1320,7 +1327,13 @@ def min(x, axis=None, keepdim=False, name=None):
"""
if axis is not None and not isinstance(axis, list):
axis= [axis]
if isinstance(axis, tuple):
axis = list(axis)
elif isinstance(axis, int):
axis= [axis]
else:
raise TypeError(
"The type of axis must be int, list or tuple, but received {}".format(type(axis)))
reduce_all = True if axis == None or axis == [] else False
axis = axis if axis != None and axis != [] else [0]
if in_dygraph_mode():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册