From 2402c742c876c77740abb21a7a7019ce47ee4cc5 Mon Sep 17 00:00:00 2001
From: Wilber
Date: Tue, 23 Jun 2020 16:48:03 +0800
Subject: [PATCH] add topk_pooling kernel. test=develop (#3813)

---
 lite/backends/cuda/cuda_utils.h               |   2 +
 lite/kernels/cuda/CMakeLists.txt              |   2 +
 lite/kernels/cuda/topk_pooling_compute.cu     | 200 ++++++++++++++++++
 lite/kernels/cuda/topk_pooling_compute.h      |  45 ++++
 .../kernels/cuda/topk_pooling_compute_test.cc | 145 +++++++++++++
 lite/operators/CMakeLists.txt                 |   1 +
 lite/operators/op_params.h                    |   9 +
 lite/operators/topk_pooling_op.cc             |  55 +++++
 lite/operators/topk_pooling_op.h              |  46 ++++
 9 files changed, 505 insertions(+)
 create mode 100644 lite/kernels/cuda/topk_pooling_compute.cu
 create mode 100644 lite/kernels/cuda/topk_pooling_compute.h
 create mode 100644 lite/kernels/cuda/topk_pooling_compute_test.cc
 create mode 100644 lite/operators/topk_pooling_op.cc
 create mode 100644 lite/operators/topk_pooling_op.h

diff --git a/lite/backends/cuda/cuda_utils.h b/lite/backends/cuda/cuda_utils.h
index f52acac731..012004a65f 100644
--- a/lite/backends/cuda/cuda_utils.h
+++ b/lite/backends/cuda/cuda_utils.h
@@ -41,6 +41,8 @@
         << "CUDA: " << cudaGetErrorString(e); \
   }
 
