From 3c6b6d2dbe7c70d92b6a48a25f68b20029d5e59e Mon Sep 17 00:00:00 2001
From: Pei Yang
Date: Mon, 18 Nov 2019 14:19:53 +0800
Subject: [PATCH] add sequence_pool cuda kernel, test=develop (#2430)

add sequence_pool cuda kernel
---
 lite/kernels/cuda/CMakeLists.txt           |   2 +
 lite/kernels/cuda/sequence_pool_compute.cu | 265 ++++++++++++++++++
 lite/kernels/cuda/sequence_pool_compute.h  |  35 +++
 .../cuda/sequence_pool_compute_test.cc     | 134 +++++++++
 4 files changed, 436 insertions(+)
 create mode 100644 lite/kernels/cuda/sequence_pool_compute.cu
 create mode 100644 lite/kernels/cuda/sequence_pool_compute.h
 create mode 100644 lite/kernels/cuda/sequence_pool_compute_test.cc

diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt
index 979e7f2730..026f391933 100644
--- a/lite/kernels/cuda/CMakeLists.txt
+++ b/lite/kernels/cuda/CMakeLists.txt
@@ -9,6 +9,7 @@ add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_k
 add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
+add_kernel(sequence_pool_compute_cuda CUDA extra SRCS sequence_pool_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(transpose_compute_cuda CUDA basic SRCS transpose_compute.cu DEPS ${lite_kernel_deps} ${math_cuda} cuda_transpose)
 add_kernel(nearest_interp_compute_cuda CUDA basic SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(conv2d_cuda CUDA basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
@@ -38,6 +39,7 @@ nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_c
 nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
 nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda)
 nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda)
+nv_test(sequence_pool_compute_cuda_test SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_cuda sequence_pooling)
 nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_compute_cuda)
 #nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
 nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
diff --git a/lite/kernels/cuda/sequence_pool_compute.cu b/lite/kernels/cuda/sequence_pool_compute.cu
new file mode 100644
index 0000000000..34853adf92
--- /dev/null
+++ b/lite/kernels/cuda/sequence_pool_compute.cu
@@ -0,0 +1,265 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+#include "lite/core/op_registry.h"
+#include "lite/core/target_wrapper.h"
+#include "lite/kernels/cuda/sequence_pool_compute.h"
+
+const int CUDA_NUM_THREADS = 512;
+#define CUDA_KERNEL_LOOP(i, n)                                 \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+/// CUDA: number of blocks for threads.
+inline int CUDA_GET_BLOCKS(const int N) {
+  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
+}
+inline int CUDA_GET_BLOCKS(const int N, const int base) {
+  return (N + base - 1) / base;
+}
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+// Each thread produces one output element. Sequence i occupies rows
+// [seq_offset[i], seq_offset[i + 1]) of the input, each row holding
+// slice_size elements.
+template <typename Dtype>
+__global__ void seq_pool_average_kernel(Dtype* dst,
+                                        const Dtype* src_in,
+                                        const int batch_size,
+                                        const uint64_t* seq_offset,
+                                        const int slice_size) {
+  int total = slice_size * batch_size;
+  CUDA_KERNEL_LOOP(tid, total) {
+    int out_batch_id = tid / slice_size;
+    int out_id = tid % slice_size;
+    int in_slice_num = static_cast<int>(seq_offset[out_batch_id + 1] -
+                                        seq_offset[out_batch_id]);
+    int in_offset = static_cast<int>(seq_offset[out_batch_id] * slice_size);
+    src_in += in_offset + out_id;
+    Dtype sum = (Dtype)0;
+    for (int i = 0; i < in_slice_num; ++i) {
+      sum += src_in[i * slice_size];
+    }
+    dst[out_batch_id * slice_size + out_id] = sum / in_slice_num;
+  }
+}
+
+template <typename Dtype>
+__global__ void seq_pool_sum_kernel(Dtype* dst,
+                                    const Dtype* src_in,
+                                    const int batch_size,
+                                    const uint64_t* seq_offset,
+                                    const int slice_size) {
+  int total = slice_size * batch_size;
+  CUDA_KERNEL_LOOP(tid, total) {
+    int out_batch_id = tid / slice_size;
+    int out_id = tid % slice_size;
+    int in_slice_num = static_cast<int>(seq_offset[out_batch_id + 1] -
+                                        seq_offset[out_batch_id]);
+    int in_offset = static_cast<int>(seq_offset[out_batch_id] * slice_size);
+    src_in += in_offset + out_id;
+    Dtype sum = (Dtype)0;
+    for (int i = 0; i < in_slice_num; ++i) {
+      sum += src_in[i * slice_size];
+    }
+    dst[out_batch_id * slice_size + out_id] = sum;
+  }
+}
+
+template <typename Dtype>
+__global__ void seq_pool_sqrt_kernel(Dtype* dst,
+                                     const Dtype* src_in,
+                                     const int batch_size,
+                                     const uint64_t* seq_offset,
+                                     const int slice_size) {
+  int total = slice_size * batch_size;
+  CUDA_KERNEL_LOOP(tid, total) {
+    int out_batch_id = tid / slice_size;
+    int out_id = tid % slice_size;
+    int in_slice_num = static_cast<int>(seq_offset[out_batch_id + 1] -
+                                        seq_offset[out_batch_id]);
+    int in_offset = static_cast<int>(seq_offset[out_batch_id] * slice_size);
+    src_in += in_offset + out_id;
+    Dtype sum = (Dtype)0;
+    for (int i = 0; i < in_slice_num; ++i) {
+      sum += src_in[i * slice_size];
+    }
+    // SQRT pooling: the sum scaled by 1 / sqrt(sequence length).
+    dst[out_batch_id * slice_size + out_id] = sum * rsqrtf(in_slice_num);
+  }
+}
+
+template <typename Dtype>
+__global__ void seq_pool_max_kernel(Dtype* dst,
+                                    const Dtype* src_in,
+                                    const int batch_size,
+                                    const uint64_t* seq_offset,
+                                    const int slice_size) {
+  int total = slice_size * batch_size;
+  CUDA_KERNEL_LOOP(tid, total) {
+    int out_batch_id = tid / slice_size;
+    int out_id = tid % slice_size;
+    int in_slice_num = static_cast<int>(seq_offset[out_batch_id + 1] -
+                                        seq_offset[out_batch_id]);
+    int in_offset = static_cast<int>(seq_offset[out_batch_id] * slice_size);
+    src_in += in_offset + out_id;
+    Dtype max = src_in[0];
+    for (int i = 1; i < in_slice_num; ++i) {
+      Dtype val = src_in[i * slice_size];
+      if (val > max) {
+        max = val;
+      }
+    }
+    dst[out_batch_id * slice_size + out_id] = max;
+  }
+}
+
+template <typename Dtype>
+__global__ void seq_pool_last_kernel(Dtype* dst,
+                                     const Dtype* src_in,
+                                     const int batch_size,
+                                     const uint64_t* seq_offset,
+                                     const int slice_size) {
+  int total = slice_size * batch_size;
+  CUDA_KERNEL_LOOP(tid, total) {
+    int out_batch_id = tid / slice_size;
+    int out_id = tid % slice_size;
+    int in_offset =
+        (static_cast<int>(seq_offset[out_batch_id + 1]) - 1) * slice_size;
+    dst[tid] = src_in[in_offset + out_id];
+  }
+}
+
+template <typename Dtype>
+__global__ void seq_pool_first_kernel(Dtype* dst,
+                                      const Dtype* src_in,
+                                      const int batch_size,
+                                      const uint64_t* seq_offset,
+                                      const int slice_size) {
+  int total = slice_size * batch_size;
+  CUDA_KERNEL_LOOP(tid, total) {
+    int out_batch_id = tid / slice_size;
+    int out_id = tid % slice_size;
+    int in_offset = static_cast<int>(seq_offset[out_batch_id] * slice_size);
+    dst[tid] = src_in[in_offset + out_id];
+  }
+}
+
+void SequencePoolCompute::Run() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<CUDAContext>();
+  auto stream = ctx.exec_stream();
+
+  std::vector<uint64_t> seq_offset = param.X->lod()[0];
+  int slice_size =
+      param.Out->dims()[1] * param.Out->dims()[2] * param.Out->dims()[3];
+
+  float* out_data = param.Out->mutable_data<float>(TARGET(kCUDA));
+  const float* in_data = param.X->data<float>();
+  // Number of sequences = offsets in the first LoD level minus one.
+  int batch_size = seq_offset.size() - 1;
+
+  // Copy the sequence offsets to the device so the kernels can read them.
+  lite::Tensor seq_offset_D;
+  seq_offset_D.Resize({static_cast<int64_t>(seq_offset.size())});
+  TargetWrapperCuda::MemcpyAsync(
+      seq_offset_D.mutable_data<uint64_t>(TARGET(kCUDA)),
+      seq_offset.data(),
+      sizeof(uint64_t) * seq_offset.size(),
+      IoDirection::HtoD,
+      stream);
+
+  // One thread per output element (batch_size * slice_size in total).
+  if (param.pool_type == "MAX") {
+    seq_pool_max_kernel<float><<<CUDA_GET_BLOCKS(batch_size * slice_size),
+                                 CUDA_NUM_THREADS, 0, stream>>>(
+        out_data, in_data, batch_size, seq_offset_D.data<uint64_t>(),
+        slice_size);
+  } else if (param.pool_type == "AVERAGE") {
+    seq_pool_average_kernel<float><<<CUDA_GET_BLOCKS(batch_size * slice_size),
+                                     CUDA_NUM_THREADS, 0, stream>>>(
+        out_data, in_data, batch_size, seq_offset_D.data<uint64_t>(),
+        slice_size);
+  } else if (param.pool_type == "SUM") {
+    seq_pool_sum_kernel<float><<<CUDA_GET_BLOCKS(batch_size * slice_size),
+                                 CUDA_NUM_THREADS, 0, stream>>>(
+        out_data, in_data, batch_size, seq_offset_D.data<uint64_t>(),
+        slice_size);
+  } else if (param.pool_type == "SQRT") {
+    seq_pool_sqrt_kernel<float><<<CUDA_GET_BLOCKS(batch_size * slice_size),
+                                  CUDA_NUM_THREADS, 0, stream>>>(
+        out_data, in_data, batch_size, seq_offset_D.data<uint64_t>(),
+        slice_size);
+  } else if (param.pool_type == "FIRST") {
+    seq_pool_first_kernel<float><<<CUDA_GET_BLOCKS(batch_size * slice_size),
+                                   CUDA_NUM_THREADS, 0, stream>>>(
+        out_data, in_data, batch_size, seq_offset_D.data<uint64_t>(),
+        slice_size);
+  } else if (param.pool_type == "LAST") {
+    seq_pool_last_kernel<float><<<CUDA_GET_BLOCKS(batch_size * slice_size),
+                                  CUDA_NUM_THREADS, 0, stream>>>(
+        out_data, in_data, batch_size, seq_offset_D.data<uint64_t>(),
+        slice_size);
+  } else {
+    LOG(ERROR) << "pool type " << param.pool_type << " is not supported.";
+  }
+
+  // After pooling, every sequence collapses to length 1.
+  std::vector<uint64_t> offset_new(static_cast<uint64_t>(batch_size + 1));
+  for (int i = 0; i <= batch_size; ++i) {
+    offset_new[i] = i;
+  }
+  std::vector<std::vector<uint64_t>> voffset_new;
+  voffset_new.push_back(offset_new);
+  param.Out->set_lod(voffset_new);
+
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(sequence_pool,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::SequencePoolCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .Finalize();
diff --git a/lite/kernels/cuda/sequence_pool_compute.h b/lite/kernels/cuda/sequence_pool_compute.h
new file mode 100644
index 0000000000..9309454d18
--- /dev/null
+++ b/lite/kernels/cuda/sequence_pool_compute.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "lite/core/kernel.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+class SequencePoolCompute
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::SequencePoolParam;
+
+  void Run() override;
+  virtual ~SequencePoolCompute() = default;
+};
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/cuda/sequence_pool_compute_test.cc b/lite/kernels/cuda/sequence_pool_compute_test.cc
new file mode 100644
index 0000000000..faced90c6c
--- /dev/null
+++ b/lite/kernels/cuda/sequence_pool_compute_test.cc
@@ -0,0 +1,134 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/cuda/sequence_pool_compute.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <utility>
+#include <vector>
+#include "lite/backends/x86/math/sequence_pooling.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+namespace {
+
+// Reference implementation: reuse the x86 sequence pooling functor.
+static void sequence_pool_ref(const operators::SequencePoolParam& param,
+                              const lite::X86Context& context) {
+  auto* x = param.X;
+  auto* out = param.Out;
+  auto dims = x->dims();
+  auto lod = x->lod();
+  CHECK_EQ(lod.size(), 1UL);
+  CHECK_GE(dims[0], static_cast<int64_t>(lod[0].size() - 1));
+
+  dims[0] = lod[0].size() - 1;
+  out->Resize({dims});
+  out->mutable_data<float>();
+  lite::Tensor* index = nullptr;
+
+  const bool is_test = true;
+  float pad_value = 0.0;
+
+  lite::x86::math::SequencePoolFunctor<TARGET(kX86), float> pool;
+  pool(context, param.pool_type, pad_value, *x, out, is_test, index);
+}
+
+#define PREPARE_INPUT_DATA(name)                                 \
+  name.Resize({name##_lod_len, feature_len});                    \
+  name##_cpu.Resize({name##_lod_len, feature_len});              \
+  name##_ref.Resize({name##_lod_len, feature_len});              \
+  name.set_lod(lod_info_##name);                                 \
+  name##_cpu.set_lod(lod_info_##name);                           \
+  name##_ref.set_lod(lod_info_##name);                           \
+  float* name##_cpu_data = name##_cpu.mutable_data<float>();     \
+  float* name##_ref_data = name##_ref.mutable_data<float>();     \
+  for (int i = 0; i < name##_cpu.numel(); ++i) {                 \
+    name##_cpu_data[i] = (i - 2.0) * 1.0;                        \
+    name##_ref_data[i] = (i - 2.0) * 1.0;                        \
+  }                                                              \
+  name.Assign<float, lite::DDim, TARGET(kCUDA)>(name##_cpu_data, \
+                                                name##_cpu.dims());
+
+#define PREPARE_OUTPUT_INFO(name)              \
+  name##_cpu.Resize({y_lod_len, feature_len}); \
+  name##_ref.Resize({y_lod_len, feature_len}); \
+  name.Resize({y_lod_len, feature_len});       \
+  float* name##_cpu_data = name##_cpu.mutable_data<float>();
+
+}  // namespace
+
+TEST(sequence_pool_cuda, normal) {
+  SequencePoolCompute seq_kernel;
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  auto& context = ctx->As<CUDAContext>();
+  std::unique_ptr<KernelContext> ctx_ref(new KernelContext);
+  auto& context_ref = ctx_ref->As<X86Context>();
+
+  operators::SequencePoolParam param;
+  lite::Tensor x1, x2, x3, x1_cpu, x2_cpu, x3_cpu, x1_ref, x2_ref, x3_ref;
+  lite::Tensor y, y_cpu, y_ref;
+
+  int32_t x1_lod_len = 10, feature_len = 4;
+  int32_t x2_lod_len = 4, x3_lod_len = 8;
+  int32_t y_lod_len = x1_lod_len + x2_lod_len + x3_lod_len;
+  LoD lod_info_x1{{0, 3, 5, 6, 10}};
+  LoD lod_info_x2{{0, 1, 2, 3, 4}};
+  LoD lod_info_x3{{0, 2, 4, 6, 8}};
+  LoD lod_info_y{{0, 0, 0, 0, 0}};
+  for (size_t i = 0; i < lod_info_x1[0].size(); ++i) {
+    lod_info_y[0][i] =
+        lod_info_x1[0][i] + lod_info_x2[0][i] + lod_info_x3[0][i];
+  }
+
+  PREPARE_INPUT_DATA(x1);
+  PREPARE_INPUT_DATA(x2);
+  PREPARE_INPUT_DATA(x3);
+  PREPARE_OUTPUT_INFO(y);
+
+  param.X = &x1;
+  param.Out = &y;
+  param.pool_type = "AVERAGE";
+  seq_kernel.SetParam(param);
+
+  cudaStream_t stream;
+  cudaStreamCreate(&stream);
+  context.SetExecStream(stream);
+
+  seq_kernel.SetContext(std::move(ctx));
+  seq_kernel.Run();
+  cudaDeviceSynchronize();
+
+  auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
+  CopySync<TARGET(kCUDA)>(
+      y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH);
+
+  // Run the x86 reference on a copy of the input and compare.
+  param.X = &x1_ref;
+  param.Out = &y_ref;
+  sequence_pool_ref(param, context_ref);
+
+  float* y_ref_data = y_ref.mutable_data<float>();
+  for (int i = 0; i < y.numel(); i++) {
+    EXPECT_NEAR(y_cpu_data[i], y_ref_data[i], 1e-5);
+  }
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
-- 
GitLab