提交 9b43edea 编写于 作者: Y yuyang18

Polish arg_min_max_op

* Remove unused arg_max/min_op.h
* Remove reference parameter. Use pointer insteaded.
* undef macro
* Always set OutT as int64_t.
上级 6d32e960
......@@ -12,24 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/arg_max_op.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h"
REGISTER_OPERATOR(arg_max, paddle::operators::ArgMaxOp,
REGISTER_OPERATOR(arg_max, paddle::operators::ArgMinMaxOp,
paddle::operators::ArgMaxOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
arg_max, paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
float, int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, double,
arg_max,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, float>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, double>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, int64_t,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, int32_t,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, int16_t,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, size_t,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, uint8_t,
int64_t>);
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
......@@ -12,21 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/arg_max_op.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h"
REGISTER_OP_CUDA_KERNEL(
arg_max,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, float,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, double,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, float>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
double>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int64_t, int64_t>,
int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int32_t, int64_t>,
int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int16_t, int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, size_t,
int64_t>,
size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
uint8_t, int64_t>);
uint8_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/arg_min_max_op_base.h"
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <type_traits>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
......@@ -37,9 +38,9 @@ struct ArgMinMaxFunctor {};
struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank, \
enum_argminmax_value> { \
void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \
framework::LoDTensor& out, int64_t axis) { \
framework::LoDTensor* out, int64_t axis) { \
auto in_eigen = framework::EigenTensor<T, Rank>::From(in); \
auto out_eigen = framework::EigenTensor<Tout, Rank - 1>::From(out); \
auto out_eigen = framework::EigenTensor<Tout, Rank - 1>::From(*out); \
out_eigen.device(*(ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} \
......@@ -62,7 +63,7 @@ class ArgMinMaxKernel : public framework::OpKernel<T> {
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMinMaxFunctor<DeviceContext, T, Tout, rank, EnumArgMinMaxValue> \
functor##rank; \
functor##rank(dev_ctx, x, out, axis)
functor##rank(dev_ctx, x, &out, axis)
switch (x.dims().size()) {
case 1:
......@@ -89,19 +90,20 @@ class ArgMinMaxKernel : public framework::OpKernel<T> {
"than 6.",
(EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax"));
break;
#undef CALL_ARG_MINMAX_FUNCTOR
}
}
};
template <typename DeviceContext, typename T, typename Tout>
template <typename DeviceContext, typename T>
using ArgMinKernel =
ArgMinMaxKernel<DeviceContext, T, Tout, ArgMinMaxType::kArgMin>;
ArgMinMaxKernel<DeviceContext, T, int64_t, ArgMinMaxType::kArgMin>;
template <typename DeviceContext, typename T, typename Tout>
template <typename DeviceContext, typename T>
using ArgMaxKernel =
ArgMinMaxKernel<DeviceContext, T, Tout, ArgMinMaxType::kArgMax>;
ArgMinMaxKernel<DeviceContext, T, int64_t, ArgMinMaxType::kArgMax>;
typedef class BaseArgMinMaxOp : public framework::OperatorWithKernel {
class ArgMinMaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -121,7 +123,7 @@ typedef class BaseArgMinMaxOp : public framework::OperatorWithKernel {
for (int64_t i = axis + 1; i < x_rank; i++) vec.push_back(x_dims[i]);
ctx->SetOutputDim("Out", framework::make_ddim(vec));
}
} ArgMinOp, ArgMaxOp;
};
class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
protected:
......@@ -133,10 +135,11 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "Input tensor.");
AddOutput("Out", "Output tensor.");
AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.");
AddComment(::paddle::string::Sprintf(R"DOC(
AddComment(string::Sprintf(R"DOC(
%s Operator.
Computes the indices of the %s elements of the input tensor's element along the provided axis.
Computes the indices of the %s elements of the input tensor's element
along the provided axis.
)DOC",
OpName(), Name()));
}
......
......@@ -12,24 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/arg_min_op.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h"
REGISTER_OPERATOR(arg_min, paddle::operators::ArgMinOp,
REGISTER_OPERATOR(arg_min, paddle::operators::ArgMinMaxOp,
paddle::operators::ArgMinOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
arg_min, paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
float, int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, double,
arg_min,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, float>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, double>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, int64_t,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, int32_t,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, int16_t,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, size_t,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, uint8_t,
int64_t>);
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
......@@ -12,21 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/arg_min_op.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h"
REGISTER_OP_CUDA_KERNEL(
arg_min,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, float,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, double,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, float>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
double>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
int64_t, int64_t>,
int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
int32_t, int64_t>,
int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
int16_t, int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, size_t,
int64_t>,
size_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
uint8_t, int64_t>);
uint8_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/arg_min_max_op_base.h"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册