From 0b092d05f1dfd68fe4afc575d232bd745f373791 Mon Sep 17 00:00:00 2001 From: wawltor Date: Sat, 4 Apr 2020 17:34:15 +0800 Subject: [PATCH] =?UTF-8?q?Add=20the=20argmax=20op=20to=20API=202.0?= =?UTF-8?q?=EF=BC=8C=20and=20update=20some=20parameters?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add the argmax op to API 2.0, test=develop * Fix the compiler problem in arg_max op, test=develop * Fix the import meesage in common_ops_import, test=develop * Fix the default dtype of arg_min_max, test=develop --- paddle/fluid/operators/arg_min_max_op_base.h | 64 ++++-- python/paddle/__init__.py | 2 +- .../tests/unittests/test_arg_min_max_op.py | 197 ++++++++++++++++++ python/paddle/tensor/__init__.py | 2 +- python/paddle/tensor/search.py | 91 +++++++- 5 files changed, 337 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/operators/arg_min_max_op_base.h b/paddle/fluid/operators/arg_min_max_op_base.h index bf7b83bb7a..76c8426123 100644 --- a/paddle/fluid/operators/arg_min_max_op_base.h +++ b/paddle/fluid/operators/arg_min_max_op_base.h @@ -38,26 +38,36 @@ struct ArgMinMaxFunctor {}; struct ArgMinMaxFunctor { \ void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \ - framework::LoDTensor* out, int64_t axis) { \ + framework::LoDTensor* out, int64_t axis, bool keepdims) { \ auto in_eigen = framework::EigenTensor::From(in); \ - auto out_eigen = framework::EigenTensor::From(*out); \ - out_eigen.device(*(ctx.eigen_device())) = \ - in_eigen.eigen_op_type(axis).template cast(); \ + if (keepdims) { \ + auto out_eigen = framework::EigenTensor::From(*out); \ + out_eigen.device(*(ctx.eigen_device())) = \ + in_eigen.eigen_op_type(axis).template cast(); \ + } else { \ + auto out_eigen = framework::EigenTensor::From(*out); \ + out_eigen.device(*(ctx.eigen_device())) = \ + in_eigen.eigen_op_type(axis).template cast(); \ + } \ } \ } DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin); DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax); -template -class ArgMinMaxKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { +template +struct VisitDataArgMinMaxFunctor { + const framework::ExecutionContext& ctx; + + explicit VisitDataArgMinMaxFunctor(const framework::ExecutionContext& ctx) + : ctx(ctx) {} + template + void apply() const { auto& x = *(ctx.Input("X")); auto& out = *(ctx.Output("Out")); - out.mutable_data(ctx.GetPlace()); + out.template mutable_data(ctx.GetPlace()); auto axis = ctx.Attr("axis"); + auto keepdims = ctx.Attr("keepdims"); auto x_rank = x.dims().size(); if (axis < 0) axis += x_rank; auto& dev_ctx = ctx.template device_context(); @@ -65,7 +75,7 @@ class ArgMinMaxKernel : public framework::OpKernel { #define CALL_ARG_MINMAX_FUNCTOR(rank) \ ArgMinMaxFunctor \ functor##rank; \ - functor##rank(dev_ctx, x, &out, axis) + functor##rank(dev_ctx, x, &out, axis, keepdims) switch (x.dims().size()) { case 1: @@ -97,13 +107,29 @@ class ArgMinMaxKernel : public framework::OpKernel { } }; +template +class ArgMinMaxKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dtype = ctx.Attr("dtype"); + if (dtype < 0) { + framework::VisitDataType( + static_cast( + framework::proto::VarType::INT64), + VisitDataArgMinMaxFunctor(ctx)); + return; + } + framework::VisitDataType( + static_cast(dtype), + VisitDataArgMinMaxFunctor(ctx)); + } +}; + template -using ArgMinKernel = - ArgMinMaxKernel; +using ArgMinKernel = ArgMinMaxKernel; template -using ArgMaxKernel = - ArgMinMaxKernel; +using ArgMaxKernel = ArgMinMaxKernel; class ArgMinMaxOp : public framework::OperatorWithKernel { public: @@ -114,14 +140,18 @@ class ArgMinMaxOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null"); const auto& x_dims = ctx->GetInputDim("X"); int64_t axis = ctx->Attrs().Get("axis"); + bool keepdims = ctx->Attrs().Get("keepdims"); + PADDLE_ENFORCE(axis >= -x_dims.size() && axis < x_dims.size(), "'axis' must be inside [-Rank(X), Rank(X))"); auto x_rank = x_dims.size(); if (axis < 0) axis += x_rank; - std::vector vec; for (int64_t i = 0; i < axis; i++) vec.push_back(x_dims[i]); + if (keepdims) { + vec.push_back(static_cast(1)); + } for (int64_t i = axis + 1; i < x_rank; i++) vec.push_back(x_dims[i]); ctx->SetOutputDim("Out", framework::make_ddim(vec)); } @@ -137,6 +167,8 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "Input tensor."); AddOutput("Out", "Output tensor."); AddAttr("axis", "The axis in which to compute the arg indics."); + AddAttr("keepdims", "Keep the dim that to reduce.").SetDefault(false); + AddAttr("dtype", "Keep the dim that to reduce.").SetDefault(-1); AddComment(string::Sprintf(R"DOC( %s Operator. diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 8a0ec03986..a1a8f30c3f 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -182,7 +182,7 @@ from .tensor.math import tanh #DEFINE_ALIAS from .tensor.manipulation import flip #DEFINE_ALIAS # from .tensor.manipulation import unbind #DEFINE_ALIAS # from .tensor.manipulation import roll #DEFINE_ALIAS -# from .tensor.search import argmax #DEFINE_ALIAS +from .tensor.search import argmax #DEFINE_ALIAS # from .tensor.search import argmin #DEFINE_ALIAS # from .tensor.search import argsort #DEFINE_ALIAS # from .tensor.search import has_inf #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py index 4f9f1ec225..3741adf8bb 100644 --- a/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py +++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py @@ -17,6 +17,9 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core class BaseTestCase(OpTest): @@ -88,5 +91,199 @@ class TestCase4(BaseTestCase): self.axis = 0 +class TestCase3(BaseTestCase): + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, ) + self.axis = 0 + + +class BaseTestComplex1_1(OpTest): + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (4, 5, 6) + self.dtype = 'int32' + self.axis = 2 + + def setUp(self): + self.initTestCase() + self.x = (np.random.random(self.dims)).astype(self.dtype) + self.inputs = {'X': self.x} + self.attrs = {'axis': self.axis} + self.attrs = {'dtype': int(core.VarDesc.VarType.INT32)} + if self.op_type == "arg_min": + self.outputs = { + 'Out': np.argmin( + self.x, axis=self.axis).asdtype("int32") + } + else: + self.outputs = { + 'Out': np.argmax( + self.x, axis=self.axis).asdtype("int32") + } + + +class BaseTestComplex1_2(OpTest): + def initTestCase(self): + self.op_type = 'arg_min' + self.dims = (4, 5, 6) + self.dtype = 'int32' + self.axis = 2 + + def setUp(self): + self.initTestCase() + self.x = (np.random.random(self.dims)).astype(self.dtype) + self.inputs = {'X': self.x} + self.attrs = {'axis': self.axis} + self.attrs = {'dtype': int(core.VarDesc.VarType.INT32)} + if self.op_type == "arg_min": + self.outputs = { + 'Out': np.argmin( + self.x, axis=self.axis).asdtype("int32") + } + else: + self.outputs = { + 'Out': np.argmax( + self.x, axis=self.axis).asdtype("int32") + } + + +class BaseTestComplex2_1(OpTest): + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (4, 5, 6) + self.dtype = 'int32' + self.axis = 2 + + def setUp(self): + self.initTestCase() + self.x = (np.random.random(self.dims)).astype(self.dtype) + self.inputs = {'X': self.x} + self.attrs = {'axis': self.axis} + self.attrs = {'dtype': int(core.VarDesc.VarType.INT32)} + self.attrs = {'keep_dims': True} + if self.op_type == "arg_min": + self.outputs = { + 'Out': np.argmin( + self.x, axis=self.axis).asdtype("int32").reshape(4, 5, 1) + } + else: + self.outputs = { + 'Out': np.argmax( + self.x, axis=self.axis).asdtype("int32").reshape(4, 5, 1) + } + + +class BaseTestComplex2_2(OpTest): + def initTestCase(self): + self.op_type = 'arg_min' + self.dims = (4, 5, 6) + self.dtype = 'int32' + self.axis = 2 + + def setUp(self): + self.initTestCase() + self.x = (np.random.random(self.dims)).astype(self.dtype) + self.inputs = {'X': self.x} + self.attrs = {'axis': self.axis} + self.attrs = {'dtype': int(core.VarDesc.VarType.INT32)} + self.attrs = {'keep_dims': True} + if self.op_type == "arg_min": + self.outputs = { + 'Out': np.argmin( + self.x, axis=self.axis).asdtype("int32").reshape(4, 5, 1) + } + else: + self.outputs = { + 'Out': np.argmax( + self.x, axis=self.axis).asdtype("int32").reshape(4, 5, 1) + } + + +class APT_ArgMaxTest(unittest.TestCase): + def test_output_result(self): + with fluid.program_guard(fluid.Program()): + data1 = fluid.data(name="X", shape=[3, 4], dtype="float32") + data2 = fluid.data(name="Y", shape=[3], dtype="int64") + out = paddle.argmax(input=data1, out=data2) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + result = exe.run( + feed={"X": np.random.rand(3, 4).astype("float32")}, + fetch_list=[data2, out]) + self.assertEqual((result[0] == result[1]).all(), True) + + def test_basic(self): + with fluid.program_guard(fluid.Program()): + data = fluid.data(name="X", shape=[3, 4], dtype="float32") + out = paddle.argmax(input=data) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + np_input = np.random.rand(3, 4).astype("float32") + expected_result = np.argmax(np_input, axis=1) + + result, = exe.run(feed={"X": np_input}, fetch_list=[out]) + self.assertEqual((result == expected_result).all(), True) + + with fluid.program_guard(fluid.Program()): + data = fluid.data(name="X", shape=[3, 4], dtype="float32") + out = paddle.argmax(input=data, axis=0) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + np_input = np.random.rand(3, 4).astype("float32") + expected_result = np.argmax(np_input, axis=0) + + result = exe.run(feed={"X": np_input}, fetch_list=[out]) + self.assertEqual((result == expected_result).all(), True) + + with fluid.program_guard(fluid.Program()): + data = fluid.data(name="X", shape=[3, 4], dtype="float32") + out = paddle.argmax(input=data, dtype="int32") + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + np_input = np.random.rand(3, 4).astype("float32") + expected_result = np.argmax(np_input, axis=1).astype(np.int32) + + result = exe.run(feed={"X": np_input}, fetch_list=[out]) + self.assertEqual((result == expected_result).all(), True) + + with fluid.program_guard(fluid.Program()): + data1 = fluid.data(name="X", shape=[3, 4], dtype="float32") + data2 = fluid.data(name="Y", shape=[3], dtype="int64") + out = paddle.argmax(input=data, out=data2) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + result = exe.run( + feed={"X": np.random.rand(3, 4).astype("float32")}, + fetch_list=[data2, out]) + self.assertEqual((result[0] == result[1]).all(), True) + + def test_name(self): + with fluid.program_guard(fluid.Program()): + x = fluid.data(name="x", shape=[100], dtype="float32") + y_1 = paddle.argmax(x, name='arg_max_res') + self.assertEqual(('arg_max_res' in y_1.name), True) + + def test_errors(self): + def test_dtype1(): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data = fluid.data(name="data", shape=[10], dtype="float32") + paddle.argmax(data, dtype="float32") + + self.assertRaises(TypeError, test_dtype1) + + def test_dtype2(): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data = fluid.data(name="data", shape=[10], dtype="float64") + paddle.argmax(data, dtype="float32") + + self.assertRaises(TypeError, test_dtype2) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 51cf5da6df..4c42d143ce 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -158,7 +158,7 @@ from .math import tanh #DEFINE_ALIAS from .manipulation import flip #DEFINE_ALIAS # from .manipulation import unbind #DEFINE_ALIAS # from .manipulation import roll #DEFINE_ALIAS -# from .search import argmax #DEFINE_ALIAS +from .search import argmax #DEFINE_ALIAS # from .search import argmin #DEFINE_ALIAS # from .search import argsort #DEFINE_ALIAS # from .search import has_inf #DEFINE_ALIAS diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 59c89797bb..5d75e56dde 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -28,7 +28,96 @@ __all__ = [ ] from paddle.common_ops_import import * -import warnings + + +def argmax(input, axis=None, dtype=None, out=None, keepdims=False, name=None): + """ + This OP computes the indices of the max elements of the input tensor's + element along the provided axis. + + Args: + input(Variable): An input N-D Tensor with type float32, float64, int16, + int32, int64, uint8. + axis(int, optional): Axis to compute indices along. The effective range + is [-R, R), where R is Rank(input). when axis<0, it works the same way + as axis+R. Default is None, it will use the last dim to select indices of max value. + dtype(np.dtype|core.VarDesc.VarType|str): Data type of the output tensor which can + be int32, int64. The default value is None, and it will + return the int64 indices. + out(Variable, optional): Optional output which can be any created + Variable that meets the requirements to store the result of operation. + if out is None, a new Varibale will be create to store the result. Defalut is None. + keepdims(bool, optional): Keep the axis that do the select max. + name(str, optional): The name of output variable, normally there is no need for user to set this this property. + Default value is None, the framework set the name of output variable. + + + Returns: + Variable: A Tensor with data type int64. + + Examples: + .. code-block:: python + + import paddle + import paddle.fluid as fluid + import numpy as np + + in1 = np.array([[[5,8,9,5], + [0,0,1,7], + [6,9,2,4]], + [[5,2,4,2], + [4,7,7,9], + [1,7,0,6]]]) + with fluid.dygraph.guard(): + x = fluid.dygraph.to_variable(in1) + out1 = paddle.argmax(input=x, axis=-1) + out2 = paddle.argmax(input=x, axis=0) + out3 = paddle.argmax(input=x, axis=1) + out4 = paddle.argmax(input=x, axis=2) + out5 = paddle.argmax(input=x, axis=2, keepdims=True) + print(out1.numpy()) + # [[2 3 1] + # [0 3 1]] + print(out2.numpy()) + # [[0 0 0 0] + # [1 1 1 1] + # [0 0 0 1]] + print(out3.numpy()) + # [[2 2 0 1] + # [0 1 1 1]] + print(out4.numpy()) + # [[2 3 1] + # [0 3 1]] + print(out5.numpy()) + #array([[[2], + # [3], + # [1]], + # [[0], + # [3], + # [1]]]) + """ + helper = LayerHelper("arg_max", **locals()) + var_dtype = None + attrs = {} + if dtype is not None: + check_dtype(dtype, 'create data type', ['int32', 'int64'], 'arg_max') + var_dtype = convert_np_dtype_to_dtype_(dtype) + attrs["dtype"] = var_dtype + else: + var_dtype = VarDesc.VarType.INT64 + if out is None: + out = helper.create_variable_for_type_inference(var_dtype) + if axis is None: + axis = -1 + attrs['keepdims'] = keepdims + attrs['axis'] = axis + helper.append_op( + type='arg_max', + inputs={'X': input}, + outputs={'Out': [out]}, + attrs=attrs) + out.stop_gradient = True + return out def sort(input, axis=-1, descending=False, out=None, name=None): -- GitLab