Unverified commit f25dba0a, authored by Zhong Hui, committed by GitHub

[PHI] Move arg min max to PHI. (#40222)

* move arg min max to phi.

* move infermeta.

* fix as per reviews.
Parent 1128db30
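In outline, this change replaces the operator's hand-written InferShape and the per-dtype REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL lists with a shared phi::ArgMinMaxInferMeta bound to the operator through an infer-shape functor, plus device-templated PHI kernels registered once per backend. The condensed sketch below is assembled only from names that appear in the hunks that follow (arg_max shown; arg_min is analogous); it is not a standalone translation unit, and the required includes live in the respective files of this diff.

// Bind the PHI InferMeta to the fluid operator instead of a hand-written InferShape.
DECLARE_INFER_SHAPE_FUNCTOR(arg_max, ArgMaxInferShapeFunctor,
                            PD_INFER_META(phi::ArgMinMaxInferMeta));
REGISTER_OPERATOR(
    arg_max, paddle::operators::ArgMinMaxOp, paddle::operators::ArgMaxOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    ArgMaxInferShapeFunctor);

// One PHI kernel registration per backend; the dtype list replaces the old
// per-type REGISTER_OP_CPU_KERNEL entries.
PD_REGISTER_KERNEL(arg_max, CPU, ALL_LAYOUT, phi::ArgMaxKernel,
                   float, double, int32_t, int64_t, int16_t, uint8_t) {}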
@@ -15,23 +15,19 @@ limitations under the License. */
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
DECLARE_INFER_SHAPE_FUNCTOR(arg_max, ArgMaxInferShapeFunctor,
PD_INFER_META(phi::ArgMinMaxInferMeta));
REGISTER_OPERATOR(
arg_max, paddle::operators::ArgMinMaxOp, paddle::operators::ArgMaxOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
arg_max,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, float>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, double>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ArgMaxInferShapeFunctor);
REGISTER_OP_VERSION(arg_max)
.AddCheckpoint(
R"ROC(
......
@@ -27,193 +27,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
enum ArgMinMaxType { kArgMin, kArgMax };
template <typename DeviceContext, typename T, typename Tout, int64_t Rank,
ArgMinMaxType argMinMaxValue>
struct ArgMinMaxFunctor {};
#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value) \
template <typename DeviceContext, typename T, typename Tout, int64_t Rank> \
struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank, \
enum_argminmax_value> { \
void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \
framework::LoDTensor* out, framework::DDim x_dims, \
int64_t axis, bool keepdims) { \
auto in_eigen = framework::EigenTensor<T, Rank>::From(in, x_dims); \
if (keepdims) { \
auto out_eigen = framework::EigenTensor<Tout, Rank>::From(*out); \
out_eigen.device(*(ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} else { \
auto out_eigen = framework::EigenTensor<Tout, Rank - 1>::From(*out); \
out_eigen.device(*(ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} \
} \
}
DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin);
DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax);
template <typename DeviceContext, typename T, ArgMinMaxType EnumArgMinMaxValue>
struct VisitDataArgMinMaxFunctor {
const framework::ExecutionContext& ctx;
explicit VisitDataArgMinMaxFunctor(const framework::ExecutionContext& ctx)
: ctx(ctx) {}
template <typename Tout>
void apply() const {
auto& x = *(ctx.Input<framework::LoDTensor>("X"));
auto& out = *(ctx.Output<framework::LoDTensor>("Out"));
out.template mutable_data<Tout>(ctx.GetPlace());
auto axis = ctx.Attr<int64_t>("axis");
auto keepdims = ctx.Attr<bool>("keepdims");
const bool& flatten = ctx.Attr<bool>("flatten");
// Paddle does not have a scalar tensor, so just return a tensor of shape [1]
if (flatten) keepdims = true;
// if flatten, construct the new dims for the calculation
framework::DDim x_dims;
if (flatten) {
x_dims = phi::make_ddim({x.numel()});
// if flatten, the axis is just 0
axis = 0;
} else {
x_dims = x.dims();
if (axis < 0) axis += x_dims.size();
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMinMaxFunctor<DeviceContext, T, Tout, rank, EnumArgMinMaxValue> \
functor##rank; \
functor##rank(dev_ctx, x, &out, x_dims, axis, keepdims)
switch (x_dims.size()) {
case 1:
CALL_ARG_MINMAX_FUNCTOR(1);
break;
case 2:
CALL_ARG_MINMAX_FUNCTOR(2);
break;
case 3:
CALL_ARG_MINMAX_FUNCTOR(3);
break;
case 4:
CALL_ARG_MINMAX_FUNCTOR(4);
break;
case 5:
CALL_ARG_MINMAX_FUNCTOR(5);
break;
case 6:
CALL_ARG_MINMAX_FUNCTOR(6);
break;
default:
PADDLE_ENFORCE_LE(
x_dims.size(), 6,
platform::errors::InvalidArgument(
"%s operator doesn't supports tensors whose ranks are greater "
"than 6.",
(EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax")));
break;
#undef CALL_ARG_MINMAX_FUNCTOR
}
}
};
template <typename DeviceContext, typename T, ArgMinMaxType EnumArgMinMaxValue>
class ArgMinMaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dtype = ctx.Attr<int>("dtype");
if (dtype < 0) {
framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(
framework::proto::VarType::INT64),
VisitDataArgMinMaxFunctor<DeviceContext, T, EnumArgMinMaxValue>(ctx));
return;
}
framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(dtype),
VisitDataArgMinMaxFunctor<DeviceContext, T, EnumArgMinMaxValue>(ctx));
}
};
template <typename DeviceContext, typename T>
using ArgMinKernel = ArgMinMaxKernel<DeviceContext, T, ArgMinMaxType::kArgMin>;
template <typename DeviceContext, typename T>
using ArgMaxKernel = ArgMinMaxKernel<DeviceContext, T, ArgMinMaxType::kArgMax>;
class ArgMinMaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "arg_min_max");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "arg_min_max");
const auto& x_dims = ctx->GetInputDim("X");
int64_t axis = ctx->Attrs().Get<int64_t>("axis");
bool keepdims = ctx->Attrs().Get<bool>("keepdims");
const bool& flatten = ctx->Attrs().Get<bool>("flatten");
PADDLE_ENFORCE_GE(axis, -x_dims.size(),
platform::errors::InvalidArgument(
"'axis'(%d) must be greater than or equal to"
" -Rank(X)(%d).",
axis, -x_dims.size()));
PADDLE_ENFORCE_LT(
axis, x_dims.size(),
platform::errors::InvalidArgument(
"'axis'(%d) must be less than Rank(X)(%d) of Input(X).", axis,
x_dims.size()));
const int& dtype = ctx->Attrs().Get<int>("dtype");
PADDLE_ENFORCE_EQ(
(dtype < 0 || dtype == 2 || dtype == 3), true,
platform::errors::InvalidArgument(
"The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
"received [%s]",
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64),
paddle::framework::DataTypeToString(
static_cast<framework::proto::VarType::Type>(dtype))));
auto x_rank = x_dims.size();
if (axis < 0) axis += x_rank;
if (ctx->IsRuntime()) {
if (dtype == framework::proto::VarType::INT32) {
int64_t all_element_num = 0;
if (flatten) {
all_element_num = phi::product(x_dims);
} else {
all_element_num = x_dims[axis];
}
PADDLE_ENFORCE_LE(
all_element_num, INT_MAX,
platform::errors::InvalidArgument(
"The element num of the argmin/argmax input at axis is "
"%d, is larger than int32 maximum value:%d, you must "
"set the dtype of argmin/argmax to 'int64'.",
all_element_num, INT_MAX));
}
}
std::vector<int64_t> vec;
if (flatten) {
vec.emplace_back(static_cast<int64_t>(1));
} else {
for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]);
if (keepdims) {
vec.emplace_back(static_cast<int64_t>(1));
}
for (int64_t i = axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
}
ctx->SetOutputDim("Out", phi::make_ddim(vec));
}
};
class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
......
@@ -12,26 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
DECLARE_INFER_SHAPE_FUNCTOR(arg_min, ArgMinInferShapeFunctor,
PD_INFER_META(phi::ArgMinMaxInferMeta));
REGISTER_OPERATOR(
arg_min, paddle::operators::ArgMinMaxOp, paddle::operators::ArgMinOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ArgMinInferShapeFunctor);
REGISTER_OP_CPU_KERNEL(
arg_min,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, float>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, double>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
REGISTER_OP_VERSION(arg_min)
.AddCheckpoint(
R"ROC(
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/arg_min_max_op_base.cu.h"
REGISTER_OP_CUDA_KERNEL(
arg_min, paddle::operators::ArgMinMaxOpCUDAKernel<float, cub::ArgMin>,
paddle::operators::ArgMinMaxOpCUDAKernel<double, cub::ArgMin>,
paddle::operators::ArgMinMaxOpCUDAKernel<int64_t, cub::ArgMin>,
paddle::operators::ArgMinMaxOpCUDAKernel<int32_t, cub::ArgMin>,
paddle::operators::ArgMinMaxOpCUDAKernel<int8_t, cub::ArgMin>);
@@ -17,6 +17,7 @@ limitations under the License. */
#include <algorithm>
#include <set>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/enforce.h"
@@ -1014,6 +1015,82 @@ void DiagInferMeta(const MetaTensor& x,
}
}
void ArgMinMaxInferMeta(const MetaTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
MetaTensor* out,
MetaConfig config) {
const auto& x_dims = x.dims();
PADDLE_ENFORCE_GE(
axis,
-x_dims.size(),
phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to"
" -Rank(X)(%d).",
axis,
-x_dims.size()));
PADDLE_ENFORCE_LT(axis,
x_dims.size(),
phi::errors::InvalidArgument(
"'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
axis,
x_dims.size()));
PADDLE_ENFORCE_EQ(
(dtype < 0 || dtype == 2 || dtype == 3),
true,
phi::errors::InvalidArgument(
"The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
"received [%s]",
paddle::framework::DataTypeToString(
paddle::framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
paddle::framework::proto::VarType::INT64),
paddle::framework::DataTypeToString(
static_cast<paddle::framework::proto::VarType::Type>(dtype))));
auto x_rank = x_dims.size();
if (axis < 0) axis += x_rank;
if (config.is_runtime) {
if (dtype == paddle::framework::proto::VarType::INT32) {
int64_t all_element_num = 0;
if (flatten) {
all_element_num = phi::product(x_dims);
} else {
all_element_num = x_dims[axis];
}
PADDLE_ENFORCE_LE(
all_element_num,
INT_MAX,
phi::errors::InvalidArgument(
"The element num of the argmin/argmax input at axis is "
"%d, is larger than int32 maximum value:%d, you must "
"set the dtype of argmin/argmax to 'int64'.",
all_element_num,
INT_MAX));
}
}
std::vector<int64_t> vec;
if (flatten) {
vec.emplace_back(static_cast<int64_t>(1));
} else {
for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]);
if (keepdims) {
vec.emplace_back(static_cast<int64_t>(1));
}
for (int64_t i = axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
}
out->set_dims(phi::make_ddim(vec));
if (dtype == 2) {
out->set_dtype(DataType::INT32);
} else if (dtype == 3) {
out->set_dtype(DataType::INT64);
}
}
void SizeInferMeta(const MetaTensor& input, MetaTensor* out) {
out->set_dtype(DataType::INT64);
out->set_dims({1});
......
@@ -147,6 +147,14 @@ void DiagInferMeta(const MetaTensor& x,
float padding_value,
MetaTensor* out);
void ArgMinMaxInferMeta(const MetaTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
MetaTensor* out,
MetaConfig config = MetaConfig());
void SizeInferMeta(const MetaTensor& input, MetaTensor* out);
void DiagonalInferMeta(
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/arg_min_max_op_base.cu.h"
REGISTER_OP_CUDA_KERNEL(
arg_max, paddle::operators::ArgMinMaxOpCUDAKernel<float, cub::ArgMax>,
paddle::operators::ArgMinMaxOpCUDAKernel<double, cub::ArgMax>,
paddle::operators::ArgMinMaxOpCUDAKernel<int64_t, cub::ArgMax>,
paddle::operators::ArgMinMaxOpCUDAKernel<int32_t, cub::ArgMax>,
paddle::operators::ArgMinMaxOpCUDAKernel<int8_t, cub::ArgMax>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ArgMinKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out);
template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/arg_min_max_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
enum ArgMinMaxType { kArgMin, kArgMax };
template <typename Context,
typename T,
typename Tout,
int64_t Rank,
ArgMinMaxType argMinMaxValue>
struct ArgMinMaxFunctor {};
#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value) \
template <typename Context, typename T, typename Tout, int64_t Rank> \
struct ArgMinMaxFunctor<Context, T, Tout, Rank, enum_argminmax_value> { \
void operator()(const Context& dev_ctx, \
const DenseTensor& in, \
DenseTensor* out, \
phi::DDim x_dims, \
int64_t axis, \
bool keepdims) { \
auto in_eigen = EigenTensor<T, Rank>::From(in, x_dims); \
if (keepdims) { \
auto out_eigen = EigenTensor<Tout, Rank>::From(*out); \
out_eigen.device(*(dev_ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} else { \
auto out_eigen = EigenTensor<Tout, Rank - 1>::From(*out); \
out_eigen.device(*(dev_ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} \
} \
}
DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin);
DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax);
template <typename Context, typename T, ArgMinMaxType EnumArgMinMaxValue>
struct VisitDataArgMinMaxFunctor {
const Context& dev_ctx;
const DenseTensor& x;
int64_t axis;
bool keepdims;
bool flatten;
DenseTensor* out;
explicit VisitDataArgMinMaxFunctor(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
DenseTensor* out)
: dev_ctx(dev_ctx),
x(x),
axis(axis),
keepdims(keepdims),
flatten(flatten),
out(out) {}
template <typename Tout>
void apply() const {
dev_ctx.template Alloc<Tout>(out);
bool new_keepdims = keepdims;
if (flatten) new_keepdims = true;
// if flatten, construct the new dims for the calculation
phi::DDim x_dims;
int new_axis = axis;
if (flatten) {
x_dims = phi::make_ddim({x.numel()});
// if flatten, the axis is just 0
new_axis = 0;
} else {
x_dims = x.dims();
if (axis < 0) new_axis = axis + x_dims.size();
}
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMinMaxFunctor<Context, T, Tout, rank, EnumArgMinMaxValue> functor##rank; \
functor##rank(dev_ctx, x, out, x_dims, new_axis, new_keepdims)
switch (x_dims.size()) {
case 1:
CALL_ARG_MINMAX_FUNCTOR(1);
break;
case 2:
CALL_ARG_MINMAX_FUNCTOR(2);
break;
case 3:
CALL_ARG_MINMAX_FUNCTOR(3);
break;
case 4:
CALL_ARG_MINMAX_FUNCTOR(4);
break;
case 5:
CALL_ARG_MINMAX_FUNCTOR(5);
break;
case 6:
CALL_ARG_MINMAX_FUNCTOR(6);
break;
default:
PADDLE_ENFORCE_LE(
x_dims.size(),
6,
phi::errors::InvalidArgument(
"%s operator doesn't supports tensors whose ranks are greater "
"than 6.",
(EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax")));
break;
#undef CALL_ARG_MINMAX_FUNCTOR
}
}
};
template <typename Context, typename T, ArgMinMaxType EnumArgMinMaxValue>
void ArgMinMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out) {
if (dtype < 0) {
paddle::framework::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(
paddle::framework::proto::VarType::INT64),
VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
dev_ctx, x, axis, keepdims, flatten, out));
return;
}
paddle::framework::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(dtype),
VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
dev_ctx, x, axis, keepdims, flatten, out));
}
template <typename T, typename Context>
void ArgMinKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out) {
ArgMinMaxKernel<Context, T, ArgMinMaxType::kArgMin>(
dev_ctx, x, axis, keepdims, flatten, dtype, out);
}
template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out) {
ArgMinMaxKernel<Context, T, ArgMinMaxType::kArgMax>(
dev_ctx, x, axis, keepdims, flatten, dtype, out);
}
} // namespace phi
PD_REGISTER_KERNEL(arg_min,
CPU,
ALL_LAYOUT,
phi::ArgMinKernel,
float,
double,
int32_t,
int64_t,
int16_t,
uint8_t) {}
PD_REGISTER_KERNEL(arg_max,
CPU,
ALL_LAYOUT,
phi::ArgMaxKernel,
float,
double,
int32_t,
int64_t,
int16_t,
uint8_t) {}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#if defined(__NVCC__) || defined(__HIPCC__)
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include <limits>
#include <string>
#include <typeinfo>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/ddim.h"
namespace paddle {
namespace operators {
namespace { // NOLINT
template <typename K, typename V>
using KeyValuePair = cub::KeyValuePair<K, V>;
using Tensor = framework::Tensor;
} // end namespace
#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...) \
case (1 << (log2_block_dim)): { \
constexpr auto kBlockDim = (1 << (log2_block_dim)); \
__VA_ARGS__; \
} break
#define FIXED_BLOCK_DIM_CASE(...) \
FIXED_BLOCK_DIM_CASE_BASE(10, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__);
template <typename T, typename IndType, class Reducer, size_t BlockDim>
__global__ void ArgCUDAKernel(const int64_t height, // n * h
const int64_t width, // c
const int64_t post_size, // h
const Reducer reducer, const T init, const T* in,
IndType* out) {
typedef cub::BlockReduce<KeyValuePair<int, T>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
for (int idx = blockIdx.x; idx < height; idx += gridDim.x) {
KeyValuePair<int, T> kv_pair = {-1, init};
int h = idx / post_size;
int w = idx % post_size;
for (int k = threadIdx.x; k < width; k += blockDim.x) {
kv_pair =
reducer({k, in[h * width * post_size + k * post_size + w]}, kv_pair);
}
kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer);
if (threadIdx.x == 0) {
out[idx] = static_cast<IndType>(kv_pair.key);
}
__syncthreads();
}
}
template <typename T, typename IndType, class Reducer>
void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
Tensor* indices, const int64_t pre, const int64_t post,
const int64_t n) {
auto cu_stream = ctx.stream();
auto ComputeBlockSize = [](int64_t col) {
auto block_size = 8;
if (col > 512)
block_size = 1024;
else if (col > 256)
block_size = 512;
else if (col > 128)
block_size = 256;
else if (col > 64)
block_size = 128;
else if (col > 32)
block_size = 64;
else if (col > 16)
block_size = 32;
else if (col > 8)
block_size = 16;
#ifdef __HIPCC__
block_size = std::min(block_size, 256);
#endif
return block_size;
};
int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
int64_t height = pre * post;
int64_t width = n;
int64_t grid_size = height < max_grid_dimx ? height : max_grid_dimx;
const T* in_data = input.data<T>();
IndType* out_data = indices->mutable_data<IndType>(ctx.GetPlace());
if (typeid(Reducer) == typeid(cub::ArgMax)) {
switch (ComputeBlockSize(width)) {
FIXED_BLOCK_DIM_CASE(
ArgCUDAKernel<T, IndType, Reducer,
kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
height, width, post, Reducer(), std::numeric_limits<T>::lowest(),
in_data, out_data));
}
} else {
switch (ComputeBlockSize(width)) {
FIXED_BLOCK_DIM_CASE(
ArgCUDAKernel<T, IndType, Reducer,
kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
height, width, post, Reducer(), std::numeric_limits<T>::max(),
in_data, out_data));
}
}
}
template <typename T, class Reducer>
struct VisitDataCudaArgMinMaxFunctor {
const framework::ExecutionContext& ctx;
explicit VisitDataCudaArgMinMaxFunctor(const framework::ExecutionContext& ctx)
: ctx(ctx) {}
template <typename IndType>
void apply() const {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
int axis = ctx.Attr<int64_t>("axis");
const bool& flatten = ctx.Attr<bool>("flatten");
framework::DDim input_dims;
if (flatten) {
input_dims = phi::make_ddim({input->numel()});
// if flatten, the axis is just 0
axis = 0;
} else {
input_dims = input->dims();
if (axis < 0) axis += input->dims().size();
}
int64_t numel = input->numel();
int64_t groups = numel / input_dims[axis];
int64_t pre = 1;
int64_t post = 1;
int64_t n = input_dims[axis];
for (int i = 0; i < axis; i++) {
pre *= input_dims[i];
}
for (int i = axis + 1; i < input_dims.size(); i++) {
post *= input_dims[i];
}
const auto& dev_ctx = ctx.cuda_device_context();
ComputeFullArg<T, IndType, Reducer>(dev_ctx, *input, output, pre, post, n);
}
};
template <typename T, class Reducer>
class ArgMinMaxOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dtype = ctx.Attr<int>("dtype");
if (dtype < 0) {
framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(
framework::proto::VarType::INT64),
VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
return;
}
framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(dtype),
VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
}
};
#endif
} // namespace operators
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/arg_min_max_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include <limits>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/core/ddim.h"
namespace phi {
namespace { // NOLINT
template <typename K, typename V>
using KeyValuePair = cub::KeyValuePair<K, V>;
} // end namespace
#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...) \
case (1 << (log2_block_dim)): { \
constexpr auto kBlockDim = (1 << (log2_block_dim)); \
__VA_ARGS__; \
} break
#define FIXED_BLOCK_DIM_CASE(...) \
FIXED_BLOCK_DIM_CASE_BASE(10, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__);
template <typename T, typename IndType, class Reducer, size_t BlockDim>
__global__ void ArgCUDAKernel(const int64_t height, // n * h
const int64_t width, // c
const int64_t post_size, // h
const Reducer reducer,
const T init,
const T* in,
IndType* out) {
typedef cub::BlockReduce<KeyValuePair<int, T>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
for (int idx = blockIdx.x; idx < height; idx += gridDim.x) {
KeyValuePair<int, T> kv_pair = {-1, init};
int h = idx / post_size;
int w = idx % post_size;
for (int k = threadIdx.x; k < width; k += blockDim.x) {
kv_pair =
reducer({k, in[h * width * post_size + k * post_size + w]}, kv_pair);
}
kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer);
if (threadIdx.x == 0) {
out[idx] = static_cast<IndType>(kv_pair.key);
}
__syncthreads();
}
}
template <typename T, typename IndType, class Reducer>
void ComputeFullArg(const phi::GPUContext& dev_ctx,
const DenseTensor& input,
DenseTensor* indices,
const int64_t pre,
const int64_t post,
const int64_t n) {
auto cu_stream = dev_ctx.stream();
auto ComputeBlockSize = [](int64_t col) {
auto block_size = 8;
if (col > 512)
block_size = 1024;
else if (col > 256)
block_size = 512;
else if (col > 128)
block_size = 256;
else if (col > 64)
block_size = 128;
else if (col > 32)
block_size = 64;
else if (col > 16)
block_size = 32;
else if (col > 8)
block_size = 16;
#ifdef __HIPCC__
block_size = std::min(block_size, 256);
#endif
return block_size;
};
int64_t max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
int64_t height = pre * post;
int64_t width = n;
int64_t grid_size = height < max_grid_dimx ? height : max_grid_dimx;
const T* in_data = input.data<T>();
IndType* out_data = dev_ctx.template Alloc<IndType>(indices);
if (typeid(Reducer) == typeid(cub::ArgMax)) {
switch (ComputeBlockSize(width)) {
FIXED_BLOCK_DIM_CASE(
ArgCUDAKernel<T,
IndType,
Reducer,
kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
height,
width,
post,
Reducer(),
std::numeric_limits<T>::lowest(),
in_data,
out_data));
}
} else {
switch (ComputeBlockSize(width)) {
FIXED_BLOCK_DIM_CASE(
ArgCUDAKernel<T,
IndType,
Reducer,
kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
height,
width,
post,
Reducer(),
std::numeric_limits<T>::max(),
in_data,
out_data));
}
}
}
template <typename Context, typename T, class Reducer>
struct VisitDataCudaArgMinMaxFunctor {
const Context& dev_ctx;
const DenseTensor& x;
int64_t axis;
bool keepdims;
bool flatten;
DenseTensor* out;
explicit VisitDataCudaArgMinMaxFunctor(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
DenseTensor* out)
: dev_ctx(dev_ctx),
x(x),
axis(axis),
keepdims(keepdims),
flatten(flatten),
out(out) {}
template <typename IndType>
void apply() const {
phi::DDim x_dims;
int new_axis = axis;
if (flatten) {
x_dims = phi::make_ddim({x.numel()});
// if flatten, the axis is just 0
new_axis = 0;
} else {
x_dims = x.dims();
if (axis < 0) new_axis = axis + x.dims().size();
}
int64_t numel = x.numel();
int64_t groups = numel / x_dims[new_axis];
int64_t pre = 1;
int64_t post = 1;
int64_t n = x_dims[new_axis];
for (int i = 0; i < new_axis; i++) {
pre *= x_dims[i];
}
for (int i = new_axis + 1; i < x_dims.size(); i++) {
post *= x_dims[i];
}
ComputeFullArg<T, IndType, Reducer>(dev_ctx, x, out, pre, post, n);
}
};
template <typename Context, typename T, class Reducer>
void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out) {
if (dtype < 0) {
paddle::framework::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(
paddle::framework::proto::VarType::INT64),
VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
dev_ctx, x, axis, keepdims, flatten, out));
return;
}
paddle::framework::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(dtype),
VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
dev_ctx, x, axis, keepdims, flatten, out));
}
template <typename T, typename Context>
void ArgMinKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out) {
ArgMinMaxOpCUDAKernel<Context, T, cub::ArgMin>(
dev_ctx, x, axis, keepdims, flatten, dtype, out);
}
template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
DenseTensor* out) {
ArgMinMaxOpCUDAKernel<Context, T, cub::ArgMax>(
dev_ctx, x, axis, keepdims, flatten, dtype, out);
}
#endif
} // namespace phi
PD_REGISTER_KERNEL(arg_min,
GPU,
ALL_LAYOUT,
phi::ArgMinKernel,
float,
double,
int32_t,
int64_t,
int16_t,
uint8_t) {}
PD_REGISTER_KERNEL(arg_max,
GPU,
ALL_LAYOUT,
phi::ArgMaxKernel,
float,
double,
int32_t,
int64_t,
int16_t,
uint8_t) {}