diff --git a/paddle/fluid/operators/arg_max_op.cc b/paddle/fluid/operators/arg_max_op.cc index 859cccd1b2dfcc8b2278e11157468095fd6a396d..8174d3735859b1fac40cd4c07545f34874d31ab7 100644 --- a/paddle/fluid/operators/arg_max_op.cc +++ b/paddle/fluid/operators/arg_max_op.cc @@ -12,24 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/arg_max_op.h" +#include "paddle/fluid/operators/arg_min_max_op_base.h" -REGISTER_OPERATOR(arg_max, paddle::operators::ArgMaxOp, +REGISTER_OPERATOR(arg_max, paddle::operators::ArgMinMaxOp, paddle::operators::ArgMaxOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( - arg_max, paddle::operators::ArgMaxKernel, - paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, - paddle::operators::ArgMaxKernel, - paddle::operators::ArgMaxKernel, - paddle::operators::ArgMaxKernel, - paddle::operators::ArgMaxKernel, - paddle::operators::ArgMaxKernel); + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel); diff --git a/paddle/fluid/operators/arg_max_op.cu b/paddle/fluid/operators/arg_max_op.cu index c9c102bdccf32b02b055cdfd800957c6c3723183..a147d77a9e9c577984028e1a6ed9582dda622069 100644 --- a/paddle/fluid/operators/arg_max_op.cu +++ b/paddle/fluid/operators/arg_max_op.cu @@ -12,21 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/arg_max_op.h" +#include "paddle/fluid/operators/arg_min_max_op_base.h" REGISTER_OP_CUDA_KERNEL( arg_max, - paddle::operators::ArgMaxKernel, - paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, paddle::operators::ArgMaxKernel, + int32_t>, paddle::operators::ArgMaxKernel, + int16_t>, paddle::operators::ArgMaxKernel, - paddle::operators::ArgMaxKernel, + size_t>, paddle::operators::ArgMaxKernel); + uint8_t>); diff --git a/paddle/fluid/operators/arg_max_op.h b/paddle/fluid/operators/arg_max_op.h deleted file mode 100644 index d232a856992fb8841a7be4ed6e1a881a1e4de6cd..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/arg_max_op.h +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/operators/arg_min_max_op_base.h" diff --git a/paddle/fluid/operators/arg_min_max_op_base.h b/paddle/fluid/operators/arg_min_max_op_base.h index 8c20461a345b88f7c7ff3722d970a23daf444041..6cbdaefeda099c36a864289ef8195c20d09c55e6 100644 --- a/paddle/fluid/operators/arg_min_max_op_base.h +++ b/paddle/fluid/operators/arg_min_max_op_base.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include #include #include "paddle/fluid/framework/ddim.h" @@ -37,9 +38,9 @@ struct ArgMinMaxFunctor {}; struct ArgMinMaxFunctor { \ void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \ - framework::LoDTensor& out, int64_t axis) { \ + framework::LoDTensor* out, int64_t axis) { \ auto in_eigen = framework::EigenTensor::From(in); \ - auto out_eigen = framework::EigenTensor::From(out); \ + auto out_eigen = framework::EigenTensor::From(*out); \ out_eigen.device(*(ctx.eigen_device())) = \ in_eigen.eigen_op_type(axis).template cast(); \ } \ @@ -62,7 +63,7 @@ class ArgMinMaxKernel : public framework::OpKernel { #define CALL_ARG_MINMAX_FUNCTOR(rank) \ ArgMinMaxFunctor \ functor##rank; \ - functor##rank(dev_ctx, x, out, axis) + functor##rank(dev_ctx, x, &out, axis) switch (x.dims().size()) { case 1: @@ -89,19 +90,20 @@ class ArgMinMaxKernel : public framework::OpKernel { "than 6.", (EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax")); break; +#undef CALL_ARG_MINMAX_FUNCTOR } } }; -template +template using ArgMinKernel = - ArgMinMaxKernel; + ArgMinMaxKernel; -template +template using ArgMaxKernel = - ArgMinMaxKernel; + ArgMinMaxKernel; -typedef class BaseArgMinMaxOp : public framework::OperatorWithKernel { +class ArgMinMaxOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -121,7 +123,7 @@ typedef class BaseArgMinMaxOp : public framework::OperatorWithKernel { for (int64_t i = axis + 1; i < x_rank; i++) vec.push_back(x_dims[i]); ctx->SetOutputDim("Out", framework::make_ddim(vec)); } -} ArgMinOp, ArgMaxOp; +}; class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker { protected: @@ -133,12 +135,13 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "Input tensor."); AddOutput("Out", "Output tensor."); AddAttr("axis", "The axis in which to compute the arg indics."); - AddComment(::paddle::string::Sprintf(R"DOC( - %s Operator. + AddComment(string::Sprintf(R"DOC( + %s Operator. - Computes the indices of the %s elements of the input tensor's element along the provided axis. + Computes the indices of the %s elements of the input tensor's element + along the provided axis. )DOC", - OpName(), Name())); + OpName(), Name())); } }; diff --git a/paddle/fluid/operators/arg_min_op.cc b/paddle/fluid/operators/arg_min_op.cc index 18c0884a042604d35978d2bde57877a018668511..41f188029f17dbe8717afc0ca0760a39edc24b54 100644 --- a/paddle/fluid/operators/arg_min_op.cc +++ b/paddle/fluid/operators/arg_min_op.cc @@ -12,24 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/arg_min_op.h" +#include "paddle/fluid/operators/arg_min_max_op_base.h" -REGISTER_OPERATOR(arg_min, paddle::operators::ArgMinOp, +REGISTER_OPERATOR(arg_min, paddle::operators::ArgMinMaxOp, paddle::operators::ArgMinOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( - arg_min, paddle::operators::ArgMinKernel, - paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, - paddle::operators::ArgMinKernel, - paddle::operators::ArgMinKernel, - paddle::operators::ArgMinKernel, - paddle::operators::ArgMinKernel, - paddle::operators::ArgMinKernel); + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel); diff --git a/paddle/fluid/operators/arg_min_op.cu b/paddle/fluid/operators/arg_min_op.cu index 6d5aaa9596101f6cbd3e8177d516e24a43223ec8..4d020508505a6ebac8be41ce1e4f99d436b67ab5 100644 --- a/paddle/fluid/operators/arg_min_op.cu +++ b/paddle/fluid/operators/arg_min_op.cu @@ -12,21 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/arg_min_op.h" +#include "paddle/fluid/operators/arg_min_max_op_base.h" REGISTER_OP_CUDA_KERNEL( arg_min, - paddle::operators::ArgMinKernel, - paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, paddle::operators::ArgMinKernel, + int32_t>, paddle::operators::ArgMinKernel, + int16_t>, paddle::operators::ArgMinKernel, - paddle::operators::ArgMinKernel, + size_t>, paddle::operators::ArgMinKernel); + uint8_t>); diff --git a/paddle/fluid/operators/arg_min_op.h b/paddle/fluid/operators/arg_min_op.h deleted file mode 100644 index d232a856992fb8841a7be4ed6e1a881a1e4de6cd..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/arg_min_op.h +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/operators/arg_min_max_op_base.h"