From f00f982a026e323eecbdbd7028907355539ae94f Mon Sep 17 00:00:00 2001
From: Zhaolong Xing
Date: Thu, 20 Aug 2020 14:42:47 +0800
Subject: [PATCH] add cub impl for arg max, min (#25941)

test=develop
---
 paddle/fluid/operators/arg_max_op.cu          |  51 +++---
 .../fluid/operators/arg_min_max_op_base.cu.h  | 163 ++++++++++++++++++
 paddle/fluid/operators/arg_min_op.cu          |  50 +++---
 3 files changed, 206 insertions(+), 58 deletions(-)
 create mode 100644 paddle/fluid/operators/arg_min_max_op_base.cu.h

diff --git a/paddle/fluid/operators/arg_max_op.cu b/paddle/fluid/operators/arg_max_op.cu
index 85e4f98173..14708c4df1 100644
--- a/paddle/fluid/operators/arg_max_op.cu
+++ b/paddle/fluid/operators/arg_max_op.cu
@@ -1,29 +1,22 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/arg_min_max_op_base.h"
-
-REGISTER_OP_CUDA_KERNEL(
-    arg_max,
-    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
-                                    float>,
-    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
-                                    double>,
-    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
-                                    int64_t>,
-    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
-                                    int32_t>,
-    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
-                                    int16_t>,
-    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
-                                    uint8_t>);
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/arg_min_max_op_base.cu.h"
+
+REGISTER_OP_CUDA_KERNEL(
+    arg_max, paddle::operators::ArgMinMaxOpCUDAKernel<float, cub::ArgMax>,
+    paddle::operators::ArgMinMaxOpCUDAKernel<double, cub::ArgMax>,
+    paddle::operators::ArgMinMaxOpCUDAKernel<int64_t, cub::ArgMax>,
+    paddle::operators::ArgMinMaxOpCUDAKernel<int32_t, cub::ArgMax>,
+    paddle::operators::ArgMinMaxOpCUDAKernel<int8_t, cub::ArgMax>);
diff --git a/paddle/fluid/operators/arg_min_max_op_base.cu.h b/paddle/fluid/operators/arg_min_max_op_base.cu.h
new file mode 100644
index 0000000000..e9b65946a5
--- /dev/null
+++ b/paddle/fluid/operators/arg_min_max_op_base.cu.h
@@ -0,0 +1,163 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#pragma once
+
+#ifdef __NVCC__
+
+#include <cub/cub.cuh>
+#include <limits>
+#include <string>
+#include <typeinfo>
+#include <vector>
+#include "paddle/fluid/framework/ddim.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/transpose_op.h"
+#include "paddle/fluid/platform/device_context.h"
+
+namespace paddle {
+namespace operators {
+
+namespace {  // NOLINT
+template <typename K, typename V>
+using KeyValuePair = cub::KeyValuePair<K, V>;
+using Tensor = framework::Tensor;
+
+}  // end namespace
+
+#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...)  \
+  case (1 << (log2_block_dim)): {                       \
+    constexpr auto kBlockDim = (1 << (log2_block_dim)); \
+    __VA_ARGS__;                                        \
+  } break
+
+#define FIXED_BLOCK_DIM_CASE(...)               \
+  FIXED_BLOCK_DIM_CASE_BASE(10, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__);
+
+template <typename T, typename IndType, class Reducer, size_t BlockDim>
+__global__ void ArgCUDAKernel(const IndType height,     // n * h
+                              const IndType width,      // c
+                              const IndType post_size,  // h
+                              const Reducer reducer, const T init, const T* in,
+                              IndType* out) {
+  typedef cub::BlockReduce<KeyValuePair<int, T>, BlockDim> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+
+  for (int idx = blockIdx.x; idx < height; idx += gridDim.x) {
+    KeyValuePair<int, T> kv_pair = {-1, init};
+    int h = idx / post_size;
+    int w = idx % post_size;
+    for (int k = threadIdx.x; k < width; k += blockDim.x) {
+      kv_pair =
+          reducer({k, in[h * width * post_size + k * post_size + w]}, kv_pair);
+    }
+    kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer);
+    if (threadIdx.x == 0) {
+      out[idx] = static_cast<IndType>(kv_pair.key);
+    }
+    __syncthreads();
+  }
+}
+
+template <typename T, typename IndType, class Reducer>
+void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
+                    Tensor* indices, const IndType pre, const IndType post,
+                    const IndType n) {
+  auto cu_stream = ctx.stream();
+  auto ComputeBlockSize = [](IndType col) {
+    if (col > 512)
+      return 1024;
+    else if (col > 256)
+      return 512;
+    else if (col > 128)
+      return 256;
+    else if (col > 64)
+      return 128;
+    else if (col > 32)
+      return 64;
+    else if (col > 16)
+      return 32;
+    else if (col > 8)
+      return 16;
+    else
+      return 8;
+  };
+
+  int max_grid_dimx = ctx.GetCUDAMaxGridDimSize().x;
+  int height = pre * post;
+  int width = n;
+  int grid_size = height < max_grid_dimx ? height : max_grid_dimx;
+
+  const T* in_data = input.data<T>();
+  IndType* out_data = indices->mutable_data<IndType>(ctx.GetPlace());
+
+  if (typeid(Reducer) == typeid(cub::ArgMax)) {
+    switch (ComputeBlockSize(width)) {
+      FIXED_BLOCK_DIM_CASE(
+          ArgCUDAKernel<T, IndType, Reducer,
+                        kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
+              height, width, post, Reducer(), std::numeric_limits<T>::lowest(),
+              in_data, out_data));
+    }
+  } else {
+    switch (ComputeBlockSize(width)) {
+      FIXED_BLOCK_DIM_CASE(
+          ArgCUDAKernel<T, IndType, Reducer,
+                        kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
+              height, width, post, Reducer(), std::numeric_limits<T>::max(),
+              in_data, out_data));
+    }
+  }
+}
+
+template <typename T, class Reducer>
+class ArgMinMaxOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* output = ctx.Output<Tensor>("Out");
+    int axis = ctx.Attr<int64_t>("axis");
+    auto in_dims = input->dims();
+    axis = (axis < 0) ? (in_dims.size() + axis) : axis;
+
+    int64_t numel = input->numel();
+    int64_t groups = numel / in_dims[axis];
+    int64_t pre = 1;
+    int64_t post = 1;
+    int64_t n = in_dims[axis];
+
+    for (int i = 0; i < axis; i++) {
+      pre *= in_dims[i];
+    }
+
+    for (int i = axis + 1; i < in_dims.size(); i++) {
+      post *= in_dims[i];
+    }
+
+    const auto& dev_ctx = ctx.cuda_device_context();
+    ComputeFullArg<T, int64_t, Reducer>(dev_ctx, *input, output, pre, post,
+                                        n);
+  }
+};
+
+#endif
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/arg_min_op.cu b/paddle/fluid/operators/arg_min_op.cu
index 47d7c8b122..23170bf008 100644
--- a/paddle/fluid/operators/arg_min_op.cu
+++ b/paddle/fluid/operators/arg_min_op.cu
@@ -1,29 +1,21 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/arg_min_max_op_base.h"
-
-REGISTER_OP_CUDA_KERNEL(
-    arg_min,
-    paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
-                                    float>,
-    paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
-                                    double>,
-    paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
-                                    int64_t>,
-    paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
-                                    int32_t>,
-    paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
-                                    int16_t>,
-    paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
-                                    uint8_t>);
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/arg_min_max_op_base.cu.h"
+REGISTER_OP_CUDA_KERNEL(
+    arg_min, paddle::operators::ArgMinMaxOpCUDAKernel<float, cub::ArgMin>,
+    paddle::operators::ArgMinMaxOpCUDAKernel<double, cub::ArgMin>,
+    paddle::operators::ArgMinMaxOpCUDAKernel<int64_t, cub::ArgMin>,
+    paddle::operators::ArgMinMaxOpCUDAKernel<int32_t, cub::ArgMin>,
+    paddle::operators::ArgMinMaxOpCUDAKernel<int8_t, cub::ArgMin>);
--
GitLab
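
For reference, the core pattern the new header relies on is a per-row CUB block reduction over (index, value) pairs. Below is a minimal standalone sketch of that pattern for the contiguous case (post_size == 1), assuming only CUDA and CUB are available; the names RowArgMax and kBlockDim and the sample data are illustrative, not part of the patch.

// Minimal sketch: one thread block reduces one row to the index of its
// maximum, using the same cub::BlockReduce + cub::ArgMax pattern as the
// patch's ArgCUDAKernel (illustrative example, not Paddle code).
#include <cfloat>
#include <cstdio>
#include <cub/cub.cuh>

constexpr int kBlockDim = 128;  // illustrative fixed block size

__global__ void RowArgMax(const float* in, int width, int* out) {
  typedef cub::BlockReduce<cub::KeyValuePair<int, float>, kBlockDim>
      BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  // Each thread keeps the best (index, value) pair over a strided slice
  // of its block's row.
  const float* row = in + blockIdx.x * width;
  cub::KeyValuePair<int, float> kv(-1, -FLT_MAX);
  for (int k = threadIdx.x; k < width; k += blockDim.x) {
    kv = cub::ArgMax()(cub::KeyValuePair<int, float>(k, row[k]), kv);
  }
  // Block-wide reduction; thread 0 holds the row's arg max afterwards.
  kv = BlockReduce(temp_storage).Reduce(kv, cub::ArgMax());
  if (threadIdx.x == 0) out[blockIdx.x] = kv.key;
}

int main() {
  const int rows = 2, width = 5;
  const float h_in[rows * width] = {0.f, 3.f, 1.f, 2.f, -1.f,  // max at 1
                                    4.f, 6.f, 8.f, 7.f, 9.f};  // max at 4
  float* d_in = nullptr;
  int* d_out = nullptr;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, rows * sizeof(int));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  RowArgMax<<<rows, kBlockDim>>>(d_in, width, d_out);
  int h_out[rows];
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("arg max per row: %d %d\n", h_out[0], h_out[1]);  // expect: 1 4
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

The production kernel generalizes this sketch with the post_size stride so the reduced axis need not be innermost, and picks the block size at runtime through the FIXED_BLOCK_DIM_CASE switch instead of a fixed kBlockDim.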