diff --git a/paddle/fluid/operators/arg_max_op.cc b/paddle/fluid/operators/arg_max_op.cc
index 0f5c048b6be9c73ae98181685269592f409196cd..c5e4188ca2d6f749a06127c41da99490a7fb3ffc 100644
--- a/paddle/fluid/operators/arg_max_op.cc
+++ b/paddle/fluid/operators/arg_max_op.cc
@@ -15,23 +15,19 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/arg_min_max_op_base.h"
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
+
+DECLARE_INFER_SHAPE_FUNCTOR(arg_max, ArgMaxInferShapeFunctor,
+                            PD_INFER_META(phi::ArgMinMaxInferMeta));
+
 REGISTER_OPERATOR(
     arg_max, paddle::operators::ArgMinMaxOp, paddle::operators::ArgMaxOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-
-REGISTER_OP_CPU_KERNEL(
-    arg_max,
-    paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, float>,
-    paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, double>,
-    paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
-                                    int64_t>,
-    paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
-                                    int32_t>,
-    paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
-                                    int16_t>,
-    paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
-                                    uint8_t>);
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    ArgMaxInferShapeFunctor);
+
 REGISTER_OP_VERSION(arg_max)
     .AddCheckpoint(
         R"ROC(
diff --git a/paddle/fluid/operators/arg_max_op.cu b/paddle/fluid/operators/arg_max_op.cu
deleted file mode 100644
index 14708c4df10f5160d0e72e7669e0015554d8215f..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/arg_max_op.cu
+++ /dev/null
@@ -1,22 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/arg_min_max_op_base.cu.h"
-
-REGISTER_OP_CUDA_KERNEL(
-    arg_max, paddle::operators::ArgMinMaxOpCUDAKernel<float, cub::ArgMax>,
-    paddle::operators::ArgMinMaxOpCUDAKernel<double, cub::ArgMax>,
-    paddle::operators::ArgMinMaxOpCUDAKernel<int64_t, cub::ArgMax>,
-    paddle::operators::ArgMinMaxOpCUDAKernel<int32_t, cub::ArgMax>,
-    paddle::operators::ArgMinMaxOpCUDAKernel<int8_t, cub::ArgMax>);
diff --git a/paddle/fluid/operators/arg_min_max_op_base.cu.h b/paddle/fluid/operators/arg_min_max_op_base.cu.h
deleted file mode 100644
index b77031f7fb4c9d94f30ed06333b9c8766fd2310d..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/arg_min_max_op_base.cu.h
+++ /dev/null
@@ -1,202 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#if defined(__NVCC__) || defined(__HIPCC__)
-
-#ifdef __NVCC__
-#include "cub/cub.cuh"
-#endif
-#ifdef __HIPCC__
-#include <hipcub/hipcub.hpp>
-namespace cub = hipcub;
-#endif
-#include <limits>
-#include <string>
-#include <typeinfo>
-#include <vector>
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/operators/transpose_op.h"
-#include "paddle/fluid/platform/device_context.h"
-#include "paddle/phi/core/ddim.h"
-
-namespace paddle {
-namespace operators {
-
-namespace {  // NOLINT
-template <typename K, typename V>
-using KeyValuePair = cub::KeyValuePair<K, V>;
-using Tensor = framework::Tensor;
-
-}  // end namespace
-
-#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...)  \
-  case (1 << (log2_block_dim)): {                       \
-    constexpr auto kBlockDim = (1 << (log2_block_dim)); \
-    __VA_ARGS__;                                        \
-  } break
-
-#define FIXED_BLOCK_DIM_CASE(...)               \
-  FIXED_BLOCK_DIM_CASE_BASE(10, ##__VA_ARGS__); \
-  FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__);  \
-  FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__);  \
-  FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__);  \
-  FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__);  \
-  FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__);  \
-  FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__);  \
-  FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__);
-
-template <typename T, typename IndType, class Reducer, size_t BlockDim>
-__global__ void ArgCUDAKernel(const int64_t height,     // n * h
-                              const int64_t width,      // c
-                              const int64_t post_size,  // h
-                              const Reducer reducer, const T init, const T* in,
-                              IndType* out) {
-  typedef cub::BlockReduce<cub::KeyValuePair<int, T>, BlockDim> BlockReduce;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-
-  for (int idx = blockIdx.x; idx < height; idx += gridDim.x) {
-    KeyValuePair<int, T> kv_pair = {-1, init};
-    int h = idx / post_size;
-    int w = idx % post_size;
-    for (int k = threadIdx.x; k < width; k += blockDim.x) {
-      kv_pair =
-          reducer({k, in[h * width * post_size + k * post_size + w]}, kv_pair);
-    }
-    kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer);
-    if (threadIdx.x == 0) {
-      out[idx] = static_cast<IndType>(kv_pair.key);
-    }
-    __syncthreads();
-  }
-}
-
-template <typename T, typename IndType, class Reducer>
-void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
-                    Tensor* indices, const int64_t pre, const int64_t post,
-                    const int64_t n) {
-  auto cu_stream = ctx.stream();
-  auto ComputeBlockSize = [](int64_t col) {
-    auto block_size = 8;
-    if (col > 512)
-      block_size = 1024;
-    else if (col > 256)
-      block_size = 512;
-    else if (col > 128)
-      block_size = 256;
-    else if (col > 64)
-      block_size = 128;
-    else if (col > 32)
-      block_size = 64;
-    else if (col > 16)
-      block_size = 32;
-    else if (col > 8)
-      block_size = 16;
-#ifdef __HIPCC__
-    block_size = std::min(block_size, 256);
-#endif
-    return block_size;
-  };
-
-  int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
-  int64_t height = pre * post;
-  int64_t width = n;
-  int64_t grid_size = height < max_grid_dimx ? height : max_grid_dimx;
-
-  const T* in_data = input.data<T>();
-  IndType* out_data = indices->mutable_data<IndType>(ctx.GetPlace());
-
-  if (typeid(Reducer) == typeid(cub::ArgMax)) {
-    switch (ComputeBlockSize(width)) {
-      FIXED_BLOCK_DIM_CASE(
-          ArgCUDAKernel<T, IndType, Reducer,
-                        kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
-              height, width, post, Reducer(), std::numeric_limits<T>::lowest(),
-              in_data, out_data));
-    }
-  } else {
-    switch (ComputeBlockSize(width)) {
-      FIXED_BLOCK_DIM_CASE(
-          ArgCUDAKernel<T, IndType, Reducer,
-                        kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
-              height, width, post, Reducer(), std::numeric_limits<T>::max(),
-              in_data, out_data));
-    }
-  }
-}
-
-template <typename T, class Reducer>
-struct VisitDataCudaArgMinMaxFunctor {
-  const framework::ExecutionContext& ctx;
-
-  explicit VisitDataCudaArgMinMaxFunctor(const framework::ExecutionContext& ctx)
-      : ctx(ctx) {}
-  template <typename IndType>
-  void apply() const {
-    auto* input = ctx.Input<framework::LoDTensor>("X");
-    auto* output = ctx.Output<framework::LoDTensor>("Out");
-    int axis = ctx.Attr<int64_t>("axis");
-    const bool& flatten = ctx.Attr<bool>("flatten");
-
-    framework::DDim input_dims;
-    if (flatten) {
-      input_dims = phi::make_ddim({input->numel()});
-      // if flatten, the axis just as 0
-      axis = 0;
-    } else {
-      input_dims = input->dims();
-      if (axis < 0) axis += input->dims().size();
-    }
-
-    int64_t numel = input->numel();
-    int64_t groups = numel / input_dims[axis];
-    int64_t pre = 1;
-    int64_t post = 1;
-    int64_t n = input_dims[axis];
-
-    for (int i = 0; i < axis; i++) {
-      pre *= input_dims[i];
-    }
-
-    for (int i = axis + 1; i < input_dims.size(); i++) {
-      post *= input_dims[i];
-    }
-
-    const auto& dev_ctx = ctx.cuda_device_context();
-    ComputeFullArg<T, IndType, Reducer>(dev_ctx, *input, output, pre, post, n);
-  }
-};
-template <typename T, class Reducer>
-class ArgMinMaxOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto& dtype = ctx.Attr<int>("dtype");
-    if (dtype < 0) {
-      framework::VisitDataTypeTiny(
-          static_cast<framework::proto::VarType::Type>(
-              framework::proto::VarType::INT64),
-          VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
-      return;
-    }
-    framework::VisitDataTypeTiny(
-        static_cast<framework::proto::VarType::Type>(dtype),
-        VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
-  }
-};
-
-#endif
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/arg_min_max_op_base.h b/paddle/fluid/operators/arg_min_max_op_base.h
index d3ce61d183a3d322e40966ce59f9a10320ceab4f..585341beea12c14fbd01a3a47af34ce57def0db5 100644
--- a/paddle/fluid/operators/arg_min_max_op_base.h
+++ b/paddle/fluid/operators/arg_min_max_op_base.h
@@ -27,193 +27,9 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-enum ArgMinMaxType { kArgMin, kArgMax };
-
-template <typename DeviceContext, typename T, typename Tout, int64_t Rank,
-          ArgMinMaxType argMinMaxValue>
-struct ArgMinMaxFunctor {};
-
-#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value)      \
-  template <typename DeviceContext, typename T, typename Tout, int64_t Rank>  \
-  struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank,                       \
-                          enum_argminmax_value> {                             \
-    void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \
-                    framework::LoDTensor* out, framework::DDim x_dims,        \
-                    int64_t axis, bool keepdims) {                            \
-      auto in_eigen = framework::EigenTensor<T, Rank>::From(in, x_dims);      \
-      if (keepdims) {                                                         \
-        auto out_eigen = framework::EigenTensor<Tout, Rank>::From(*out);      \
-        out_eigen.device(*(ctx.eigen_device())) =                             \
-            in_eigen.eigen_op_type(axis).template cast<Tout>();               \
-      } else {                                                                \
-        auto out_eigen = framework::EigenTensor<Tout, Rank - 1>::From(*out);  \
-        out_eigen.device(*(ctx.eigen_device())) =                             \
-            in_eigen.eigen_op_type(axis).template cast<Tout>();               \
-      }                                                                       \
-    }                                                                         \
-  }
-
-DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin);
-DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax);
-
-template <typename DeviceContext, typename T, ArgMinMaxType EnumArgMinMaxValue>
-struct VisitDataArgMinMaxFunctor {
-  const framework::ExecutionContext& ctx;
-
-  explicit VisitDataArgMinMaxFunctor(const framework::ExecutionContext& ctx)
-      : ctx(ctx) {}
-  template <typename Tout>
-  void apply() const {
-    auto& x = *(ctx.Input<framework::LoDTensor>("X"));
-    auto& out = *(ctx.Output<framework::LoDTensor>("Out"));
-    out.template mutable_data<Tout>(ctx.GetPlace());
-    auto axis = ctx.Attr<int64_t>("axis");
-    auto keepdims = ctx.Attr<bool>("keepdims");
-    const bool& flatten = ctx.Attr<bool>("flatten");
-    // paddle do not have the scalar tensor, just return the shape [1] tensor
-    if (flatten) keepdims = true;
-
-    // if flatten, will construct the new dims for the cacluate
-    framework::DDim x_dims;
-    if (flatten) {
-      x_dims = phi::make_ddim({x.numel()});
-      // if flatten, the axis just as 0
-      axis = 0;
-    } else {
-      x_dims = x.dims();
-      if (axis < 0) axis += x_dims.size();
-    }
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-
-#define CALL_ARG_MINMAX_FUNCTOR(rank)                                \
-  ArgMinMaxFunctor<DeviceContext, T, Tout, rank, EnumArgMinMaxValue> \
-      functor##rank;                                                 \
-  functor##rank(dev_ctx, x, &out, x_dims, axis, keepdims)
-
-    switch (x_dims.size()) {
-      case 1:
-        CALL_ARG_MINMAX_FUNCTOR(1);
-        break;
-      case 2:
-        CALL_ARG_MINMAX_FUNCTOR(2);
-        break;
-      case 3:
-        CALL_ARG_MINMAX_FUNCTOR(3);
-        break;
-      case 4:
-        CALL_ARG_MINMAX_FUNCTOR(4);
-        break;
-      case 5:
-        CALL_ARG_MINMAX_FUNCTOR(5);
-        break;
-      case 6:
-        CALL_ARG_MINMAX_FUNCTOR(6);
-        break;
-      default:
-        PADDLE_ENFORCE_LE(
-            x_dims.size(), 6,
-            platform::errors::InvalidArgument(
-                "%s operator doesn't supports tensors whose ranks are greater "
-                "than 6.",
-                (EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax")));
-        break;
-#undef CALL_ARG_MINMAX_FUNCTOR
-    }
-  }
-};
-
-template <typename DeviceContext, typename T, ArgMinMaxType EnumArgMinMaxValue>
-class ArgMinMaxKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto& dtype = ctx.Attr<int>("dtype");
-    if (dtype < 0) {
-      framework::VisitDataTypeTiny(
-          static_cast<framework::proto::VarType::Type>(
-              framework::proto::VarType::INT64),
-          VisitDataArgMinMaxFunctor<DeviceContext, T, EnumArgMinMaxValue>(ctx));
-      return;
-    }
-    framework::VisitDataTypeTiny(
-        static_cast<framework::proto::VarType::Type>(dtype),
-        VisitDataArgMinMaxFunctor<DeviceContext, T, EnumArgMinMaxValue>(ctx));
-  }
-};
-
-template <typename DeviceContext, typename T>
-using ArgMinKernel =
-    ArgMinMaxKernel<DeviceContext, T, ArgMinMaxType::kArgMin>;
-
-template <typename DeviceContext, typename T>
-using ArgMaxKernel =
-    ArgMinMaxKernel<DeviceContext, T, ArgMinMaxType::kArgMax>;
-
 class ArgMinMaxOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "arg_min_max");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "arg_min_max");
-    const auto& x_dims = ctx->GetInputDim("X");
-    int64_t axis = ctx->Attrs().Get<int64_t>("axis");
-    bool keepdims = ctx->Attrs().Get<bool>("keepdims");
-    const bool& flatten = ctx->Attrs().Get<bool>("flatten");
-
-    PADDLE_ENFORCE_GE(axis, -x_dims.size(),
-                      platform::errors::InvalidArgument(
-                          "'axis'(%d) must be greater than or equal to"
-                          " -Rank(X)(%d).",
-                          axis, -x_dims.size()));
-    PADDLE_ENFORCE_LT(
-        axis, x_dims.size(),
-        platform::errors::InvalidArgument(
-            "'axis'(%d) must be less than Rank(X)(%d) of Input(X).", axis,
-            x_dims.size()));
-
-    const int& dtype = ctx->Attrs().Get<int>("dtype");
-    PADDLE_ENFORCE_EQ(
-        (dtype < 0 || dtype == 2 || dtype == 3), true,
-        platform::errors::InvalidArgument(
-            "The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
-            "received [%s]",
-            paddle::framework::DataTypeToString(
-                framework::proto::VarType::INT32),
-            paddle::framework::DataTypeToString(
-                framework::proto::VarType::INT64),
-            paddle::framework::DataTypeToString(
-                static_cast<framework::proto::VarType::Type>(dtype))));
-
-    auto x_rank = x_dims.size();
-    if (axis < 0) axis += x_rank;
-    if (ctx->IsRuntime()) {
-      if (dtype == framework::proto::VarType::INT32) {
-        int64_t all_element_num = 0;
-        if (flatten) {
-          all_element_num = phi::product(x_dims);
-
-        } else {
-          all_element_num = x_dims[axis];
-        }
-        PADDLE_ENFORCE_LE(
-            all_element_num, INT_MAX,
-            platform::errors::InvalidArgument(
-                "The element num of the argmin/argmax input at axis is "
-                "%d, is larger than int32 maximum value:%d, you must "
-                "set the dtype of argmin/argmax to 'int64'.",
-                all_element_num, INT_MAX));
-      }
-    }
-    std::vector<int64_t> vec;
-    if (flatten) {
-      vec.emplace_back(static_cast<int64_t>(1));
-    } else {
-      for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]);
-      if (keepdims) {
-        vec.emplace_back(static_cast<int64_t>(1));
-      }
-      for (int64_t i = axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
-    }
-    ctx->SetOutputDim("Out", phi::make_ddim(vec));
-  }
 };
 
 class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
diff --git a/paddle/fluid/operators/arg_min_op.cc b/paddle/fluid/operators/arg_min_op.cc
index 0a4ba6fb0bfdfccfc4eae99da730e96fe5f0a540..fb3abd01af8c396d764f9f1d247f24c41bd15959 100644
--- a/paddle/fluid/operators/arg_min_op.cc
+++ b/paddle/fluid/operators/arg_min_op.cc
@@ -12,26 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/arg_min_max_op_base.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
+
+DECLARE_INFER_SHAPE_FUNCTOR(arg_min, ArgMinInferShapeFunctor,
+                            PD_INFER_META(phi::ArgMinMaxInferMeta));
 
 REGISTER_OPERATOR(
     arg_min, paddle::operators::ArgMinMaxOp, paddle::operators::ArgMinOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    ArgMinInferShapeFunctor);
 
-REGISTER_OP_CPU_KERNEL(
-    arg_min,
-    paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, float>,
-    paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, double>,
-    paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
-                                    int64_t>,
-    paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
-                                    int32_t>,
-    paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
-                                    int16_t>,
-    paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
-                                    uint8_t>);
 REGISTER_OP_VERSION(arg_min)
     .AddCheckpoint(
         R"ROC(
diff --git a/paddle/fluid/operators/arg_min_op.cu b/paddle/fluid/operators/arg_min_op.cu
deleted file mode 100644
index 23170bf0087906d752767051ce58874cb3584ee5..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/arg_min_op.cu
+++ /dev/null
@@ -1,21 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/arg_min_max_op_base.cu.h"
-REGISTER_OP_CUDA_KERNEL(
-    arg_min, paddle::operators::ArgMinMaxOpCUDAKernel<float, cub::ArgMin>,
-    paddle::operators::ArgMinMaxOpCUDAKernel<double, cub::ArgMin>,
-    paddle::operators::ArgMinMaxOpCUDAKernel<int64_t, cub::ArgMin>,
-    paddle::operators::ArgMinMaxOpCUDAKernel<int32_t, cub::ArgMin>,
-    paddle::operators::ArgMinMaxOpCUDAKernel<int8_t, cub::ArgMin>);
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 544a5593014f41924eb544563c9379d020504af8..8c2707e1d2369b544c0748b9eb909635d7908326 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <algorithm>
 #include <set>
 
+#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/type_traits.h"
 #include "paddle/phi/core/enforce.h"
@@ -1014,6 +1015,82 @@ void DiagInferMeta(const MetaTensor& x,
   }
 }
 
+void ArgMinMaxInferMeta(const MetaTensor& x,
+                        int64_t axis,
+                        bool keepdims,
+                        bool flatten,
+                        int dtype,
+                        MetaTensor* out,
+                        MetaConfig config) {
+  const auto& x_dims = x.dims();
+
+  PADDLE_ENFORCE_GE(
+      axis,
+      -x_dims.size(),
+      phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to"
+                                   " -Rank(X)(%d).",
+                                   axis,
+                                   -x_dims.size()));
+  PADDLE_ENFORCE_LT(axis,
+                    x_dims.size(),
+                    phi::errors::InvalidArgument(
+                        "'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
+                        axis,
+                        x_dims.size()));
+
+  PADDLE_ENFORCE_EQ(
+      (dtype < 0 || dtype == 2 || dtype == 3),
+      true,
+      phi::errors::InvalidArgument(
+          "The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
+          "received [%s]",
+          paddle::framework::DataTypeToString(
+              paddle::framework::proto::VarType::INT32),
+          paddle::framework::DataTypeToString(
+              paddle::framework::proto::VarType::INT64),
+          paddle::framework::DataTypeToString(
+              static_cast<paddle::framework::proto::VarType::Type>(dtype))));
+
+  auto x_rank = x_dims.size();
+  if (axis < 0) axis += x_rank;
+  if (config.is_runtime) {
+    if (dtype == paddle::framework::proto::VarType::INT32) {
+      int64_t all_element_num = 0;
+      if (flatten) {
+        all_element_num = phi::product(x_dims);
+      } else {
+        all_element_num = x_dims[axis];
+      }
+      PADDLE_ENFORCE_LE(
+          all_element_num,
+          INT_MAX,
+          phi::errors::InvalidArgument(
+              "The element num of the argmin/argmax input at axis is "
+              "%d, is larger than int32 maximum value:%d, you must "
+              "set the dtype of argmin/argmax to 'int64'.",
+              all_element_num,
+              INT_MAX));
+    }
+  }
+  std::vector<int64_t> vec;
+  if (flatten) {
+    vec.emplace_back(static_cast<int64_t>(1));
+  } else {
+    for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]);
+    if (keepdims) {
+      vec.emplace_back(static_cast<int64_t>(1));
+    }
+    for (int64_t i = axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
+  }
+  out->set_dims(phi::make_ddim(vec));
+  if (dtype == 2) {
+    out->set_dtype(DataType::INT32);
+  } else if (dtype == 3) {
+    out->set_dtype(DataType::INT64);
+  }
+}
+
 void SizeInferMeta(const MetaTensor& input, MetaTensor* out) {
   out->set_dtype(DataType::INT64);
   out->set_dims({1});
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index c57e1bdec8da84c5cccc456212b1cfda3c476a7a..df9258644ac1210d680bc5848ffa934118b068bb 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -147,6 +147,14 @@ void DiagInferMeta(const MetaTensor& x,
                    float padding_value,
                    MetaTensor* out);
 
+void ArgMinMaxInferMeta(const MetaTensor& x,
+                        int64_t axis,
+                        bool keepdims,
+                        bool flatten,
+                        int dtype,
+                        MetaTensor* out,
+                        MetaConfig config = MetaConfig());
+
 void SizeInferMeta(const MetaTensor& input, MetaTensor* out);
 
 void DiagonalInferMeta(
diff --git a/paddle/phi/kernels/arg_min_max_kernel.h b/paddle/phi/kernels/arg_min_max_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..917babeef07e99c9cdd6ab71e772a4e0cb1f9e12
--- /dev/null
+++ b/paddle/phi/kernels/arg_min_max_kernel.h
@@ -0,0 +1,39 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ArgMinKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  int64_t axis,
+                  bool keepdims,
+                  bool flatten,
+                  int dtype,
+                  DenseTensor* out);
+
+template <typename T, typename Context>
+void ArgMaxKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  int64_t axis,
+                  bool keepdims,
+                  bool flatten,
+                  int dtype,
+                  DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/cpu/arg_min_max_kernel.cc b/paddle/phi/kernels/cpu/arg_min_max_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f4ad830e149321f417392f6e83bbc8cc06ad3876
--- /dev/null
+++ b/paddle/phi/kernels/cpu/arg_min_max_kernel.cc
@@ -0,0 +1,203 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/arg_min_max_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/ddim.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+
+enum ArgMinMaxType { kArgMin, kArgMax };
+
+template <typename Context,
+          typename T,
+          typename Tout,
+          int64_t Rank,
+          ArgMinMaxType argMinMaxValue>
+struct ArgMinMaxFunctor {};
+
+#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value)   \
+  template <typename Context, typename T, typename Tout, int64_t Rank>     \
+  struct ArgMinMaxFunctor<Context, T, Tout, Rank, enum_argminmax_value> {  \
+    void operator()(const Context& dev_ctx,                                \
+                    const DenseTensor& in,                                 \
+                    DenseTensor* out,                                      \
+                    phi::DDim x_dims,                                      \
+                    int64_t axis,                                          \
+                    bool keepdims) {                                       \
+      auto in_eigen = EigenTensor<T, Rank>::From(in, x_dims);              \
+      if (keepdims) {                                                      \
+        auto out_eigen = EigenTensor<Tout, Rank>::From(*out);              \
+        out_eigen.device(*(dev_ctx.eigen_device())) =                      \
+            in_eigen.eigen_op_type(axis).template cast<Tout>();            \
+      } else {                                                             \
+        auto out_eigen = EigenTensor<Tout, Rank - 1>::From(*out);          \
+        out_eigen.device(*(dev_ctx.eigen_device())) =                      \
+            in_eigen.eigen_op_type(axis).template cast<Tout>();            \
+      }                                                                    \
+    }                                                                      \
+  }
+
+DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin);
+DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax);
+
+template <typename Context, typename T, ArgMinMaxType EnumArgMinMaxValue>
+struct VisitDataArgMinMaxFunctor {
+  const Context& dev_ctx;
+  const DenseTensor& x;
+  int64_t axis;
+  bool keepdims;
+  bool flatten;
+  DenseTensor* out;
+
+  explicit VisitDataArgMinMaxFunctor(const Context& dev_ctx,
+                                     const DenseTensor& x,
+                                     int64_t axis,
+                                     bool keepdims,
+                                     bool flatten,
+                                     DenseTensor* out)
+      : dev_ctx(dev_ctx),
+        x(x),
+        axis(axis),
+        keepdims(keepdims),
+        flatten(flatten),
+        out(out) {}
+  template <typename Tout>
+  void apply() const {
+    dev_ctx.template Alloc<Tout>(out);
+    bool new_keepdims = keepdims;
+    if (flatten) new_keepdims = true;
+
+    // if flatten, construct the new dims for the calculation
+    phi::DDim x_dims;
+    int new_axis = axis;
+    if (flatten) {
+      x_dims = phi::make_ddim({x.numel()});
+      // if flatten, treat the axis as 0
+      new_axis = 0;
+    } else {
+      x_dims = x.dims();
+      if (axis < 0) new_axis = axis + x_dims.size();
+    }
+
+#define CALL_ARG_MINMAX_FUNCTOR(rank)                          \
+  ArgMinMaxFunctor<Context, T, Tout, rank, EnumArgMinMaxValue> \
+      functor##rank;                                           \
+  functor##rank(dev_ctx, x, out, x_dims, new_axis, new_keepdims)
+
+    switch (x_dims.size()) {
+      case 1:
+        CALL_ARG_MINMAX_FUNCTOR(1);
+        break;
+      case 2:
+        CALL_ARG_MINMAX_FUNCTOR(2);
+        break;
+      case 3:
+        CALL_ARG_MINMAX_FUNCTOR(3);
+        break;
+      case 4:
+        CALL_ARG_MINMAX_FUNCTOR(4);
+        break;
+      case 5:
+        CALL_ARG_MINMAX_FUNCTOR(5);
+        break;
+      case 6:
+        CALL_ARG_MINMAX_FUNCTOR(6);
+        break;
+      default:
+        PADDLE_ENFORCE_LE(
+            x_dims.size(),
+            6,
+            phi::errors::InvalidArgument(
+                "%s operator doesn't support tensors whose ranks are greater "
+                "than 6.",
+                (EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax")));
+        break;
+#undef CALL_ARG_MINMAX_FUNCTOR
+    }
+  }
+};
+
+template <typename Context, typename T, ArgMinMaxType EnumArgMinMaxValue>
+void ArgMinMaxKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     int64_t axis,
+                     bool keepdims,
+                     bool flatten,
+                     int dtype,
+                     DenseTensor* out) {
+  if (dtype < 0) {
+    paddle::framework::VisitDataTypeTiny(
+        static_cast<paddle::framework::proto::VarType::Type>(
+            paddle::framework::proto::VarType::INT64),
+        VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
+            dev_ctx, x, axis, keepdims, flatten, out));
+    return;
+  }
+  paddle::framework::VisitDataTypeTiny(
+      static_cast<paddle::framework::proto::VarType::Type>(dtype),
+      VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
+          dev_ctx, x, axis, keepdims, flatten, out));
+}
+
+template <typename T, typename Context>
+void ArgMinKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  int64_t axis,
+                  bool keepdims,
+                  bool flatten,
+                  int dtype,
+                  DenseTensor* out) {
+  ArgMinMaxKernel<Context, T, ArgMinMaxType::kArgMin>(
+      dev_ctx, x, axis, keepdims, flatten, dtype, out);
+}
+
+template <typename T, typename Context>
+void ArgMaxKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  int64_t axis,
+                  bool keepdims,
+                  bool flatten,
+                  int dtype,
+                  DenseTensor* out) {
+  ArgMinMaxKernel<Context, T, ArgMinMaxType::kArgMax>(
+      dev_ctx, x, axis, keepdims, flatten, dtype, out);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(arg_min,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::ArgMinKernel,
+                   float,
+                   double,
+                   int32_t,
+                   int64_t,
+                   int16_t,
+                   uint8_t) {}
+
+PD_REGISTER_KERNEL(arg_max,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::ArgMaxKernel,
+                   float,
+                   double,
+                   int32_t,
+                   int64_t,
+                   int16_t,
+                   uint8_t) {}
diff --git a/paddle/phi/kernels/gpu/arg_min_max_kernel.cu b/paddle/phi/kernels/gpu/arg_min_max_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..6feee512cc9f4ec411167d1dc26feed1d766787d
--- /dev/null
+++ b/paddle/phi/kernels/gpu/arg_min_max_kernel.cu
@@ -0,0 +1,278 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/arg_min_max_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+#if defined(__NVCC__) || defined(__HIPCC__)
+
+#ifdef __NVCC__
+#include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
+#include <limits>
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/phi/core/ddim.h"
+
+namespace phi {
+
+namespace {  // NOLINT
+template <typename K, typename V>
+using KeyValuePair = cub::KeyValuePair<K, V>;
+
+}  // end namespace
+
+#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...)  \
+  case (1 << (log2_block_dim)): {                       \
+    constexpr auto kBlockDim = (1 << (log2_block_dim)); \
+    __VA_ARGS__;                                        \
+  } break
+
+#define FIXED_BLOCK_DIM_CASE(...)               \
+  FIXED_BLOCK_DIM_CASE_BASE(10, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__);
+
+template <typename T, typename IndType, class Reducer, size_t BlockDim>
+__global__ void ArgCUDAKernel(const int64_t height,     // n * h
+                              const int64_t width,      // c
+                              const int64_t post_size,  // h
+                              const Reducer reducer,
+                              const T init,
+                              const T* in,
+                              IndType* out) {
+  typedef cub::BlockReduce<cub::KeyValuePair<int, T>, BlockDim> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+
+  for (int idx = blockIdx.x; idx < height; idx += gridDim.x) {
+    KeyValuePair<int, T> kv_pair = {-1, init};
+    int h = idx / post_size;
+    int w = idx % post_size;
+    for (int k = threadIdx.x; k < width; k += blockDim.x) {
+      kv_pair =
+          reducer({k, in[h * width * post_size + k * post_size + w]}, kv_pair);
+    }
+    kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer);
+    if (threadIdx.x == 0) {
+      out[idx] = static_cast<IndType>(kv_pair.key);
+    }
+    __syncthreads();
+  }
+}
+
+template <typename T, typename IndType, class Reducer>
+void ComputeFullArg(const phi::GPUContext& dev_ctx,
+                    const DenseTensor& input,
+                    DenseTensor* indices,
+                    const int64_t pre,
+                    const int64_t post,
+                    const int64_t n) {
+  auto cu_stream = dev_ctx.stream();
+  auto ComputeBlockSize = [](int64_t col) {
+    auto block_size = 8;
+    if (col > 512)
+      block_size = 1024;
+    else if (col > 256)
+      block_size = 512;
+    else if (col > 128)
+      block_size = 256;
+    else if (col > 64)
+      block_size = 128;
+    else if (col > 32)
+      block_size = 64;
+    else if (col > 16)
+      block_size = 32;
+    else if (col > 8)
+      block_size = 16;
+#ifdef __HIPCC__
+    block_size = std::min(block_size, 256);
+#endif
+    return block_size;
+  };
+
+  int64_t max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
+  int64_t height = pre * post;
+  int64_t width = n;
+  int64_t grid_size = height < max_grid_dimx ? height : max_grid_dimx;
+
+  const T* in_data = input.data<T>();
+  IndType* out_data = dev_ctx.template Alloc<IndType>(indices);
+
+  if (typeid(Reducer) == typeid(cub::ArgMax)) {
+    switch (ComputeBlockSize(width)) {
+      FIXED_BLOCK_DIM_CASE(
+          ArgCUDAKernel<T, IndType, Reducer, kBlockDim>
+          <<<grid_size, kBlockDim, 0, cu_stream>>>(
+              height,
+              width,
+              post,
+              Reducer(),
+              std::numeric_limits<T>::lowest(),
+              in_data,
+              out_data));
+    }
+  } else {
+    switch (ComputeBlockSize(width)) {
+      FIXED_BLOCK_DIM_CASE(
+          ArgCUDAKernel<T, IndType, Reducer, kBlockDim>
+          <<<grid_size, kBlockDim, 0, cu_stream>>>(
+              height,
+              width,
+              post,
+              Reducer(),
+              std::numeric_limits<T>::max(),
+              in_data,
+              out_data));
+    }
+  }
+}
+
+template <typename Context, typename T, class Reducer>
+struct VisitDataCudaArgMinMaxFunctor {
+  const Context& dev_ctx;
+  const DenseTensor& x;
+  int64_t axis;
+  bool keepdims;
+  bool flatten;
+  DenseTensor* out;
+
+  explicit VisitDataCudaArgMinMaxFunctor(const Context& dev_ctx,
+                                         const DenseTensor& x,
+                                         int64_t axis,
+                                         bool keepdims,
+                                         bool flatten,
+                                         DenseTensor* out)
+      : dev_ctx(dev_ctx),
+        x(x),
+        axis(axis),
+        keepdims(keepdims),
+        flatten(flatten),
+        out(out) {}
+
+  template <typename IndType>
+  void apply() const {
+    phi::DDim x_dims;
+    int new_axis = axis;
+    if (flatten) {
+      x_dims = phi::make_ddim({x.numel()});
+      // if flatten, treat the axis as 0
+      new_axis = 0;
+    } else {
+      x_dims = x.dims();
+      if (axis < 0) new_axis = axis + x.dims().size();
+    }
+
+    int64_t numel = x.numel();
+    int64_t groups = numel / x_dims[new_axis];
+    int64_t pre = 1;
+    int64_t post = 1;
+    int64_t n = x_dims[new_axis];
+
+    for (int i = 0; i < new_axis; i++) {
+      pre *= x_dims[i];
+    }
+
+    for (int i = new_axis + 1; i < x_dims.size(); i++) {
+      post *= x_dims[i];
+    }
+
+    ComputeFullArg<T, IndType, Reducer>(dev_ctx, x, out, pre, post, n);
+  }
+};
+
+template <typename Context, typename T, class Reducer>
+void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
+                           const DenseTensor& x,
+                           int64_t axis,
+                           bool keepdims,
+                           bool flatten,
+                           int dtype,
+                           DenseTensor* out) {
+  if (dtype < 0) {
+    paddle::framework::VisitDataTypeTiny(
+        static_cast<paddle::framework::proto::VarType::Type>(
+            paddle::framework::proto::VarType::INT64),
+        VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
+            dev_ctx, x, axis, keepdims, flatten, out));
+    return;
+  }
+  paddle::framework::VisitDataTypeTiny(
+      static_cast<paddle::framework::proto::VarType::Type>(dtype),
+      VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
+          dev_ctx, x, axis, keepdims, flatten, out));
+}
+
+template <typename T, typename Context>
+void ArgMinKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  int64_t axis,
+                  bool keepdims,
+                  bool flatten,
+                  int dtype,
+                  DenseTensor* out) {
+  ArgMinMaxOpCUDAKernel<Context, T, cub::ArgMin>(
+      dev_ctx, x, axis, keepdims, flatten, dtype, out);
+}
+
+template <typename T, typename Context>
+void ArgMaxKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  int64_t axis,
+                  bool keepdims,
+                  bool flatten,
+                  int dtype,
+                  DenseTensor* out) {
+  ArgMinMaxOpCUDAKernel<Context, T, cub::ArgMax>(
+      dev_ctx, x, axis, keepdims, flatten, dtype, out);
+}
+
+#endif
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(arg_min,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ArgMinKernel,
+                   float,
+                   double,
+                   int32_t,
+                   int64_t,
+                   int16_t,
+                   uint8_t) {}
+
+PD_REGISTER_KERNEL(arg_max,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ArgMaxKernel,
+                   float,
+                   double,
+                   int32_t,
+                   int64_t,
+                   int16_t,
+                   uint8_t) {}
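
Note on the index math shared by `ArgCUDAKernel` and `ComputeFullArg`: an N-D argmin/argmax over `axis` is flattened into a (pre, n, post) decomposition, where `pre` is the product of the dims before `axis`, `n` is the reduced dim, and `post` is the product of the dims after it. Each of the `height = pre * post` output slots scans `n` candidates spaced `post` elements apart, reading `in[h * n * post + k * post + w]`; with `flatten` this degenerates to `pre = post = 1` and `n = numel`. Below is a minimal host-side C++ sketch of the same decomposition, for illustration only; the function name and types are hypothetical and not part of the patch.

#include <cstdint>
#include <iostream>
#include <vector>

// Argmax over `axis` (assumed 0 <= axis < dims.size()) of a row-major tensor
// with shape `dims`, using the same (pre, n, post) decomposition as
// ComputeFullArg above: one output slot per (outer slice h, inner offset w)
// pair, each reducing n candidates strided `post` elements apart.
std::vector<int64_t> ArgMaxAlongAxis(const std::vector<float>& in,
                                     const std::vector<int64_t>& dims,
                                     int axis) {
  int64_t pre = 1, post = 1;
  const int64_t n = dims[axis];
  for (int i = 0; i < axis; ++i) pre *= dims[i];
  for (size_t i = axis + 1; i < dims.size(); ++i) post *= dims[i];

  std::vector<int64_t> out(pre * post);
  for (int64_t h = 0; h < pre; ++h) {
    for (int64_t w = 0; w < post; ++w) {
      int64_t best = 0;
      for (int64_t k = 1; k < n; ++k) {
        if (in[h * n * post + k * post + w] >
            in[h * n * post + best * post + w]) {
          best = k;
        }
      }
      out[h * post + w] = best;  // index along `axis`
    }
  }
  return out;
}

int main() {
  // Shape [2, 3], argmax over axis 1: one index per row.
  std::vector<float> x = {1.f, 5.f, 2.f, 9.f, 0.f, 3.f};
  for (int64_t idx : ArgMaxAlongAxis(x, {2, 3}, 1)) std::cout << idx << " ";
  // prints: 1 0
  return 0;
}

The GPU kernel walks the same space with `h = idx / post_size` and `w = idx % post_size` for `idx` in `[0, height)`, so each CUDA block performs one of the inner `k` loops via `cub::BlockReduce`.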