Unverified commit 126691f0 authored by Wilber, committed by GitHub

[Kernels] [CUDA] Add sequence_pad & unpad cuda kernel. (#3858)

Parent ab8af5c4
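For reference, the semantics these kernels implement can be sketched with a small standalone CPU version of the pad direction (an illustrative sketch only, not part of this patch; CpuSequencePad is a hypothetical name). A LoD tensor holding concatenated sequences is copied into a dense [seq_num, pad_seq_len, step_width] buffer, and the tail of every short sequence is filled with the pad value. For example, with seq_offsets {0, 2, 5}, step_width 2, pad_seq_len 3, a scalar pad value of 0, and input data 0..9, the result is [0, 1, 2, 3, 0, 0] for the first sequence and [4, 5, 6, 7, 8, 9] for the second.

#include <cstddef>
#include <vector>

// Illustrative CPU sketch (hypothetical helper, not part of the patch):
// copy each sequence described by seq_offsets into a slot of pad_seq_len
// time steps and fill the remaining steps with a scalar pad value.
void CpuSequencePad(std::vector<float>* pad_data,
                    const std::vector<float>& seq_data,
                    const std::vector<size_t>& seq_offsets,
                    float pad_value,
                    int pad_seq_len,
                    int step_width) {
  const int seq_num = static_cast<int>(seq_offsets.size()) - 1;
  pad_data->assign(static_cast<size_t>(seq_num) * pad_seq_len * step_width,
                   pad_value);
  for (int s = 0; s < seq_num; ++s) {
    const size_t seq_len = seq_offsets[s + 1] - seq_offsets[s];
    for (size_t step = 0; step < seq_len; ++step) {
      for (int w = 0; w < step_width; ++w) {
        (*pad_data)[(s * pad_seq_len + step) * step_width + w] =
            seq_data[(seq_offsets[s] + step) * step_width + w];
      }
    }
  }
}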
@@ -13,6 +13,7 @@ nv_library(cuda_elementwise SRCS elementwise.cu DEPS ${cuda_static_deps})
nv_library(cudnn_pool SRCS cudnn_pool.cc DEPS ${cuda_static_deps})
nv_library(cuda_gemm SRCS gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_batched_gemm SRCS batched_gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_sequence_padding SRCS sequence_padding.cu DEPS ${cuda_static_deps})
set (
math_cuda
@@ -25,6 +26,7 @@ set (
cudnn_pool
cuda_gemm
cuda_batched_gemm
cuda_sequence_padding
)
set(math_cuda "${math_cuda}" CACHE GLOBAL "math cuda")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/sequence_padding.h"
#include "lite/backends/cuda/math/utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
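// Direction of the element copy: sequence (LoD) layout to padded layout, or the reverse.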
enum CopyType { kSeqToPad, kPadToSeq };
template <typename T, CopyType Type>
__global__ void SequencePadKernel(T* dst,
const T* src,
const T* pad_value,
bool is_constant_pad,
const size_t* seq_offsets,
const int seq_num,
const int pad_seq_len,
const int step_width) {
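// Grid mapping: blockIdx.y selects the sequence, blockIdx.x * blockDim.y +
// threadIdx.y selects the time step inside that sequence, and threadIdx.x
// strides over the step_width elements of one time step.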
size_t seq_idx = blockIdx.y;
size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];
size_t step_idx = blockIdx.x * blockDim.y + threadIdx.y;
size_t seq_data_offset = (seq_offsets[seq_idx] + step_idx) * step_width;
size_t pad_data_offset = (seq_idx * pad_seq_len + step_idx) * step_width;
T* dst_data = dst + (Type == kSeqToPad ? pad_data_offset : seq_data_offset);
const T* src_data =
src + (Type == kSeqToPad ? seq_data_offset : pad_data_offset);
if (step_idx < seq_len) {
for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
dst_data[i] = src_data[i];
}
} else if (step_idx < pad_seq_len && Type == kSeqToPad) {
for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
dst_data[i] = is_constant_pad ? pad_value[0] : pad_value[i];
}
}
}
template <typename T>
void SequencePadding(T* pad_data,
const T* seq_data,
const T* pad_value_data,
bool is_constant_pad,
const size_t* seq_offsets_data,
int seq_num,
int pad_seq_len,
int step_width,
cudaStream_t* stream) {
const int kBlockSize = 512;
/* Use at least 32 threads (one warp) along x to copy the step_width
 * elements of one time step, and give each thread at least 8 elements.
 */
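// block_dim_x: ceil(step_width / 8) rounded up to a multiple of 32 (one warp), capped at kBlockSize.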
size_t block_dim_x =
std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
size_t block_dim_y = kBlockSize / block_dim_x;
dim3 threads(block_dim_x, block_dim_y);
size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
size_t grid_dim_y = seq_num;
dim3 grid(grid_dim_x, grid_dim_y);
SequencePadKernel<T, kSeqToPad><<<grid, threads, 0, *stream>>>(
pad_data,
seq_data,
pad_value_data,
is_constant_pad,
seq_offsets_data,
seq_num,
pad_seq_len,
step_width);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
}
template <typename T>
void SequenceUnpadding(T* seq_data,
const T* pad_data,
const size_t* seq_offsets_data,
int seq_num,
int pad_seq_len,
int step_width,
cudaStream_t* stream) {
const int kBlockSize = 512;
/* Use at least 32 threads (one warp) along x to copy the step_width
 * elements of one time step, and give each thread at least 8 elements.
 */
size_t block_dim_x =
std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
size_t block_dim_y = kBlockSize / block_dim_x;
dim3 threads(block_dim_x, block_dim_y);
size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
size_t grid_dim_y = seq_num;
dim3 grid(grid_dim_x, grid_dim_y);
SequencePadKernel<T, kPadToSeq><<<grid, threads, 0, *stream>>>(
seq_data,
pad_data,
nullptr,
false,
seq_offsets_data,
seq_num,
pad_seq_len,
step_width);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
}
template void SequencePadding(float* pad_data,
const float* seq_data,
const float* pad_value_data,
bool is_constant_pad,
const size_t* seq_offsets_data,
int seq_num,
int pad_seq_len,
int step_width,
cudaStream_t* stream);
template void SequenceUnpadding(float* seq_data,
const float* pad_data,
const size_t* seq_offsets_data,
int seq_num,
int pad_seq_len,
int step_width,
cudaStream_t* stream);
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <string>
#include <vector>
#include "lite/core/context.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
void SequenceUnpadding(T* seq_data,
const T* pad_data,
const size_t* seq_offsets_data,
int seq_num,
int pad_seq_len,
int step_width,
cudaStream_t* stream);
template <typename T>
void SequencePadding(T* pad_data,
const T* seq_data,
const T* pad_value_data,
bool is_constant_pad,
const size_t* seq_offsets_data,
int seq_num,
int pad_seq_len,
int step_width,
cudaStream_t* stream);
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
@@ -34,6 +34,8 @@ add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.
add_kernel(search_seq_depadding_compute_cuda CUDA extra SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS ${lite_kernel_deps} cuda_gemm ${math_cuda})
add_kernel(sequence_reverse_compute_cuda CUDA extra SRCS sequence_reverse_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_pad_compute_cuda CUDA extra SRCS sequence_pad_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(sequence_unpad_compute_cuda CUDA extra SRCS sequence_unpad_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(sequence_concat_compute_cuda CUDA extra SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_arithmetic_compute_cuda CUDA extra SRCS sequence_arithmetic_compute.cu DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps})
@@ -75,6 +77,8 @@ if(LITE_BUILD_EXTRA)
nv_test(search_aligned_mat_mul_compute_cuda_test SRCS search_aligned_mat_mul_compute_test.cc DEPS search_aligned_mat_mul_compute_cuda)
nv_test(search_seq_fc_compute_cuda_test SRCS search_seq_fc_compute_test.cc DEPS search_seq_fc_compute_cuda)
nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda)
nv_test(sequence_pad_compute_cuda_test SRCS sequence_pad_compute_test.cc DEPS sequence_pad_compute_cuda)
nv_test(sequence_unpad_compute_cuda_test SRCS sequence_unpad_compute_test.cc DEPS sequence_unpad_compute_cuda)
nv_test(var_conv_2d_compute_cuda_test SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_cuda)
#nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
#nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/math/sequence_padding.h"
#include "lite/core/op_registry.h"
#include "lite/core/target_wrapper.h"
#include "lite/kernels/cuda/sequence_pad_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T, PrecisionType Ptype>
void SequencePadCompute<T, Ptype>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
const auto* x = param.X;
const auto* pad_value = param.PadValue;
auto* out = param.Out;
auto* len_t = param.Length;
int padded_length = param.padded_length;
int seq_num = x->lod()[0].size() - 1;
int max_seq_len = 0;
int step_width = x->numel() / x->dims()[0];
// Compute per-sequence lengths (for param.Length) and LoD offsets on the host.
seq_len_.resize(seq_num);
seq_offsets_vec_.resize(x->lod()[0].size());
for (size_t i = 0; i < seq_num; ++i) {
max_seq_len = std::max(
max_seq_len, static_cast<int>(x->lod()[0][i + 1] - x->lod()[0][i]));
seq_len_[i] = x->lod()[0][i + 1] - x->lod()[0][i];
seq_offsets_vec_[i] = x->lod()[0][i];
}
seq_offsets_vec_[seq_num] = x->lod()[0][seq_num];
TargetWrapperCuda::MemcpyAsync(
len_t->template mutable_data<int64_t>(TARGET(kCUDA)),
seq_len_.data(),
sizeof(int64_t) * seq_len_.size(),
IoDirection::HtoD,
stream);
seq_offsets_.Resize({static_cast<int64_t>(x->lod()[0].size())});
TargetWrapperCuda::MemcpyAsync(
seq_offsets_.mutable_data<size_t>(TARGET(kCUDA)),
seq_offsets_vec_.data(),
sizeof(size_t) * seq_offsets_vec_.size(),
IoDirection::HtoD,
stream);
const T* seq_data = x->template data<T>();
T* pad_data = out->template mutable_data<T>(TARGET(kCUDA));
const T* pad_value_data = pad_value->template data<T>();
lite::cuda::math::SequencePadding(pad_data,
seq_data,
pad_value_data,
pad_value->numel() == 1,
seq_offsets_.data<size_t>(),
seq_num,
padded_length,
step_width,
&stream);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
using SeqPadFp32 =
paddle::lite::kernels::cuda::SequencePadCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(sequence_pad, kCUDA, kFloat, kNCHW, SeqPadFp32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("PadValue", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Length", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T, PrecisionType Ptype>
class SequencePadCompute : public KernelLite<TARGET(kCUDA), Ptype> {
public:
using param_t = operators::SequencePadParam;
void Run() override;
virtual ~SequencePadCompute() = default;
private:
lite::Tensor seq_offsets_;
std::vector<int64_t> seq_len_;
std::vector<size_t> seq_offsets_vec_;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/sequence_pad_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <random>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
#include "lite/backends/cuda/cuda_utils.h"
// #include "lite/utils/float16.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class SequencePadTest : public ::testing::Test {
protected:
SequencePadTest()
: batch(5),
features(2),
padded_length(3),
x_lod({{0, 2, 5}}),
x_shape({batch, features}),
pad_value_shape({features}),
out_shape({static_cast<int64_t>(x_lod[0].size() - 1),
padded_length,
features}) {
X_ref.Resize(lite::DDim(x_shape));
X_ref.set_lod(x_lod);
X_gpu.Resize(X_ref.dims());
PadValue_ref.Resize(lite::DDim(pad_value_shape));
PadValue_gpu.Resize(PadValue_ref.dims());
Length_ref.Resize(lite::DDim({static_cast<int64_t>(x_lod[0].size() - 1)}));
Length_gpu.Resize(Length_ref.dims());
Length_cpu.Resize(Length_ref.dims());
auto x_ref_data = X_ref.mutable_data<float>();
auto pad_value_ref_data = PadValue_ref.mutable_data<float>();
// prepare input
for (int64_t i = 0; i < X_ref.numel(); i++) {
x_ref_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < PadValue_ref.numel(); i++) {
pad_value_ref_data[i] = static_cast<float>(i);
}
Out_ref.Resize(lite::DDim(out_shape));
Out_gpu.Resize(Out_ref.dims());
Out_cpu.Resize(Out_ref.dims());
cpu_base(&X_ref, &PadValue_ref, &Out_ref, &Length_ref);
device_init();
}
void device_init() {
ctx.reset(new KernelContext);
cudaStreamCreate(&stream);
param.X = &X_gpu;
param.PadValue = &PadValue_gpu;
param.Length = &Length_gpu;
param.Out = &Out_gpu;
param.padded_length = padded_length;
}
void float_data_init() {
X_gpu.Assign<float, lite::DDim, TARGET(kCUDA)>(X_ref.data<float>(),
X_gpu.dims());
X_gpu.set_lod(X_ref.lod());
PadValue_gpu.Assign<float, lite::DDim, TARGET(kCUDA)>(
PadValue_ref.data<float>(), PadValue_gpu.dims());
}
void half_data_init() {}
void cpu_base(const lite::Tensor* X,
const lite::Tensor* PadValue,
lite::Tensor* Out,
lite::Tensor* Length) {
auto* length_data = Length->mutable_data<int64_t>();
auto* out_data = Out->mutable_data<float>();
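// Hand-computed reference for x_lod {{0, 2, 5}}, features = 2, padded_length = 3
// and PadValue = {0, 1}: the first sequence (data 0..3) gets one extra step of
// {0, 1}; the second sequence (data 4..9) already fills all 3 steps.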
length_data[0] = 2;
length_data[1] = 3;
for (size_t i = 0; i < 4; ++i) {
out_data[i] = i;
}
out_data[4] = 0;
out_data[5] = 1;
for (size_t i = 4; i < 10; ++i) {
out_data[2 + i] = i;
}
}
int batch, features, padded_length;
LoD x_lod;
std::vector<int64_t> x_shape, pad_value_shape, out_shape;
lite::Tensor X_ref, PadValue_ref, Out_ref, Length_ref;
lite::Tensor X_gpu, PadValue_gpu, Out_gpu, Length_gpu;
lite::Tensor Out_cpu, Length_cpu;
operators::SequencePadParam param;
std::unique_ptr<KernelContext> ctx;
cudaStream_t stream;
};
TEST_F(SequencePadTest, fp32) {
float_data_init();
auto& context = ctx->As<CUDAContext>();
context.SetExecStream(stream);
SequencePadCompute<float, PRECISION(kFloat)> kernel;
kernel.SetParam(param);
kernel.SetContext(std::move(ctx));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp32, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
CopySync<TARGET(kCUDA)>(Out_cpu.mutable_data<float>(),
Out_gpu.data<float>(),
sizeof(float) * Out_gpu.numel(),
IoDirection::DtoH);
CopySync<TARGET(kCUDA)>(Length_cpu.mutable_data<int64_t>(),
Length_gpu.data<int64_t>(),
sizeof(int64_t) * Length_gpu.numel(),
IoDirection::DtoH);
for (int i = 0; i < Out_gpu.numel(); ++i) {
EXPECT_NEAR(Out_cpu.data<float>()[i], Out_ref.data<float>()[i], 1e-5);
}
for (int i = 0; i < Length_gpu.numel(); ++i) {
EXPECT_NEAR(
Length_cpu.data<int64_t>()[i], Length_ref.data<int64_t>()[i], 1e-5);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "lite/backends/cuda/math/sequence_padding.h"
#include "lite/core/op_registry.h"
#include "lite/core/target_wrapper.h"
#include "lite/kernels/cuda/sequence_unpad_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T, PrecisionType Ptype>
void SequenceUnpadCompute<T, Ptype>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
const auto* pad_tensor = param.X;
const auto* len_t = param.Length;
auto* seq_tensor = param.Out;
int padded_length = pad_tensor->dims()[1];
int seq_num = seq_tensor->lod()[0].size() - 1;
int max_seq_len = 0;
int step_width = seq_tensor->numel() / seq_tensor->dims()[0];
seq_offsets_vec_.resize(seq_tensor->lod()[0].size());
for (size_t i = 0; i < seq_num; ++i) {
max_seq_len = std::max(max_seq_len,
static_cast<int>(seq_tensor->lod()[0][i + 1] -
seq_tensor->lod()[0][i]));
seq_offsets_vec_[i] = seq_tensor->lod()[0][i];
}
seq_offsets_vec_[seq_num] = seq_tensor->lod()[0][seq_num];
seq_offsets_.Resize({static_cast<int64_t>(seq_tensor->lod()[0].size())});
TargetWrapperCuda::MemcpyAsync(
seq_offsets_.mutable_data<size_t>(TARGET(kCUDA)),
seq_offsets_vec_.data(),
sizeof(size_t) * seq_offsets_vec_.size(),
IoDirection::HtoD,
stream);
const T* pad_data = pad_tensor->template data<T>();
T* seq_data = seq_tensor->template mutable_data<T>(TARGET(kCUDA));
lite::cuda::math::SequenceUnpadding(seq_data,
pad_data,
seq_offsets_.data<size_t>(),
seq_num,
padded_length,
step_width,
&stream);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
using SeqUnpadFp32 =
paddle::lite::kernels::cuda::SequenceUnpadCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(sequence_unpad, kCUDA, kFloat, kNCHW, SeqUnpadFp32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Length", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T, PrecisionType Ptype>
class SequenceUnpadCompute : public KernelLite<TARGET(kCUDA), Ptype> {
public:
using param_t = operators::SequenceUnpadParam;
void Run() override;
virtual ~SequenceUnpadCompute() = default;
private:
lite::Tensor seq_offsets_;
std::vector<size_t> seq_offsets_vec_;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/sequence_unpad_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <random>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
#include "lite/backends/cuda/cuda_utils.h"
// #include "lite/utils/float16.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class SequenceUnpadTest : public ::testing::Test {
protected:
SequenceUnpadTest()
: batch(5),
features(2),
padded_length(3),
out_lod({{0, 2, 5}}),
x_shape({static_cast<int64_t>(out_lod[0].size() - 1),
padded_length,
features}),
out_shape({batch, features}) {
X_ref.Resize(lite::DDim(x_shape));
X_gpu.Resize(X_ref.dims());
Length_ref.Resize(
lite::DDim({static_cast<int64_t>(out_lod[0].size() - 1)}));
Length_gpu.Resize(Length_ref.dims());
auto* x_ref_data = X_ref.mutable_data<float>();
auto* length_ref_data = Length_ref.mutable_data<int64_t>();
// prepare input
for (int64_t i = 0; i < X_ref.numel(); i++) {
x_ref_data[i] = static_cast<float>(i);
}
for (size_t i = 0; i < out_lod[0].size() - 1; ++i) {
length_ref_data[i] = out_lod[0][i + 1] - out_lod[0][i];
}
Out_ref.Resize(lite::DDim(out_shape));
Out_ref.set_lod(out_lod);
Out_gpu.Resize(Out_ref.dims());
Out_gpu.set_lod(Out_ref.lod());
Out_cpu.Resize(Out_ref.dims());
Out_cpu.set_lod(Out_ref.lod());
cpu_base(&X_ref, &Length_ref, &Out_ref);
device_init();
}
void device_init() {
ctx.reset(new KernelContext);
cudaStreamCreate(&stream);
param.X = &X_gpu;
param.Length = &Length_gpu;
param.Out = &Out_gpu;
}
void float_data_init() {
X_gpu.Assign<float, lite::DDim, TARGET(kCUDA)>(X_ref.data<float>(),
X_gpu.dims());
Length_gpu.Assign<int64_t, lite::DDim, TARGET(kCUDA)>(
Length_ref.data<int64_t>(), Length_gpu.dims());
}
void half_data_init() {}
void cpu_base(const lite::Tensor* X,
const lite::Tensor* Length,
lite::Tensor* Out) {
auto* out_data = Out->mutable_data<float>();
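// Hand-computed reference: X is the 2 x 3 x 2 padded tensor with data 0..11;
// the true lengths are 2 and 3, so unpadding keeps elements 0..3 of the first
// slot and 6..11 of the second.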
for (size_t i = 0; i < 4; ++i) {
out_data[i] = i;
}
for (size_t i = 6; i < 12; ++i) {
out_data[i - 2] = i;
}
}
int batch, features, padded_length;
LoD out_lod;
std::vector<int64_t> x_shape, out_shape;
lite::Tensor X_ref, Out_ref, Length_ref;
lite::Tensor X_gpu, Out_gpu, Length_gpu;
lite::Tensor Out_cpu, Length_cpu;
operators::SequenceUnpadParam param;
std::unique_ptr<KernelContext> ctx;
cudaStream_t stream;
};
TEST_F(SequenceUnpadTest, fp32) {
float_data_init();
auto& context = ctx->As<CUDAContext>();
context.SetExecStream(stream);
SequenceUnpadCompute<float, PRECISION(kFloat)> kernel;
kernel.SetParam(param);
kernel.SetContext(std::move(ctx));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp32, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
CopySync<TARGET(kCUDA)>(Out_cpu.mutable_data<float>(),
Out_gpu.data<float>(),
sizeof(float) * Out_gpu.numel(),
IoDirection::DtoH);
for (int i = 0; i < Out_gpu.numel(); ++i) {
EXPECT_NEAR(Out_cpu.data<float>()[i], Out_ref.data<float>()[i], 1e-5);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
@@ -77,6 +77,7 @@ add_operator(reduce_max_op_lite extra SRCS reduce_max_op.cc DEPS ${op_DEPS})
add_operator(shape_op_lite extra SRCS shape_op.cc DEPS ${op_DEPS})
add_operator(sequence_expand_op_lite extra SRCS sequence_expand_op.cc DEPS ${op_DEPS})
add_operator(sequence_unpad_op_lite extra SRCS sequence_unpad_op.cc DEPS ${op_DEPS})
add_operator(sequence_pad_op_lite extra SRCS sequence_pad_op.cc DEPS ${op_DEPS})
add_operator(im2sequence_op extra SRCS im2sequence_op.cc DEPS ${op_DEPS})
add_operator(gather_op extra SRCS gather_op.cc DEPS ${op_DEPS})
add_operator(anchor_generator_op extra SRCS anchor_generator_op.cc DEPS ${op_DEPS})
......
@@ -1031,6 +1031,14 @@ struct SequenceExpandParam : ParamBase {
int ref_level{-1};
};
struct SequencePadParam : ParamBase {
const lite::Tensor* X{};
const lite::Tensor* PadValue{};
lite::Tensor* Out{};
lite::Tensor* Length{};
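// Length to which every sequence is padded; -1 means pad to the longest
// sequence in the batch (resolved in SequencePadOp::InferShapeImpl).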
int padded_length{-1};
};
struct SequenceUnpadParam : ParamBase {
const lite::Tensor* X{};
const lite::Tensor* Length{};
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/sequence_pad_op.h"
#include <algorithm>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SequencePadOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.PadValue);
CHECK_OR_FALSE(param_.Out);
CHECK_OR_FALSE(param_.Length);
return true;
}
bool SequencePadOp::InferShapeImpl() const {
auto x_dims = param_.X->dims();
CHECK_GE(x_dims.size(), 2) << "The rank of SequencePad OP Input(x) can't be "
"less than 2. But the rank we received is "
<< x_dims.size();
auto time_step_dims = x_dims.Slice(1, x_dims.size());
auto pad_value_dims = param_.PadValue->dims();
CHECK_EQ((pad_value_dims == DDim({1})) || (pad_value_dims == time_step_dims),
true)
<< "The SequencePad OP Input(PadValue) must be a scalar or a tensor "
"whiose shape equals to time steps in sequences";
auto x_lod = param_.X->lod();
CHECK_EQ(x_lod.empty(), false)
<< "The SequencePad OP Input(X) must hold lod info.";
const auto &x_lod_0 = x_lod[0];
CHECK_GE(x_lod_0.size(), 2)
<< "The size of SequencePadOp Input(X)'s lod info can't be less than 2. "
"But the size we received is "
<< x_lod_0.size();
CHECK_EQ(x_dims[0], static_cast<int64_t>(x_lod_0.back()))
<< "The SequencePadOp Input(X)'s lod info mismatches the actual tensor "
"shape. The 1st dimension of Input(X)'s lod info is "
<< x_dims[0] << ", the 1st dimension of actual tensor shape is "
<< static_cast<int64_t>(x_lod_0.back());
int seq_num = x_lod_0.size() - 1;
int max_seq_len = 0;
for (int i = 0; i < seq_num; ++i) {
max_seq_len =
std::max(max_seq_len, static_cast<int>(x_lod_0[i + 1] - x_lod_0[i]));
}
if (param_.padded_length == -1) {
param_.padded_length = max_seq_len;
}
CHECK_GE(param_.padded_length, max_seq_len)
<< "The SequencePadOp Attr(padded_length) should be greater than or "
"equal to the length of the longest original sequence. But the "
"padded_length we received is "
<< param_.padded_length
<< ", the length of the longest original sequence is " << max_seq_len;
int out_dim_0 = seq_num;
std::vector<int64_t> out_dims_vec{out_dim_0, param_.padded_length};
std::vector<int64_t> len_dims_vec{out_dim_0};
auto time_step_dims_vec = time_step_dims.Vectorize();
out_dims_vec.insert(
out_dims_vec.end(), time_step_dims_vec.begin(), time_step_dims_vec.end());
param_.Out->Resize(out_dims_vec);
param_.Length->Resize(len_dims_vec);
return true;
}
bool SequencePadOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.X = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.PadValue = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("PadValue").front())->Get<lite::Tensor>());
param_.Length = scope->FindVar(opdesc.Input("Length").front())
->GetMutable<lite::Tensor>();
param_.Out =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.padded_length = opdesc.GetAttr<int>("padded_length");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(sequence_pad, paddle::lite::operators::SequencePadOp);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
namespace paddle {
namespace lite {
namespace operators {
class SequencePadOp : public OpLite {
public:
SequencePadOp() {}
explicit SequencePadOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "sequence_pad"; }
private:
mutable SequencePadParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle