From 126691f01f96aa0dbe34417165cb0cb09b7e557a Mon Sep 17 00:00:00 2001 From: Wilber Date: Tue, 30 Jun 2020 21:42:59 +0800 Subject: [PATCH] [Kernels] [CUDA] Add sequence_pad & unpad cuda kernel. (#3858) --- lite/backends/cuda/math/CMakeLists.txt | 2 + lite/backends/cuda/math/sequence_padding.cu | 148 +++++++++++++++ lite/backends/cuda/math/sequence_padding.h | 51 ++++++ lite/kernels/cuda/CMakeLists.txt | 4 + lite/kernels/cuda/sequence_pad_compute.cu | 93 ++++++++++ lite/kernels/cuda/sequence_pad_compute.h | 41 +++++ .../kernels/cuda/sequence_pad_compute_test.cc | 170 ++++++++++++++++++ lite/kernels/cuda/sequence_unpad_compute.cu | 81 +++++++++ lite/kernels/cuda/sequence_unpad_compute.h | 40 +++++ .../cuda/sequence_unpad_compute_test.cc | 153 ++++++++++++++++ lite/operators/CMakeLists.txt | 1 + lite/operators/op_params.h | 8 + lite/operators/sequence_pad_op.cc | 102 +++++++++++ lite/operators/sequence_pad_op.h | 45 +++++ 14 files changed, 939 insertions(+) create mode 100644 lite/backends/cuda/math/sequence_padding.cu create mode 100644 lite/backends/cuda/math/sequence_padding.h create mode 100644 lite/kernels/cuda/sequence_pad_compute.cu create mode 100644 lite/kernels/cuda/sequence_pad_compute.h create mode 100644 lite/kernels/cuda/sequence_pad_compute_test.cc create mode 100644 lite/kernels/cuda/sequence_unpad_compute.cu create mode 100644 lite/kernels/cuda/sequence_unpad_compute.h create mode 100644 lite/kernels/cuda/sequence_unpad_compute_test.cc create mode 100644 lite/operators/sequence_pad_op.cc create mode 100644 lite/operators/sequence_pad_op.h diff --git a/lite/backends/cuda/math/CMakeLists.txt b/lite/backends/cuda/math/CMakeLists.txt index 9e33d38fee..9f82d11eac 100644 --- a/lite/backends/cuda/math/CMakeLists.txt +++ b/lite/backends/cuda/math/CMakeLists.txt @@ -13,6 +13,7 @@ nv_library(cuda_elementwise SRCS elementwise.cu DEPS ${cuda_static_deps}) nv_library(cudnn_pool SRCS cudnn_pool.cc DEPS ${cuda_static_deps}) nv_library(cuda_gemm SRCS gemm.cc DEPS ${cuda_static_deps}) nv_library(cuda_batched_gemm SRCS batched_gemm.cc DEPS ${cuda_static_deps}) +nv_library(cuda_sequence_padding SRCS sequence_padding.cu DEPS ${cuda_static_deps}) set ( math_cuda @@ -25,6 +26,7 @@ set ( cudnn_pool cuda_gemm cuda_batched_gemm + cuda_sequence_padding ) set(math_cuda "${math_cuda}" CACHE GLOBAL "math cuda") diff --git a/lite/backends/cuda/math/sequence_padding.cu b/lite/backends/cuda/math/sequence_padding.cu new file mode 100644 index 0000000000..6d38adcd48 --- /dev/null +++ b/lite/backends/cuda/math/sequence_padding.cu @@ -0,0 +1,148 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/backends/cuda/cuda_utils.h" +#include "lite/backends/cuda/math/sequence_padding.h" +#include "lite/backends/cuda/math/utils.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +enum CopyType { kSeqToPad, kPadToSeq }; + +template +__global__ void SequencePadKernel(T* dst, + const T* src, + const T* pad_value, + bool is_constant_pad, + const size_t* seq_offsets, + const int seq_num, + const int pad_seq_len, + const int step_width) { + size_t seq_idx = blockIdx.y; + size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx]; + + size_t step_idx = blockIdx.x * blockDim.y + threadIdx.y; + size_t seq_data_offset = (seq_offsets[seq_idx] + step_idx) * step_width; + size_t pad_data_offset = (seq_idx * pad_seq_len + step_idx) * step_width; + T* dst_data = dst + (Type == kSeqToPad ? pad_data_offset : seq_data_offset); + const T* src_data = + src + (Type == kSeqToPad ? seq_data_offset : pad_data_offset); + + if (step_idx < seq_len) { + for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) { + dst_data[i] = src_data[i]; + } + } else if (step_idx < pad_seq_len && Type == kSeqToPad) { + for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) { + dst_data[i] = is_constant_pad ? pad_value[0] : pad_value[i]; + } + } +} + +template +void SequencePadding(T* pad_data, + const T* seq_data, + const T* pad_value_data, + bool is_constant_pad, + const size_t* seq_offsets_data, + int seq_num, + int pad_seq_len, + int step_width, + cudaStream_t* stream) { + const int kBlockSize = 512; + /* At least use 32 threads to copy sequence_width elements, + * and at least 8 elements for each thread. + */ + size_t block_dim_x = + std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); + size_t block_dim_y = kBlockSize / block_dim_x; + dim3 threads(block_dim_x, block_dim_y); + + size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y; + size_t grid_dim_y = seq_num; + dim3 grid(grid_dim_x, grid_dim_y); + + SequencePadKernel<<>>( + pad_data, + seq_data, + pad_value_data, + is_constant_pad, + seq_offsets_data, + seq_num, + pad_seq_len, + step_width); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error); +} + +template +void SequenceUnpadding(T* seq_data, + const T* pad_data, + const size_t* seq_offsets_data, + int seq_num, + int pad_seq_len, + int step_width, + cudaStream_t* stream) { + const int kBlockSize = 512; + /* At least use 32 threads to copy sequence_width elements, + * and at least 8 elements for each thread. + */ + size_t block_dim_x = + std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); + size_t block_dim_y = kBlockSize / block_dim_x; + dim3 threads(block_dim_x, block_dim_y); + + size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y; + size_t grid_dim_y = seq_num; + dim3 grid(grid_dim_x, grid_dim_y); + + SequencePadKernel<<>>( + seq_data, + pad_data, + nullptr, + false, + seq_offsets_data, + seq_num, + pad_seq_len, + step_width); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error); +} + +template void SequencePadding(float* pad_data, + const float* seq_data, + const float* pad_value_data, + bool is_constant_pad, + const size_t* seq_offsets_data, + int seq_num, + int pad_seq_len, + int step_width, + cudaStream_t* stream); + +template void SequenceUnpadding(float* seq_data, + const float* pad_data, + const size_t* seq_offsets_data, + int seq_num, + int pad_seq_len, + int step_width, + cudaStream_t* stream); + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/sequence_padding.h b/lite/backends/cuda/math/sequence_padding.h new file mode 100644 index 0000000000..cfbac9b5bc --- /dev/null +++ b/lite/backends/cuda/math/sequence_padding.h @@ -0,0 +1,51 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include "lite/core/context.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +template +void SequenceUnpadding(T* seq_data, + const T* pad_data, + const size_t* seq_offsets_data, + int seq_num, + int pad_seq_len, + int step_width, + cudaStream_t* stream); + +template +void SequencePadding(T* pad_data, + const T* seq_data, + const T* pad_value_data, + bool is_constant_pad, + const size_t* seq_offsets_data, + int seq_num, + int pad_seq_len, + int step_width, + cudaStream_t* stream); + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt index a70d7e8004..3e92a69c54 100644 --- a/lite/kernels/cuda/CMakeLists.txt +++ b/lite/kernels/cuda/CMakeLists.txt @@ -34,6 +34,8 @@ add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute. add_kernel(search_seq_depadding_compute_cuda CUDA extra SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps}) add_kernel(search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS ${lite_kernel_deps} cuda_gemm ${math_cuda}) add_kernel(sequence_reverse_compute_cuda CUDA extra SRCS sequence_reverse_compute.cu DEPS ${lite_kernel_deps}) +add_kernel(sequence_pad_compute_cuda CUDA extra SRCS sequence_pad_compute.cu DEPS ${lite_kernel_deps} ${math_cuda}) +add_kernel(sequence_unpad_compute_cuda CUDA extra SRCS sequence_unpad_compute.cu DEPS ${lite_kernel_deps} ${math_cuda}) add_kernel(sequence_concat_compute_cuda CUDA extra SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps}) add_kernel(sequence_arithmetic_compute_cuda CUDA extra SRCS sequence_arithmetic_compute.cu DEPS ${lite_kernel_deps}) add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps}) @@ -75,6 +77,8 @@ if(LITE_BUILD_EXTRA) nv_test(search_aligned_mat_mul_compute_cuda_test SRCS search_aligned_mat_mul_compute_test.cc DEPS search_aligned_mat_mul_compute_cuda) nv_test(search_seq_fc_compute_cuda_test SRCS search_seq_fc_compute_test.cc DEPS search_seq_fc_compute_cuda) nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda) + nv_test(sequence_pad_compute_cuda_test SRCS sequence_pad_compute_test.cc DEPS sequence_pad_compute_cuda) + nv_test(sequence_unpad_compute_cuda_test SRCS sequence_unpad_compute_test.cc DEPS sequence_unpad_compute_cuda) nv_test(var_conv_2d_compute_cuda_test SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_cuda) #nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda) #nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda) diff --git a/lite/kernels/cuda/sequence_pad_compute.cu b/lite/kernels/cuda/sequence_pad_compute.cu new file mode 100644 index 0000000000..a4dede84ef --- /dev/null +++ b/lite/kernels/cuda/sequence_pad_compute.cu @@ -0,0 +1,93 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/cuda/math/sequence_padding.h" +#include "lite/core/op_registry.h" +#include "lite/core/target_wrapper.h" +#include "lite/kernels/cuda/sequence_pad_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +void SequencePadCompute::Run() { + auto& param = this->template Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + + const auto* x = param.X; + const auto* pad_value = param.PadValue; + auto* out = param.Out; + auto* len_t = param.Length; + int padded_length = param.padded_length; + + int seq_num = x->lod()[0].size() - 1; + int max_seq_len = 0; + int step_width = x->numel() / x->dims()[0]; + + // calc for param.Lenght + seq_len_.resize(seq_num); + seq_offsets_vec_.resize(x->lod()[0].size()); + for (size_t i = 0; i < seq_num; ++i) { + max_seq_len = std::max( + max_seq_len, static_cast(x->lod()[0][i + 1] - x->lod()[0][i])); + seq_len_[i] = x->lod()[0][i + 1] - x->lod()[0][i]; + seq_offsets_vec_[i] = x->lod()[0][i]; + } + seq_offsets_vec_[seq_num] = x->lod()[0][seq_num]; + TargetWrapperCuda::MemcpyAsync( + len_t->template mutable_data(TARGET(kCUDA)), + seq_len_.data(), + sizeof(int64_t) * seq_len_.size(), + IoDirection::HtoD, + stream); + seq_offsets_.Resize({static_cast(x->lod()[0].size())}); + TargetWrapperCuda::MemcpyAsync( + seq_offsets_.mutable_data(TARGET(kCUDA)), + seq_offsets_vec_.data(), + sizeof(size_t) * seq_offsets_vec_.size(), + IoDirection::HtoD, + stream); + + const T* seq_data = x->template data(); + T* pad_data = out->template mutable_data(TARGET(kCUDA)); + const T* pad_value_data = pad_value->template data(); + + lite::cuda::math::SequencePadding(pad_data, + seq_data, + pad_value_data, + pad_value->numel() == 1, + seq_offsets_.data(), + seq_num, + padded_length, + step_width, + &stream); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +using SeqPadFp32 = + paddle::lite::kernels::cuda::SequencePadCompute; + +REGISTER_LITE_KERNEL(sequence_pad, kCUDA, kFloat, kNCHW, SeqPadFp32, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("PadValue", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Length", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/sequence_pad_compute.h b/lite/kernels/cuda/sequence_pad_compute.h new file mode 100644 index 0000000000..c494fe127d --- /dev/null +++ b/lite/kernels/cuda/sequence_pad_compute.h @@ -0,0 +1,41 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +class SequencePadCompute : public KernelLite { + public: + using param_t = operators::SequencePadParam; + + void Run() override; + virtual ~SequencePadCompute() = default; + + private: + lite::Tensor seq_offsets_; + std::vector seq_len_; + std::vector seq_offsets_vec_; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/sequence_pad_compute_test.cc b/lite/kernels/cuda/sequence_pad_compute_test.cc new file mode 100644 index 0000000000..ba168939ab --- /dev/null +++ b/lite/kernels/cuda/sequence_pad_compute_test.cc @@ -0,0 +1,170 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/sequence_pad_compute.h" + +#include + +#include +#include +#include +#include + +#include "lite/api/test_helper.h" +#include "lite/backends/cuda/cuda_utils.h" +// #include "lite/utils/float16.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class SequencePadTest : public ::testing::Test { + protected: + SequencePadTest() + : batch(5), + features(2), + padded_length(3), + x_lod({{0, 2, 5}}), + x_shape({batch, features}), + pad_value_shape({features}), + out_shape({static_cast(x_lod[0].size() - 1), + padded_length, + features}) { + X_ref.Resize(lite::DDim(x_shape)); + X_ref.set_lod(x_lod); + X_gpu.Resize(X_ref.dims()); + + PadValue_ref.Resize(lite::DDim(pad_value_shape)); + PadValue_gpu.Resize(PadValue_ref.dims()); + + Length_ref.Resize(lite::DDim({static_cast(x_lod[0].size() - 1)})); + Length_gpu.Resize(Length_ref.dims()); + + auto x_ref_data = X_ref.mutable_data(); + auto pad_value_ref_data = PadValue_ref.mutable_data(); + + // prepare input + for (int64_t i = 0; i < X_ref.numel(); i++) { + x_ref_data[i] = static_cast(i); + } + for (int64_t i = 0; i < PadValue_ref.numel(); i++) { + pad_value_ref_data[i] = static_cast(i); + } + + Out_ref.Resize(lite::DDim(out_shape)); + Out_gpu.Resize(Out_ref.dims()); + Out_cpu.Resize(Out_ref.dims()); + cpu_base(&X_ref, &PadValue_ref, &Out_ref, &Length_ref); + + device_init(); + } + + void device_init() { + ctx.reset(new KernelContext); + cudaStreamCreate(&stream); + param.X = &X_gpu; + param.PadValue = &PadValue_gpu; + param.Length = &Length_gpu; + param.Out = &Out_gpu; + param.padded_length = padded_length; + } + + void float_data_init() { + X_gpu.Assign(X_ref.data(), + X_gpu.dims()); + X_gpu.set_lod(X_ref.lod()); + PadValue_gpu.Assign( + PadValue_ref.data(), PadValue_gpu.dims()); + } + + void half_data_init() {} + + void cpu_base(const lite::Tensor* X, + const lite::Tensor* PadValue, + lite::Tensor* Out, + lite::Tensor* Length) { + auto* length_data = Length->mutable_data(); + auto* out_data = Out->mutable_data(); + length_data[0] = 2; + length_data[1] = 3; + + for (size_t i = 0; i < 4; ++i) { + out_data[i] = i; + } + out_data[4] = 0; + out_data[5] = 1; + for (size_t i = 4; i < 10; ++i) { + out_data[2 + i] = i; + } + } + + int batch, features, padded_length; + LoD x_lod; + std::vector x_shape, pad_value_shape, out_shape; + + lite::Tensor X_ref, PadValue_ref, Out_ref, Length_ref; + lite::Tensor X_gpu, PadValue_gpu, Out_gpu, Length_gpu; + lite::Tensor Out_cpu, Length_cpu; + + operators::SequencePadParam param; + std::unique_ptr ctx; + cudaStream_t stream; +}; + +TEST_F(SequencePadTest, fp32) { + float_data_init(); + auto& context = ctx->As(); + context.SetExecStream(stream); + SequencePadCompute kernel; + kernel.SetParam(param); + kernel.SetContext(std::move(ctx)); + + for (int i = 0; i < FLAGS_warmup; ++i) { + kernel.Launch(); + cudaDeviceSynchronize(); + } + + auto start = GetCurrentUS(); + kernel.PrepareForRun(); + for (int i = 0; i < FLAGS_repeats; ++i) { + kernel.Run(); + } + cudaDeviceSynchronize(); + auto duration = (GetCurrentUS() - start) / 1000.0; + LOG(INFO) << "fp32, warmup: " << FLAGS_warmup + << ", repeats: " << FLAGS_repeats << ", spend " + << duration / FLAGS_repeats << " ms in average."; + + CopySync(Out_cpu.mutable_data(), + Out_gpu.data(), + sizeof(float) * Out_gpu.numel(), + IoDirection::DtoH); + CopySync(Length_cpu.mutable_data(), + Length_gpu.data(), + sizeof(int64_t) * Length_gpu.numel(), + IoDirection::DtoH); + for (int i = 0; i < Out_gpu.numel(); ++i) { + EXPECT_NEAR(Out_cpu.data()[i], Out_ref.data()[i], 1e-5); + } + for (int i = 0; i < Length_gpu.numel(); ++i) { + EXPECT_NEAR( + Length_cpu.data()[i], Length_ref.data()[i], 1e-5); + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/sequence_unpad_compute.cu b/lite/kernels/cuda/sequence_unpad_compute.cu new file mode 100644 index 0000000000..7b0d95bc12 --- /dev/null +++ b/lite/kernels/cuda/sequence_unpad_compute.cu @@ -0,0 +1,81 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/backends/cuda/math/sequence_padding.h" +#include "lite/core/op_registry.h" +#include "lite/core/target_wrapper.h" +#include "lite/kernels/cuda/sequence_unpad_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +void SequenceUnpadCompute::Run() { + auto& param = this->template Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + + const auto* pad_tensor = param.X; + const auto* len_t = param.Length; + auto* seq_tensor = param.Out; + + int padded_length = pad_tensor->dims()[1]; + int seq_num = seq_tensor->lod()[0].size() - 1; + int max_seq_len = 0; + int step_width = seq_tensor->numel() / seq_tensor->dims()[0]; + + seq_offsets_vec_.resize(seq_tensor->lod()[0].size()); + for (size_t i = 0; i < seq_num; ++i) { + max_seq_len = std::max(max_seq_len, + static_cast(seq_tensor->lod()[0][i + 1] - + seq_tensor->lod()[0][i])); + seq_offsets_vec_[i] = seq_tensor->lod()[0][i]; + } + seq_offsets_vec_[seq_num] = seq_tensor->lod()[0][seq_num]; + seq_offsets_.Resize({static_cast(seq_tensor->lod()[0].size())}); + TargetWrapperCuda::MemcpyAsync( + seq_offsets_.mutable_data(TARGET(kCUDA)), + seq_offsets_vec_.data(), + sizeof(size_t) * seq_offsets_vec_.size(), + IoDirection::HtoD, + stream); + + const T* pad_data = pad_tensor->template data(); + T* seq_data = seq_tensor->template mutable_data(TARGET(kCUDA)); + + lite::cuda::math::SequenceUnpadding(seq_data, + pad_data, + seq_offsets_.data(), + seq_num, + padded_length, + step_width, + &stream); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +using SeqUnadFp32 = + paddle::lite::kernels::cuda::SequenceUnpadCompute; + +REGISTER_LITE_KERNEL(sequence_unpad, kCUDA, kFloat, kNCHW, SeqUnadFp32, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("Length", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/sequence_unpad_compute.h b/lite/kernels/cuda/sequence_unpad_compute.h new file mode 100644 index 0000000000..f36520ea15 --- /dev/null +++ b/lite/kernels/cuda/sequence_unpad_compute.h @@ -0,0 +1,40 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +class SequenceUnpadCompute : public KernelLite { + public: + using param_t = operators::SequenceUnpadParam; + + void Run() override; + virtual ~SequenceUnpadCompute() = default; + + private: + lite::Tensor seq_offsets_; + std::vector seq_offsets_vec_; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/sequence_unpad_compute_test.cc b/lite/kernels/cuda/sequence_unpad_compute_test.cc new file mode 100644 index 0000000000..a76f9e5af2 --- /dev/null +++ b/lite/kernels/cuda/sequence_unpad_compute_test.cc @@ -0,0 +1,153 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/sequence_unpad_compute.h" + +#include + +#include +#include +#include +#include + +#include "lite/api/test_helper.h" +#include "lite/backends/cuda/cuda_utils.h" +// #include "lite/utils/float16.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class SequenceUnpadTest : public ::testing::Test { + protected: + SequenceUnpadTest() + : batch(5), + features(2), + padded_length(3), + out_lod({{0, 2, 5}}), + x_shape({static_cast(out_lod[0].size() - 1), + padded_length, + features}), + out_shape({batch, features}) { + X_ref.Resize(lite::DDim(x_shape)); + X_gpu.Resize(X_ref.dims()); + + Length_ref.Resize( + lite::DDim({static_cast(out_lod[0].size() - 1)})); + Length_gpu.Resize(Length_ref.dims()); + + auto* x_ref_data = X_ref.mutable_data(); + auto* length_ref_data = Length_ref.mutable_data(); + + // prepare input + for (int64_t i = 0; i < X_ref.numel(); i++) { + x_ref_data[i] = static_cast(i); + } + for (size_t i = 0; i < out_lod[0].size() - 1; ++i) { + length_ref_data[i] = out_lod[0][i + 1] - out_lod[0][i]; + } + + Out_ref.Resize(lite::DDim(out_shape)); + Out_ref.set_lod(out_lod); + Out_gpu.Resize(Out_ref.dims()); + Out_gpu.set_lod(Out_ref.lod()); + Out_cpu.Resize(Out_ref.dims()); + Out_cpu.set_lod(Out_ref.lod()); + + cpu_base(&X_ref, &Length_ref, &Out_ref); + + device_init(); + } + + void device_init() { + ctx.reset(new KernelContext); + cudaStreamCreate(&stream); + param.X = &X_gpu; + param.Length = &Length_gpu; + param.Out = &Out_gpu; + } + + void float_data_init() { + X_gpu.Assign(X_ref.data(), + X_gpu.dims()); + Length_gpu.Assign( + Length_ref.data(), Length_gpu.dims()); + } + + void half_data_init() {} + + void cpu_base(const lite::Tensor* X, + const lite::Tensor* Length, + lite::Tensor* Out) { + auto* out_data = Out->mutable_data(); + + for (size_t i = 0; i < 4; ++i) { + out_data[i] = i; + } + for (size_t i = 6; i < 12; ++i) { + out_data[i - 2] = i; + } + } + + int batch, features, padded_length; + LoD out_lod; + std::vector x_shape, out_shape; + + lite::Tensor X_ref, Out_ref, Length_ref; + lite::Tensor X_gpu, Out_gpu, Length_gpu; + lite::Tensor Out_cpu, Length_cpu; + + operators::SequencePadParam param; + std::unique_ptr ctx; + cudaStream_t stream; +}; + +TEST_F(SequenceUnpadTest, fp32) { + float_data_init(); + auto& context = ctx->As(); + context.SetExecStream(stream); + SequenceUnpadCompute kernel; + kernel.SetParam(param); + kernel.SetContext(std::move(ctx)); + + for (int i = 0; i < FLAGS_warmup; ++i) { + kernel.Launch(); + cudaDeviceSynchronize(); + } + + auto start = GetCurrentUS(); + kernel.PrepareForRun(); + for (int i = 0; i < FLAGS_repeats; ++i) { + kernel.Run(); + } + cudaDeviceSynchronize(); + auto duration = (GetCurrentUS() - start) / 1000.0; + LOG(INFO) << "fp32, warmup: " << FLAGS_warmup + << ", repeats: " << FLAGS_repeats << ", spend " + << duration / FLAGS_repeats << " ms in average."; + + CopySync(Out_cpu.mutable_data(), + Out_gpu.data(), + sizeof(float) * Out_gpu.numel(), + IoDirection::DtoH); + for (int i = 0; i < Out_gpu.numel(); ++i) { + EXPECT_NEAR(Out_cpu.data()[i], Out_ref.data()[i], 1e-5); + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 2fc2c58593..1186858eb8 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -77,6 +77,7 @@ add_operator(reduce_max_op_lite extra SRCS reduce_max_op.cc DEPS ${op_DEPS}) add_operator(shape_op_lite extra SRCS shape_op.cc DEPS ${op_DEPS}) add_operator(sequence_expand_op_lite extra SRCS sequence_expand_op.cc DEPS ${op_DEPS}) add_operator(sequence_unpad_op_lite extra SRCS sequence_unpad_op.cc DEPS ${op_DEPS}) +add_operator(sequence_pad_op_lite extra SRCS sequence_pad_op.cc DEPS ${op_DEPS}) add_operator(im2sequence_op extra SRCS im2sequence_op.cc DEPS ${op_DEPS}) add_operator(gather_op extra SRCS gather_op.cc DEPS ${op_DEPS}) add_operator(anchor_generator_op extra SRCS anchor_generator_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 457abaffeb..9c97ce4ed2 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -1031,6 +1031,14 @@ struct SequenceExpandParam : ParamBase { int ref_level{-1}; }; +struct SequencePadParam : ParamBase { + const lite::Tensor* X{}; + const lite::Tensor* PadValue{}; + lite::Tensor* Out{}; + lite::Tensor* Length{}; + int padded_length{-1}; +}; + struct SequenceUnpadParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Length{}; diff --git a/lite/operators/sequence_pad_op.cc b/lite/operators/sequence_pad_op.cc new file mode 100644 index 0000000000..687c4a1989 --- /dev/null +++ b/lite/operators/sequence_pad_op.cc @@ -0,0 +1,102 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/sequence_pad_op.h" +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool SequencePadOp::CheckShape() const { + CHECK_OR_FALSE(param_.X); + CHECK_OR_FALSE(param_.PadValue); + CHECK_OR_FALSE(param_.Out); + CHECK_OR_FALSE(param_.Length); + + return true; +} + +bool SequencePadOp::InferShapeImpl() const { + auto x_dims = param_.X->dims(); + CHECK_GE(x_dims.size(), 2) << "The rank of SequencePad OP Input(x) can't be " + "less than 2. But the rank we received is " + << x_dims.size(); + auto time_step_dims = x_dims.Slice(1, x_dims.size()); + auto pad_value_dims = param_.PadValue->dims(); + CHECK_EQ((pad_value_dims == DDim({1})) || (pad_value_dims == time_step_dims), + true) + << "The SequencePad OP Input(PadValue) must be a scalar or a tensor " + "whiose shape equals to time steps in sequences"; + + auto x_lod = param_.X->lod(); + CHECK_EQ(x_lod.empty(), false) + << "The SequencePad OP Input(X) must hold lod info."; + const auto &x_lod_0 = x_lod[0]; + CHECK_GE(x_lod_0.size(), 2) + << "The size of SequencePadOp Input(X)'s lod info can't be less than 2. " + "But the size we received is " + << x_lod_0.size(); + CHECK_EQ(x_dims[0], static_cast(x_lod_0.back())) + << "The SequencePadOp Input(X)'s lod info mismatches the actual tensor " + "shape. The 1st dimension of Input(X)'s lod info is " + << x_dims[0] << ", the 1st dimension of actual tensor shape is " + << static_cast(x_lod_0.back()); + + int seq_num = x_lod_0.size() - 1; + int max_seq_len = 0; + for (int i = 0; i < seq_num; ++i) { + max_seq_len = + std::max(max_seq_len, static_cast(x_lod_0[i + 1] - x_lod_0[i])); + } + if (param_.padded_length == -1) { + param_.padded_length = max_seq_len; + } + CHECK_GE(param_.padded_length, max_seq_len) + << "The SequencePadOp Attr(padded_length) should be greater than or " + "equal to the length of the longest original sequence. But the " + "padded_length we received is " + << param_.padded_length + << ", the length of the longest original sequence is " << max_seq_len; + + int out_dim_0 = seq_num; + std::vector out_dims_vec{out_dim_0, param_.padded_length}; + std::vector len_dims_vec{out_dim_0}; + auto time_step_dims_vec = time_step_dims.Vectorize(); + out_dims_vec.insert( + out_dims_vec.end(), time_step_dims_vec.begin(), time_step_dims_vec.end()); + param_.Out->Resize(out_dims_vec); + param_.Length->Resize(len_dims_vec); + return true; +} + +bool SequencePadOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + param_.X = const_cast( + &scope->FindVar(opdesc.Input("X").front())->Get()); + param_.PadValue = const_cast( + &scope->FindVar(opdesc.Input("PadValue").front())->Get()); + param_.Length = scope->FindVar(opdesc.Input("Length").front()) + ->GetMutable(); + param_.Out = + scope->FindVar(opdesc.Output("Out").front())->GetMutable(); + param_.padded_length = opdesc.GetAttr("padded_length"); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(sequence_pad, paddle::lite::operators::SequencePadOp); diff --git a/lite/operators/sequence_pad_op.h b/lite/operators/sequence_pad_op.h new file mode 100644 index 0000000000..bd5d732a5d --- /dev/null +++ b/lite/operators/sequence_pad_op.h @@ -0,0 +1,45 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" + +namespace paddle { +namespace lite { +namespace operators { + +class SequencePadOp : public OpLite { + public: + SequencePadOp() {} + explicit SequencePadOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShapeImpl() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "sequence_pad"; } + + private: + mutable SequencePadParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle -- GitLab