diff --git a/paddle/fluid/operators/arg_max_op.cc b/paddle/fluid/operators/arg_max_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..5603607357ab860bbf11928a1faa4a3ed94f2580 --- /dev/null +++ b/paddle/fluid/operators/arg_max_op.cc @@ -0,0 +1,40 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/arg_max_op.h" +/* +REGISTER_ARG_MINMAX_OP_WITHOUT_GRADIENT(arg_max, ArgMax); + +REGISTER_ARG_MINMAX_KERNEL(arg_max, ArgMax, CPU); +*/ + +REGISTER_OPERATOR(arg_max, paddle::operators::ArgMaxOp, + paddle::operators::ArgMaxOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL( + arg_max, paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel); diff --git a/paddle/fluid/operators/arg_max_op.cu b/paddle/fluid/operators/arg_max_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..8f57c63beb432e01992ea5325844b8b2bd7f6387 --- /dev/null +++ b/paddle/fluid/operators/arg_max_op.cu @@ -0,0 +1,34 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/arg_max_op.h" + +// REGISTER_ARG_MINMAX_KERNEL(arg_max, ArgMax, CUDA); + +REGISTER_OP_CUDA_KERNEL( + arg_max, + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel); diff --git a/paddle/fluid/operators/arg_max_op.h b/paddle/fluid/operators/arg_max_op.h new file mode 100644 index 0000000000000000000000000000000000000000..d232a856992fb8841a7be4ed6e1a881a1e4de6cd --- /dev/null +++ b/paddle/fluid/operators/arg_max_op.h @@ -0,0 +1,16 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/operators/arg_min_max_op_base.h" diff --git a/paddle/fluid/operators/arg_min_max_op_base.h b/paddle/fluid/operators/arg_min_max_op_base.h new file mode 100644 index 0000000000000000000000000000000000000000..8c20461a345b88f7c7ff3722d970a23daf444041 --- /dev/null +++ b/paddle/fluid/operators/arg_min_max_op_base.h @@ -0,0 +1,157 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/string/printf.h" + +namespace paddle { +namespace operators { + +enum ArgMinMaxType { kArgMin, kArgMax }; + +template +struct ArgMinMaxFunctor {}; + +#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value) \ + template \ + struct ArgMinMaxFunctor { \ + void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \ + framework::LoDTensor& out, int64_t axis) { \ + auto in_eigen = framework::EigenTensor::From(in); \ + auto out_eigen = framework::EigenTensor::From(out); \ + out_eigen.device(*(ctx.eigen_device())) = \ + in_eigen.eigen_op_type(axis).template cast(); \ + } \ + } + +DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin); +DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax); + +template +class ArgMinMaxKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& x = *(ctx.Input("X")); + auto& out = *(ctx.Output("Out")); + out.mutable_data(ctx.GetPlace()); + auto axis = ctx.Attr("axis"); + auto& dev_ctx = ctx.template device_context(); + +#define CALL_ARG_MINMAX_FUNCTOR(rank) \ + ArgMinMaxFunctor \ + functor##rank; \ + functor##rank(dev_ctx, x, out, axis) + + switch (x.dims().size()) { + case 1: + CALL_ARG_MINMAX_FUNCTOR(1); + break; + case 2: + CALL_ARG_MINMAX_FUNCTOR(2); + break; + case 3: + CALL_ARG_MINMAX_FUNCTOR(3); + break; + case 4: + CALL_ARG_MINMAX_FUNCTOR(4); + break; + case 5: + CALL_ARG_MINMAX_FUNCTOR(5); + break; + case 6: + CALL_ARG_MINMAX_FUNCTOR(6); + break; + default: + PADDLE_THROW( + "%s operator doesn't supports tensors whose ranks are greater " + "than 6.", + (EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax")); + break; + } + } +}; + +template +using ArgMinKernel = + ArgMinMaxKernel; + +template +using ArgMaxKernel = + ArgMinMaxKernel; + +typedef class BaseArgMinMaxOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null"); + const auto& x_dims = ctx->GetInputDim("X"); + int64_t axis = ctx->Attrs().Get("axis"); + PADDLE_ENFORCE(axis >= -x_dims.size() && axis < x_dims.size(), + "'axis' must be inside [-Rank(X), Rank(X))"); + + auto x_rank = x_dims.size(); + if (axis < 0) axis += x_rank; + + std::vector vec; + for (int64_t i = 0; i < axis; i++) vec.push_back(x_dims[i]); + for (int64_t i = axis + 1; i < x_rank; i++) vec.push_back(x_dims[i]); + ctx->SetOutputDim("Out", framework::make_ddim(vec)); + } +} ArgMinOp, ArgMaxOp; + +class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker { + protected: + virtual const char* OpName() const = 0; + virtual const char* Name() const = 0; + + public: + void Make() override { + AddInput("X", "Input tensor."); + AddOutput("Out", "Output tensor."); + AddAttr("axis", "The axis in which to compute the arg indics."); + AddComment(::paddle::string::Sprintf(R"DOC( + %s Operator. + + Computes the indices of the %s elements of the input tensor's element along the provided axis. +)DOC", + OpName(), Name())); + } +}; + +class ArgMinOpMaker : public BaseArgMinMaxOpMaker { + protected: + const char* OpName() const override { return "ArgMin"; } + const char* Name() const override { return "min"; } +}; + +class ArgMaxOpMaker : public BaseArgMinMaxOpMaker { + protected: + const char* OpName() const override { return "ArgMax"; } + const char* Name() const override { return "max"; } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/arg_min_op.cc b/paddle/fluid/operators/arg_min_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..fe17ed711b2fa8ee4b432366f3d858756bb3ccc5 --- /dev/null +++ b/paddle/fluid/operators/arg_min_op.cc @@ -0,0 +1,40 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/arg_min_op.h" +/* +REGISTER_ARG_MINMAX_OP_WITHOUT_GRADIENT(arg_min, ArgMin); + +REGISTER_ARG_MINMAX_KERNEL(arg_min, ArgMin, CPU); +*/ + +REGISTER_OPERATOR(arg_min, paddle::operators::ArgMinOp, + paddle::operators::ArgMinOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL( + arg_min, paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel); diff --git a/paddle/fluid/operators/arg_min_op.cu b/paddle/fluid/operators/arg_min_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..da9262044a72c58998f65ae4b1569c56a163e35f --- /dev/null +++ b/paddle/fluid/operators/arg_min_op.cu @@ -0,0 +1,34 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/arg_min_op.h" + +// REGISTER_ARG_MINMAX_KERNEL(arg_min, ArgMin, CUDA); + +REGISTER_OP_CUDA_KERNEL( + arg_min, + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel); diff --git a/paddle/fluid/operators/arg_min_op.h b/paddle/fluid/operators/arg_min_op.h new file mode 100644 index 0000000000000000000000000000000000000000..d232a856992fb8841a7be4ed6e1a881a1e4de6cd --- /dev/null +++ b/paddle/fluid/operators/arg_min_op.h @@ -0,0 +1,16 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/operators/arg_min_max_op_base.h" diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 75d3bf879703a1db1108eae45d879164e0024156..3dfacfff6acf24329e710e005e3d56c4b8df97b0 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -30,6 +30,8 @@ __all__ = [ 'assign', 'fill_constant_batch_size_like', 'fill_constant', + 'argmin', + 'argmax', 'ones', 'zeros', ] @@ -315,6 +317,68 @@ def fill_constant_batch_size_like(input, return out +def argmin(x, axis=0): + """ + **argmin** + + This function computes the indices of the min elements + of the input tensor's element along the provided axis. + + Args: + x(Variable): The input to compute the indices of + the min elements. + axis(int): Axis to compute indices along. + + Returns: + Variable: The tensor variable storing the output + + Examples: + .. code-block:: python + + out = fluid.layers.argmin(x=in, axis=0) + out = fluid.layers.argmin(x=in, axis=-1) + """ + helper = LayerHelper("arg_min", **locals()) + out = helper.create_tmp_variable(VarDesc.VarType.INT64) + helper.append_op( + type='arg_min', + inputs={'X': x}, + outputs={'Out': [out]}, + attrs={'axis': axis}) + return out + + +def argmax(x, axis=0): + """ + **argmax** + + This function computes the indices of the max elements + of the input tensor's element along the provided axis. + + Args: + x(Variable): The input to compute the indices of + the max elements. + axis(int): Axis to compute indices along. + + Returns: + Variable: The tensor variable storing the output + + Examples: + .. code-block:: python + + out = fluid.layers.argmax(x=in, axis=0) + out = fluid.layers.argmax(x=in, axis=-1) + """ + helper = LayerHelper("arg_max", **locals()) + out = helper.create_tmp_variable(VarDesc.VarType.INT64) + helper.append_op( + type='arg_max', + inputs={'X': x}, + outputs={'Out': [out]}, + attrs={'axis': axis}) + return out + + def ones(shape, dtype, force_cpu=False): """ **ones** diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py new file mode 100644 index 0000000000000000000000000000000000000000..e04412f809cdd75d07d28a60f0c2f19041a684f6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py @@ -0,0 +1,82 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +class BaseTestCase(OpTest): + def initTestCase(self): + self.op_type = 'arg_min' + self.dims = (3, 4, 5) + self.dtype = 'float32' + self.axis = 0 + + def setUp(self): + self.initTestCase() + self.x = (1000 * np.random.random(self.dims)).astype(self.dtype) + self.inputs = {'X': self.x} + self.attrs = {'axis': self.axis} + if self.op_type == "arg_min": + self.outputs = {'Out': np.argmin(self.x, axis=self.axis)} + else: + self.outputs = {'Out': np.argmax(self.x, axis=self.axis)} + + def test_check_output(self): + self.check_output() + + +class TestCase0(BaseTestCase): + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, 4, 5) + self.dtype = 'float32' + self.axis = 0 + + +class TestCase1(BaseTestCase): + def initTestCase(self): + self.op_type = 'arg_min' + self.dims = (3, 4) + self.dtype = 'float64' + self.axis = 1 + + +class TestCase2(BaseTestCase): + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, 4) + self.dtype = 'int64' + self.axis = 0 + + +class TestCase3(BaseTestCase): + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, ) + self.dtype = 'int64' + self.axis = 0 + + +class TestCase4(BaseTestCase): + def initTestCase(self): + self.op_type = 'arg_min' + self.dims = (1, ) + self.dtype = 'int32' + self.axis = 0 + + +if __name__ == '__main__': + unittest.main()