From 3a8818612bab3f3b8859f3cb3c5156bd5acb9f1a Mon Sep 17 00:00:00 2001 From: huzhiqiang <912790387@qq.com> Date: Thu, 21 Nov 2019 08:07:08 -0600 Subject: [PATCH] add cuda kernel for sequence_topk_avg_pooling and search_fc (#2451) * cuda kernel for sequence_topk_avg_pooling and search_fc test=develop --- lite/kernels/cuda/CMakeLists.txt | 3 + lite/kernels/cuda/search_fc_compute.cu | 176 ++++++++++++++++ lite/kernels/cuda/search_fc_compute.h | 52 +++++ lite/kernels/cuda/search_fc_compute_test.cc | 110 ++++++++++ .../cuda/sequence_topk_avg_pooling_compute.cu | 196 ++++++++++++++++++ .../cuda/sequence_topk_avg_pooling_compute.h | 43 ++++ 6 files changed, 580 insertions(+) create mode 100644 lite/kernels/cuda/search_fc_compute.cu create mode 100644 lite/kernels/cuda/search_fc_compute.h create mode 100644 lite/kernels/cuda/search_fc_compute_test.cc create mode 100644 lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu create mode 100644 lite/kernels/cuda/sequence_topk_avg_pooling_compute.h diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt index 34a192b040..e7bef76d05 100644 --- a/lite/kernels/cuda/CMakeLists.txt +++ b/lite/kernels/cuda/CMakeLists.txt @@ -30,6 +30,8 @@ add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute. add_kernel(sequence_arithmetic_compute_cuda CUDA basic SRCS sequence_arithmetic_compute.cu DEPS ${lite_kernel_deps}) add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps}) add_kernel(attention_padding_mask_compute_cuda CUDA extra SRCS attention_padding_mask_compute.cu DEPS ${lite_kernel_deps}) +add_kernel(search_fc_compute_cuda CUDA basic SRCS search_fc_compute.cu DEPS ${lite_kernel_deps} ${math_cuda}) +add_kernel(sequence_topk_avg_pooling_compute_cuda CUDA basic SRCS sequence_topk_avg_pooling_compute.cu DEPS ${lite_kernel_deps}) add_kernel(match_matrix_tensor_compute_cuda CUDA extra SRCS match_matrix_tensor_compute.cu DEPS ${lite_kernel_deps} cuda_gemm) add_kernel(search_aligned_mat_mul_compute_cuda CUDA extra SRCS search_aligned_mat_mul_compute.cc DEPS ${lite_kernel_deps} cuda_batched_gemm) add_kernel(search_seq_fc_compute_cuda CUDA extra SRCS search_seq_fc_compute.cu DEPS ${lite_kernel_deps} cuda_gemm) @@ -53,6 +55,7 @@ nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda) nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda) nv_test(sequence_arithmetic_compute_cuda_test SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_cuda) +nv_test(search_fc_test SRCS search_fc_compute_test.cc DEPS search_fc_compute_cuda sequence_topk_avg_pooling_compute_cuda) nv_test(var_conv_2d_compute_cuda_test SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_cuda) if(LITE_BUILD_EXTRA) diff --git a/lite/kernels/cuda/search_fc_compute.cu b/lite/kernels/cuda/search_fc_compute.cu new file mode 100644 index 0000000000..b634bc933d --- /dev/null +++ b/lite/kernels/cuda/search_fc_compute.cu @@ -0,0 +1,176 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/search_fc_compute.h" +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { +template +static void anakin_NV_gemv(cublasHandle_t handle, + const bool TransA, + const int M, + const int N, + const T alpha, + const T* A, + const T* x, + const T beta, + T* y); +template <> +void anakin_NV_gemv(cublasHandle_t handle, + const bool TransA, + const int M, + const int N, + const float alpha, + const float* A, + const float* x, + const float beta, + float* y) { + LOG(INFO) << "1"; + cublasOperation_t cuTransA = (TransA == false) ? CUBLAS_OP_T : CUBLAS_OP_N; + CUBLAS_CHECK( + cublasSgemv(handle, cuTransA, N, M, &alpha, A, N, x, 1, &beta, y, 1)); +} +template +static void anakin_NV_gemm(cublasHandle_t handle, + const bool TransA, + const bool TransB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const T* B, + const T beta, + T* C); + +template <> +void anakin_NV_gemm(cublasHandle_t handle, + const bool TransA, + const bool TransB, + const int M, + const int N, + const int K, + const float alpha, + const float* A, + const float* B, + const float beta, + float* C) { + LOG(INFO) << "1"; + // Note that cublas follows fortran order. + int lda = (!TransA /* == CblasNoTrans*/) ? K : M; + int ldb = (!TransB /* == CblasNoTrans*/) ? N : K; + LOG(INFO) << "1"; + cublasOperation_t cuTransA = + (!TransA /* == CblasNoTrans*/) ? CUBLAS_OP_N : CUBLAS_OP_T; + LOG(INFO) << "1"; + cublasOperation_t cuTransB = + (!TransB /* == CblasNoTrans*/) ? CUBLAS_OP_N : CUBLAS_OP_T; + LOG(INFO) << "1"; + CUBLAS_CHECK(cublasSgemm(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + N)); + LOG(INFO) << "1"; +} + +template <> +void anakin_NV_gemm(cublasHandle_t handle, + const bool TransA, + const bool TransB, + const int M, + const int N, + const int K, + const char alpha, + const char* A, + const char* B, + const char beta, + char* C) { + LOG(FATAL) << "int8 gemm is not implemented"; +} + +template +static __global__ void add_bias(int n, + int output_size, + const T* bias, + T* dout) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int bias_index = index % output_size; + if (index < n) { + dout[index] = dout[index] + bias[bias_index]; + } +} + +template +void SearchFcCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + const Tensor* x_tensor = param.X; + param.Out->Resize({x_tensor->dims()[0], param.out_size}); + _M = x_tensor->dims().count(0, 1); + _K = x_tensor->dims().count(1, x_tensor->numel()); + _N = param.out_size; + const T* din = x_tensor->data(); + Tensor* out_tensor = param.Out; + T* dout = out_tensor->mutable_data(TARGET(kCUDA)); + const Tensor* w_tensor = param.W; + const T* weight = w_tensor->data(); + const Tensor* b_tensor = param.b; + const T* bias = b_tensor->data(); + cublasCreate(&_handle); + if (_M == 1 && _K > 50000) { + anakin_NV_gemv(_handle, false, _N, _K, (T)1, weight, din, (T)0, dout); + } else { + anakin_NV_gemm(_handle, + false, + !_flag_trans_weights, + _M, + _N, + _K, + (T)1, + din, + weight, + (T)0, + dout); + } + int total_size = _M * _N; + add_bias<<>>( + total_size, _N, bias, dout); +} +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(search_fc, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::SearchFcCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("W", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("b", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/search_fc_compute.h b/lite/kernels/cuda/search_fc_compute.h new file mode 100644 index 0000000000..db09362734 --- /dev/null +++ b/lite/kernels/cuda/search_fc_compute.h @@ -0,0 +1,52 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/backends/cuda/cuda_utils.h" +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +const int CUDA_NUM_THREADS = 512; +inline int CUDA_GET_BLOCKS(const int N) { + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} +inline int CUDA_GET_BLOCKS(const int N, const int base) { + return (N + base - 1) / base; +} + +template +class SearchFcCompute : public KernelLite { + public: + using param_t = operators::SearchFcParam; + void Run() override; + virtual ~SearchFcCompute() = default; + + private: + bool _flag_trans_weights{false}; + int _M; + int _K; + int _N; + cublasHandle_t _handle; + bool _is_continue_buf{true}; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/search_fc_compute_test.cc b/lite/kernels/cuda/search_fc_compute_test.cc new file mode 100644 index 0000000000..f06028fbe1 --- /dev/null +++ b/lite/kernels/cuda/search_fc_compute_test.cc @@ -0,0 +1,110 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/search_fc_compute.h" +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +void fc_cpu_base(const lite::Tensor* X, + const lite::Tensor* W, + const lite::Tensor* b, + int out_size, + lite::Tensor* Out) { + const float* data_in = X->data(); + const float* bias = b->data(); + const float* weights = W->data(); + float* data_out = Out->mutable_data(); + int out_rows = X->dims()[0]; + int in_cols = X->numel() / out_rows; + int out_cols = W->numel() / in_cols; + int index_out; + + for (int i = 0; i < out_rows; i++) { + for (int j = 0; j < out_cols; j++) { + index_out = i * out_cols + j; + data_out[index_out] = bias ? bias[j] : 0; + + for (int k = 0; k < in_cols; k++) { + data_out[index_out] += + data_in[i * in_cols + k] * weights[j * in_cols + k]; + } + } + } +} + +TEST(search_fc, normal) { + SearchFcCompute search_fc_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + operators::SearchFcParam param; + lite::Tensor X, X_gpu, W, W_gpu, b, b_gpu; + lite::Tensor Out, Out_cpu, out_ref; + std::vector x_shape{1, 4}; + X.Resize(lite::DDim(x_shape)); + std::vector w_shape{3, 4}; + W.Resize(lite::DDim(w_shape)); + std::vector b_shape{3}; + b.Resize(lite::DDim(b_shape)); + std::vector out_shape{1, 4}; + Out.Resize(lite::DDim(out_shape)); + out_ref.Resize(lite::DDim(out_shape)); + auto x_data = X.mutable_data(); + auto w_data = W.mutable_data(); + auto b_data = b.mutable_data(); + auto out_data_ref = out_ref.mutable_data(); + for (int64_t i = 0; i < X.dims().production(); i++) { + x_data[i] = static_cast(i); + } + for (int64_t i = 0; i < W.dims().production(); i++) { + w_data[i] = static_cast(i); + } + for (int64_t i = 0; i < b.dims().production(); i++) { + b_data[i] = static_cast(i); + } + X_gpu.Assign(x_data, X.dims()); + W_gpu.Assign(w_data, W.dims()); + b_gpu.Assign(b_data, b.dims()); + param.X = &X_gpu; + param.W = &W_gpu; + param.b = &b_gpu; + param.out_size = 4; + param.Out = &Out; + search_fc_kernel.SetParam(param); + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + search_fc_kernel.SetContext(std::move(ctx)); + search_fc_kernel.Run(); + fc_cpu_base(&X, &W, &b, 4, &out_ref); + cudaDeviceSynchronize(); + const float* out_data = Out.data(); + float* out_cpu_data = Out_cpu.mutable_data(); + CopySync( + out_cpu_data, out_data, sizeof(float) * Out.numel(), IoDirection::DtoH); + for (int i = 0; i < Out.numel(); ++i) { + EXPECT_NEAR(out_cpu_data[i], out_data_ref[i], 1e-5); + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu b/lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu new file mode 100644 index 0000000000..7f4500c158 --- /dev/null +++ b/lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu @@ -0,0 +1,196 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/sequence_topk_avg_pooling_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +__global__ void topk_avg_pooling_kernel_by_row_improve( + Dtype *output_data, + const Dtype *input, + const int *gpu_input_offset_l, + const int *gpu_input_offset_r, + const int row_max, + const int col_max, + const int topk_size, + const int *topks, + const int feat_map_num) { + int row = + gpu_input_offset_l[blockIdx.x + 1] - gpu_input_offset_l[blockIdx.x]; // 8 + int col = gpu_input_offset_r[blockIdx.x + 1] - + gpu_input_offset_r[blockIdx.x]; // 30 + + int max_k = topks[topk_size - 1]; + max_k = max_k < col ? max_k : col; + + extern __shared__ Dtype smem[]; // H*W + + const Dtype *fm_row_in_data = input + + blockIdx.x * row_max * feat_map_num * col_max + + blockIdx.y * row_max * col_max; + + for (int i = threadIdx.x; i < row * col_max; i += blockDim.x) { + smem[i] = fm_row_in_data[i]; + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < row; idx += blockDim.x) { + Dtype *fm_row_out_data = + output_data + + (gpu_input_offset_l[blockIdx.x] + idx) * feat_map_num * topk_size + + blockIdx.y * topk_size; + + Dtype *smem_start_col = smem + idx * col_max; + + int counter = max_k; // topk_size; + Dtype last_max_val = -20000.0; + while (counter) { + Dtype max_val = -10000.0; + int max_pos = 0; + int m = 0; + for (; m < col; m++) { + Dtype cur_data = smem_start_col[m]; + if (cur_data > max_val) { + max_val = cur_data; + max_pos = m; + last_max_val = max_val; + } + } + if (max_val < -9999.0) { // == -10000.0 + max_val = last_max_val; + } + smem_start_col[max_pos] = 10000000.0; + int i = max_k - counter; + for (int c = 0; c < topk_size; c++) { + if (i <= topks[c] - 1) { + fm_row_out_data[c] += max_val; + } + } + counter--; + } + __syncthreads(); + // compute avg + for (int i = 0; i < topk_size; i++) { + fm_row_out_data[i] = fm_row_out_data[i] / topks[i]; + } + } +} + +template +void SequenceTopkAvgPoolingCompute::Run() { + auto ¶m = this->Param(); + auto &ctx = this->ctx_->template As(); + auto cuda_stream = ctx.exec_stream(); + + int topk_num = param.topks.size(); + lite::DDim top_ks_shape(std::vector{topk_num, 1, 1, 1}); + _top_ks.Resize(top_ks_shape); + cudaMemcpyAsync(_top_ks.mutable_data(TARGET(kCUDA)), + ¶m.topks[0], + sizeof(int) * topk_num, + cudaMemcpyHostToDevice, + cuda_stream); + + int width_offset_len = param.X->lod()[0].size(); + lite::DDim width_offset_shape( + std::vector{width_offset_len, 1, 1, 1}); + _width_offset.Resize(width_offset_shape); + cudaMemcpyAsync(_width_offset.mutable_data(TARGET(kCUDA)), + &(param.X->lod()[0][0]), + sizeof(int) * width_offset_len, + cudaMemcpyHostToDevice, + cuda_stream); + + int height_offset_len = param.ROW->lod()[0].size(); + lite::DDim height_offset_shape( + std::vector{height_offset_len, 1, 1, 1}); + _height_offset.Resize(height_offset_shape); + cudaMemcpyAsync(_height_offset.mutable_data(TARGET(kCUDA)), + &(param.ROW->lod()[0][0]), + sizeof(int) * height_offset_len, + cudaMemcpyHostToDevice, + cuda_stream); + + const Tensor *x_tensor = param.X; + Tensor *out_tensor = param.Out; + const T *in_data = x_tensor->data(); + T *out_data = out_tensor->mutable_data(TARGET(kCUDA)); + TargetWrapperCuda::MemsetAsync(out_tensor->mutable_data(TARGET(kCUDA)), + 0, + sizeof(T) * out_tensor->numel(), + cuda_stream); + + auto x_dims = x_tensor->dims(); + int num = x_dims[0]; + int channel = x_dims[1]; + int height = x_dims[2]; + int width = x_dims[3]; + + const int *height_offset = _height_offset.data(); + const int *width_offset = _width_offset.data(); + + int feat_map_size = height * width; + dim3 blocks(num, channel); + dim3 threads(32, 1); + topk_avg_pooling_kernel_by_row_improve< + T><<>>( + out_data, + in_data, + height_offset, + width_offset, + height, + width, + param.topks.size(), + _top_ks.data(), + param.channel_num); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + sequence_topk_avg_pooling, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::SequenceTopkAvgPoolingCompute, + def) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindInput("ROW", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindInput("COLUMN", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("pos", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); diff --git a/lite/kernels/cuda/sequence_topk_avg_pooling_compute.h b/lite/kernels/cuda/sequence_topk_avg_pooling_compute.h new file mode 100644 index 0000000000..321ec9cfce --- /dev/null +++ b/lite/kernels/cuda/sequence_topk_avg_pooling_compute.h @@ -0,0 +1,43 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/backends/cuda/cuda_utils.h" +#include "lite/core/kernel.h" +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +class SequenceTopkAvgPoolingCompute + : public KernelLite { + public: + using param_t = operators::SequenceTopkAvgPoolingParam; + + void Run() override; + + virtual ~SequenceTopkAvgPoolingCompute() = default; + + protected: + lite::Tensor _height_offset; + lite::Tensor _width_offset; + lite::Tensor _top_ks; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle -- GitLab