diff --git a/paddle/fluid/operators/arg_min_max_op_base.cu.h b/paddle/fluid/operators/arg_min_max_op_base.cu.h index e9b65946a5b1e1e69932332126566427affebe8d..73581dac4e419ca9c970db4414ff54d4cbd3fd70 100644 --- a/paddle/fluid/operators/arg_min_max_op_base.cu.h +++ b/paddle/fluid/operators/arg_min_max_op_base.cu.h @@ -53,9 +53,9 @@ using Tensor = framework::Tensor; FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__); template -__global__ void ArgCUDAKernel(const IndType height, // n * h - const IndType width, // c - const IndType post_size, // h +__global__ void ArgCUDAKernel(const int64_t height, // n * h + const int64_t width, // c + const int64_t post_size, // h const Reducer reducer, const T init, const T* in, IndType* out) { typedef cub::BlockReduce, BlockDim> BlockReduce; @@ -79,10 +79,10 @@ __global__ void ArgCUDAKernel(const IndType height, // n * h template void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input, - Tensor* indices, const IndType pre, const IndType post, - const IndType n) { + Tensor* indices, const int64_t pre, const int64_t post, + const int64_t n) { auto cu_stream = ctx.stream(); - auto ComputeBlockSize = [](IndType col) { + auto ComputeBlockSize = [](int64_t col) { if (col > 512) return 1024; else if (col > 256) @@ -101,10 +101,10 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input, return 8; }; - int max_grid_dimx = ctx.GetCUDAMaxGridDimSize().x; - int height = pre * post; - int width = n; - int grid_size = height < max_grid_dimx ? height : max_grid_dimx; + int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize().x; + int64_t height = pre * post; + int64_t width = n; + int64_t grid_size = height < max_grid_dimx ? height : max_grid_dimx; const T* in_data = input.data(); IndType* out_data = indices->mutable_data(ctx.GetPlace()); @@ -129,31 +129,60 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input, } template -class ArgMinMaxOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { +struct VisitDataCudaArgMinMaxFunctor { + const framework::ExecutionContext& ctx; + + explicit VisitDataCudaArgMinMaxFunctor(const framework::ExecutionContext& ctx) + : ctx(ctx) {} + template + void apply() const { auto* input = ctx.Input("X"); auto* output = ctx.Output("Out"); int axis = ctx.Attr("axis"); - auto in_dims = input->dims(); - axis = (axis < 0) ? 
(in_dims.size() + axis) : axis; + const bool& flatten = ctx.Attr("flatten"); + + framework::DDim input_dims; + if (flatten) { + input_dims = framework::make_ddim({input->numel()}); + // if flatten, the axis just as 0 + axis = 0; + } else { + input_dims = input->dims(); + if (axis < 0) axis += input->dims().size(); + } int64_t numel = input->numel(); - int64_t groups = numel / in_dims[axis]; + int64_t groups = numel / input_dims[axis]; int64_t pre = 1; int64_t post = 1; - int64_t n = in_dims[axis]; + int64_t n = input_dims[axis]; for (int i = 0; i < axis; i++) { - pre *= in_dims[i]; + pre *= input_dims[i]; } - for (int i = axis + 1; i < in_dims.size(); i++) { - post *= in_dims[i]; + for (int i = axis + 1; i < input_dims.size(); i++) { + post *= input_dims[i]; } const auto& dev_ctx = ctx.cuda_device_context(); - ComputeFullArg(dev_ctx, *input, output, pre, post, n); + ComputeFullArg(dev_ctx, *input, output, pre, post, n); + } +}; +template +class ArgMinMaxOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dtype = ctx.Attr("dtype"); + if (dtype < 0) { + framework::VisitDataType(static_cast( + framework::proto::VarType::INT64), + VisitDataCudaArgMinMaxFunctor(ctx)); + return; + } + framework::VisitDataType( + static_cast(dtype), + VisitDataCudaArgMinMaxFunctor(ctx)); } }; diff --git a/paddle/fluid/operators/arg_min_max_op_base.h b/paddle/fluid/operators/arg_min_max_op_base.h index 0fc7b47c62ea9d7da805b797fcf5e4db4e39328d..ae3637f6f99783d70bd57a3935a979b0387692de 100644 --- a/paddle/fluid/operators/arg_min_max_op_base.h +++ b/paddle/fluid/operators/arg_min_max_op_base.h @@ -38,8 +38,9 @@ struct ArgMinMaxFunctor {}; struct ArgMinMaxFunctor { \ void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \ - framework::LoDTensor* out, int64_t axis, bool keepdims) { \ - auto in_eigen = framework::EigenTensor::From(in); \ + framework::LoDTensor* out, framework::DDim x_dims, \ + int64_t axis, bool keepdims) { \ + auto in_eigen = framework::EigenTensor::From(in, x_dims); \ if (keepdims) { \ auto out_eigen = framework::EigenTensor::From(*out); \ out_eigen.device(*(ctx.eigen_device())) = \ @@ -68,16 +69,26 @@ struct VisitDataArgMinMaxFunctor { out.template mutable_data(ctx.GetPlace()); auto axis = ctx.Attr("axis"); auto keepdims = ctx.Attr("keepdims"); - auto x_rank = x.dims().size(); - if (axis < 0) axis += x_rank; + const bool& flatten = ctx.Attr("flatten"); + + // if flatten, will construct the new dims for the cacluate + framework::DDim x_dims; + if (flatten) { + x_dims = framework::make_ddim({x.numel()}); + // if flatten, the axis just as 0 + axis = 0; + } else { + x_dims = x.dims(); + if (axis < 0) axis += x_dims.size(); + } auto& dev_ctx = ctx.template device_context(); #define CALL_ARG_MINMAX_FUNCTOR(rank) \ ArgMinMaxFunctor \ functor##rank; \ - functor##rank(dev_ctx, x, &out, axis, keepdims) + functor##rank(dev_ctx, x, &out, x_dims, axis, keepdims) - switch (x.dims().size()) { + switch (x_dims.size()) { case 1: CALL_ARG_MINMAX_FUNCTOR(1); break; @@ -141,6 +152,7 @@ class ArgMinMaxOp : public framework::OperatorWithKernel { const auto& x_dims = ctx->GetInputDim("X"); int64_t axis = ctx->Attrs().Get("axis"); bool keepdims = ctx->Attrs().Get("keepdims"); + const bool& flatten = ctx->Attrs().Get("flatten"); PADDLE_ENFORCE_GE(axis, -x_dims.size(), platform::errors::InvalidArgument( @@ -152,14 +164,21 @@ class ArgMinMaxOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( 
"'axis'(%d) must be less than Rank(X)(%d).", axis, x_dims.size())); - auto x_rank = x_dims.size(); - if (axis < 0) axis += x_rank; std::vector vec; - for (int64_t i = 0; i < axis; i++) vec.push_back(x_dims[i]); - if (keepdims) { - vec.push_back(static_cast(1)); + if (flatten) { + // if is flatten, will return the only on element + if (keepdims) { + vec.emplace_back(static_cast(1)); + } + } else { + auto x_rank = x_dims.size(); + if (axis < 0) axis += x_rank; + for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]); + if (keepdims) { + vec.emplace_back(static_cast(1)); + } + for (int64_t i = axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]); } - for (int64_t i = axis + 1; i < x_rank; i++) vec.push_back(x_dims[i]); ctx->SetOutputDim("Out", framework::make_ddim(vec)); } }; @@ -176,6 +195,9 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("axis", "The axis in which to compute the arg indics."); AddAttr("keepdims", "Keep the dim that to reduce.").SetDefault(false); AddAttr("dtype", "Keep the dim that to reduce.").SetDefault(-1); + AddAttr("flatten", + "Flatten the input value, and search the min or max indices") + .SetDefault(false); AddComment(string::Sprintf(R"DOC( %s Operator. diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py index 0201f0635a5afeb285cdbca3e8d526a1ff5032f2..3639c4dea0a3a12aa46d2875affeebd4c623a4dd 100644 --- a/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py +++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py @@ -201,107 +201,5 @@ class BaseTestComplex2_2(OpTest): } -class APT_ArgMaxTest(unittest.TestCase): - def test_output_result(self): - with fluid.program_guard(fluid.Program()): - data1 = fluid.data(name="X", shape=[3, 4], dtype="float32") - data2 = fluid.data(name="Y", shape=[3], dtype="int64") - out = paddle.argmax(input=data1, out=data2) - - place = fluid.CPUPlace() - exe = fluid.Executor(place) - result = exe.run( - feed={"X": np.random.rand(3, 4).astype("float32")}, - fetch_list=[data2, out]) - self.assertEqual((result[0] == result[1]).all(), True) - - def test_basic(self): - with fluid.program_guard(fluid.Program()): - data = fluid.data(name="X", shape=[3, 4], dtype="float32") - out = paddle.argmax(input=data) - - place = fluid.CPUPlace() - exe = fluid.Executor(place) - np_input = np.random.rand(3, 4).astype("float32") - expected_result = np.argmax(np_input, axis=1) - - result, = exe.run(feed={"X": np_input}, fetch_list=[out]) - self.assertEqual((result == expected_result).all(), True) - - with fluid.program_guard(fluid.Program()): - data = fluid.data(name="X", shape=[3, 4], dtype="float32") - out = paddle.argmax(input=data, axis=0) - - place = fluid.CPUPlace() - exe = fluid.Executor(place) - np_input = np.random.rand(3, 4).astype("float32") - expected_result = np.argmax(np_input, axis=0) - - result = exe.run(feed={"X": np_input}, fetch_list=[out]) - self.assertEqual((result == expected_result).all(), True) - - with fluid.program_guard(fluid.Program()): - data = fluid.data(name="X", shape=[3, 4], dtype="float32") - out = paddle.argmax(input=data, dtype="int32") - - place = fluid.CPUPlace() - exe = fluid.Executor(place) - np_input = np.random.rand(3, 4).astype("float32") - expected_result = np.argmax(np_input, axis=1).astype(np.int32) - - result = exe.run(feed={"X": np_input}, fetch_list=[out]) - self.assertEqual((result == expected_result).all(), True) - - with fluid.program_guard(fluid.Program()): - data1 = 
fluid.data(name="X", shape=[3, 4], dtype="float32") - data2 = fluid.data(name="Y", shape=[3], dtype="int64") - out = paddle.argmax(input=data, out=data2) - - place = fluid.CPUPlace() - exe = fluid.Executor(place) - result = exe.run( - feed={"X": np.random.rand(3, 4).astype("float32")}, - fetch_list=[data2, out]) - self.assertEqual((result[0] == result[1]).all(), True) - - def test_name(self): - with fluid.program_guard(fluid.Program()): - x = fluid.data(name="x", shape=[100], dtype="float32") - y_1 = paddle.argmax(x, name='arg_max_res') - self.assertEqual(('arg_max_res' in y_1.name), True) - - def test_errors(self): - def test_dtype1(): - with fluid.program_guard(fluid.Program(), fluid.Program()): - data = fluid.data(name="data", shape=[10], dtype="float32") - paddle.argmax(data, dtype="float32") - - self.assertRaises(TypeError, test_dtype1) - - def test_dtype2(): - with fluid.program_guard(fluid.Program(), fluid.Program()): - data = fluid.data(name="data", shape=[10], dtype="float64") - paddle.argmax(data, dtype="float32") - - self.assertRaises(TypeError, test_dtype2) - - -class TestArgMinMaxOpError(unittest.TestCase): - def test_errors(self): - with program_guard(Program(), Program()): - - def test_argmax_x_type(): - x1 = [1, 2, 3] - output = fluid.layers.argmax(x=x1) - - self.assertRaises(TypeError, test_argmax_x_type) - - def test_argmin_x_type(): - x2 = [1, 2, 3] - output = fluid.layers.argmin(x=x2) - - self.assertRaises(TypeError, test_argmin_x_type) - - if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py new file mode 100644 index 0000000000000000000000000000000000000000..7c1f9d802c31ac2c3b244541936ba25018e1487a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py @@ -0,0 +1,313 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid import Program, program_guard + + +def create_kernel_case(op_type, numpy_op_type): + class ArgMinMaxKernelBaseCase(OpTest): + def initTestCase(self): + self.op_type = op_type + self.numpy_op_type = numpy_op_type + self.axis = 0 + + def setUp(self): + np.random.seed(123) + self.initTestCase() + self.dims = (4, 5, 6) + self.dtype = "float64" + self.x = (1000 * np.random.random(self.dims).astype(self.dtype)) + self.inputs = {'X': self.x} + self.attrs = {"axis": self.axis} + self.numpy_op = eval("np.%s" % (numpy_op_type)) + self.outputs = {'Out': self.numpy_op(self.x, axis=self.axis)} + + def test_check_output(self): + paddle.enable_static() + self.check_output() + + class ArgMinMaxKernelCase0(ArgMinMaxKernelBaseCase): + def initTestCase(self): + self.op_type = op_type + self.numpy_op_type = numpy_op_type + self.axis = 1 + + class ArgMinMaxKernelCase1(ArgMinMaxKernelBaseCase): + def initTestCase(self): + self.op_type = op_type + self.numpy_op_type = numpy_op_type + self.axis = 2 + + class ArgMinMaxKernelCase2(ArgMinMaxKernelBaseCase): + def initTestCase(self): + self.op_type = op_type + self.numpy_op_type = numpy_op_type + self.axis = -1 + + class ArgMinMaxKernelCase3(ArgMinMaxKernelBaseCase): + def initTestCase(self): + self.op_type = op_type + self.numpy_op_type = numpy_op_type + self.axis = -2 + + class ArgMinMaxKernelCase4(ArgMinMaxKernelBaseCase): + def setUp(self): + self.initTestCase() + self.dims = (4, 5, 6) + self.dtype = "float64" + self.x = (1000 * np.random.random(self.dims).astype(self.dtype)) + self.inputs = {'X': self.x} + self.attrs = {"axis": self.axis, "keepdims": True} + self.numpy_op = eval("np.%s" % (numpy_op_type)) + self.outputs = { + 'Out': self.numpy_op( + self.x, axis=self.axis).reshape((1, 5, 6)) + } + + class ArgMinMaxKernelCase5(ArgMinMaxKernelBaseCase): + def setUp(self): + self.initTestCase() + self.dims = (4) + self.dtype = "float64" + self.x = (1000 * np.random.random(self.dims).astype(self.dtype)) + self.inputs = {'X': self.x} + self.attrs = {"axis": self.axis, "flatten": True} + self.numpy_op = eval("np.%s" % (numpy_op_type)) + self.outputs = { + 'Out': self.numpy_op( + self.x.flatten(), axis=self.axis) + } + + class ArgMinMaxKernelCase6(ArgMinMaxKernelBaseCase): + def setUp(self): + self.initTestCase() + self.dims = (4) + self.dtype = "float64" + self.x = (1000 * np.random.random(self.dims).astype(self.dtype)) + self.inputs = {'X': self.x} + self.attrs = {"axis": self.axis, "flatten": True, "keepdims": True} + self.numpy_op = eval("np.%s" % (numpy_op_type)) + self.outputs = { + 'Out': + np.array(self.numpy_op( + self.x.flatten(), axis=self.axis)) + } + + cls_name = "ArgMinMaxKernelBaseCase_%s" % (op_type) + ArgMinMaxKernelBaseCase.__name__ = cls_name + globals()[cls_name] = ArgMinMaxKernelBaseCase + + cls_name = "ArgMinMaxKernelCase0_%s" % (op_type) + ArgMinMaxKernelCase0.__name__ = cls_name + globals()[cls_name] = ArgMinMaxKernelCase0 + + cls_name = "ArgMinMaxKernelCase1_%s" % (op_type) + ArgMinMaxKernelCase1.__name__ = cls_name + globals()[cls_name] = ArgMinMaxKernelCase1 + + cls_name = "ArgMinMaxKernelCase2_%s" % (op_type) + ArgMinMaxKernelCase2.__name__ = cls_name + globals()[cls_name] = ArgMinMaxKernelCase2 + + cls_name = "ArgMinMaxKernelCase3_%s" % (op_type) + ArgMinMaxKernelCase3.__name__ = cls_name + globals()[cls_name] = ArgMinMaxKernelCase3 + + 
cls_name = "ArgMinMaxKernelCase4_%s" % (op_type) + ArgMinMaxKernelCase4.__name__ = cls_name + globals()[cls_name] = ArgMinMaxKernelCase4 + + cls_name = "ArgMinMaxKernelCase5_%s" % (op_type) + ArgMinMaxKernelCase5.__name__ = cls_name + globals()[cls_name] = ArgMinMaxKernelCase5 + + cls_name = "ArgMinMaxKernelCase6_%s" % (op_type) + ArgMinMaxKernelCase6.__name__ = cls_name + globals()[cls_name] = ArgMinMaxKernelCase6 + + +for op_type, numpy_op_type in zip(['arg_max', 'arg_min'], ['argmax', 'argmin']): + create_kernel_case(op_type, numpy_op_type) + + +def create_test_case(op_type): + class ArgMaxMinTestCase(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.input_data = np.random.rand(10, 10).astype("float32") + self.places = [] + self.places.append(fluid.CPUPlace()) + if core.is_compiled_with_cuda(): + self.places.append(paddle.CUDAPlace(0)) + self.op = eval("paddle.%s" % (op_type)) + self.numpy_op = eval("np.%s" % (op_type)) + + def run_static(self, place): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + data_var = paddle.static.data( + name="data", shape=[10, 10], dtype="float32") + op = eval("paddle.%s" % (op_type)) + result = op(data_var) + exe = paddle.static.Executor(place) + result_data = exe.run(feed={"data": self.input_data}, + fetch_list=[result]) + expected_data = self.numpy_op(self.input_data) + self.assertTrue((result_data == np.array(expected_data)).all(), + True) + + with paddle.static.program_guard(paddle.static.Program()): + data_var = paddle.static.data( + name="data", shape=[10, 10], dtype="float32") + op = eval("paddle.%s" % (op_type)) + result = op(data_var, axis=1) + exe = paddle.static.Executor(place) + result_data = exe.run(feed={"data": self.input_data}, + fetch_list=[result]) + expected_data = self.numpy_op(self.input_data, axis=1) + self.assertTrue((result_data == expected_data).all(), True) + + with paddle.static.program_guard(paddle.static.Program()): + data_var = paddle.static.data( + name="data", shape=[10, 10], dtype="float32") + op = eval("paddle.%s" % (op_type)) + result = op(data_var, axis=-1) + exe = paddle.static.Executor(place) + result_data = exe.run(feed={"data": self.input_data}, + fetch_list=[result]) + expected_data = self.numpy_op(self.input_data, axis=-1) + self.assertTrue((result_data == expected_data).all(), True) + + with paddle.static.program_guard(paddle.static.Program()): + data_var = paddle.static.data( + name="data", shape=[10, 10], dtype="float32") + + op = eval("paddle.%s" % (op_type)) + result = op(data_var, axis=-1, keepdim=True) + exe = paddle.static.Executor(place) + result_data = exe.run(feed={"data": self.input_data}, + fetch_list=[result]) + expected_data = self.numpy_op( + self.input_data, axis=-1).reshape((10, 1)) + self.assertTrue((result_data == expected_data).all(), True) + + with paddle.static.program_guard(paddle.static.Program()): + op = eval("paddle.%s" % (op_type)) + data_var = paddle.static.data( + name="data", shape=[10, 10], dtype="float32") + result = op(data_var, axis=-1, name="test_arg_api") + self.assertTrue("test_arg_api" in result.name) + + def run_dygraph(self, place): + paddle.disable_static() + op = eval("paddle.%s" % (op_type)) + data_tensor = paddle.to_tensor(self.input_data) + + #case 1 + result_data = op(data_tensor) + excepted_data = self.numpy_op(self.input_data) + self.assertTrue((result_data.numpy() == excepted_data).all(), True) + + #case 2 + result_data = op(data_tensor, axis=1) + excepted_data = self.numpy_op(self.input_data, axis=1) + 
self.assertTrue((result_data.numpy() == excepted_data).all(), True) + + #case 3 + result_data = op(data_tensor, axis=-1) + excepted_data = self.numpy_op(self.input_data, axis=-1) + self.assertTrue((result_data.numpy() == excepted_data).all(), True) + + #case 4 + result_data = op(data_tensor, axis=-1, keepdim=True) + excepted_data = self.numpy_op(self.input_data, axis=-1) + excepted_data = excepted_data.reshape((10)) + self.assertTrue((result_data.numpy() == excepted_data).all(), True) + + #case 5 + result_data = op(data_tensor, axis=-1, keepdim=True, dtype="int32") + self.assertTrue(result_data.numpy().dtype == np.int32) + + # case for dim 4, 5, 6, for test case coverage + input_data = np.random.rand(5, 5, 5, 5) + excepted_data = self.numpy_op(input_data, axis=0) + result_data = op(paddle.to_tensor(input_data), axis=0) + self.assertTrue((result_data.numpy() == excepted_data).all(), True) + + input_data = np.random.rand(4, 4, 4, 4, 4) + excepted_data = self.numpy_op(input_data, axis=0) + result_data = op(paddle.to_tensor(input_data), axis=0) + self.assertTrue((result_data.numpy() == excepted_data).all(), True) + + input_data = np.random.rand(3, 3, 3, 3, 3, 3) + excepted_data = self.numpy_op(input_data, axis=0) + result_data = op(paddle.to_tensor(input_data), axis=0) + self.assertTrue((result_data.numpy() == excepted_data).all(), True) + + def test_case(self): + for place in self.places: + self.run_static(place) + self.run_dygraph(place) + + cls_name = "ArgMaxMinTestCase_{}".format(op_type) + ArgMaxMinTestCase.__name__ = cls_name + globals()[cls_name] = ArgMaxMinTestCase + + +for op_type in ['argmin', 'argmax']: + create_test_case(op_type) + + +class TestArgMinMaxOpError(unittest.TestCase): + def test_errors(self): + paddle.enable_static() + with program_guard(Program(), Program()): + + def test_argmax_x_type(): + x1 = [1, 2, 3] + output = paddle.argmax(x=x1) + + self.assertRaises(TypeError, test_argmax_x_type) + + def test_argmin_x_type(): + x2 = [1, 2, 3] + output = paddle.argmin(x=x2) + + self.assertRaises(TypeError, test_argmin_x_type) + + def test_argmax_attr_type(): + data = paddle.static.data( + name="test_argmax", shape=[10], dtype="float32") + output = paddle.argmax(x=data, dtype="float32") + + self.assertRaises(ValueError, test_argmax_attr_type) + + def test_argmin_attr_type(): + data = paddle.static.data( + name="test_argmax", shape=[10], dtype="float32") + output = paddle.argmin(x=data, dtype="float32") + + self.assertRaises(ValueError, test_argmin_attr_type) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 840621199952447846136a4142d9e3bd2d9066ae..91ad3bfa9cc1babd22cf9419ede33ae26f0dc900 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -125,95 +125,168 @@ def argsort(x, axis=-1, descending=False, name=None): return ids -def argmax(input, axis=None, dtype=None, out=None, keepdims=False, name=None): +def argmax(x, axis=None, dtype=None, keepdim=False, name=None): """ - :alias_main: paddle.argmax - :alias: paddle.argmax,paddle.tensor.argmax,paddle.tensor.search.argmax - This OP computes the indices of the max elements of the input tensor's element along the provided axis. Args: - input(Variable): An input N-D Tensor with type float32, float64, int16, + x(Tensor): An input N-D Tensor with type float32, float64, int16, int32, int64, uint8. axis(int, optional): Axis to compute indices along. The effective range - is [-R, R), where R is Rank(input). 
when axis<0, it works the same way
-        as axis+R. Default is None, it will use the last dim to select indices of max value.
-        dtype(np.dtype|core.VarDesc.VarType|str): Data type of the output tensor which can
+            is [-R, R), where R is x.ndim. When axis < 0, it works the same way
+            as axis + R. Default is None; in that case the input `x` is flattened and
+            the index of the max value over all elements is returned.
+        dtype(str): Data type of the output tensor which can
                    be int32, int64. The default value is None, and it will
                    return the int64 indices.
-        out(Variable, optional): Optional output which can be any created
-            Variable that meets the requirements to store the result of operation.
-            if out is None, a new Varibale will be create to store the result. Defalut is None.
-        keepdims(bool, optional): Keep the axis that do the select max.
+        keepdim(bool, optional): Whether to keep the reduced dimension.
            The default value is False.
        name(str, optional): The default value is None. Normally there is no
            need for user to set this property. For more information, please
            refer to :ref:`api_guide_Name`.

    Returns:
-        Variable: A Tensor with data type int64.
+        Tensor, the indices of the max elements; the data type is `int32` if
+            :attr:`dtype` is set to `int32`, otherwise `int64`.

    Examples:
        .. code-block:: python

-            import paddle
-            import paddle.fluid as fluid
            import numpy as np
+            import paddle

-            in1 = np.array([[[5,8,9,5],
-                             [0,0,1,7],
-                             [6,9,2,4]],
-                            [[5,2,4,2],
-                             [4,7,7,9],
-                             [1,7,0,6]]])
-            with fluid.dygraph.guard():
-                x = fluid.dygraph.to_variable(in1)
-                out1 = paddle.argmax(input=x, axis=-1)
-                out2 = paddle.argmax(input=x, axis=0)
-                out3 = paddle.argmax(input=x, axis=1)
-                out4 = paddle.argmax(input=x, axis=2)
-                out5 = paddle.argmax(input=x, axis=2, keepdims=True)
-                print(out1.numpy())
-                # [[2 3 1]
-                #  [0 3 1]]
-                print(out2.numpy())
-                # [[0 0 0 0]
-                #  [1 1 1 1]
-                #  [0 0 0 1]]
-                print(out3.numpy())
-                # [[2 2 0 1]
-                #  [0 1 1 1]]
-                print(out4.numpy())
-                # [[2 3 1]
-                #  [0 3 1]]
-                print(out5.numpy())
-                # array([[[2],
-                #         [3],
-                #         [1]],
-                #        [[0],
-                #         [3],
-                #         [1]]])
+            paddle.disable_static()
+            data = np.array([[5,8,9,5],
+                             [0,0,1,7],
+                             [6,9,2,4]])
+            x = paddle.to_variable(data)
+            out1 = paddle.argmax(x)
+            print(out1.numpy())  # 2
+            out2 = paddle.argmax(x, axis=1)
+            print(out2.numpy())
+            # [2 3 1]
+            out3 = paddle.argmax(x, axis=-1)
+            print(out3.numpy())
+            # [2 3 1]
    """
-    helper = LayerHelper("arg_max", **locals())
+    flatten = False
+    if axis is None:
+        flatten = True
+        axis = 0
+
+    if in_dygraph_mode():
+        if dtype is not None:
+            var_dtype = convert_np_dtype_to_dtype_(dtype)
+            out = core.ops.arg_max(x, 'axis', axis, 'dtype', var_dtype,
+                                   'keepdim', keepdim, 'flatten', flatten)
+        else:
+            out = core.ops.arg_max(x, 'axis', axis, 'keepdim', keepdim,
+                                   'flatten', flatten)
+        return out
+
+    helper = LayerHelper("argmax", **locals())
+    check_variable_and_dtype(
+        x, 'x', ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
+        'paddle.argmax')
    var_dtype = None
    attrs = {}
    if dtype is not None:
-        check_dtype(dtype, 'create data type', ['int32', 'int64'], 'arg_max')
+        if dtype not in ['int32', 'int64']:
+            raise ValueError(
+                "The value of 'dtype' in argmax op must be int32, int64, but received {}".
+                format(dtype))
        var_dtype = convert_np_dtype_to_dtype_(dtype)
        attrs["dtype"] = var_dtype
    else:
        var_dtype = VarDesc.VarType.INT64
-    if out is None:
-        out = helper.create_variable_for_type_inference(var_dtype)
+
+    out = helper.create_variable_for_type_inference(var_dtype)
+    attrs['keepdims'] = keepdim
+    attrs['axis'] = axis
+    attrs['flatten'] = flatten
+    helper.append_op(
+        type='arg_max', inputs={'X': x}, outputs={'Out': [out]}, attrs=attrs)
+    out.stop_gradient = True
+    return out
+
+
+def argmin(x, axis=None, dtype=None, keepdim=False, name=None):
+    """
+    This OP computes the indices of the min elements of the input tensor
+    along the provided axis.
+
+    Args:
+        x(Tensor): An input N-D Tensor with type float32, float64, int16,
+            int32, int64, uint8.
+        axis(int, optional): Axis to compute indices along. The effective range
+            is [-R, R), where R is x.ndim. When axis < 0, it works the same way
+            as axis + R. Default is None; in that case the input `x` is flattened and
+            the index of the min value over all elements is returned.
+        dtype(str): Data type of the output tensor which can
+                    be int32, int64. The default value is None, and it will
+                    return the int64 indices.
+        keepdim(bool, optional): Whether to keep the reduced dimension.
+            The default value is False.
+        name(str, optional): The default value is None. Normally there is no
+            need for user to set this property. For more information, please
+            refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor, the indices of the min elements; the data type is `int32` if
+            :attr:`dtype` is set to `int32`, otherwise `int64`.
+
+    Examples:
+        .. code-block:: python
+
+            import numpy as np
+            import paddle
+
+            paddle.disable_static()
+            data = np.array([[5,8,9,5],
+                             [0,0,1,7],
+                             [6,9,2,4]])
+            x = paddle.to_variable(data)
+            out1 = paddle.argmin(x)
+            print(out1.numpy())  # 4
+            out2 = paddle.argmin(x, axis=1)
+            print(out2.numpy())
+            # [0 0 2]
+            out3 = paddle.argmin(x, axis=-1)
+            print(out3.numpy())
+            # [0 0 2]
+    """
+    flatten = False
    if axis is None:
-        axis = -1
-    attrs['keepdims'] = keepdims
+        flatten = True
+        axis = 0
+
+    if in_dygraph_mode():
+        if dtype is not None:
+            var_dtype = convert_np_dtype_to_dtype_(dtype)
+            out = core.ops.arg_min(x, 'axis', axis, 'dtype', var_dtype,
+                                   'keepdim', keepdim, 'flatten', flatten)
+        else:
+            out = core.ops.arg_min(x, 'axis', axis, 'keepdim', keepdim,
+                                   'flatten', flatten)
+        return out
+
+    helper = LayerHelper("argmin", **locals())
+    check_variable_and_dtype(
+        x, 'x', ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
+        'paddle.argmin')
+    var_dtype = None
+    attrs = {}
+    if dtype is not None:
+        if dtype not in ['int32', 'int64']:
+            raise ValueError(
+                "The value of 'dtype' in argmin op must be int32, int64, but received {}".
+                format(dtype))
+        var_dtype = convert_np_dtype_to_dtype_(dtype)
+        attrs["dtype"] = var_dtype
+    else:
+        var_dtype = VarDesc.VarType.INT64
+
+    out = helper.create_variable_for_type_inference(var_dtype)
+    attrs['keepdims'] = keepdim
    attrs['axis'] = axis
+    attrs['flatten'] = flatten
    helper.append_op(
-        type='arg_max',
-        inputs={'X': input},
-        outputs={'Out': [out]},
-        attrs=attrs)
+        type='arg_min', inputs={'X': x}, outputs={'Out': [out]}, attrs=attrs)
    out.stop_gradient = True
    return out
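Two illustrative notes follow; neither is part of the patch. First, a NumPy-only sketch of the `pre`/`n`/`post` decomposition that `ComputeFullArg` and `VisitDataCudaArgMinMaxFunctor` rely on: the input is viewed as `height = pre * post` rows of width `n`, the arg-reduction runs over the middle dimension, and `flatten=True` simply collapses the whole tensor to `pre = post = 1`, `n = numel`. The helper name `argmax_via_pre_n_post` is invented for illustration.

```python
import numpy as np


def argmax_via_pre_n_post(x, axis=None):
    """Illustrative NumPy mirror of the [pre, n, post] view used by the CUDA path."""
    if axis is None:
        # flatten=True in this patch: the whole tensor becomes a single reduction row
        pre, n, post = 1, x.size, 1
    else:
        axis += x.ndim if axis < 0 else 0
        pre = int(np.prod(x.shape[:axis], dtype=np.int64))
        n = x.shape[axis]
        post = int(np.prod(x.shape[axis + 1:], dtype=np.int64))
    # height = pre * post independent rows, each of width n, which is how
    # ArgCUDAKernel iterates the data (one block reduction per row)
    out = x.reshape(pre, n, post).argmax(axis=1)  # shape (pre, post)
    return out  # the real op reshapes this back to the input shape with `axis` removed


x = np.random.rand(4, 5, 6)
assert (argmax_via_pre_n_post(x, axis=-2) == x.argmax(axis=1)).all()
assert argmax_via_pre_n_post(x).item() == x.argmax()  # flatten path
```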
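Second, a quick usage sketch of the reworked `paddle.argmax` / `paddle.argmin` API, stitched together from the docstring examples and the new unit tests above. It assumes a build that includes this patch; the `astype("float32")` cast is only there to make the input dtype explicit.

```python
import numpy as np
import paddle

paddle.disable_static()
data = np.array([[5, 8, 9, 5],
                 [0, 0, 1, 7],
                 [6, 9, 2, 4]]).astype("float32")
x = paddle.to_tensor(data)

print(paddle.argmax(x).numpy())                         # 2  (axis=None -> flatten=True)
print(paddle.argmax(x, axis=1).numpy())                 # [2 3 1]
print(paddle.argmax(x, axis=-1, keepdim=True).numpy())  # shape (3, 1): [[2], [3], [1]]
print(paddle.argmin(x).numpy())                         # 4
print(paddle.argmin(x, axis=0, dtype="int32").numpy())  # [1 1 1 2], int32 indices
```

Passing a non-integer `dtype` such as `"float32"` raises `ValueError`, which is what the new `test_argmax_attr_type` / `test_argmin_attr_type` cases assert.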