未验证 提交 f25dba0a 编写于 作者: Z Zhong Hui 提交者: GitHub

[PHI] Move arg min max to PHI. (#40222)

* move arg min max to phi.

* move infermeta.

* fix as reviews.
上级 1128db30
...@@ -15,23 +15,19 @@ limitations under the License. */ ...@@ -15,23 +15,19 @@ limitations under the License. */
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h" #include "paddle/fluid/operators/arg_min_max_op_base.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
// Declares ArgMaxInferShapeFunctor for the `arg_max` op, built from
// phi::ArgMinMaxInferMeta through PD_INFER_META.
DECLARE_INFER_SHAPE_FUNCTOR(arg_max, ArgMaxInferShapeFunctor,
PD_INFER_META(phi::ArgMinMaxInferMeta));
REGISTER_OPERATOR( REGISTER_OPERATOR(
arg_max, paddle::operators::ArgMinMaxOp, paddle::operators::ArgMaxOpMaker, arg_max, paddle::operators::ArgMinMaxOp, paddle::operators::ArgMaxOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ArgMaxInferShapeFunctor);
// Registers the fluid CPU kernels of `arg_max`: one ArgMaxKernel instantiation
// on CPUDeviceContext per listed element type (float, double, int64_t,
// int32_t, int16_t, uint8_t).
REGISTER_OP_CPU_KERNEL(
arg_max,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, float>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, double>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
REGISTER_OP_VERSION(arg_max) REGISTER_OP_VERSION(arg_max)
.AddCheckpoint( .AddCheckpoint(
R"ROC( R"ROC(
......
...@@ -27,193 +27,9 @@ limitations under the License. */ ...@@ -27,193 +27,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// Tag selecting which reduction an ArgMinMaxFunctor specialization performs.
enum ArgMinMaxType { kArgMin = 0, kArgMax = 1 };

// Primary template, intentionally empty. It is parameterized on the device
// context, the input element type (T), the index element type (Tout), the
// tensor rank and the reduction tag.
template <typename DeviceContext, typename T, typename Tout, int64_t Rank,
          ArgMinMaxType Kind>
struct ArgMinMaxFunctor {};
// Generates the partial specialization of ArgMinMaxFunctor for one
// ArgMinMaxType value.
//
// eigen_op_type        : name of the member function invoked on the Eigen view
//                        of the input (argmin / argmax in the expansions
//                        below).
// enum_argminmax_value : the ArgMinMaxType value the specialization is keyed
//                        on.
//
// The generated operator() views `in` as a Rank-dimensional tensor with dims
// `x_dims`, applies the reduction along `axis`, casts the result to Tout and
// assigns it to `out` on the context's Eigen device. `out` is viewed with Rank
// dimensions when `keepdims` is true and with Rank - 1 dimensions otherwise.
#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value) \
template <typename DeviceContext, typename T, typename Tout, int64_t Rank> \
struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank, \
enum_argminmax_value> { \
void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \
framework::LoDTensor* out, framework::DDim x_dims, \
int64_t axis, bool keepdims) { \
auto in_eigen = framework::EigenTensor<T, Rank>::From(in, x_dims); \
if (keepdims) { \
auto out_eigen = framework::EigenTensor<Tout, Rank>::From(*out); \
out_eigen.device(*(ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} else { \
auto out_eigen = framework::EigenTensor<Tout, Rank - 1>::From(*out); \
out_eigen.device(*(ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} \
} \
}
// Instantiate the specializations for both reduction kinds.
DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin);
DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax);
// Visitor passed to framework::VisitDataTypeTiny: apply<Tout>() runs the CPU
// arg-min / arg-max reduction described by the op's execution context, writing
// indices of element type Tout into "Out".
template <typename DeviceContext, typename T, ArgMinMaxType EnumArgMinMaxValue>
struct VisitDataArgMinMaxFunctor {
  const framework::ExecutionContext& ctx;

  explicit VisitDataArgMinMaxFunctor(const framework::ExecutionContext& ctx)
      : ctx(ctx) {}

  template <typename Tout>
  void apply() const {
    const auto& x = *(ctx.Input<framework::LoDTensor>("X"));
    auto& out = *(ctx.Output<framework::LoDTensor>("Out"));
    out.template mutable_data<Tout>(ctx.GetPlace());

    int64_t axis = ctx.Attr<int64_t>("axis");
    const bool keepdims_attr = ctx.Attr<bool>("keepdims");
    const bool flatten = ctx.Attr<bool>("flatten");
    // Paddle has no scalar tensor, so a flattened reduction still produces a
    // shape-[1] result; that is handled as if keepdims were set.
    const bool keepdims = flatten ? true : keepdims_attr;

    // Dims the input is viewed with: its own dims, or a single dimension
    // holding every element when flattening (the reduction axis is then 0).
    framework::DDim x_dims;
    if (!flatten) {
      x_dims = x.dims();
      if (axis < 0) axis += x_dims.size();
    } else {
      x_dims = phi::make_ddim({x.numel()});
      axis = 0;
    }

    auto& dev_ctx = ctx.template device_context<DeviceContext>();

// The functor's rank is a template argument, so each supported rank needs its
// own instantiation; the switch below picks the one matching x_dims.
#define CALL_ARG_MINMAX_FUNCTOR(rank)                                    \
  {                                                                      \
    ArgMinMaxFunctor<DeviceContext, T, Tout, rank, EnumArgMinMaxValue>   \
        rank_functor;                                                    \
    rank_functor(dev_ctx, x, &out, x_dims, axis, keepdims);              \
  }
    switch (x_dims.size()) {
      case 1:
        CALL_ARG_MINMAX_FUNCTOR(1);
        break;
      case 2:
        CALL_ARG_MINMAX_FUNCTOR(2);
        break;
      case 3:
        CALL_ARG_MINMAX_FUNCTOR(3);
        break;
      case 4:
        CALL_ARG_MINMAX_FUNCTOR(4);
        break;
      case 5:
        CALL_ARG_MINMAX_FUNCTOR(5);
        break;
      case 6:
        CALL_ARG_MINMAX_FUNCTOR(6);
        break;
      default:
        PADDLE_ENFORCE_LE(
            x_dims.size(), 6,
            platform::errors::InvalidArgument(
                "%s operator doesn't supports tensors whose ranks are greater "
                "than 6.",
                (EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax")));
        break;
    }
#undef CALL_ARG_MINMAX_FUNCTOR
  }
};
// Fluid CPU kernel shared by arg_min and arg_max. The "dtype" attribute picks
// the index element type; a negative value selects INT64.
template <typename DeviceContext, typename T, ArgMinMaxType EnumArgMinMaxValue>
class ArgMinMaxKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using ProtoType = framework::proto::VarType::Type;

    const int dtype_attr = ctx.Attr<int>("dtype");
    const ProtoType index_type =
        dtype_attr < 0
            ? static_cast<ProtoType>(framework::proto::VarType::INT64)
            : static_cast<ProtoType>(dtype_attr);

    // Dispatch on the index type; the visitor performs the actual reduction.
    framework::VisitDataTypeTiny(
        index_type,
        VisitDataArgMinMaxFunctor<DeviceContext, T, EnumArgMinMaxValue>(ctx));
  }
};
// Convenience alias: ArgMinMaxKernel fixed to the kArgMin reduction.
template <typename DeviceContext, typename T>
using ArgMinKernel = ArgMinMaxKernel<DeviceContext, T, ArgMinMaxType::kArgMin>;
// Convenience alias: ArgMinMaxKernel fixed to the kArgMax reduction.
template <typename DeviceContext, typename T>
using ArgMaxKernel = ArgMinMaxKernel<DeviceContext, T, ArgMinMaxType::kArgMax>;
class ArgMinMaxOp : public framework::OperatorWithKernel { class ArgMinMaxOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
// Validates the "axis" and "dtype" attributes against input "X" and sets the
// dims of output "Out":
//   - flatten:  Out has shape [1];
//   - keepdims: the reduced axis is kept with extent 1;
//   - otherwise the reduced axis is removed.
void InferShape(framework::InferShapeContext* ctx) const override {
  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "arg_min_max");
  OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "arg_min_max");

  const auto& x_dims = ctx->GetInputDim("X");
  int64_t axis = ctx->Attrs().Get<int64_t>("axis");
  const bool keepdims = ctx->Attrs().Get<bool>("keepdims");
  const bool flatten = ctx->Attrs().Get<bool>("flatten");

  // axis must lie in [-Rank(X), Rank(X)).
  PADDLE_ENFORCE_GE(axis, -x_dims.size(),
                    platform::errors::InvalidArgument(
                        "'axis'(%d) must be greater than or equal to"
                        " -Rank(X)(%d).",
                        axis, -x_dims.size()));
  PADDLE_ENFORCE_LT(
      axis, x_dims.size(),
      platform::errors::InvalidArgument(
          "'axis'(%d) must be less than Rank(X)(%d) of Input(X).", axis,
          x_dims.size()));

  // Only a negative value, 2 (INT32) or 3 (INT64) are accepted for dtype.
  const int dtype = ctx->Attrs().Get<int>("dtype");
  PADDLE_ENFORCE_EQ(
      (dtype < 0 || dtype == 2 || dtype == 3), true,
      platform::errors::InvalidArgument(
          "The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
          "received [%s]",
          paddle::framework::DataTypeToString(
              framework::proto::VarType::INT32),
          paddle::framework::DataTypeToString(
              framework::proto::VarType::INT64),
          paddle::framework::DataTypeToString(
              static_cast<framework::proto::VarType::Type>(dtype))));

  const auto x_rank = x_dims.size();
  if (axis < 0) axis += x_rank;

  // At runtime, INT32 indices must be able to address the reduced extent.
  if (ctx->IsRuntime() && dtype == framework::proto::VarType::INT32) {
    const int64_t all_element_num =
        flatten ? phi::product(x_dims) : x_dims[axis];
    PADDLE_ENFORCE_LE(
        all_element_num, INT_MAX,
        platform::errors::InvalidArgument(
            "The element num of the argmin/argmax input at axis is "
            "%d, is larger than int32 maximum value:%d, you must "
            "set the dtype of argmin/argmax to 'int64'.",
            all_element_num, INT_MAX));
  }

  std::vector<int64_t> vec;
  if (flatten) {
    vec.push_back(static_cast<int64_t>(1));
  } else {
    for (int64_t i = 0; i < x_rank; ++i) {
      if (i != axis) {
        vec.push_back(x_dims[i]);
      } else if (keepdims) {
        vec.push_back(static_cast<int64_t>(1));
      }
    }
  }
  ctx->SetOutputDim("Out", phi::make_ddim(vec));
}
}; };
class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker { class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -12,26 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,26 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h" #include "paddle/fluid/operators/arg_min_max_op_base.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
// Declares ArgMinInferShapeFunctor for the `arg_min` op, built from
// phi::ArgMinMaxInferMeta through PD_INFER_META.
DECLARE_INFER_SHAPE_FUNCTOR(arg_min, ArgMinInferShapeFunctor,
PD_INFER_META(phi::ArgMinMaxInferMeta));
REGISTER_OPERATOR( REGISTER_OPERATOR(
arg_min, paddle::operators::ArgMinMaxOp, paddle::operators::ArgMinOpMaker, arg_min, paddle::operators::ArgMinMaxOp, paddle::operators::ArgMinOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ArgMinInferShapeFunctor);
// Registers the fluid CPU kernels of `arg_min`: one ArgMinKernel instantiation
// on CPUDeviceContext per listed element type (float, double, int64_t,
// int32_t, int16_t, uint8_t).
REGISTER_OP_CPU_KERNEL(
arg_min,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, float>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, double>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
REGISTER_OP_VERSION(arg_min) REGISTER_OP_VERSION(arg_min)
.AddCheckpoint( .AddCheckpoint(
R"ROC( R"ROC(
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/arg_min_max_op_base.cu.h"
// Registers the fluid CUDA kernels of `arg_min`: ArgMinMaxOpCUDAKernel with
// cub::ArgMin as the reducer, instantiated for float, double, int64_t, int32_t
// and int8_t.
REGISTER_OP_CUDA_KERNEL(
arg_min, paddle::operators::ArgMinMaxOpCUDAKernel<float, cub::ArgMin>,
paddle::operators::ArgMinMaxOpCUDAKernel<double, cub::ArgMin>,
paddle::operators::ArgMinMaxOpCUDAKernel<int64_t, cub::ArgMin>,
paddle::operators::ArgMinMaxOpCUDAKernel<int32_t, cub::ArgMin>,
paddle::operators::ArgMinMaxOpCUDAKernel<int8_t, cub::ArgMin>);
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <set> #include <set>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/type_traits.h" #include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
...@@ -1014,6 +1015,82 @@ void DiagInferMeta(const MetaTensor& x, ...@@ -1014,6 +1015,82 @@ void DiagInferMeta(const MetaTensor& x,
} }
} }
// Infers the output meta (dims and dtype) of arg_min / arg_max.
//
//   x        : input tensor meta; only its dims are read.
//   axis     : axis to reduce along; must lie in [-Rank(X), Rank(X)).
//   keepdims : keep the reduced axis as an extent-1 dimension.
//   flatten  : reduce over all elements; the output then has shape [1].
//   dtype    : index dtype as a proto VarType code: 2 (INT32), 3 (INT64), or
//              a negative value meaning "not specified".
//   out      : receives the inferred dims and dtype.
//   config   : when config.is_runtime is set and INT32 indices are requested,
//              additionally checks that the reduced extent fits in an int32.
//
// Raises InvalidArgument when axis is out of range, dtype is not one of the
// accepted codes, or the runtime int32 range check fails.
void ArgMinMaxInferMeta(const MetaTensor& x,
                        int64_t axis,
                        bool keepdims,
                        bool flatten,
                        int dtype,
                        MetaTensor* out,
                        MetaConfig config) {
  const auto& x_dims = x.dims();

  PADDLE_ENFORCE_GE(
      axis,
      -x_dims.size(),
      phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to"
                                   " -Rank(X)(%d).",
                                   axis,
                                   -x_dims.size()));
  PADDLE_ENFORCE_LT(axis,
                    x_dims.size(),
                    phi::errors::InvalidArgument(
                        "'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
                        axis,
                        x_dims.size()));

  PADDLE_ENFORCE_EQ(
      (dtype < 0 || dtype == 2 || dtype == 3),
      true,
      phi::errors::InvalidArgument(
          "The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
          "received [%s]",
          paddle::framework::DataTypeToString(
              paddle::framework::proto::VarType::INT32),
          paddle::framework::DataTypeToString(
              paddle::framework::proto::VarType::INT64),
          paddle::framework::DataTypeToString(
              static_cast<paddle::framework::proto::VarType::Type>(dtype))));

  // Normalize a negative axis now that it is known to be in range.
  auto x_rank = x_dims.size();
  if (axis < 0) axis += x_rank;

  if (config.is_runtime) {
    if (dtype == paddle::framework::proto::VarType::INT32) {
      // Indices range over the reduced extent: every element when flattening,
      // otherwise the extent of the reduced axis.
      int64_t all_element_num = 0;
      if (flatten) {
        all_element_num = phi::product(x_dims);
      } else {
        all_element_num = x_dims[axis];
      }
      PADDLE_ENFORCE_LE(
          all_element_num,
          INT_MAX,
          phi::errors::InvalidArgument(
              "The element num of the argmin/argmax input at axis is "
              "%d, is larger than int32 maximum value:%d, you must "
              "set the dtype of argmin/argmax to 'int64'.",
              all_element_num,
              INT_MAX));
    }
  }

  // Output shape: [1] when flattening; otherwise the input shape with the
  // reduced axis either collapsed to 1 (keepdims) or dropped.
  std::vector<int64_t> vec;
  if (flatten) {
    vec.emplace_back(static_cast<int64_t>(1));
  } else {
    vec.reserve(static_cast<size_t>(x_rank));
    for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]);
    if (keepdims) {
      vec.emplace_back(static_cast<int64_t>(1));
    }
    for (int64_t i = axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
  }
  out->set_dims(phi::make_ddim(vec));

  // The check above leaves only three possibilities for dtype. The arg_min /
  // arg_max kernels produce INT64 indices when dtype is negative, so the
  // inferred meta must say INT64 in that case too instead of leaving the
  // output dtype unset.
  if (dtype == 2) {
    out->set_dtype(DataType::INT32);
  } else {
    out->set_dtype(DataType::INT64);
  }
}
void SizeInferMeta(const MetaTensor& input, MetaTensor* out) { void SizeInferMeta(const MetaTensor& input, MetaTensor* out) {
out->set_dtype(DataType::INT64); out->set_dtype(DataType::INT64);
out->set_dims({1}); out->set_dims({1});
......
...@@ -147,6 +147,14 @@ void DiagInferMeta(const MetaTensor& x, ...@@ -147,6 +147,14 @@ void DiagInferMeta(const MetaTensor& x,
float padding_value, float padding_value,
MetaTensor* out); MetaTensor* out);
// Infers the meta of the arg_min / arg_max output `out` from the input meta
// `x` and the op attributes (axis, keepdims, flatten, dtype). `config`
// defaults to a default-constructed MetaConfig.
void ArgMinMaxInferMeta(const MetaTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
MetaTensor* out,
MetaConfig config = MetaConfig());
void SizeInferMeta(const MetaTensor& input, MetaTensor* out); void SizeInferMeta(const MetaTensor& input, MetaTensor* out);
void DiagonalInferMeta( void DiagonalInferMeta(
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -12,11 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,11 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/arg_min_max_op_base.cu.h" #pragma once
REGISTER_OP_CUDA_KERNEL( #include "paddle/phi/core/dense_tensor.h"
arg_max, paddle::operators::ArgMinMaxOpCUDAKernel<float, cub::ArgMax>,
paddle::operators::ArgMinMaxOpCUDAKernel<double, cub::ArgMax>, namespace phi {
paddle::operators::ArgMinMaxOpCUDAKernel<int64_t, cub::ArgMax>,
paddle::operators::ArgMinMaxOpCUDAKernel<int32_t, cub::ArgMax>, template <typename T, typename Context>
paddle::operators::ArgMinMaxOpCUDAKernel<int8_t, cub::ArgMax>); void ArgMinKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out);
// Declaration of the arg_max kernel entry point, templated on the input
// element type T and the device context type.
//   x        : input tensor.
//   axis     : axis to reduce along.
//   keepdims : whether the reduced axis is kept in the output.
//   flatten  : whether to reduce over all elements of x.
//   dtype    : code of the index (output) element type.
//   out      : output tensor receiving the indices.
template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/arg_min_max_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
// Tag selecting which reduction an ArgMinMaxFunctor specialization performs.
enum ArgMinMaxType { kArgMin = 0, kArgMax = 1 };

// Primary template, intentionally empty. It is parameterized on the device
// context, the input element type (T), the index element type (Tout), the
// tensor rank and the reduction tag.
template <typename Context,
          typename T,
          typename Tout,
          int64_t Rank,
          ArgMinMaxType Kind>
struct ArgMinMaxFunctor {};
// Generates the partial specialization of ArgMinMaxFunctor for one
// ArgMinMaxType value.
//
// eigen_op_type        : name of the member function invoked on the Eigen view
//                        of the input (argmin / argmax in the expansions
//                        below).
// enum_argminmax_value : the ArgMinMaxType value the specialization is keyed
//                        on.
//
// The generated operator() views `in` as a Rank-dimensional tensor with dims
// `x_dims`, applies the reduction along `axis`, casts the result to Tout and
// assigns it to `out` on the context's Eigen device. `out` is viewed with Rank
// dimensions when `keepdims` is true and with Rank - 1 dimensions otherwise.
#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value) \
template <typename Context, typename T, typename Tout, int64_t Rank> \
struct ArgMinMaxFunctor<Context, T, Tout, Rank, enum_argminmax_value> { \
void operator()(const Context& dev_ctx, \
const DenseTensor& in, \
DenseTensor* out, \
phi::DDim x_dims, \
int64_t axis, \
bool keepdims) { \
auto in_eigen = EigenTensor<T, Rank>::From(in, x_dims); \
if (keepdims) { \
auto out_eigen = EigenTensor<Tout, Rank>::From(*out); \
out_eigen.device(*(dev_ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} else { \
auto out_eigen = EigenTensor<Tout, Rank - 1>::From(*out); \
out_eigen.device(*(dev_ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} \
} \
}
// Instantiate the specializations for both reduction kinds.
DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin);
DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax);
// Visitor passed to paddle::framework::VisitDataTypeTiny: it stores the kernel
// arguments, and apply<Tout>() runs the CPU arg-min / arg-max reduction with
// Tout as the index (output) element type.
template <typename Context, typename T, ArgMinMaxType EnumArgMinMaxValue>
struct VisitDataArgMinMaxFunctor {
  const Context& dev_ctx;
  const DenseTensor& x;
  int64_t axis;
  bool keepdims;
  bool flatten;
  DenseTensor* out;

  explicit VisitDataArgMinMaxFunctor(const Context& dev_ctx,
                                     const DenseTensor& x,
                                     int64_t axis,
                                     bool keepdims,
                                     bool flatten,
                                     DenseTensor* out)
      : dev_ctx(dev_ctx),
        x(x),
        axis(axis),
        keepdims(keepdims),
        flatten(flatten),
        out(out) {}

  template <typename Tout>
  void apply() const {
    dev_ctx.template Alloc<Tout>(out);

    // A flattened reduction is always computed as if keepdims were set.
    const bool new_keepdims = flatten || keepdims;

    // Dims the input is viewed with: its own dims, or a single dimension
    // holding every element when flattening (the reduction axis is then 0).
    // NOTE(review): the axis is held in an int here and converted back to
    // int64_t when handed to the functor.
    phi::DDim x_dims;
    int new_axis = axis;
    if (!flatten) {
      x_dims = x.dims();
      if (axis < 0) new_axis = axis + x_dims.size();
    } else {
      x_dims = phi::make_ddim({x.numel()});
      new_axis = 0;
    }

// The functor's rank is a template argument, so each supported rank needs its
// own instantiation; the switch below picks the one matching x_dims.
#define CALL_ARG_MINMAX_FUNCTOR(rank)                                        \
  {                                                                          \
    ArgMinMaxFunctor<Context, T, Tout, rank, EnumArgMinMaxValue>             \
        rank_functor;                                                        \
    rank_functor(dev_ctx, x, out, x_dims, new_axis, new_keepdims);           \
  }
    switch (x_dims.size()) {
      case 1:
        CALL_ARG_MINMAX_FUNCTOR(1);
        break;
      case 2:
        CALL_ARG_MINMAX_FUNCTOR(2);
        break;
      case 3:
        CALL_ARG_MINMAX_FUNCTOR(3);
        break;
      case 4:
        CALL_ARG_MINMAX_FUNCTOR(4);
        break;
      case 5:
        CALL_ARG_MINMAX_FUNCTOR(5);
        break;
      case 6:
        CALL_ARG_MINMAX_FUNCTOR(6);
        break;
      default:
        PADDLE_ENFORCE_LE(
            x_dims.size(),
            6,
            phi::errors::InvalidArgument(
                "%s operator doesn't supports tensors whose ranks are greater "
                "than 6.",
                (EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax")));
        break;
    }
#undef CALL_ARG_MINMAX_FUNCTOR
  }
};
template <typename Context, typename T, ArgMinMaxType EnumArgMinMaxValue>
void ArgMinMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out) {
if (dtype < 0) {
paddle::framework::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(
paddle::framework::proto::VarType::INT64),
VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
dev_ctx, x, axis, keepdims, flatten, out));
return;
}
paddle::framework::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(dtype),
VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
dev_ctx, x, axis, keepdims, flatten, out));
}
// CPU arg_min entry point: forwards every argument unchanged to
// ArgMinMaxKernel instantiated with ArgMinMaxType::kArgMin.
template <typename T, typename Context>
void ArgMinKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out) {
ArgMinMaxKernel<Context, T, ArgMinMaxType::kArgMin>(
dev_ctx, x, axis, keepdims, flatten, dtype, out);
}
// CPU arg_max entry point: forwards every argument unchanged to
// ArgMinMaxKernel instantiated with ArgMinMaxType::kArgMax.
template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out) {
ArgMinMaxKernel<Context, T, ArgMinMaxType::kArgMax>(
dev_ctx, x, axis, keepdims, flatten, dtype, out);
}
} // namespace phi
// Registers phi::ArgMinKernel as the `arg_min` kernel for the CPU backend
// (ALL_LAYOUT), for the element types float, double, int32_t, int64_t,
// int16_t and uint8_t.
PD_REGISTER_KERNEL(arg_min,
CPU,
ALL_LAYOUT,
phi::ArgMinKernel,
float,
double,
int32_t,
int64_t,
int16_t,
uint8_t) {}
// Registers phi::ArgMaxKernel as the `arg_max` kernel for the CPU backend
// (ALL_LAYOUT), for the element types float, double, int32_t, int64_t,
// int16_t and uint8_t.
PD_REGISTER_KERNEL(arg_max,
CPU,
ALL_LAYOUT,
phi::ArgMaxKernel,
float,
double,
int32_t,
int64_t,
int16_t,
uint8_t) {}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License"); #include "paddle/phi/kernels/arg_min_max_kernel.h"
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#if defined(__NVCC__) || defined(__HIPCC__) #if defined(__NVCC__) || defined(__HIPCC__)
...@@ -24,21 +27,14 @@ limitations under the License. */ ...@@ -24,21 +27,14 @@ limitations under the License. */
namespace cub = hipcub; namespace cub = hipcub;
#endif #endif
#include <limits> #include <limits>
#include <string> #include "paddle/fluid/framework/data_type.h"
#include <typeinfo>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
namespace paddle { namespace phi {
namespace operators {
namespace { // NOLINT namespace { // NOLINT
template <typename K, typename V> template <typename K, typename V>
using KeyValuePair = cub::KeyValuePair<K, V>; using KeyValuePair = cub::KeyValuePair<K, V>;
using Tensor = framework::Tensor;
} // end namespace } // end namespace
...@@ -62,7 +58,9 @@ template <typename T, typename IndType, class Reducer, size_t BlockDim> ...@@ -62,7 +58,9 @@ template <typename T, typename IndType, class Reducer, size_t BlockDim>
__global__ void ArgCUDAKernel(const int64_t height, // n * h __global__ void ArgCUDAKernel(const int64_t height, // n * h
const int64_t width, // c const int64_t width, // c
const int64_t post_size, // h const int64_t post_size, // h
const Reducer reducer, const T init, const T* in, const Reducer reducer,
const T init,
const T* in,
IndType* out) { IndType* out) {
typedef cub::BlockReduce<KeyValuePair<int, T>, BlockDim> BlockReduce; typedef cub::BlockReduce<KeyValuePair<int, T>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage; __shared__ typename BlockReduce::TempStorage temp_storage;
...@@ -84,10 +82,13 @@ __global__ void ArgCUDAKernel(const int64_t height, // n * h ...@@ -84,10 +82,13 @@ __global__ void ArgCUDAKernel(const int64_t height, // n * h
} }
template <typename T, typename IndType, class Reducer> template <typename T, typename IndType, class Reducer>
void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input, void ComputeFullArg(const phi::GPUContext& dev_ctx,
Tensor* indices, const int64_t pre, const int64_t post, const DenseTensor& input,
DenseTensor* indices,
const int64_t pre,
const int64_t post,
const int64_t n) { const int64_t n) {
auto cu_stream = ctx.stream(); auto cu_stream = dev_ctx.stream();
auto ComputeBlockSize = [](int64_t col) { auto ComputeBlockSize = [](int64_t col) {
auto block_size = 8; auto block_size = 8;
if (col > 512) if (col > 512)
...@@ -110,93 +111,168 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input, ...@@ -110,93 +111,168 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
return block_size; return block_size;
}; };
int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0]; int64_t max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
int64_t height = pre * post; int64_t height = pre * post;
int64_t width = n; int64_t width = n;
int64_t grid_size = height < max_grid_dimx ? height : max_grid_dimx; int64_t grid_size = height < max_grid_dimx ? height : max_grid_dimx;
const T* in_data = input.data<T>(); const T* in_data = input.data<T>();
IndType* out_data = indices->mutable_data<IndType>(ctx.GetPlace()); IndType* out_data = dev_ctx.template Alloc<IndType>(indices);
if (typeid(Reducer) == typeid(cub::ArgMax)) { if (typeid(Reducer) == typeid(cub::ArgMax)) {
switch (ComputeBlockSize(width)) { switch (ComputeBlockSize(width)) {
FIXED_BLOCK_DIM_CASE( FIXED_BLOCK_DIM_CASE(
ArgCUDAKernel<T, IndType, Reducer, ArgCUDAKernel<T,
IndType,
Reducer,
kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>( kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
height, width, post, Reducer(), std::numeric_limits<T>::lowest(), height,
in_data, out_data)); width,
post,
Reducer(),
std::numeric_limits<T>::lowest(),
in_data,
out_data));
} }
} else { } else {
switch (ComputeBlockSize(width)) { switch (ComputeBlockSize(width)) {
FIXED_BLOCK_DIM_CASE( FIXED_BLOCK_DIM_CASE(
ArgCUDAKernel<T, IndType, Reducer, ArgCUDAKernel<T,
IndType,
Reducer,
kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>( kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
height, width, post, Reducer(), std::numeric_limits<T>::max(), height,
in_data, out_data)); width,
post,
Reducer(),
std::numeric_limits<T>::max(),
in_data,
out_data));
} }
} }
} }
template <typename T, class Reducer> template <typename Context, typename T, class Reducer>
struct VisitDataCudaArgMinMaxFunctor { struct VisitDataCudaArgMinMaxFunctor {
const framework::ExecutionContext& ctx; const Context& dev_ctx;
const DenseTensor& x;
int64_t axis;
bool keepdims;
bool flatten;
DenseTensor* out;
explicit VisitDataCudaArgMinMaxFunctor(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
DenseTensor* out)
: dev_ctx(dev_ctx),
x(x),
axis(axis),
keepdims(keepdims),
flatten(flatten),
out(out) {}
explicit VisitDataCudaArgMinMaxFunctor(const framework::ExecutionContext& ctx)
: ctx(ctx) {}
template <typename IndType> template <typename IndType>
void apply() const { void apply() const {
auto* input = ctx.Input<Tensor>("X"); phi::DDim x_dims;
auto* output = ctx.Output<Tensor>("Out"); int new_axis = axis;
int axis = ctx.Attr<int64_t>("axis");
const bool& flatten = ctx.Attr<bool>("flatten");
framework::DDim input_dims;
if (flatten) { if (flatten) {
input_dims = phi::make_ddim({input->numel()}); x_dims = phi::make_ddim({x.numel()});
// if flatten, the axis just as 0 // if flatten, the axis just as 0
axis = 0; new_axis = 0;
} else { } else {
input_dims = input->dims(); x_dims = x.dims();
if (axis < 0) axis += input->dims().size(); if (axis < 0) new_axis = axis + x.dims().size();
} }
int64_t numel = input->numel(); int64_t numel = x.numel();
int64_t groups = numel / input_dims[axis]; int64_t groups = numel / x_dims[new_axis];
int64_t pre = 1; int64_t pre = 1;
int64_t post = 1; int64_t post = 1;
int64_t n = input_dims[axis]; int64_t n = x_dims[new_axis];
for (int i = 0; i < axis; i++) { for (int i = 0; i < new_axis; i++) {
pre *= input_dims[i]; pre *= x_dims[i];
} }
for (int i = axis + 1; i < input_dims.size(); i++) { for (int i = new_axis + 1; i < x_dims.size(); i++) {
post *= input_dims[i]; post *= x_dims[i];
} }
const auto& dev_ctx = ctx.cuda_device_context(); ComputeFullArg<T, IndType, Reducer>(dev_ctx, x, out, pre, post, n);
ComputeFullArg<T, IndType, Reducer>(dev_ctx, *input, output, pre, post, n);
} }
}; };
template <typename T, class Reducer>
class ArgMinMaxOpCUDAKernel : public framework::OpKernel<T> { template <typename Context, typename T, class Reducer>
public: void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
void Compute(const framework::ExecutionContext& ctx) const override { const DenseTensor& x,
auto& dtype = ctx.Attr<int>("dtype"); int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out) {
if (dtype < 0) { if (dtype < 0) {
framework::VisitDataTypeTiny( paddle::framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>( static_cast<paddle::framework::proto::VarType::Type>(
framework::proto::VarType::INT64), paddle::framework::proto::VarType::INT64),
VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx)); VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
dev_ctx, x, axis, keepdims, flatten, out));
return; return;
} }
framework::VisitDataTypeTiny( paddle::framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(dtype), static_cast<paddle::framework::proto::VarType::Type>(dtype),
VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx)); VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
} dev_ctx, x, axis, keepdims, flatten, out));
}; }
// GPU arg_min entry point: forwards every argument unchanged to
// ArgMinMaxOpCUDAKernel with cub::ArgMin as the reducer.
template <typename T, typename Context>
void ArgMinKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out) {
ArgMinMaxOpCUDAKernel<Context, T, cub::ArgMin>(
dev_ctx, x, axis, keepdims, flatten, dtype, out);
}
// GPU arg_max entry point: forwards every argument unchanged to
// ArgMinMaxOpCUDAKernel with cub::ArgMax as the reducer.
template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out) {
ArgMinMaxOpCUDAKernel<Context, T, cub::ArgMax>(
dev_ctx, x, axis, keepdims, flatten, dtype, out);
}
#endif #endif
} // namespace operators } // namespace phi
} // namespace paddle
// Registers phi::ArgMinKernel as the `arg_min` kernel for the GPU backend
// (ALL_LAYOUT), for the element types float, double, int32_t, int64_t,
// int16_t and uint8_t.
// NOTE(review): the fluid REGISTER_OP_CUDA_KERNEL for arg_min listed float,
// double, int64_t, int32_t and int8_t; here int8_t is absent and
// int16_t/uint8_t are new - confirm this change of supported types is
// intended.
PD_REGISTER_KERNEL(arg_min,
GPU,
ALL_LAYOUT,
phi::ArgMinKernel,
float,
double,
int32_t,
int64_t,
int16_t,
uint8_t) {}
// Registers phi::ArgMaxKernel as the `arg_max` kernel for the GPU backend
// (ALL_LAYOUT), for the element types float, double, int32_t, int64_t,
// int16_t and uint8_t.
// NOTE(review): the fluid REGISTER_OP_CUDA_KERNEL for arg_max listed float,
// double, int64_t, int32_t and int8_t; here int8_t is absent and
// int16_t/uint8_t are new - confirm this change of supported types is
// intended.
PD_REGISTER_KERNEL(arg_max,
GPU,
ALL_LAYOUT,
phi::ArgMaxKernel,
float,
double,
int32_t,
int64_t,
int16_t,
uint8_t) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册