+#define CUDA_POST_KERNEL_CHECK CUDA_CALL(cudaPeekAtLastError())
+
 #define CUBLAS_CALL(func) \
   {                       \
     auto e = (func);      \
diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt
index 1a58a51c36..a70d7e8004 100644
--- a/lite/kernels/cuda/CMakeLists.txt
+++ b/lite/kernels/cuda/CMakeLists.txt
@@ -44,6 +44,7 @@ add_kernel(match_matrix_tensor_compute_cuda CUDA extra SRCS match_matrix_tensor_
 add_kernel(search_aligned_mat_mul_compute_cuda CUDA extra SRCS search_aligned_mat_mul_compute.cc DEPS ${lite_kernel_deps} cuda_batched_gemm)
 add_kernel(search_seq_fc_compute_cuda CUDA extra SRCS search_seq_fc_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
 add_kernel(var_conv_2d_compute_cuda CUDA extra SRCS var_conv_2d_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
+add_kernel(topk_pooling_compute_cuda CUDA extra SRCS topk_pooling_compute.cu DEPS ${lite_kernel_deps})
 
 # unit test
 lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda)
@@ -79,4 +80,5 @@ if(LITE_BUILD_EXTRA)
     #nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda)
     nv_test(sequence_arithmetic_compute_cuda_test SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_cuda)
     #nv_test(search_fc_cuda_test SRCS search_fc_compute_test.cc DEPS search_fc_compute_cuda)
+    nv_test(topk_pooling_compute_cuda_test SRCS topk_pooling_compute_test.cc DEPS topk_pooling_compute_cuda)
 endif()
diff --git a/lite/kernels/cuda/topk_pooling_compute.cu b/lite/kernels/cuda/topk_pooling_compute.cu
new file mode 100644
index 0000000000..bb4499b637
--- /dev/null
+++ b/lite/kernels/cuda/topk_pooling_compute.cu
@@ -0,0 +1,200 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "lite/kernels/cuda/topk_pooling_compute.h" + +#include +#include + +#include "lite/backends/cuda/target_wrapper.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +__global__ void top_k_pooling_batch_kernel_reduction(Dtype *output_data, + const Dtype *input, + const int *height_offset, + const int *width_offset, + const int batch_size, + const int channel_num, + const int height_stride, + const int width_stride, + const int k) { + const Dtype *input_start = + input + + (blockIdx.x * channel_num + blockIdx.y) * height_stride * width_stride; + Dtype *output_start = + output_data + (blockIdx.x * channel_num + blockIdx.y) * k; + + int width = width_offset[blockIdx.x + 1] - width_offset[blockIdx.x]; + int height = height_offset[blockIdx.x + 1] - height_offset[blockIdx.x]; + int real_k = k < height * width ? k : height * width; + + extern __shared__ Dtype smem[]; + + Dtype min_val = -100000.0f; + for (int j = threadIdx.x; j < height * width; j += blockDim.x) { + int index_tmp = (j / width) * width_stride + j % width; + smem[j] = input_start[index_tmp]; + } + __syncthreads(); + + // get max val + int t = 0; + for (; t < real_k; ++t) { + // reduction + for (int gap = height * width; gap > 1;) { + if (threadIdx.x == 0) { // edge cond + if (gap % 2 != 0) { + Dtype value_first = smem[0]; + Dtype value_gap = smem[gap - 1]; + if (value_first < value_gap) { + smem[0] = value_gap; + smem[gap - 1] = value_first; + } + } + } + gap >>= 1; + for (int j = threadIdx.x; j < gap; j += blockDim.x) { + Dtype value_first = smem[j]; + Dtype value_gap = smem[j + gap]; + if (value_first < value_gap) { + smem[j] = value_gap; + smem[j + gap] = value_first; + } + } + __syncthreads(); + } + if (threadIdx.x == 0) { + output_start[t] = smem[0]; + smem[0] = min_val; + } + __syncthreads(); + } + for (int i = threadIdx.x; i < (k - t); i += blockDim.x) { + // output_start[t + i] = 0.0f; + } +} + +template +void TopkPoolingCompute::PrepareForRun() { + int device_id = lite::TargetWrapperCuda::GetCurDevice(); + cudaDeviceProp deviceProp; + CUDA_CALL(cudaGetDeviceProperties(&deviceProp, device_id)); + _shared_mem_size = deviceProp.sharedMemPerBlock; +} + +template +void TopkPoolingCompute::Run() { + auto ¶m = this->Param(); + auto &ctx = this->ctx_->template As(); + auto cuda_stream = ctx.exec_stream(); + + CHECK(param.X->lod().size() > 0 && param.X->lod()[0].size() > 0) + << "X sequence offset is not valid"; + CHECK(param.Y->lod().size() > 0 && param.Y->lod()[0].size() > 0) + << "Y sequence offset is not valid"; + + int width_offset_len = param.X->lod()[0].size(); + lite::DDim width_offset_shape(std::vector{width_offset_len}); + _width_offset.Resize(width_offset_shape); + std::vector width_lod_0(width_offset_len, 0); + for (size_t i = 0; i < param.X->lod()[0].size(); ++i) { + width_lod_0[i] = static_cast(param.X->lod()[0][i]); + } + lite::TargetWrapperCuda::MemcpyAsync( + _width_offset.mutable_data(TARGET(kCUDA)), + width_lod_0.data(), + sizeof(int) * width_offset_len, + lite::IoDirection::HtoD, + cuda_stream); + + int height_offset_len = param.Y->lod()[0].size(); + lite::DDim height_offset_shape(std::vector{height_offset_len}); + _height_offset.Resize(height_offset_shape); + std::vector height_lod_0(height_offset_len, 0); + for (size_t i = 0; i < param.Y->lod()[0].size(); ++i) { + height_lod_0[i] = static_cast(param.Y->lod()[0][i]); + } + lite::TargetWrapperCuda::MemcpyAsync( + _height_offset.mutable_data(TARGET(kCUDA)), + 
+      height_lod_0.data(),
+      sizeof(int) * height_offset_len,
+      lite::IoDirection::HtoD,
+      cuda_stream);
+
+  const Tensor *x_tensor = param.X;
+  Tensor *out_tensor = param.Out;
+  const T *in_data = x_tensor->data<T>();
+  T *out_data = out_tensor->mutable_data<T>(TARGET(kCUDA));
+
+  int num = x_tensor->dims()[0];
+  int channel = x_tensor->dims()[1];
+  int height = x_tensor->dims()[2];
+  int width = x_tensor->dims()[3];
+
+  const int *height_offset = _height_offset.data<int>();
+  const int *width_offset = _width_offset.data<int>();
+
+  int feat_map_size = height * width;
+
+  if (feat_map_size * sizeof(T) <= _shared_mem_size) {
+    // one block per (batch, channel) feature map; dynamic shared memory holds
+    // one full feature map
+    dim3 blocks(num, channel);
+    dim3 threads(32, 1);
+
+    top_k_pooling_batch_kernel_reduction<
+        T><<<blocks, threads, feat_map_size * sizeof(T), cuda_stream>>>(
+        out_data,
+        in_data,
+        height_offset,
+        width_offset,
+        num,
+        channel,
+        height,
+        width,
+        param.top_k);
+  } else {
+    LOG(FATAL) << "Not implemented. Exceeded the shared memory limit.";
+  }
+  CUDA_POST_KERNEL_CHECK;
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(topk_pooling,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::TopkPoolingCompute<float>,
+                     def)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindInput("Y",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNCHW))})
+    .Finalize();
diff --git a/lite/kernels/cuda/topk_pooling_compute.h b/lite/kernels/cuda/topk_pooling_compute.h
new file mode 100644
index 0000000000..abf1616381
--- /dev/null
+++ b/lite/kernels/cuda/topk_pooling_compute.h
@@ -0,0 +1,45 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <memory>
+#include "lite/backends/cuda/cuda_utils.h"
+#include "lite/core/kernel.h"
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+template <typename T>
+class TopkPoolingCompute
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::TopkPoolingParam;
+
+  void Run() override;
+
+  void PrepareForRun() override;
+
+  virtual ~TopkPoolingCompute() = default;
+
+ protected:
+  lite::Tensor _height_offset;
+  lite::Tensor _width_offset;
+  int _shared_mem_size;
+};
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/cuda/topk_pooling_compute_test.cc b/lite/kernels/cuda/topk_pooling_compute_test.cc
new file mode 100644
index 0000000000..0fb5c29f25
--- /dev/null
+++ b/lite/kernels/cuda/topk_pooling_compute_test.cc
@@ -0,0 +1,145 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/cuda/topk_pooling_compute.h"
+
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <random>
+#include <utility>
+#include <vector>
+
+#include "lite/api/test_helper.h"
+#include "lite/utils/float16.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+class TopkPoolingTest : public ::testing::Test {
+ protected:
+  TopkPoolingTest()
+      : num(2),
+        channels(4),
+        height(4),
+        width(4),
+        top_k(2),
+        feat_map_num(height * width),
+        x_lod({{0, 4, 7}}),
+        y_lod({{0, 4, 7}}),
+        x_shape({num, channels, height, width}),
+        out_shape({num, channels * top_k}) {
+    CHECK_EQ(x_lod[0].size(), num + 1) << "invalid input.";
+    for (size_t i = 1; i < x_lod[0].size(); ++i) {
+      CHECK_LE(x_lod[0][i] - x_lod[0][i - 1], height) << "invalid input.";
+    }
+
+    X_gpu.Resize(lite::DDim(x_shape));
+    X_ref.Resize(lite::DDim(x_shape));
+    X_ref.set_lod(x_lod);
+    Y_gpu.Resize(lite::DDim(x_shape));
+    Y_ref.Resize(lite::DDim(x_shape));
+    Y_ref.set_lod(y_lod);
+    auto x_ref_data = X_ref.mutable_data<float>();
+    auto y_ref_data = Y_ref.mutable_data<float>();
+
+    // prepare input
+    for (int64_t i = 0; i < X_ref.numel(); i++) {
+      x_ref_data[i] = static_cast<float>(i % 16);
+    }
+    for (int64_t i = 0; i < Y_ref.numel(); i++) {
+      y_ref_data[i] = static_cast<float>(i % 16);
+    }
+
+    Out_ref.Resize(lite::DDim(out_shape));
+    Out_gpu.Resize(lite::DDim(out_shape));
+    Out_cpu.Resize(lite::DDim(out_shape));
+
+    device_init();
+  }
+
+  void device_init() {
+    ctx.reset(new KernelContext);
+    cudaStreamCreate(&stream);
+    param.X = &X_gpu;
+    param.Y = &Y_gpu;
+    param.Out = &Out_gpu;
+    param.top_k = top_k;
+    param.feat_map_num = feat_map_num;
+  }
+
+  void float_data_init() {
+    X_gpu.Assign<float, lite::DDim, TARGET(kCUDA)>(X_ref.data<float>(),
+                                                   X_gpu.dims());
+    X_gpu.set_lod(X_ref.lod());
+    Y_gpu.Assign<float, lite::DDim, TARGET(kCUDA)>(Y_ref.data<float>(),
+                                                   Y_gpu.dims());
+    Y_gpu.set_lod(Y_ref.lod());
+  }
+
+  void half_data_init() {}
+
+  void cpu_base(const lite::Tensor* X,
+                const lite::Tensor* Y,
+                lite::Tensor* Out) {}
+
+  int num, channels, height, width;
+  int top_k, feat_map_num;
+  std::vector<std::vector<uint64_t>> x_lod, y_lod;
+  std::vector<int64_t> x_shape, out_shape;
+  lite::Tensor X_ref, Y_ref, Out_ref;
+  lite::Tensor X_gpu, Y_gpu;
+  lite::Tensor Out_cpu, Out_gpu;
+
+  operators::TopkPoolingParam param;
+  std::unique_ptr<KernelContext> ctx;
+  cudaStream_t stream;
+};
+
+TEST_F(TopkPoolingTest, fp32) {
+  float_data_init();
+  auto& context = ctx->As<CUDAContext>();
+  context.SetExecStream(stream);
+  TopkPoolingCompute<float> kernel;
+  kernel.SetParam(param);
+  kernel.SetContext(std::move(ctx));
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    kernel.Launch();
+    cudaDeviceSynchronize();
+  }
+
+  auto start = GetCurrentUS();
+  kernel.PrepareForRun();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    kernel.Run();
+  }
+  cudaDeviceSynchronize();
+  auto duration = (GetCurrentUS() - start) / 1000.0;
+  LOG(INFO) << "fp32, warmup: " << FLAGS_warmup
+            << ", repeats: " << FLAGS_repeats << ", spend "
+            << duration / FLAGS_repeats << " ms in average.";
+
+  CopySync<TARGET(kCUDA)>(Out_cpu.mutable_data<float>(),
+                          Out_gpu.data<float>(),
+                          sizeof(float) * Out_gpu.numel(),
+                          IoDirection::DtoH);
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
index 76f223c3d5..dac3c3f7dd 100644
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -147,6 +147,7 @@ add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc DEPS ${op_DEPS})
 add_operator(sequence_topk_avg_pooling_op basic SRCS sequence_topk_avg_pooling_op.cc DEPS ${op_DEPS})
 add_operator(search_fc_op basic SRCS search_fc_op.cc DEPS ${op_DEPS})
 add_operator(lstm_op extra SRCS lstm_op.cc DEPS ${op_DEPS})
