Unverified commit f25dba0a authored by Zhong Hui, committed by GitHub

[PHI] Move arg min max to PHI. (#40222)

* move arg min max to phi.

* move infermeta.

* fix as reviews.
Parent 1128db30
@@ -15,23 +15,19 @@ limitations under the License. */
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
DECLARE_INFER_SHAPE_FUNCTOR(arg_max, ArgMaxInferShapeFunctor,
PD_INFER_META(phi::ArgMinMaxInferMeta));
REGISTER_OPERATOR(
arg_max, paddle::operators::ArgMinMaxOp, paddle::operators::ArgMaxOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ArgMaxInferShapeFunctor);
REGISTER_OP_CPU_KERNEL(
arg_max,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, float>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, double>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
REGISTER_OP_VERSION(arg_max)
.AddCheckpoint(
R"ROC(
...
@@ -27,193 +27,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
enum ArgMinMaxType { kArgMin, kArgMax };
template <typename DeviceContext, typename T, typename Tout, int64_t Rank,
ArgMinMaxType argMinMaxValue>
struct ArgMinMaxFunctor {};
#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value) \
template <typename DeviceContext, typename T, typename Tout, int64_t Rank> \
struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank, \
enum_argminmax_value> { \
void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \
framework::LoDTensor* out, framework::DDim x_dims, \
int64_t axis, bool keepdims) { \
auto in_eigen = framework::EigenTensor<T, Rank>::From(in, x_dims); \
if (keepdims) { \
auto out_eigen = framework::EigenTensor<Tout, Rank>::From(*out); \
out_eigen.device(*(ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} else { \
auto out_eigen = framework::EigenTensor<Tout, Rank - 1>::From(*out); \
out_eigen.device(*(ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} \
} \
}
DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin);
DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax);
template <typename DeviceContext, typename T, ArgMinMaxType EnumArgMinMaxValue>
struct VisitDataArgMinMaxFunctor {
const framework::ExecutionContext& ctx;
explicit VisitDataArgMinMaxFunctor(const framework::ExecutionContext& ctx)
: ctx(ctx) {}
template <typename Tout>
void apply() const {
auto& x = *(ctx.Input<framework::LoDTensor>("X"));
auto& out = *(ctx.Output<framework::LoDTensor>("Out"));
out.template mutable_data<Tout>(ctx.GetPlace());
auto axis = ctx.Attr<int64_t>("axis");
auto keepdims = ctx.Attr<bool>("keepdims");
const bool& flatten = ctx.Attr<bool>("flatten");
// paddle does not have a scalar tensor, so just return a tensor of shape [1]
if (flatten) keepdims = true;
// if flatten, will construct the new dims for the calculation
framework::DDim x_dims;
if (flatten) {
x_dims = phi::make_ddim({x.numel()});
// if flatten, the axis just as 0
axis = 0;
} else {
x_dims = x.dims();
if (axis < 0) axis += x_dims.size();
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMinMaxFunctor<DeviceContext, T, Tout, rank, EnumArgMinMaxValue> \
functor##rank; \
functor##rank(dev_ctx, x, &out, x_dims, axis, keepdims)
switch (x_dims.size()) {
case 1:
CALL_ARG_MINMAX_FUNCTOR(1);
break;
case 2:
CALL_ARG_MINMAX_FUNCTOR(2);
break;
case 3:
CALL_ARG_MINMAX_FUNCTOR(3);
break;
case 4:
CALL_ARG_MINMAX_FUNCTOR(4);
break;
case 5:
CALL_ARG_MINMAX_FUNCTOR(5);
break;
case 6:
CALL_ARG_MINMAX_FUNCTOR(6);
break;
default:
PADDLE_ENFORCE_LE(
x_dims.size(), 6,
platform::errors::InvalidArgument(
"%s operator doesn't supports tensors whose ranks are greater "
"than 6.",
(EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax")));
break;
#undef CALL_ARG_MINMAX_FUNCTOR
}
}
};
template <typename DeviceContext, typename T, ArgMinMaxType EnumArgMinMaxValue>
class ArgMinMaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dtype = ctx.Attr<int>("dtype");
if (dtype < 0) {
framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(
framework::proto::VarType::INT64),
VisitDataArgMinMaxFunctor<DeviceContext, T, EnumArgMinMaxValue>(ctx));
return;
}
framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(dtype),
VisitDataArgMinMaxFunctor<DeviceContext, T, EnumArgMinMaxValue>(ctx));
}
};
template <typename DeviceContext, typename T>
using ArgMinKernel = ArgMinMaxKernel<DeviceContext, T, ArgMinMaxType::kArgMin>;
template <typename DeviceContext, typename T>
using ArgMaxKernel = ArgMinMaxKernel<DeviceContext, T, ArgMinMaxType::kArgMax>;
class ArgMinMaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "arg_min_max");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "arg_min_max");
const auto& x_dims = ctx->GetInputDim("X");
int64_t axis = ctx->Attrs().Get<int64_t>("axis");
bool keepdims = ctx->Attrs().Get<bool>("keepdims");
const bool& flatten = ctx->Attrs().Get<bool>("flatten");
PADDLE_ENFORCE_GE(axis, -x_dims.size(),
platform::errors::InvalidArgument(
"'axis'(%d) must be greater than or equal to"
" -Rank(X)(%d).",
axis, -x_dims.size()));
PADDLE_ENFORCE_LT(
axis, x_dims.size(),
platform::errors::InvalidArgument(
"'axis'(%d) must be less than Rank(X)(%d) of Input(X).", axis,
x_dims.size()));
const int& dtype = ctx->Attrs().Get<int>("dtype");
PADDLE_ENFORCE_EQ(
(dtype < 0 || dtype == 2 || dtype == 3), true,
platform::errors::InvalidArgument(
"The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
"received [%s]",
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64),
paddle::framework::DataTypeToString(
static_cast<framework::proto::VarType::Type>(dtype))));
auto x_rank = x_dims.size();
if (axis < 0) axis += x_rank;
if (ctx->IsRuntime()) {
if (dtype == framework::proto::VarType::INT32) {
int64_t all_element_num = 0;
if (flatten) {
all_element_num = phi::product(x_dims);
} else {
all_element_num = x_dims[axis];
}
PADDLE_ENFORCE_LE(
all_element_num, INT_MAX,
platform::errors::InvalidArgument(
"The element num of the argmin/argmax input at axis is "
"%d, is larger than int32 maximum value:%d, you must "
"set the dtype of argmin/argmax to 'int64'.",
all_element_num, INT_MAX));
}
}
std::vector<int64_t> vec;
if (flatten) {
vec.emplace_back(static_cast<int64_t>(1));
} else {
for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]);
if (keepdims) {
vec.emplace_back(static_cast<int64_t>(1));
}
for (int64_t i = axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
}
ctx->SetOutputDim("Out", phi::make_ddim(vec));
}
};
class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -12,26 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h" #include "paddle/fluid/operators/arg_min_max_op_base.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
DECLARE_INFER_SHAPE_FUNCTOR(arg_min, ArgMinInferShapeFunctor,
PD_INFER_META(phi::ArgMinMaxInferMeta));
REGISTER_OPERATOR(
arg_min, paddle::operators::ArgMinMaxOp, paddle::operators::ArgMinOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ArgMinInferShapeFunctor);
REGISTER_OP_CPU_KERNEL(
arg_min,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, float>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, double>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
REGISTER_OP_VERSION(arg_min)
.AddCheckpoint(
R"ROC(
...
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/arg_min_max_op_base.cu.h"
REGISTER_OP_CUDA_KERNEL(
arg_min, paddle::operators::ArgMinMaxOpCUDAKernel<float, cub::ArgMin>,
paddle::operators::ArgMinMaxOpCUDAKernel<double, cub::ArgMin>,
paddle::operators::ArgMinMaxOpCUDAKernel<int64_t, cub::ArgMin>,
paddle::operators::ArgMinMaxOpCUDAKernel<int32_t, cub::ArgMin>,
paddle::operators::ArgMinMaxOpCUDAKernel<int8_t, cub::ArgMin>);
@@ -17,6 +17,7 @@ limitations under the License. */
#include <algorithm>
#include <set>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/type_traits.h" #include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
...@@ -1014,6 +1015,82 @@ void DiagInferMeta(const MetaTensor& x, ...@@ -1014,6 +1015,82 @@ void DiagInferMeta(const MetaTensor& x,
} }
} }
void ArgMinMaxInferMeta(const MetaTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
MetaTensor* out,
MetaConfig config) {
const auto& x_dims = x.dims();
PADDLE_ENFORCE_GE(
axis,
-x_dims.size(),
phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to"
" -Rank(X)(%d).",
axis,
-x_dims.size()));
PADDLE_ENFORCE_LT(axis,
x_dims.size(),
phi::errors::InvalidArgument(
"'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
axis,
x_dims.size()));
PADDLE_ENFORCE_EQ(
(dtype < 0 || dtype == 2 || dtype == 3),
true,
phi::errors::InvalidArgument(
"The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
"received [%s]",
paddle::framework::DataTypeToString(
paddle::framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
paddle::framework::proto::VarType::INT64),
paddle::framework::DataTypeToString(
static_cast<paddle::framework::proto::VarType::Type>(dtype))));
auto x_rank = x_dims.size();
if (axis < 0) axis += x_rank;
if (config.is_runtime) {
if (dtype == paddle::framework::proto::VarType::INT32) {
int64_t all_element_num = 0;
if (flatten) {
all_element_num = phi::product(x_dims);
} else {
all_element_num = x_dims[axis];
}
PADDLE_ENFORCE_LE(
all_element_num,
INT_MAX,
phi::errors::InvalidArgument(
"The element num of the argmin/argmax input at axis is "
"%d, is larger than int32 maximum value:%d, you must "
"set the dtype of argmin/argmax to 'int64'.",
all_element_num,
INT_MAX));
}
}
std::vector<int64_t> vec;
if (flatten) {
vec.emplace_back(static_cast<int64_t>(1));
} else {
for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]);
if (keepdims) {
vec.emplace_back(static_cast<int64_t>(1));
}
for (int64_t i = axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
}
out->set_dims(phi::make_ddim(vec));
if (dtype == 2) {
out->set_dtype(DataType::INT32);
} else if (dtype == 3) {
out->set_dtype(DataType::INT64);
}
}
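As an aside (not part of the commit), the shape rule that ArgMinMaxInferMeta implements above can be exercised in isolation. A minimal standalone C++ sketch, with illustrative names only:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Mirrors the dim construction in ArgMinMaxInferMeta: drop (or keep as 1)
    // the reduced axis; with flatten=true the result is always shape [1].
    std::vector<int64_t> ArgMinMaxOutDims(std::vector<int64_t> x_dims,
                                          int64_t axis, bool keepdims,
                                          bool flatten) {
      const int64_t rank = static_cast<int64_t>(x_dims.size());
      if (axis < 0) axis += rank;
      std::vector<int64_t> out;
      if (flatten) {
        out.push_back(1);
        return out;
      }
      for (int64_t i = 0; i < axis; ++i) out.push_back(x_dims[i]);
      if (keepdims) out.push_back(1);
      for (int64_t i = axis + 1; i < rank; ++i) out.push_back(x_dims[i]);
      return out;
    }

    int main() {
      for (int64_t d : ArgMinMaxOutDims({2, 3, 4}, -2, false, false)) {
        std::cout << d << " ";  // prints: 2 4
      }
      std::cout << std::endl;
      return 0;
    }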
void SizeInferMeta(const MetaTensor& input, MetaTensor* out) {
out->set_dtype(DataType::INT64);
out->set_dims({1});
...
@@ -147,6 +147,14 @@ void DiagInferMeta(const MetaTensor& x,
float padding_value,
MetaTensor* out);
void ArgMinMaxInferMeta(const MetaTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
MetaTensor* out,
MetaConfig config = MetaConfig());
void SizeInferMeta(const MetaTensor& input, MetaTensor* out);
void DiagonalInferMeta(
...
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/arg_min_max_op_base.cu.h"
REGISTER_OP_CUDA_KERNEL(
arg_max, paddle::operators::ArgMinMaxOpCUDAKernel<float, cub::ArgMax>,
paddle::operators::ArgMinMaxOpCUDAKernel<double, cub::ArgMax>,
paddle::operators::ArgMinMaxOpCUDAKernel<int64_t, cub::ArgMax>,
paddle::operators::ArgMinMaxOpCUDAKernel<int32_t, cub::ArgMax>,
paddle::operators::ArgMinMaxOpCUDAKernel<int8_t, cub::ArgMax>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ArgMinKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out);
template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/arg_min_max_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
enum ArgMinMaxType { kArgMin, kArgMax };
template <typename Context,
typename T,
typename Tout,
int64_t Rank,
ArgMinMaxType argMinMaxValue>
struct ArgMinMaxFunctor {};
#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value) \
template <typename Context, typename T, typename Tout, int64_t Rank> \
struct ArgMinMaxFunctor<Context, T, Tout, Rank, enum_argminmax_value> { \
void operator()(const Context& dev_ctx, \
const DenseTensor& in, \
DenseTensor* out, \
phi::DDim x_dims, \
int64_t axis, \
bool keepdims) { \
auto in_eigen = EigenTensor<T, Rank>::From(in, x_dims); \
if (keepdims) { \
auto out_eigen = EigenTensor<Tout, Rank>::From(*out); \
out_eigen.device(*(dev_ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} else { \
auto out_eigen = EigenTensor<Tout, Rank - 1>::From(*out); \
out_eigen.device(*(dev_ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} \
} \
}
DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin);
DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax);
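The functors declared above delegate to Eigen's tensor argmin/argmax reduction (in_eigen.argmax(axis) / in_eigen.argmin(axis)). As a rough standalone illustration of that Eigen call, a sketch using Eigen's unsupported Tensor module directly (not Paddle code; tie-breaking and printing details may differ from the kernel):

    #include <iostream>
    #include <unsupported/Eigen/CXX11/Tensor>

    int main() {
      Eigen::Tensor<float, 2> x(2, 3);
      x.setValues({{1.f, 5.f, 2.f}, {7.f, 0.f, 3.f}});
      // Reduce axis 1: for each row, the index of the largest element.
      Eigen::Tensor<Eigen::DenseIndex, 1> idx = x.argmax(1);
      std::cout << idx << std::endl;  // row-wise argmax indices: 1 and 0
      return 0;
    }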
template <typename Context, typename T, ArgMinMaxType EnumArgMinMaxValue>
struct VisitDataArgMinMaxFunctor {
const Context& dev_ctx;
const DenseTensor& x;
int64_t axis;
bool keepdims;
bool flatten;
DenseTensor* out;
explicit VisitDataArgMinMaxFunctor(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
DenseTensor* out)
: dev_ctx(dev_ctx),
x(x),
axis(axis),
keepdims(keepdims),
flatten(flatten),
out(out) {}
template <typename Tout>
void apply() const {
dev_ctx.template Alloc<Tout>(out);
bool new_keepdims = keepdims;
if (flatten) new_keepdims = true;
// if flatten, will construct the new dims for the calculation
phi::DDim x_dims;
int new_axis = axis;
if (flatten) {
x_dims = phi::make_ddim({x.numel()});
// if flatten, the axis just as 0
new_axis = 0;
} else {
x_dims = x.dims();
if (axis < 0) new_axis = axis + x_dims.size();
}
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMinMaxFunctor<Context, T, Tout, rank, EnumArgMinMaxValue> functor##rank; \
functor##rank(dev_ctx, x, out, x_dims, new_axis, new_keepdims)
switch (x_dims.size()) {
case 1:
CALL_ARG_MINMAX_FUNCTOR(1);
break;
case 2:
CALL_ARG_MINMAX_FUNCTOR(2);
break;
case 3:
CALL_ARG_MINMAX_FUNCTOR(3);
break;
case 4:
CALL_ARG_MINMAX_FUNCTOR(4);
break;
case 5:
CALL_ARG_MINMAX_FUNCTOR(5);
break;
case 6:
CALL_ARG_MINMAX_FUNCTOR(6);
break;
default:
PADDLE_ENFORCE_LE(
x_dims.size(),
6,
phi::errors::InvalidArgument(
"%s operator doesn't supports tensors whose ranks are greater "
"than 6.",
(EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax")));
break;
#undef CALL_ARG_MINMAX_FUNCTOR
}
}
};
template <typename Context, typename T, ArgMinMaxType EnumArgMinMaxValue>
void ArgMinMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out) {
if (dtype < 0) {
paddle::framework::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(
paddle::framework::proto::VarType::INT64),
VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
dev_ctx, x, axis, keepdims, flatten, out));
return;
}
paddle::framework::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(dtype),
VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
dev_ctx, x, axis, keepdims, flatten, out));
}
template <typename T, typename Context>
void ArgMinKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out) {
ArgMinMaxKernel<Context, T, ArgMinMaxType::kArgMin>(
dev_ctx, x, axis, keepdims, flatten, dtype, out);
}
template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out) {
ArgMinMaxKernel<Context, T, ArgMinMaxType::kArgMax>(
dev_ctx, x, axis, keepdims, flatten, dtype, out);
}
} // namespace phi
PD_REGISTER_KERNEL(arg_min,
CPU,
ALL_LAYOUT,
phi::ArgMinKernel,
float,
double,
int32_t,
int64_t,
int16_t,
uint8_t) {}
PD_REGISTER_KERNEL(arg_max,
CPU,
ALL_LAYOUT,
phi::ArgMaxKernel,
float,
double,
int32_t,
int64_t,
int16_t,
uint8_t) {}
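A note on the dtype attribute threaded through these kernels: it carries a framework::proto::VarType value, and the checks in ArgMinMaxInferMeta accept only 2 (INT32), 3 (INT64), or a negative "unset" value, which the dtype < 0 branch above maps to INT64. A tiny standalone sketch of that selection rule (illustrative only, not code from the commit):

    #include <iostream>

    enum class IndexType { kInt32, kInt64 };

    // 2 and 3 are the VarType::INT32 / VarType::INT64 values checked in
    // ArgMinMaxInferMeta; anything negative means "use the default", INT64.
    IndexType SelectIndexType(int dtype) {
      if (dtype == 2) return IndexType::kInt32;
      return IndexType::kInt64;  // dtype == 3, or dtype < 0 (default)
    }

    int main() {
      std::cout << (SelectIndexType(-1) == IndexType::kInt64) << std::endl;  // 1
      return 0;
    }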
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#if defined(__NVCC__) || defined(__HIPCC__)
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include <limits>
#include <string>
#include <typeinfo>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/ddim.h"
namespace paddle {
namespace operators {
namespace { // NOLINT
template <typename K, typename V>
using KeyValuePair = cub::KeyValuePair<K, V>;
using Tensor = framework::Tensor;
} // end namespace
#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...) \
case (1 << (log2_block_dim)): { \
constexpr auto kBlockDim = (1 << (log2_block_dim)); \
__VA_ARGS__; \
} break
#define FIXED_BLOCK_DIM_CASE(...) \
FIXED_BLOCK_DIM_CASE_BASE(10, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__);
template <typename T, typename IndType, class Reducer, size_t BlockDim>
__global__ void ArgCUDAKernel(const int64_t height, // n * h
const int64_t width, // c
const int64_t post_size, // h
const Reducer reducer, const T init, const T* in,
IndType* out) {
typedef cub::BlockReduce<KeyValuePair<int, T>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
for (int idx = blockIdx.x; idx < height; idx += gridDim.x) {
KeyValuePair<int, T> kv_pair = {-1, init};
int h = idx / post_size;
int w = idx % post_size;
for (int k = threadIdx.x; k < width; k += blockDim.x) {
kv_pair =
reducer({k, in[h * width * post_size + k * post_size + w]}, kv_pair);
}
kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer);
if (threadIdx.x == 0) {
out[idx] = static_cast<IndType>(kv_pair.key);
}
__syncthreads();
}
}
template <typename T, typename IndType, class Reducer>
void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
Tensor* indices, const int64_t pre, const int64_t post,
const int64_t n) {
auto cu_stream = ctx.stream();
auto ComputeBlockSize = [](int64_t col) {
auto block_size = 8;
if (col > 512)
block_size = 1024;
else if (col > 256)
block_size = 512;
else if (col > 128)
block_size = 256;
else if (col > 64)
block_size = 128;
else if (col > 32)
block_size = 64;
else if (col > 16)
block_size = 32;
else if (col > 8)
block_size = 16;
#ifdef __HIPCC__
block_size = std::min(block_size, 256);
#endif
return block_size;
};
int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
int64_t height = pre * post;
int64_t width = n;
int64_t grid_size = height < max_grid_dimx ? height : max_grid_dimx;
const T* in_data = input.data<T>();
IndType* out_data = indices->mutable_data<IndType>(ctx.GetPlace());
if (typeid(Reducer) == typeid(cub::ArgMax)) {
switch (ComputeBlockSize(width)) {
FIXED_BLOCK_DIM_CASE(
ArgCUDAKernel<T, IndType, Reducer,
kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
height, width, post, Reducer(), std::numeric_limits<T>::lowest(),
in_data, out_data));
}
} else {
switch (ComputeBlockSize(width)) {
FIXED_BLOCK_DIM_CASE(
ArgCUDAKernel<T, IndType, Reducer,
kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
height, width, post, Reducer(), std::numeric_limits<T>::max(),
in_data, out_data));
}
}
}
template <typename T, class Reducer>
struct VisitDataCudaArgMinMaxFunctor {
const framework::ExecutionContext& ctx;
explicit VisitDataCudaArgMinMaxFunctor(const framework::ExecutionContext& ctx)
: ctx(ctx) {}
template <typename IndType>
void apply() const {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
int axis = ctx.Attr<int64_t>("axis");
const bool& flatten = ctx.Attr<bool>("flatten");
framework::DDim input_dims;
if (flatten) {
input_dims = phi::make_ddim({input->numel()});
// if flatten, the axis just as 0
axis = 0;
} else {
input_dims = input->dims();
if (axis < 0) axis += input->dims().size();
}
int64_t numel = input->numel();
int64_t groups = numel / input_dims[axis];
int64_t pre = 1;
int64_t post = 1;
int64_t n = input_dims[axis];
for (int i = 0; i < axis; i++) {
pre *= input_dims[i];
}
for (int i = axis + 1; i < input_dims.size(); i++) {
post *= input_dims[i];
}
const auto& dev_ctx = ctx.cuda_device_context();
ComputeFullArg<T, IndType, Reducer>(dev_ctx, *input, output, pre, post, n);
}
};
template <typename T, class Reducer>
class ArgMinMaxOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dtype = ctx.Attr<int>("dtype");
if (dtype < 0) {
framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(
framework::proto::VarType::INT64),
VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
return;
}
framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(dtype),
VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
}
};
#endif
} // namespace operators
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/arg_min_max_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include <limits>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/core/ddim.h"
namespace phi {
namespace { // NOLINT
template <typename K, typename V>
using KeyValuePair = cub::KeyValuePair<K, V>;
} // end namespace
#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...) \
case (1 << (log2_block_dim)): { \
constexpr auto kBlockDim = (1 << (log2_block_dim)); \
__VA_ARGS__; \
} break
#define FIXED_BLOCK_DIM_CASE(...) \
FIXED_BLOCK_DIM_CASE_BASE(10, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__);
template <typename T, typename IndType, class Reducer, size_t BlockDim>
__global__ void ArgCUDAKernel(const int64_t height, // n * h
const int64_t width, // c
const int64_t post_size, // h
const Reducer reducer,
const T init,
const T* in,
IndType* out) {
typedef cub::BlockReduce<KeyValuePair<int, T>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
for (int idx = blockIdx.x; idx < height; idx += gridDim.x) {
KeyValuePair<int, T> kv_pair = {-1, init};
int h = idx / post_size;
int w = idx % post_size;
for (int k = threadIdx.x; k < width; k += blockDim.x) {
kv_pair =
reducer({k, in[h * width * post_size + k * post_size + w]}, kv_pair);
}
kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer);
if (threadIdx.x == 0) {
out[idx] = static_cast<IndType>(kv_pair.key);
}
__syncthreads();
}
}
template <typename T, typename IndType, class Reducer>
void ComputeFullArg(const phi::GPUContext& dev_ctx,
const DenseTensor& input,
DenseTensor* indices,
const int64_t pre,
const int64_t post,
const int64_t n) {
auto cu_stream = dev_ctx.stream();
auto ComputeBlockSize = [](int64_t col) {
auto block_size = 8;
if (col > 512)
block_size = 1024;
else if (col > 256)
block_size = 512;
else if (col > 128)
block_size = 256;
else if (col > 64)
block_size = 128;
else if (col > 32)
block_size = 64;
else if (col > 16)
block_size = 32;
else if (col > 8)
block_size = 16;
#ifdef __HIPCC__
block_size = std::min(block_size, 256);
#endif
return block_size;
};
int64_t max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
int64_t height = pre * post;
int64_t width = n;
int64_t grid_size = height < max_grid_dimx ? height : max_grid_dimx;
const T* in_data = input.data<T>();
IndType* out_data = dev_ctx.template Alloc<IndType>(indices);
if (typeid(Reducer) == typeid(cub::ArgMax)) {
switch (ComputeBlockSize(width)) {
FIXED_BLOCK_DIM_CASE(
ArgCUDAKernel<T,
IndType,
Reducer,
kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
height,
width,
post,
Reducer(),
std::numeric_limits<T>::lowest(),
in_data,
out_data));
}
} else {
switch (ComputeBlockSize(width)) {
FIXED_BLOCK_DIM_CASE(
ArgCUDAKernel<T,
IndType,
Reducer,
kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
height,
width,
post,
Reducer(),
std::numeric_limits<T>::max(),
in_data,
out_data));
}
}
}
template <typename Context, typename T, class Reducer>
struct VisitDataCudaArgMinMaxFunctor {
const Context& dev_ctx;
const DenseTensor& x;
int64_t axis;
bool keepdims;
bool flatten;
DenseTensor* out;
explicit VisitDataCudaArgMinMaxFunctor(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
DenseTensor* out)
: dev_ctx(dev_ctx),
x(x),
axis(axis),
keepdims(keepdims),
flatten(flatten),
out(out) {}
template <typename IndType>
void apply() const {
phi::DDim x_dims;
int new_axis = axis;
if (flatten) {
x_dims = phi::make_ddim({x.numel()});
// if flatten, the axis just as 0
new_axis = 0;
} else {
x_dims = x.dims();
if (axis < 0) new_axis = axis + x.dims().size();
}
int64_t numel = x.numel();
int64_t groups = numel / x_dims[new_axis];
int64_t pre = 1;
int64_t post = 1;
int64_t n = x_dims[new_axis];
for (int i = 0; i < new_axis; i++) {
pre *= x_dims[i];
}
for (int i = new_axis + 1; i < x_dims.size(); i++) {
post *= x_dims[i];
}
ComputeFullArg<T, IndType, Reducer>(dev_ctx, x, out, pre, post, n);
}
};
template <typename Context, typename T, class Reducer>
void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out) {
if (dtype < 0) {
paddle::framework::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(
paddle::framework::proto::VarType::INT64),
VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
dev_ctx, x, axis, keepdims, flatten, out));
return;
}
paddle::framework::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(dtype),
VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
dev_ctx, x, axis, keepdims, flatten, out));
}
template <typename T, typename Context>
void ArgMinKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out) {
ArgMinMaxOpCUDAKernel<Context, T, cub::ArgMin>(
dev_ctx, x, axis, keepdims, flatten, dtype, out);
}
template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out) {
ArgMinMaxOpCUDAKernel<Context, T, cub::ArgMax>(
dev_ctx, x, axis, keepdims, flatten, dtype, out);
}
#endif
} // namespace phi
PD_REGISTER_KERNEL(arg_min,
GPU,
ALL_LAYOUT,
phi::ArgMinKernel,
float,
double,
int32_t,
int64_t,
int16_t,
uint8_t) {}
PD_REGISTER_KERNEL(arg_max,
GPU,
ALL_LAYOUT,
phi::ArgMaxKernel,
float,
double,
int32_t,
int64_t,
int16_t,
uint8_t) {}
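The GPU path flattens the input into a [pre, n, post] view around the reduced axis and lets ArgCUDAKernel index elements as in[h * n * post + k * post + w]. A standalone CPU sketch of that decomposition (plain C++, illustrative only, not code from the commit):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // argmax along `axis` of a row-major tensor with shape `dims`, using the
    // same pre/post/n flattening as ComputeFullArg: one output index per
    // (pre, post) pair.
    std::vector<int64_t> ArgMaxAxis(const std::vector<float>& in,
                                    const std::vector<int64_t>& dims,
                                    int64_t axis) {
      int64_t pre = 1, post = 1, n = dims[axis];
      for (int64_t i = 0; i < axis; ++i) pre *= dims[i];
      for (int64_t i = axis + 1; i < static_cast<int64_t>(dims.size()); ++i)
        post *= dims[i];
      std::vector<int64_t> out(pre * post, 0);
      for (int64_t h = 0; h < pre; ++h) {
        for (int64_t w = 0; w < post; ++w) {
          int64_t best = 0;
          for (int64_t k = 1; k < n; ++k) {
            if (in[h * n * post + k * post + w] >
                in[h * n * post + best * post + w]) {
              best = k;
            }
          }
          out[h * post + w] = best;
        }
      }
      return out;
    }

    int main() {
      // shape [2, 3], reduce axis 0: one index per column.
      std::vector<float> x = {1.f, 5.f, 2.f, 7.f, 0.f, 3.f};
      for (int64_t v : ArgMaxAxis(x, {2, 3}, 0)) std::cout << v << " ";  // 1 0 1
      std::cout << std::endl;
      return 0;
    }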