From da043b1499780f5ec3305c76776518332e0edef2 Mon Sep 17 00:00:00 2001
From: Wilber
Date: Wed, 4 Dec 2019 14:08:18 +0800
Subject: [PATCH] update cuda kernels to run content-dnn models test=develop
 (#2554)

update cuda kernels to run content-dnn models
---
 lite/api/paddle_api.cc                        |  1 +
 lite/core/op_registry.cc                      |  2 +
 .../cuda/attention_padding_mask_compute.cu    | 42 +++++++++++++-
 lite/kernels/cuda/feed_compute.cc             | 45 ++++++++++++---
 lite/kernels/cuda/feed_compute.h              |  3 +-
 .../cuda/match_matrix_tensor_compute.cu       | 24 ++++++++
 .../cuda/search_aligned_mat_mul_compute.cc    |  3 +
 lite/kernels/cuda/search_fc_compute.cu        |  6 --
 .../cuda/search_group_padding_compute.cu      | 17 ++++--
 .../cuda/search_seq_depadding_compute.cu      |  9 ++-
 .../cuda/sequence_arithmetic_compute.cu       |  3 +-
 lite/kernels/cuda/sequence_concat_compute.cu  | 20 +++++++
 lite/kernels/cuda/sequence_pool_compute.cu    |  1 +
 lite/kernels/cuda/sequence_reverse_compute.cu | 47 +++++++++-------
 lite/kernels/cuda/sequence_reverse_compute.h  |  4 +-
 .../cuda/sequence_reverse_compute_test.cc     |  2 +-
 .../cuda/sequence_topk_avg_pooling_compute.cu | 55 ++++++++++++-------
 lite/kernels/cuda/softmax_compute.cu          |  8 ++-
 .../operators/sequence_topk_avg_pooling_op.cc |  3 +-
 19 files changed, 216 insertions(+), 79 deletions(-)

diff --git a/lite/api/paddle_api.cc b/lite/api/paddle_api.cc
index 131041012a..aabb535292 100644
--- a/lite/api/paddle_api.cc
+++ b/lite/api/paddle_api.cc
@@ -121,6 +121,7 @@ template void Tensor::CopyFromCpu<int, TargetType::kARM>(const int *);
 template void Tensor::CopyFromCpu<float, TargetType::kARM>(const float *);
 template void Tensor::CopyFromCpu<int8_t, TargetType::kARM>(const int8_t *);
 template void Tensor::CopyFromCpu<int, TargetType::kCUDA>(const int *);
+template void Tensor::CopyFromCpu<int64_t, TargetType::kCUDA>(const int64_t *);
 template void Tensor::CopyFromCpu<float, TargetType::kCUDA>(const float *);
 template void Tensor::CopyFromCpu<int8_t, TargetType::kCUDA>(const int8_t *);
 
diff --git a/lite/core/op_registry.cc b/lite/core/op_registry.cc
index 3b8b350ad8..c23d3157e0 100644
--- a/lite/core/op_registry.cc
+++ b/lite/core/op_registry.cc
@@ -115,6 +115,8 @@ KernelRegistry::KernelRegistry()
   INIT_FOR(kCUDA, kAny, kNCHW);
   INIT_FOR(kCUDA, kAny, kAny);
   INIT_FOR(kCUDA, kInt8, kNHWC);
+  INIT_FOR(kCUDA, kInt64, kNCHW);
+  INIT_FOR(kCUDA, kInt64, kNHWC);
 
   INIT_FOR(kHost, kFloat, kNCHW);
   INIT_FOR(kHost, kAny, kNCHW);
diff --git a/lite/kernels/cuda/attention_padding_mask_compute.cu b/lite/kernels/cuda/attention_padding_mask_compute.cu
index 8627903b23..fac73b1adc 100644
--- a/lite/kernels/cuda/attention_padding_mask_compute.cu
+++ b/lite/kernels/cuda/attention_padding_mask_compute.cu
@@ -40,6 +40,7 @@ __global__ void ker_attention_padding_mask(T* out_data,
                                            const int attn_seq_len,
                                            const int src_seq_num,
                                            const int src_seq_len,
+                                           const T* pad_begin_data,
                                            const T mask,
                                            const int count) {
   CUDA_KERNEL_LOOP(tid, count) {
@@ -49,7 +50,12 @@ __global__ void ker_attention_padding_mask(T* out_data,
     int attn_word_id = tmp_tid % attn_seq_len;
     int src_seq_id = attn_seq_id % src_seq_num;
     int cur_len = src_offset[src_seq_id + 1] - src_offset[src_seq_id];
-    if (src_word_id >= cur_len) {
+
+    int k = static_cast<int>(pad_begin_data[src_seq_id]);
+    if (k < cur_len &&
+        tid >= src_seq_len * (attn_seq_len * attn_seq_id + attn_word_id) + k &&
+        tid < src_seq_len * (attn_seq_len * attn_seq_id + attn_word_id) +
+                  cur_len) {
       out_data[tid] = mask;
     } else {
       out_data[tid] = attn_data[tid];
@@ -79,6 +85,35 @@ void AttentionPaddingMaskCompute::Run() {
   auto attn_data = attn->data<float>();
   auto out_data = out->mutable_data<float>(TARGET(kCUDA));
 
+  std::vector<float> src_cpu(src->numel(), 0);
+  TargetWrapperCuda::MemcpyAsync(src_cpu.data(),
+                                 src->data<float>(),
+                                 sizeof(float) * src->numel(),
+                                 IoDirection::DtoH,
+                                 stream);
+  cudaStreamSynchronize(stream);
+
+  std::vector<float> pad_begin(src_seq_num, 0);
+  auto src_len = static_cast<int64_t>(src->lod()[0][1]);
+  int _pad_id = param.pad_id;
+  for (int i = 0; i < src_seq_num; ++i) {
+    const auto* src_data = src_cpu.data() + src_len * i;
+    int index = src_len - 1;
+    for (; index >= 0 && _pad_id == static_cast<int>(src_data[index]);
+         --index) {
+    }
+    pad_begin[i] = static_cast<float>(index + 1);
+  }
+
+  param.pad_begin->Resize({static_cast<int64_t>(src_seq_num)});
+  auto pad_begin_cuda_data =
+      param.pad_begin->mutable_data<float>(TARGET(kCUDA));
+  TargetWrapperCuda::MemcpyAsync(pad_begin_cuda_data,
+                                 pad_begin.data(),
+                                 sizeof(float) * src_seq_num,
+                                 IoDirection::HtoD,
+                                 stream);
+
   std::vector<int> src_offset_cpu(src_offset.size(), 0);
   for (int i = 0; i < src_offset.size(); i++) {
     src_offset_cpu[i] = src_offset[i];
@@ -101,11 +136,12 @@ void AttentionPaddingMaskCompute::Run() {
                                      attn_seq_len,
                                      src_seq_num,
                                      src_seq_len,
+                                     pad_begin_cuda_data,
                                      param.mask,
                                      count);
 
   cudaError_t error = cudaGetLastError();
-  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+  if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
 }
 
 }  // namespace cuda
@@ -113,7 +149,7 @@ void AttentionPaddingMaskCompute::Run() {
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_LITE_KERNEL(attention_padding_mask,
+REGISTER_LITE_KERNEL(search_attention_padding_mask,
                      kCUDA,
                      kFloat,
                      kNCHW,
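For context, the host-side block added above derives pad_begin by scanning each source sequence backwards for trailing pad tokens before uploading the result to the device. A minimal standalone sketch of that scan, with hypothetical sizes and pad id (not code from the patch):

#include <iostream>
#include <vector>

// Returns, per sequence, the position where trailing pad tokens start.
// `src` is assumed to hold `seq_num` rows of `seq_len` token ids each.
std::vector<int> ComputePadBegin(const std::vector<float>& src,
                                 int seq_num, int seq_len, int pad_id) {
  std::vector<int> pad_begin(seq_num, 0);
  for (int i = 0; i < seq_num; ++i) {
    const float* row = src.data() + i * seq_len;
    int index = seq_len - 1;
    // Walk backwards while we keep seeing the pad token.
    while (index >= 0 && static_cast<int>(row[index]) == pad_id) --index;
    pad_begin[i] = index + 1;  // first padded position (== seq_len if none)
  }
  return pad_begin;
}

int main() {
  // One sequence of length 6 whose last two tokens are padding (pad_id = 0).
  std::vector<float> src = {3, 7, 5, 9, 0, 0};
  std::cout << ComputePadBegin(src, 1, 6, 0)[0] << "\n";  // prints 4
}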
diff --git a/lite/kernels/cuda/feed_compute.cc b/lite/kernels/cuda/feed_compute.cc
index cffa8a573d..e54c5b9b03 100644
--- a/lite/kernels/cuda/feed_compute.cc
+++ b/lite/kernels/cuda/feed_compute.cc
@@ -20,21 +20,22 @@ namespace lite {
 namespace kernels {
 namespace cuda {
 
-void FeedCompute::Run() {
-  auto& param = this->Param<param_t>();
+template <typename T, PrecisionType Prec>
+void FeedCompute<T, Prec>::Run() {
+  auto& param = this->template Param<param_t>();
   auto& ctx = this->ctx_->template As<CUDAContext>();
   auto stream = ctx.exec_stream();
   VLOG(4) << "feed_list.size: " << param.feed_list->size();
   const lite::Tensor& feed_item = (*param.feed_list)[param.col];
   int num = static_cast<int>(feed_item.numel());
-  auto input = feed_item.data<float>();
+  auto input = feed_item.data<T>();
   param.out->Resize(feed_item.dims());
-  auto output = param.out->mutable_data<float>(TARGET(kCUDA));
+  auto output = param.out->template mutable_data<T>(TARGET(kCUDA));
   VLOG(4) << "col: " << param.col << " num:" << num;
 
   TargetW::MemcpyAsync(
-      output, input, num * sizeof(float), IoDirection::HtoD, stream);
+      output, input, num * sizeof(T), IoDirection::HtoD, stream);
 }
 
 }  // namespace cuda
@@ -42,8 +43,13 @@ void FeedCompute::Run() {
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_LITE_KERNEL(
-    feed, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::FeedCompute, nchw)
+typedef paddle::lite::kernels::cuda::FeedCompute<float, PRECISION(kFloat)>
+    FeedFp32;
+
+typedef paddle::lite::kernels::cuda::FeedCompute<int64_t, PRECISION(kInt64)>
+    FeedInt64;
+
+REGISTER_LITE_KERNEL(feed, kCUDA, kFloat, kNCHW, FeedFp32, nchw)
     .BindInput("X",
                {LiteType::GetTensorTy(TARGET(kHost),
                                       PRECISION(kFloat),
@@ -54,8 +60,7 @@ REGISTER_LITE_KERNEL(
                                       DATALAYOUT(kNCHW))})
     .Finalize();
 
-REGISTER_LITE_KERNEL(
-    feed, kCUDA, kFloat, kNHWC, paddle::lite::kernels::cuda::FeedCompute, nhwc)
+REGISTER_LITE_KERNEL(feed, kCUDA, kFloat, kNHWC, FeedFp32, nhwc)
     .BindInput("X",
                {LiteType::GetTensorTy(TARGET(kHost),
                                       PRECISION(kFloat),
@@ -65,3 +70,25 @@ REGISTER_LITE_KERNEL(
                                       PRECISION(kFloat),
                                       DATALAYOUT(kNHWC))})
     .Finalize();
+
+REGISTER_LITE_KERNEL(feed, kCUDA, kInt64, kNCHW, FeedInt64, nchw)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kInt64),
+                                      DATALAYOUT(kNCHW))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kInt64),
+                                       DATALAYOUT(kNCHW))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(feed, kCUDA, kInt64, kNHWC, FeedInt64, nhwc)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kInt64),
+                                      DATALAYOUT(kNHWC))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kInt64),
+                                       DATALAYOUT(kNHWC))})
+    .Finalize();
diff --git a/lite/kernels/cuda/feed_compute.h b/lite/kernels/cuda/feed_compute.h
index 0510404b2b..9c42dcc1ca 100644
--- a/lite/kernels/cuda/feed_compute.h
+++ b/lite/kernels/cuda/feed_compute.h
@@ -20,7 +20,8 @@ namespace lite {
 namespace kernels {
 namespace cuda {
 
-class FeedCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
+template <typename T, PrecisionType Prec>
+class FeedCompute : public KernelLite<TARGET(kCUDA), Prec> {
  public:
   using param_t = operators::FeedParam;
   using TargetW = TargetWrapper<TARGET(kCUDA)>;
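With the templated FeedCompute and the kInt64 registrations above (plus the CopyFromCpu<int64_t> instantiation in paddle_api.cc), int64 inputs such as word ids can now be fed end to end. A rough usage sketch, assuming the usual predictor setup; the names and shapes here are illustrative, not from the patch:

#include <cstdint>
#include <vector>
#include "lite/api/paddle_api.h"

using paddle::lite_api::PaddlePredictor;
using paddle::lite_api::TargetType;

void FeedWordIds(PaddlePredictor* predictor) {
  std::vector<int64_t> word_ids = {12, 7, 512, 3};  // hypothetical token ids
  auto input = predictor->GetInput(0);
  input->Resize({static_cast<int64_t>(word_ids.size()), 1});
  // Relies on the int64_t/kCUDA instantiation added in paddle_api.cc.
  input->CopyFromCpu<int64_t, TargetType::kCUDA>(word_ids.data());
  predictor->Run();
}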
.BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kInt64), + DATALAYOUT(kNCHW))}) + .Finalize(); + +REGISTER_LITE_KERNEL(feed, kCUDA, kInt64, kNHWC, FeedInt64, nhwc) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kInt64), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kInt64), + DATALAYOUT(kNHWC))}) + .Finalize(); diff --git a/lite/kernels/cuda/feed_compute.h b/lite/kernels/cuda/feed_compute.h index 0510404b2b..9c42dcc1ca 100644 --- a/lite/kernels/cuda/feed_compute.h +++ b/lite/kernels/cuda/feed_compute.h @@ -20,7 +20,8 @@ namespace lite { namespace kernels { namespace cuda { -class FeedCompute : public KernelLite { +template +class FeedCompute : public KernelLite { public: using param_t = operators::FeedParam; using TargetW = TargetWrapper; diff --git a/lite/kernels/cuda/match_matrix_tensor_compute.cu b/lite/kernels/cuda/match_matrix_tensor_compute.cu index 751bcb03ca..f89b9c9578 100644 --- a/lite/kernels/cuda/match_matrix_tensor_compute.cu +++ b/lite/kernels/cuda/match_matrix_tensor_compute.cu @@ -82,8 +82,32 @@ void MatchMatrixTensorCompute::Run() { gemm_impl_->run(1.0f, 0.0f, l_t_data, r_data, top_data, &context); } } + + int batch_size = x->lod()[0].size() - 1; + int lod_lv1_size = batch_size * dim_t; + int lod_lv2_size = x->lod()[0].back() * dim_t; + std::vector out_lod0(batch_size + 1, 0); + std::vector out_lod1(lod_lv1_size + 1, 0); + std::vector out_lod2(lod_lv2_size + 1, 0); + for (int i = 0; i < batch_size; i++) { + out_lod0[i + 1] = out_lod0[i] + dim_t; + int len_l = offset_l[i + 1] - offset_l[i]; + + for (int j = 0; j < dim_t; j++) { + out_lod1[i * dim_t + j + 1] = out_lod1[i * dim_t + j] + len_l; + int len_r = offset_r[i + 1] - offset_r[i]; + + for (int k = 0; k < len_l; k++) { + out_lod2[offset_l[i] * dim_t + j * len_l + k + 1] = + out_lod2[offset_l[i] * dim_t + j * len_l + k] + len_r; + } + } + } + LoD out_lod; out_lod.push_back(top_offset); + out_lod.push_back(offset_l); + out_lod.push_back(offset_r); out->set_lod(out_lod); } diff --git a/lite/kernels/cuda/search_aligned_mat_mul_compute.cc b/lite/kernels/cuda/search_aligned_mat_mul_compute.cc index 525765de28..ddefb608dd 100644 --- a/lite/kernels/cuda/search_aligned_mat_mul_compute.cc +++ b/lite/kernels/cuda/search_aligned_mat_mul_compute.cc @@ -32,4 +32,7 @@ REGISTER_LITE_KERNEL(search_aligned_mat_mul, .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("_a_addr", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("_b_addr", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("_c_addr", {LiteType::GetTensorTy(TARGET(kCUDA))}) .Finalize(); diff --git a/lite/kernels/cuda/search_fc_compute.cu b/lite/kernels/cuda/search_fc_compute.cu index b634bc933d..591e2474a4 100644 --- a/lite/kernels/cuda/search_fc_compute.cu +++ b/lite/kernels/cuda/search_fc_compute.cu @@ -36,7 +36,6 @@ void anakin_NV_gemv(cublasHandle_t handle, const float* x, const float beta, float* y) { - LOG(INFO) << "1"; cublasOperation_t cuTransA = (TransA == false) ? CUBLAS_OP_T : CUBLAS_OP_N; CUBLAS_CHECK( cublasSgemv(handle, cuTransA, N, M, &alpha, A, N, x, 1, &beta, y, 1)); @@ -66,17 +65,13 @@ void anakin_NV_gemm(cublasHandle_t handle, const float* B, const float beta, float* C) { - LOG(INFO) << "1"; // Note that cublas follows fortran order. int lda = (!TransA /* == CblasNoTrans*/) ? K : M; int ldb = (!TransB /* == CblasNoTrans*/) ? 
diff --git a/lite/kernels/cuda/search_aligned_mat_mul_compute.cc b/lite/kernels/cuda/search_aligned_mat_mul_compute.cc
index 525765de28..ddefb608dd 100644
--- a/lite/kernels/cuda/search_aligned_mat_mul_compute.cc
+++ b/lite/kernels/cuda/search_aligned_mat_mul_compute.cc
@@ -32,4 +32,7 @@ REGISTER_LITE_KERNEL(search_aligned_mat_mul,
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("_a_addr", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("_b_addr", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("_c_addr", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .Finalize();
diff --git a/lite/kernels/cuda/search_fc_compute.cu b/lite/kernels/cuda/search_fc_compute.cu
index b634bc933d..591e2474a4 100644
--- a/lite/kernels/cuda/search_fc_compute.cu
+++ b/lite/kernels/cuda/search_fc_compute.cu
@@ -36,7 +36,6 @@ void anakin_NV_gemv(cublasHandle_t handle,
                     const float* x,
                     const float beta,
                     float* y) {
-  LOG(INFO) << "1";
   cublasOperation_t cuTransA = (TransA == false) ? CUBLAS_OP_T : CUBLAS_OP_N;
   CUBLAS_CHECK(
       cublasSgemv(handle, cuTransA, N, M, &alpha, A, N, x, 1, &beta, y, 1));
@@ -66,17 +65,13 @@ void anakin_NV_gemm(cublasHandle_t handle,
                     const float* B,
                     const float beta,
                     float* C) {
-  LOG(INFO) << "1";
   // Note that cublas follows fortran order.
   int lda = (!TransA /* == CblasNoTrans*/) ? K : M;
   int ldb = (!TransB /* == CblasNoTrans*/) ? N : K;
-  LOG(INFO) << "1";
   cublasOperation_t cuTransA =
       (!TransA /* == CblasNoTrans*/) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  LOG(INFO) << "1";
   cublasOperation_t cuTransB =
       (!TransB /* == CblasNoTrans*/) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  LOG(INFO) << "1";
   CUBLAS_CHECK(cublasSgemm(handle,
                            cuTransB,
                            cuTransA,
@@ -91,7 +86,6 @@ void anakin_NV_gemm(cublasHandle_t handle,
                            &beta,
                            C,
                            N));
-  LOG(INFO) << "1";
 }
 
 template <>
diff --git a/lite/kernels/cuda/search_group_padding_compute.cu b/lite/kernels/cuda/search_group_padding_compute.cu
index f395aad018..697e53dbb6 100644
--- a/lite/kernels/cuda/search_group_padding_compute.cu
+++ b/lite/kernels/cuda/search_group_padding_compute.cu
@@ -46,7 +46,9 @@ __global__ void ker_search_group_padding(Dtype* out_emb_padding_data,
           in_data[(offset[seq_id] + word_id_in_seq) * emb_size + emb_id];
     } else {
       out_emb_padding_data[tid] = 0.f;
-      out_padding_data[word_id] = pad_id;
+      if (emb_id == 0) {
+        out_padding_data[word_id] = pad_id;
+      }
     }
   }
 }
@@ -61,12 +63,7 @@ void SearchGroupPaddingCompute::Run() {
   Tensor* out_new = param.out_new;
   Tensor* out_padding = param.out_padding;
   const float pad_id = static_cast<float>(param.pad_id);
-
   const float* in_data = x->data<float>();
-  float* out_emb_padding_data =
-      out_emb_padding->mutable_data<float>(TARGET(kCUDA));
-  float* out_new_data = out_new->mutable_data<float>(TARGET(kCUDA));
-  float* out_padding_data = out_padding->mutable_data<float>(TARGET(kCUDA));
   const auto& in_seq_offset = x->lod()[0];
   int batch = in_seq_offset.size() - 1;
   int max_seq = 0;
@@ -85,16 +82,20 @@ void SearchGroupPaddingCompute::Run() {
   out_emb_padding_lod.push_back(new_offset);
   out_emb_padding->set_lod(out_emb_padding_lod);
   out_emb_padding->Resize({batch * max_seq, x_dims[1]});
+  float* out_emb_padding_data =
+      out_emb_padding->mutable_data<float>(TARGET(kCUDA));
 
   LoD out_new_lod;
   out_new_lod.push_back(in_seq_offset);
   out_new->set_lod(out_new_lod);
   out_new->Resize({x_dims[0], 1});
+  float* out_new_data = out_new->mutable_data<float>(TARGET(kCUDA));
 
   LoD out_padding_lod;
   out_padding_lod.push_back(new_offset);
   out_padding->set_lod(out_padding_lod);
   out_padding->Resize({batch * max_seq, 1});
+  float* out_padding_data = out_padding->mutable_data<float>(TARGET(kCUDA));
 
   const int count = out_emb_padding->numel();
   const auto& out_emb_padding_seq_offset = out_emb_padding->lod()[0];
@@ -112,6 +113,10 @@ void SearchGroupPaddingCompute::Run() {
   TargetWrapperCuda::MemsetSync(
       out_new_data,
       0,
       out_new->dims()[0] * out_new->dims()[1] * sizeof(float));
+  TargetWrapperCuda::MemsetSync(
+      out_padding_data,
+      0,
+      out_padding->dims()[0] * out_padding->dims()[1] * sizeof(float));
 
   ker_search_group_padding<
       float><<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, cuda_stream>>>(
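Two things are worth noting in the search_group_padding diff above: mutable_data allocates according to the tensor's current dims, so the allocations now happen after each Resize(); and out_padding is now written by a single thread per word (emb_id == 0) rather than once per embedding element. A minimal sketch of that one-writer-per-row pattern (a hypothetical kernel, not the patch's):

// When a [rows x emb_size] grid of threads must also fill a per-row flag,
// let exactly one thread per row do that store so the same value is not
// redundantly written emb_size times.
__global__ void FillRowsAndFlags(float* out, float* row_flag,
                                 int rows, int emb_size, float pad_val) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= rows * emb_size) return;
  int row = tid / emb_size;
  int col = tid % emb_size;
  out[tid] = 0.f;                         // every thread zeroes its element
  if (col == 0) row_flag[row] = pad_val;  // one writer per row
}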
diff --git a/lite/kernels/cuda/search_seq_depadding_compute.cu b/lite/kernels/cuda/search_seq_depadding_compute.cu
index 179041596b..ecadceab58 100644
--- a/lite/kernels/cuda/search_seq_depadding_compute.cu
+++ b/lite/kernels/cuda/search_seq_depadding_compute.cu
@@ -50,6 +50,7 @@ void SearchSeqDepaddingCompute::Run() {
   auto* out = param.out;
 
   auto* in_data = pad->data<float>();
+  out->Resize({src->dims()[0], pad->dims()[1]});
   auto* out_data = out->mutable_data<float>(TARGET(kCUDA));
 
   const int count = out->numel();
@@ -59,6 +60,9 @@ void SearchSeqDepaddingCompute::Run() {
   int seq_num = pad_seq_offset.size() - 1;
   int emb_size = pad->dims()[1];
 
+  LoD out_lod;
+  out_lod.push_back(src_seq_offset);
+  out->set_lod(out_lod);
   std::vector<int> seq_id_map;
   for (int i = 0; i < seq_num; i++) {
     int cur_len = src_seq_offset[i + 1] - src_seq_offset[i];
@@ -77,11 +81,12 @@ void SearchSeqDepaddingCompute::Run() {
                                  cuda_stream);
 
   int threads = 512;
-  ker_sequence_depadding_fwd<<<count, threads, 0, cuda_stream>>>(
+  int blocks = (count + threads - 1) / threads;
+  ker_sequence_depadding_fwd<<<blocks, threads, 0, cuda_stream>>>(
       out_data, in_data, seq_id_map_data, seq_num, max_len, emb_size, count);
 
   cudaError_t error = cudaGetLastError();
-  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+  if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
 }
 
 }  // namespace cuda
diff --git a/lite/kernels/cuda/sequence_arithmetic_compute.cu b/lite/kernels/cuda/sequence_arithmetic_compute.cu
index 5ca12267f9..7593632a14 100644
--- a/lite/kernels/cuda/sequence_arithmetic_compute.cu
+++ b/lite/kernels/cuda/sequence_arithmetic_compute.cu
@@ -120,7 +120,7 @@ void SequenceArithmeticCompute::Run() {
 
   auto x_data = param.X->data<float>();
   auto x_lod = param.X->lod()[0];
-  auto y_data = param.X->data<float>();
+  auto y_data = param.Y->data<float>();
   auto y_lod = param.Y->lod()[0];
   auto out_data = param.Out->mutable_data<float>(TARGET(kCUDA));
 
@@ -174,7 +174,6 @@ void SequenceArithmeticCompute::Run() {
   int seq_num = x_lod.size() - 1;
   int count = param.X->numel();
   int inner_size = param.X->dims()[1];
-
   switch (param.op_type) {
     case 1:  // sum
       ker_arithmetic_sum<
diff --git a/lite/kernels/cuda/sequence_concat_compute.cu b/lite/kernels/cuda/sequence_concat_compute.cu
index 3488c829ce..d4390046b0 100644
--- a/lite/kernels/cuda/sequence_concat_compute.cu
+++ b/lite/kernels/cuda/sequence_concat_compute.cu
@@ -24,6 +24,24 @@ namespace cuda {
 
 const int CUDA_NUM_THREADS = 512;
 
+template <typename T>
+inline LoD ConcatLoD(const std::vector<lite::Tensor*>& xs) {
+  std::vector<size_t> result;
+  result.resize(xs[0]->lod()[0].size());
+
+  for (size_t i = 1; i < result.size(); ++i) {
+    size_t sum = 0;
+    for (size_t j = 0; j < xs.size(); ++j) {
+      auto& x_lod = xs[j]->lod()[0];
+      sum += x_lod[i];
+    }
+    result[i] = sum;
+  }
+  LoD lod;
+  lod.emplace_back(result);
+  return lod;
+}
+
 template <typename Dtype>
 __global__ void ker_sequence_concat(Dtype* out_data,
                                     const uint64_t* in_locate_data,
@@ -96,6 +114,8 @@ void SequenceConcatCompute::Run() {
                               IoDirection::HtoD,
                               stream);
 
+  param.Out->set_lod(ConcatLoD<float>(param.X));
+
   int count = param.X[0]->numel();
   for (int i = 1; i < param.X.size(); ++i) {
     count += param.X[i]->numel();
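The new ConcatLoD helper above merges the inputs' LoD offsets element-wise, because sequence i of the output is sequence i of every input laid end to end. A quick hand-worked check with hypothetical lods:

#include <iostream>
#include <vector>

int main() {
  // Two inputs, each with two sequences: lengths {2,2} and {1,2}.
  std::vector<std::vector<size_t>> lods = {{0, 2, 4}, {0, 1, 3}};
  std::vector<size_t> result(lods[0].size(), 0);
  for (size_t i = 1; i < result.size(); ++i) {
    size_t sum = 0;
    for (const auto& lod : lods) sum += lod[i];  // element-wise offset sum
    result[i] = sum;
  }
  for (auto v : result) std::cout << v << " ";  // 0 3 7
  std::cout << "\n";
}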
diff --git a/lite/kernels/cuda/sequence_pool_compute.cu b/lite/kernels/cuda/sequence_pool_compute.cu
index 962804d03c..97876ec32f 100644
--- a/lite/kernels/cuda/sequence_pool_compute.cu
+++ b/lite/kernels/cuda/sequence_pool_compute.cu
@@ -254,4 +254,5 @@ REGISTER_LITE_KERNEL(sequence_pool,
                      def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("MaxIndex", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .Finalize();
diff --git a/lite/kernels/cuda/sequence_reverse_compute.cu b/lite/kernels/cuda/sequence_reverse_compute.cu
index ee2550cd96..68447fcebb 100644
--- a/lite/kernels/cuda/sequence_reverse_compute.cu
+++ b/lite/kernels/cuda/sequence_reverse_compute.cu
@@ -42,11 +42,9 @@ __host__ __device__ inline size_t UpperBound(const T* x,
   return static_cast<size_t>(first - x);
 }
 
-__global__ void SequenceReverseKernelGridIsOne(const float* x,
-                                               float* y,
-                                               const int64_t* lod,
-                                               size_t lod_count,
-                                               int64_t row_numel) {
+template <typename T>
+__global__ void SequenceReverseKernelGridIsOne(
+    const T* x, T* y, const int64_t* lod, size_t lod_count, int64_t row_numel) {
   int64_t idx = static_cast<int64_t>(threadIdx.x);
   auto row_idx_x = idx / row_numel;
   auto lod_idx = UpperBound(lod, lod_count, row_idx_x);
@@ -55,8 +53,9 @@ __global__ void SequenceReverseKernelGridIsOne(const float* x,
   y[idx_y] = x[idx];
 }
 
-__global__ void SequenceReverseKernel(const float* x,
-                                      float* y,
+template <typename T>
+__global__ void SequenceReverseKernel(const T* x,
+                                      T* y,
                                       const int64_t* lod,
                                       size_t lod_count,
                                       int64_t row_numel,
@@ -71,19 +70,20 @@ __global__ void SequenceReverseKernel(const float* x,
   }
 }
 
-void SequenceReverseCompute::Run() {
-  auto& param = this->Param<param_t>();
+template <typename T, PrecisionType Prec>
+void SequenceReverseCompute<T, Prec>::Run() {
+  auto& param = this->template Param<param_t>();
   auto& ctx = this->ctx_->template As<CUDAContext>();
   auto stream = ctx.exec_stream();
-
   size_t limit = static_cast<size_t>(param.X->numel());
   int64_t row_numel = static_cast<int64_t>(limit / param.X->dims()[0]);
-  const auto* x_data = param.X->data<float>();
-  auto y_data = param.Out->mutable_data<float>(TARGET(kCUDA));
+  const auto* x_data = param.X->template data<T>();
+  auto y_data = param.Out->template mutable_data<T>(TARGET(kCUDA));
   CHECK_NE(x_data, y_data)
      << "SequenceReverse Op does not support in-place operation";
   const auto lod = param.X->lod()[param.X->lod().size() - 1];
   const size_t lod_count = lod.size();
+  param.Out->set_lod(param.X->lod());
 
   lod_cuda.Resize({static_cast<int64_t>(lod.size())});
   int64_t* lod_data = lod_cuda.mutable_data<int64_t>(TARGET(kCUDA));
@@ -92,11 +92,9 @@ void SequenceReverseCompute::Run() {
                               sizeof(int64_t) * lod.size(),
                               IoDirection::HtoD,
                               stream);
-
   constexpr int num_threads = 1024;
   int block_size = limit <= num_threads ? limit : num_threads;
   int grid_size = (limit + num_threads - 1) / num_threads;
-
   if (grid_size == 1) {
     SequenceReverseKernelGridIsOne<<<1, block_size, 0, stream>>>(
         x_data, y_data, lod_data, lod_count, row_numel);
@@ -104,7 +102,6 @@ void SequenceReverseCompute::Run() {
     SequenceReverseKernel<<<grid_size, block_size, 0, stream>>>(
         x_data, y_data, lod_data, lod_count, row_numel, limit);
   }
-
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
@@ -114,12 +111,20 @@ void SequenceReverseCompute::Run() {
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_LITE_KERNEL(sequence_reverse,
-                     kCUDA,
-                     kFloat,
-                     kNCHW,
-                     paddle::lite::kernels::cuda::SequenceReverseCompute,
-                     def)
+typedef paddle::lite::kernels::cuda::SequenceReverseCompute<float,
+                                                            PRECISION(kFloat)>
+    ReverseFp32;
+
+typedef paddle::lite::kernels::cuda::SequenceReverseCompute<int64_t,
+                                                            PRECISION(kInt64)>
+    ReverseInt64;
+
+REGISTER_LITE_KERNEL(sequence_reverse, kCUDA, kFloat, kNCHW, ReverseFp32, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .Finalize();
+
+REGISTER_LITE_KERNEL(sequence_reverse, kCUDA, kInt64, kNCHW, ReverseInt64, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt64))})
+    .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt64))})
+    .Finalize();
diff --git a/lite/kernels/cuda/sequence_reverse_compute.h b/lite/kernels/cuda/sequence_reverse_compute.h
index ba85f08563..6b6199e020 100644
--- a/lite/kernels/cuda/sequence_reverse_compute.h
+++ b/lite/kernels/cuda/sequence_reverse_compute.h
@@ -20,8 +20,8 @@ namespace lite {
 namespace kernels {
 namespace cuda {
 
-class SequenceReverseCompute
-    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
+template <typename T, PrecisionType Prec>
+class SequenceReverseCompute : public KernelLite<TARGET(kCUDA), Prec> {
  public:
   using param_t = operators::SequenceReverseParam;
 
diff --git a/lite/kernels/cuda/sequence_reverse_compute_test.cc b/lite/kernels/cuda/sequence_reverse_compute_test.cc
index 3659f0d12c..3317b52303 100644
--- a/lite/kernels/cuda/sequence_reverse_compute_test.cc
+++ b/lite/kernels/cuda/sequence_reverse_compute_test.cc
@@ -40,7 +40,7 @@ static void sequence_reverse_ref(const lite::Tensor* x, lite::Tensor* y) {
 }
 
 TEST(sequence_reverse_cuda, normal) {
-  SequenceReverseCompute seq_kernel;
+  SequenceReverseCompute<float, PRECISION(kFloat)> seq_kernel;
   std::unique_ptr<KernelContext> ctx(new KernelContext);
   auto& context = ctx->As<CUDAContext>();
 
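The reverse kernels above locate a row's sequence with UpperBound over the lod offsets and then mirror the row inside that sequence. A CPU sketch of the same index mapping, with a hypothetical lod:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> lod = {0, 3, 5};  // two sequences: rows [0,3) and [3,5)
  for (int64_t row = 0; row < 5; ++row) {
    // Index of the first offset strictly greater than `row`, as UpperBound does.
    size_t idx = std::upper_bound(lod.begin(), lod.end(), row) - lod.begin();
    // Mirror the row within its own sequence.
    int64_t mirrored = lod[idx - 1] + (lod[idx] - 1 - row);
    std::cout << row << " -> " << mirrored << "\n";  // 0->2 1->1 2->0 3->4 4->3
  }
}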
diff --git a/lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu b/lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu
index 7f4500c158..8ea3edb30d 100644
--- a/lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu
+++ b/lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu
@@ -26,8 +26,6 @@ __global__ void topk_avg_pooling_kernel_by_row_improve(
     const Dtype *input,
     const int *gpu_input_offset_l,
     const int *gpu_input_offset_r,
-    const int row_max,
-    const int col_max,
    const int topk_size,
     const int *topks,
     const int feat_map_num) {
@@ -35,17 +33,20 @@ __global__ void topk_avg_pooling_kernel_by_row_improve(
   int row =
       gpu_input_offset_l[blockIdx.x + 1] - gpu_input_offset_l[blockIdx.x];  // 8
   int col =
       gpu_input_offset_r[blockIdx.x + 1] - gpu_input_offset_r[blockIdx.x];  // 30
-
   int max_k = topks[topk_size - 1];
   max_k = max_k < col ? max_k : col;
 
   extern __shared__ Dtype smem[];  // H*W
 
-  const Dtype *fm_row_in_data = input +
-                                blockIdx.x * row_max * feat_map_num * col_max +
-                                blockIdx.y * row_max * col_max;
+  const Dtype *fm_row_in_data = input;
+  for (int i = 0; i < blockIdx.x; ++i) {
+    int tmp_row = gpu_input_offset_l[i + 1] - gpu_input_offset_l[i];
+    int tmp_col = gpu_input_offset_r[i + 1] - gpu_input_offset_r[i];
+    fm_row_in_data += tmp_row * feat_map_num * tmp_col;
+  }
+  fm_row_in_data += blockIdx.y * row * col;
 
-  for (int i = threadIdx.x; i < row * col_max; i += blockDim.x) {
+  for (int i = threadIdx.x; i < row * col; i += blockDim.x) {
     smem[i] = fm_row_in_data[i];
   }
   __syncthreads();
@@ -56,7 +57,7 @@ __global__ void topk_avg_pooling_kernel_by_row_improve(
         (gpu_input_offset_l[blockIdx.x] + idx) * feat_map_num * topk_size +
         blockIdx.y * topk_size;
 
-    Dtype *smem_start_col = smem + idx * col_max;
+    Dtype *smem_start_col = smem + idx * col;
 
     int counter = max_k;  // topk_size;
     Dtype last_max_val = -20000.0;
@@ -75,7 +76,7 @@ __global__ void topk_avg_pooling_kernel_by_row_improve(
       if (max_val < -9999.0) {  // == -10000.0
         max_val = last_max_val;
       }
-      smem_start_col[max_pos] = 10000000.0;
+      smem_start_col[max_pos] = -10000000.0;
       int i = max_k - counter;
       for (int c = 0; c < topk_size; c++) {
         if (i <= topks[c] - 1) {
@@ -97,7 +98,6 @@ void SequenceTopkAvgPoolingCompute<T>::Run() {
   auto &param = this->Param<param_t>();
   auto &ctx = this->ctx_->template As<CUDAContext>();
   auto cuda_stream = ctx.exec_stream();
-
   int topk_num = param.topks.size();
   lite::DDim top_ks_shape(std::vector<int64_t>{topk_num, 1, 1, 1});
   _top_ks.Resize(top_ks_shape);
@@ -107,12 +107,16 @@ void SequenceTopkAvgPoolingCompute<T>::Run() {
   cudaMemcpyAsync(_top_ks.mutable_data<int>(TARGET(kCUDA)),
                  &param.topks[0],
                  sizeof(int) * topk_num,
                  cudaMemcpyHostToDevice,
                  cuda_stream);
 
-  int width_offset_len = param.X->lod()[0].size();
+  int width_offset_len = param.COLUMN->lod()[0].size();
   lite::DDim width_offset_shape(
       std::vector<int64_t>{width_offset_len, 1, 1, 1});
   _width_offset.Resize(width_offset_shape);
+  std::vector<int> width_lod_0(width_offset_len, 0);
+  for (size_t i = 0; i < param.COLUMN->lod()[0].size(); ++i) {
+    width_lod_0[i] = static_cast<int>(param.COLUMN->lod()[0][i]);
+  }
   cudaMemcpyAsync(_width_offset.mutable_data<int>(TARGET(kCUDA)),
-                  &(param.X->lod()[0][0]),
+                  &width_lod_0[0],
                   sizeof(int) * width_offset_len,
                   cudaMemcpyHostToDevice,
                   cuda_stream);
@@ -121,8 +125,12 @@ void SequenceTopkAvgPoolingCompute<T>::Run() {
   lite::DDim height_offset_shape(
       std::vector<int64_t>{height_offset_len, 1, 1, 1});
   _height_offset.Resize(height_offset_shape);
+  std::vector<int> height_lod_0(height_offset_len, 0);
+  for (size_t i = 0; i < param.ROW->lod()[0].size(); ++i) {
+    height_lod_0[i] = static_cast<int>(param.ROW->lod()[0][i]);
+  }
   cudaMemcpyAsync(_height_offset.mutable_data<int>(TARGET(kCUDA)),
-                  &(param.ROW->lod()[0][0]),
+                  &height_lod_0[0],
                   sizeof(int) * height_offset_len,
                   cudaMemcpyHostToDevice,
                   cuda_stream);
@@ -136,16 +144,20 @@ void SequenceTopkAvgPoolingCompute<T>::Run() {
                                  sizeof(T) * out_tensor->numel(),
                                  cuda_stream);
 
-  auto x_dims = x_tensor->dims();
-  int num = x_dims[0];
-  int channel = x_dims[1];
-  int height = x_dims[2];
-  int width = x_dims[3];
+  int num = param.ROW->lod()[0].size() - 1;
+  int channel = param.channel_num;
 
   const int *height_offset = _height_offset.data<int>();
   const int *width_offset = _width_offset.data<int>();
 
-  int feat_map_size = height * width;
+  int feat_map_size = 0;
+  for (size_t i = 0; i < height_lod_0.size() - 1; ++i) {
+    int height = height_lod_0[i + 1] - height_lod_0[i];
+    int width = width_lod_0[i + 1] - width_lod_0[i];
+    if (height * width > feat_map_size) {
+      feat_map_size = height * width;
+    }
+  }
   dim3 blocks(num, channel);
   dim3 threads(32, 1);
 
   topk_avg_pooling_kernel_by_row_improve<
       T><<<blocks, threads, feat_map_size * sizeof(T), cuda_stream>>>(
       out_data,
       in_data,
       height_offset,
       width_offset,
-      height,
-      width,
       param.topks.size(),
       _top_ks.data<int>(),
       param.channel_num);
+
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
 }
 
 }  // namespace cuda
diff --git a/lite/kernels/cuda/softmax_compute.cu b/lite/kernels/cuda/softmax_compute.cu
index 14ed391f7f..6293f7295e 100644
--- a/lite/kernels/cuda/softmax_compute.cu
+++ b/lite/kernels/cuda/softmax_compute.cu
@@ -173,9 +173,10 @@ void SoftmaxCompute::Run() {
   cudaGetDeviceProperties(&deviceProp, device_id);
   size_t sharedmem_size = deviceProp.sharedMemPerBlock;
   int max_dimsize = sharedmem_size / sizeof(float) / threads;
-
   auto input_data = param.x->data<float>();
   auto output_data = param.output->mutable_data<float>(TARGET(kCUDA));
+  TargetWrapperCuda::MemsetSync(
+      output_data, 0, param.output->numel() * sizeof(float));
   if (axis_size <= max_dimsize) {
     int use_sharemem_size = axis_size * threads * sizeof(float);
     sharemem_softmax_kernel<<<blocks, threads, use_sharemem_size, stream>>>(
@@ -194,7 +195,7 @@ void SoftmaxCompute::Run() {
     auto max_data = tmax_data.mutable_data<float>(TARGET(kCUDA));
     auto sum_data = tsum_data.mutable_data<float>(TARGET(kCUDA));
     //! firstly, get maximum data
-    float min_data = std::numeric_limits<float>::min();
+    float min_data = std::numeric_limits<float>::lowest();
     softmax_max_kernel<<<blocks, threads, 0, stream>>>(total_threads,
                                                        input_data,
                                                        max_data,
@@ -217,7 +218,7 @@ void SoftmaxCompute::Run() {
         total_threads, output_data, sum_data, inner_num, outer_num, axis_size);
   }
   cudaError_t error = cudaGetLastError();
-  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+  if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
 }
 
 }  // namespace cuda
@@ -258,4 +259,5 @@ REGISTER_LITE_KERNEL(search_seq_softmax,
                      {LiteType::GetTensorTy(TARGET(kCUDA),
                                             PRECISION(kFloat),
                                             DATALAYOUT(kNCHW))})
+    .BindOutput("Out_log", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .Finalize();
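One of the quieter fixes in the softmax diff above: std::numeric_limits<float>::min() is the smallest positive normalized float, not the most negative value, so it is the wrong seed for a running maximum; lowest() is the correct one. A quick standalone check:

#include <iostream>
#include <limits>

int main() {
  // min() is the smallest positive normalized value (~1.18e-38);
  // lowest() is the most negative finite value (~-3.40e+38).
  std::cout << std::numeric_limits<float>::min() << "\n";     // 1.17549e-38
  std::cout << std::numeric_limits<float>::lowest() << "\n";  // -3.40282e+38
  // Seeding a running max with min() silently drops all-negative inputs.
}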
diff --git a/lite/operators/sequence_topk_avg_pooling_op.cc b/lite/operators/sequence_topk_avg_pooling_op.cc
index 637b3d65a0..6f5cbeeeee 100644
--- a/lite/operators/sequence_topk_avg_pooling_op.cc
+++ b/lite/operators/sequence_topk_avg_pooling_op.cc
@@ -54,8 +54,7 @@ bool SequenceTopkAvgPoolingOpLite::InferShape() const {
   vec_out_shape.push_back(channel_num * num_k);
 
   param_.Out->Resize(lite::DDim(vec_out_shape));
-  auto out_lod = param_.Out->mutable_lod();
-  *out_lod = param_.X->lod();
+  param_.Out->set_lod(param_.ROW->lod());
 
   return true;
 }
-- 
GitLab