+add_operator(topk_pooling_op extra SRCS topk_pooling_op.cc DEPS ${op_DEPS})
 
 # for deformable-convNet
 add_operator(deformable_conv_op extra SRCS deformable_conv_op.cc DEPS ${op_DEPS})
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 98099a1a1e..4254ded3b5 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -1344,6 +1344,15 @@ struct SequenceTopkAvgPoolingParam : ParamBase {
   std::vector<int> topks{};
 };
 
+/// --------------- topk_pooling operators ------------------
+struct TopkPoolingParam : ParamBase {
+  const lite::Tensor* X{};
+  const lite::Tensor* Y{};
+  lite::Tensor* Out{};
+  int top_k{1};
+  int feat_map_num{1};
+};
+
 /// --------------- search_fc operators ------------------
 struct SearchFcParam : ParamBase {
   const lite::Tensor* X{};
diff --git a/lite/operators/topk_pooling_op.cc b/lite/operators/topk_pooling_op.cc
new file mode 100644
index 0000000000..76634d216a
--- /dev/null
+++ b/lite/operators/topk_pooling_op.cc
@@ -0,0 +1,55 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/topk_pooling_op.h"
+#include "lite/core/op_registry.h"
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool TopkPoolingOp::CheckShape() const {
+  CHECK_OR_FALSE(param_.X);
+  CHECK_OR_FALSE(param_.Y);
+  CHECK_OR_FALSE(param_.Out);
+  return true;
+}
+
+bool TopkPoolingOp::InferShapeImpl() const {
+  auto out_dims = param_.X->dims();
+  out_dims[1] *= param_.top_k;
+  auto out = param_.Out;
+  out->Resize(out_dims);
+  out->set_lod(param_.X->lod());
+
+  return true;
+}
+
+bool TopkPoolingOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
+  auto x = op_desc.Input("X").front();
+  auto y = op_desc.Input("Y").front();
+  param_.X = scope->FindTensor(x);
+  param_.Y = scope->FindTensor(y);
+  auto output = op_desc.Output("Out").front();
+  param_.Out = scope->FindMutableTensor(output);
+  param_.top_k = op_desc.GetAttr<int>("top_k");
+  param_.feat_map_num = op_desc.GetAttr<int>("feat_map_num");
+
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(topk_pooling, paddle::lite::operators::TopkPoolingOp);
diff --git a/lite/operators/topk_pooling_op.h b/lite/operators/topk_pooling_op.h
new file mode 100644
index 0000000000..ec48c476ca
--- /dev/null
+++ b/lite/operators/topk_pooling_op.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class TopkPoolingOp : public OpLite {
+ public:
+  TopkPoolingOp() {}
+  explicit TopkPoolingOp(const std::string &op_type) : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShapeImpl() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+  std::string DebugString() const override { return "topk_pooling"; }
+
+ private:
+  mutable TopkPoolingParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
-- 
GitLab