/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/string/printf.h" namespace paddle { namespace operators { enum ArgMinMaxType { kArgMin, kArgMax }; template struct ArgMinMaxFunctor {}; #define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value) \ template \ struct ArgMinMaxFunctor { \ void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \ framework::LoDTensor* out, int64_t axis) { \ auto in_eigen = framework::EigenTensor::From(in); \ auto out_eigen = framework::EigenTensor::From(*out); \ out_eigen.device(*(ctx.eigen_device())) = \ in_eigen.eigen_op_type(axis).template cast(); \ } \ } DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin); DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax); template class ArgMinMaxKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto& x = *(ctx.Input("X")); auto& out = *(ctx.Output("Out")); out.mutable_data(ctx.GetPlace()); auto axis = ctx.Attr("axis"); auto& dev_ctx = ctx.template device_context(); #define CALL_ARG_MINMAX_FUNCTOR(rank) \ ArgMinMaxFunctor \ functor##rank; \ functor##rank(dev_ctx, x, &out, axis) switch (x.dims().size()) { case 1: CALL_ARG_MINMAX_FUNCTOR(1); break; case 2: CALL_ARG_MINMAX_FUNCTOR(2); break; case 3: CALL_ARG_MINMAX_FUNCTOR(3); break; case 4: CALL_ARG_MINMAX_FUNCTOR(4); break; case 5: CALL_ARG_MINMAX_FUNCTOR(5); break; case 6: CALL_ARG_MINMAX_FUNCTOR(6); break; default: PADDLE_THROW( "%s operator doesn't supports tensors whose ranks are greater " "than 6.", (EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax")); break; #undef CALL_ARG_MINMAX_FUNCTOR } } }; template using ArgMinKernel = ArgMinMaxKernel; template using ArgMaxKernel = ArgMinMaxKernel; class ArgMinMaxOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null"); const auto& x_dims = ctx->GetInputDim("X"); int64_t axis = ctx->Attrs().Get("axis"); PADDLE_ENFORCE(axis >= -x_dims.size() && axis < x_dims.size(), "'axis' must be inside [-Rank(X), Rank(X))"); auto x_rank = x_dims.size(); if (axis < 0) axis += x_rank; std::vector vec; for (int64_t i = 0; i < axis; i++) vec.push_back(x_dims[i]); for (int64_t i = axis + 1; i < x_rank; i++) vec.push_back(x_dims[i]); ctx->SetOutputDim("Out", framework::make_ddim(vec)); } }; class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker { protected: virtual const char* OpName() const = 0; virtual const char* Name() const = 0; public: void Make() override { AddInput("X", "Input tensor."); AddOutput("Out", "Output tensor."); AddAttr("axis", "The axis in which to compute the arg indics."); AddComment(string::Sprintf(R"DOC( %s Operator. Computes the indices of the %s elements of the input tensor's element along the provided axis. )DOC", OpName(), Name())); } }; class ArgMinOpMaker : public BaseArgMinMaxOpMaker { protected: const char* OpName() const override { return "ArgMin"; } const char* Name() const override { return "min"; } }; class ArgMaxOpMaker : public BaseArgMinMaxOpMaker { protected: const char* OpName() const override { return "ArgMax"; } const char* Name() const override { return "max"; } }; } // namespace operators } // namespace paddle