diff --git a/lite/backends/cuda/math/CMakeLists.txt b/lite/backends/cuda/math/CMakeLists.txt
index 14e5ae38407be7e60e09bd2884a7cfa69ec8744c..495b273a301c004c96bd03b4cdcd2ca418f26f53 100644
--- a/lite/backends/cuda/math/CMakeLists.txt
+++ b/lite/backends/cuda/math/CMakeLists.txt
@@ -20,6 +20,7 @@ nv_library(cuda_batched_gemm SRCS batched_gemm.cc DEPS ${cuda_static_deps})
 nv_library(cuda_strided_gemm SRCS strided_gemm.cc DEPS ${cuda_static_deps})
 nv_library(cuda_sequence_padding SRCS sequence_padding.cu DEPS ${cuda_static_deps})
 nv_library(cuda_bias SRCS bias.cu DEPS ${cuda_static_deps})
+nv_library(cuda_sequence_helper SRCS sequence_helper.cu DEPS ${cuda_static_deps})
 
 set (
   math_cuda
@@ -39,6 +40,7 @@ set (
   cuda_sequence_padding
   cuda_bias
   cudnn_helper
+  cuda_sequence_helper
 )
 
 set(math_cuda "${math_cuda}" CACHE GLOBAL "math cuda")
diff --git a/lite/backends/cuda/math/cudnn_conv.cc b/lite/backends/cuda/math/cudnn_conv.cc
index 786ca33a18b0a85e3cde5e8a07dc77656abc65a1..1c99001a4933d9683e2e78b6d709abdc84cd0a42 100644
--- a/lite/backends/cuda/math/cudnn_conv.cc
+++ b/lite/backends/cuda/math/cudnn_conv.cc
@@ -55,31 +55,32 @@ bool CudnnConv2D::create(const operators::ConvParam& param,
 
   CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->input_desc_,
                                          CUDNN_TENSOR_NCHW,
-                                         GetCudnnDataType<T>(),
+                                         cudnn::cudnnTypeWrapper<T>::type,
                                          batch,
                                          ic,
                                          ih,
                                          iw));
   CUDNN_CHECK(cudnnSetFilter4dDescriptor(this->filter_desc_,
-                                         GetCudnnDataType<T>(),
+                                         cudnn::cudnnTypeWrapper<T>::type,
                                          CUDNN_TENSOR_NCHW,
                                          oc,
                                          ic / param.groups,
                                          kh,
                                          kw));
-  CUDNN_CHECK(cudnnSetConvolution2dDescriptor(this->conv_desc_,
-                                              ph,
-                                              pw,
-                                              sh,
-                                              sw,
-                                              dh,
-                                              dw,
-                                              CUDNN_CROSS_CORRELATION,
-                                              GetCudnnDataType<T>()));
+  CUDNN_CHECK(
+      cudnnSetConvolution2dDescriptor(this->conv_desc_,
+                                      ph,
+                                      pw,
+                                      sh,
+                                      sw,
+                                      dh,
+                                      dw,
+                                      CUDNN_CROSS_CORRELATION,
+                                      cudnn::cudnnTypeWrapper<T>::type));
   CUDNN_CHECK(cudnnSetConvolutionGroupCount(this->conv_desc_, param.groups));
   CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->output_desc_,
                                          CUDNN_TENSOR_NCHW,
-                                         GetCudnnDataType<T>(),
+                                         cudnn::cudnnTypeWrapper<T>::type,
                                          batch,
                                          oc,
                                          oh,
@@ -179,7 +180,7 @@ bool CudnnConv2D::create(const operators::ConvParam& param,
   int dim_bias[] = {1, oc, 1, 1};
   int stride_bias[] = {oc, 1, 1, 1};
   cudnnSetTensorNdDescriptor(this->bias_desc_,
-                             GetCudnnDataType<T>(),
+                             cudnn::cudnnTypeWrapper<T>::type,
                              4,
                              dim_bias,
                              stride_bias);
diff --git a/lite/backends/cuda/math/cudnn_helper.cc b/lite/backends/cuda/math/cudnn_helper.cc
index 92cb32096161e019433047ba76ee499adeda19d1..86ef61945cf2c59ee94d8e476e74f5a0455f00d5 100644
--- a/lite/backends/cuda/math/cudnn_helper.cc
+++ b/lite/backends/cuda/math/cudnn_helper.cc
@@ -21,17 +21,7 @@ namespace paddle {
 namespace lite {
 namespace cuda {
 namespace math {
-
-template <>
-cudnnDataType_t GetCudnnDataType<float>() {
-  return CUDNN_DATA_FLOAT;
-}
-
-template <>
-cudnnDataType_t GetCudnnDataType<half>() {
-  return CUDNN_DATA_HALF;
-}
-
+namespace cudnn {}  // namespace cudnn
 }  // namespace math
 }  // namespace cuda
 }  // namespace lite
diff --git a/lite/backends/cuda/math/cudnn_helper.h b/lite/backends/cuda/math/cudnn_helper.h
index 972d841d972b7cf65245e593392a5d34e8b1d4f7..fd6722a113fe09b2bcff6d5186522bcdcc11012d 100644
--- a/lite/backends/cuda/math/cudnn_helper.h
+++ b/lite/backends/cuda/math/cudnn_helper.h
@@ -25,10 +25,97 @@ namespace paddle {
 namespace lite {
 namespace cuda {
 namespace math {
+namespace cudnn {
 
-template <typename T>
-cudnnDataType_t GetCudnnDataType();
+template <typename T>
+class cudnnTypeWrapper;
 
+template <>
+class cudnnTypeWrapper<float> {
+ public:
+  static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
+  typedef const float ScalingParamType;
+  static ScalingParamType* kOne() {
+    static ScalingParamType v = 1.0f;
+    return &v;
+  }
+  static ScalingParamType* kZero() {
+    static ScalingParamType v = 0.0f;
+    return &v;
+  }
+};
+
+template <>
+class cudnnTypeWrapper<half> {
+ public:
+  static const cudnnDataType_t type = CUDNN_DATA_HALF;
+  typedef const half ScalingParamType;
+  static ScalingParamType* kOne() {
+    static ScalingParamType v = __float2half(1.0f);
+    return &v;
+  }
+  static ScalingParamType* kZero() {
+    static ScalingParamType v = __float2half(0.0f);
+    return &v;
+  }
+};
+
+struct ParamsRegion {
+  ParamsRegion() : offset_(nullptr), size_(0) {}
+  ParamsRegion(void* offset, size_t size) : offset_(offset), size_(size) {}
+  ~ParamsRegion() {}
+
+  ParamsRegion& operator=(const ParamsRegion& right) {
+    offset_ = right.offset_;
+    size_ = right.size_;
+    return *this;
+  }
+
+  bool operator==(const ParamsRegion& right) {
+    bool comp_eq = true;
+    comp_eq = comp_eq && (offset_ == right.offset_);
+    comp_eq = comp_eq && (size_ == right.size_);
+    return comp_eq;
+  }
+
+  void* offset_;
+  size_t size_;
+};
+
+template <typename T>
+class TensorDescriptors {
+ public:
+  TensorDescriptors(size_t n,
+                    const std::vector<std::vector<int>>& dim,
+                    const std::vector<std::vector<int>>& stride) {
+    descs_.resize(n);
+    CHECK_EQ(dim.size(), stride.size())
+        << "dim size should be equal to stride size";
+    for (size_t i = 0; i < n; ++i) {
+      CUDNN_CHECK(cudnnCreateTensorDescriptor(&descs_[i]));
+      CUDNN_CHECK(cudnnSetTensorNdDescriptor(descs_[i],
+                                             cudnnTypeWrapper<T>::type,
+                                             dim[i].size(),
+                                             dim[i].data(),
+                                             stride[i].data()));
+    }
+  }
+
+  ~TensorDescriptors() {
+    for (auto desc : descs_) {
+      CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc));
+    }
+  }
+
+  const cudnnTensorDescriptor_t* descs() const { return descs_.data(); }
+
+  int size() const { return descs_.size(); }
+
+ private:
+  std::vector<cudnnTensorDescriptor_t> descs_;
+};
+
+}  // namespace cudnn
 }  // namespace math
 }  // namespace cuda
 }  // namespace lite
diff --git a/lite/backends/cuda/math/cudnn_softmax.cc b/lite/backends/cuda/math/cudnn_softmax.cc
index 5aafc519ac7f2c123c136de56569a79bf1ace1bd..cb8422fa014c12efdc76ad0878f04fe845c2c235 100644
--- a/lite/backends/cuda/math/cudnn_softmax.cc
+++ b/lite/backends/cuda/math/cudnn_softmax.cc
@@ -54,7 +54,7 @@ bool CudnnSoftmax::Create(const operators::SoftmaxParam& param,
   const int stride_c = H * stride_h;
   const int stride_n = C * stride_c;
   CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(bottom_desc_,
-                                           GetCudnnDataType<T>(),
+                                           cudnn::cudnnTypeWrapper<T>::type,
                                            N,
                                            C,
                                            H,
@@ -64,7 +64,7 @@ bool CudnnSoftmax::Create(const operators::SoftmaxParam& param,
                                            stride_h,
                                            stride_w));
   CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(top_desc_,
-                                           GetCudnnDataType<T>(),
+                                           cudnn::cudnnTypeWrapper<T>::type,
                                            N,
                                            C,
                                            H,
diff --git a/lite/backends/cuda/math/sequence2batch.cu b/lite/backends/cuda/math/sequence2batch.cu
index 9a93362b3bb163b889049d07186634987ed63940..1d7e9d756e78a2060c529792821c314b1a725904 100644
--- a/lite/backends/cuda/math/sequence2batch.cu
+++ b/lite/backends/cuda/math/sequence2batch.cu
@@ -30,17 +30,12 @@ __global__ void CopyMatrixRowsKernel(const T* src,
                                      int height,
                                      int width,
                                      bool is_src_index) {
-  int idx = threadIdx.x;
-  int idy = threadIdx.y;
-  int row_id = blockDim.y * blockIdx.x + idy;
-  if (row_id < height) {
-    int src_idx = is_src_index ? index[row_id] : row_id;
-    int dst_idx = is_src_index ? row_id : index[row_id];
-    const T* src_data = src + src_idx * width;
-    T* dst_data = dst + dst_idx * width;
-    for (int i = idx; i < width; i += blockDim.x) {
-      dst_data[i] = src_data[i];
-    }
+  CUDA_KERNEL_LOOP(tid, height * width) {
+    int row = tid / width;
+    int idx = tid % width;
+    int src_row = is_src_index ? index[row] : row;
+    int dst_row = is_src_index ? row : index[row];
+    dst[dst_row * width + idx] = src[src_row * width + idx];
   }
 }
 
@@ -69,9 +64,8 @@ void CopyMatrixRowsFunctor::operator()(
       sizeof(uint64_t) * index_lod.size(),
       IoDirection::HtoD,
       stream);
-  dim3 threads(128, 8);
-  dim3 grids((height + threads.y - 1) / threads.y);
-  CopyMatrixRowsKernel<T><<<grids, threads, 0, stream>>>(
+  CopyMatrixRowsKernel<
+      T><<<CUDA_GET_BLOCKS(height * width), CUDA_NUM_THREADS, 0, stream>>>(
       src_data, dst_data, index_tensor_data, height, width, is_src_index);
   CUDA_POST_KERNEL_CHECK;
 }
diff --git a/lite/backends/cuda/math/sequence_helper.cu b/lite/backends/cuda/math/sequence_helper.cu
new file mode 100644
index 0000000000000000000000000000000000000000..1818b9134b2dba595ca1e8a7bffc1077654b932e
--- /dev/null
+++ b/lite/backends/cuda/math/sequence_helper.cu
@@ -0,0 +1,215 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include <algorithm>
+
+#include "lite/backends/cuda/cuda_utils.h"
+#include "lite/backends/cuda/math/sequence_helper.h"
+#include "lite/backends/cuda/math/utils.h"
+
+namespace paddle {
+namespace lite {
+namespace cuda {
+namespace math {
+
+template <typename Dtype>
+__global__ void Map2Out(
+    Dtype* output, const Dtype* input, const int* map, int count, int lastdim) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < count) {
+    int seq = tid / lastdim;
+    output[map[seq] * lastdim + tid % lastdim] = input[tid];
+  }
+}
+
+template <typename Dtype>
+__global__ void Map2In(
+    Dtype* output, const Dtype* input, const int* map, int count, int lastdim) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < count) {
+    int seq = tid / lastdim;
+    output[tid] = input[map[seq] * lastdim + tid % lastdim];
+  }
+}
+
+template <typename Dtype>
+void Map2OutFunc(const Dtype* input,
+                 Dtype* output,
+                 int word_size,
+                 int seq_sum,
+                 cudaStream_t stream,
+                 int* dev_map_vec) {
+  int count = seq_sum * word_size;
+  int block_dim = count;
+  int grid_dim = 1;
+
+  if (count > 1024) {
+    block_dim = 256;
+    grid_dim = (count + block_dim - 1) / block_dim;
+  }
+
+  Map2Out<<<grid_dim, block_dim, 0, stream>>>(
+      output, input, dev_map_vec, count, word_size);
+}
+
+template <typename Dtype>
+void Map2InFunc(const Dtype* input,
+                Dtype* output,
+                int hidden_size,
+                int seq_sum,
+                cudaStream_t stream,
+                int* dev_map_vec) {
+  int count = seq_sum * hidden_size;
+  int block_dim = count;
+  int grid_dim = 1;
+  if (count > 1024) {
+    block_dim = 256;
+    grid_dim = (count + block_dim - 1) / block_dim;
+  }
+
+  Map2In<<<grid_dim, block_dim, 0, stream>>>(
+      output, input, dev_map_vec, count, hidden_size);
+}
+
+template <typename Dtype>
+void SeqSortedseqTranseUtil::Seq2SortedSeq(const Dtype* input,
+                                           Dtype* output,
+                                           int word_size,
+                                           cudaStream_t stream) {
+  int seq_sum = map_vec_.size();
+  Map2OutFunc(input, output, word_size, seq_sum, stream, dev_map_vec_);
+}
+
+template <typename Dtype>
+void SeqSortedseqTranseUtil::SortedSeq2Seq(const Dtype* input,
+                                           Dtype* output,
+                                           int hidden_size,
+                                           cudaStream_t stream) {
+  int seq_sum = map_vec_.size();
+  Map2InFunc(input, output, hidden_size, seq_sum, stream, dev_map_vec_);
+}
+
+bool SeqSortedseqTranseUtil::GetSortedMap(const std::vector<int>& offset_vec,
+                                          cudaStream_t stream_id) {
+  int batch_size = offset_vec.size() - 1;
+  int word_sum = offset_vec[offset_vec.size() - 1];
+  std::vector<int> length_vec(batch_size);
+  length_index_.resize(batch_size);
+  int emit_length = 0;
+
+  if (batch_size == 1) {
+    emit_length = offset_vec[1] - offset_vec[0];
+    emit_offset_vec_.resize(emit_length + 1);
+
+    for (int i = 0; i <= emit_length; ++i) {
+      emit_offset_vec_[i] = i;
+    }
+
+    return false;
+  }
+
+  int max_len = 0;
+
+  for (int i = 0; i < offset_vec.size() - 1; ++i) {
+    int len = offset_vec[i + 1] - offset_vec[i];
+    max_len = max_len > len ? max_len : len;
+    length_vec[i] = len;
+    length_index_[i] = i;
+  }
+
+  emit_length = max_len;
+
+  if (max_len == 1) {
+    emit_offset_vec_.resize(2);
+    emit_offset_vec_[0] = 0;
+    emit_offset_vec_[1] = emit_length * batch_size;
+    return false;
+  }
+
+  std::stable_sort(length_index_.begin(),
+                   length_index_.end(),
+                   [&length_vec](int i1, int i2) {
+                     return length_vec[i1] > length_vec[i2];
+                   });
+
+  emit_offset_vec_.resize(max_len + 1);
+  map_vec_.resize(word_sum);
+
+  if (word_sum > dev_map_vec_length_) {
+    if (dev_map_vec_ != nullptr) {
+      TargetWrapperCuda::Free(static_cast<void*>(dev_map_vec_));
+    }
+
+    dev_map_vec_ =
+        static_cast<int*>(TargetWrapperCuda::Malloc(sizeof(int) * word_sum));
+    dev_map_vec_length_ = word_sum;
+  }
+
+  int target_word_id = 0;
+  std::vector<int> length_vec_cnt = length_vec;
+  int last_batch_size = batch_size;
+  for (int word_id_in_seq = 0; word_id_in_seq < max_len; word_id_in_seq++) {
+    emit_offset_vec_[word_id_in_seq] = target_word_id;
+
+    for (int batch_id = 0; batch_id < last_batch_size; batch_id++) {
+      int old_batch_id = length_index_[batch_id];
+
+      if (length_vec_cnt[old_batch_id] > 0) {
+        int inner_word_id_in_seq = word_id_in_seq;
+
+        if (is_reverse_) {
+          inner_word_id_in_seq = length_vec[old_batch_id] - 1 - word_id_in_seq;
+        }
+
+        int old_word_id = offset_vec[old_batch_id] + inner_word_id_in_seq;
+        map_vec_[old_word_id] = target_word_id;
+        length_vec_cnt[old_batch_id]--;
+        target_word_id++;
+      } else {
+        last_batch_size--;
+        break;
+      }
+    }
+  }
+
+  TargetWrapperCuda::MemcpyAsync(dev_map_vec_,
+                                 map_vec_.data(),
+                                 sizeof(int) * word_sum,
+                                 IoDirection::HtoD,
+                                 stream_id);
+  emit_offset_vec_[max_len] = word_sum;
+  emit_length_ = emit_length;
+  return true;
+}
+
+template void SeqSortedseqTranseUtil::Seq2SortedSeq(const float* input,
+                                                    float* output,
+                                                    int word_size,
+                                                    cudaStream_t stream);
+template void SeqSortedseqTranseUtil::SortedSeq2Seq(const float* input,
+                                                    float* output,
+                                                    int hidden_size,
+                                                    cudaStream_t stream);
+template void SeqSortedseqTranseUtil::Seq2SortedSeq(const half* input,
+                                                    half* output,
+                                                    int word_size,
+                                                    cudaStream_t stream);
+template void SeqSortedseqTranseUtil::SortedSeq2Seq(const half* input,
+                                                    half* output,
+                                                    int hidden_size,
+                                                    cudaStream_t stream);
+
+}  // namespace math
+}  // namespace cuda
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/cuda/math/sequence_helper.h b/lite/backends/cuda/math/sequence_helper.h
new file mode 100644
index 0000000000000000000000000000000000000000..ea02279c762671844a68226df9c14905ce766563
--- /dev/null
+++ b/lite/backends/cuda/math/sequence_helper.h
@@ -0,0 +1,77 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <vector>
+
+#include "lite/backends/cuda/target_wrapper.h"
+
+namespace paddle {
+namespace lite {
+namespace cuda {
+namespace math {
+
+class SeqSortedseqTranseUtil {
+ public:
+  explicit SeqSortedseqTranseUtil(bool is_reverse = false, bool is_bi = false)
+      : is_reverse_(is_reverse),
+        is_bi_(is_bi),
+        dev_map_vec_(nullptr),
+        dev_map_vec_length_(0) {}
+
+  ~SeqSortedseqTranseUtil() {
+    if (dev_map_vec_ != nullptr) {
+      TargetWrapperCuda::Free(static_cast<void*>(dev_map_vec_));
+    }
+  }
+
+  std::vector<int>& GetLengthIndex() { return length_index_; }
+  std::vector<int>& GetEmitOffsetVec() { return emit_offset_vec_; }
+  std::vector<int>& GetMapVec() { return map_vec_; }
+  int* GetDevMapVec() { return dev_map_vec_; }
+  int GetEmitLength() { return emit_length_; }
+
+  template <typename Dtype>
+  void Seq2SortedSeq(const Dtype* input,
+                     Dtype* output,
+                     int word_size,
+                     cudaStream_t stream);
+
+  template <typename Dtype>
+  void SortedSeq2Seq(const Dtype* input,
+                     Dtype* output,
+                     int hidden_size,
+                     cudaStream_t stream);
+
+  bool GetSortedMap(const std::vector<int>& offset_vec, cudaStream_t stream_id);
+
+ private:
+  std::vector<int> length_index_;
+  std::vector<int> emit_offset_vec_;
+  std::vector<int> map_vec_;
+  int emit_length_;
+
+  bool is_reverse_;
+  bool is_bi_;
+  int* dev_map_vec_;
+  int dev_map_vec_length_;
+};
+
+}  // namespace math
+}  // namespace cuda
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/CMakeLists.txt b/lite/core/CMakeLists.txt
index af2bfbe86aaa1b3f145838015a6d6a62090cb3b1..d51a93971ff197b6073cc80774322ca151377659 100644
--- a/lite/core/CMakeLists.txt
+++ b/lite/core/CMakeLists.txt
@@ -133,7 +133,7 @@ lite_cc_library(type_system SRCS type_system.cc DEPS tensor target_wrapper)
 lite_cc_library(program SRCS program.cc
     DEPS op kernel model_parser ${ops} ${cpp_wrapper}
     PROFILE_DEPS lite_profiler
-    CUDA_DEPS nvtx_wrapper)
+    CUDA_DEPS nvtx_wrapper cuda_type_trans)
 
 if (NOT LITE_ON_TINY_PUBLISH)
   lite_cc_library(optimizer SRCS optimizer.cc DEPS mir_pass_manager model_parser program)
diff --git a/lite/kernels/cuda/gru_compute.cu b/lite/kernels/cuda/gru_compute.cu
index ddca95048b303cce55cc3435b15f945a84fc8c0c..84f59496e5fce7771f08cbc6d04d71cdc46055bc 100644
--- a/lite/kernels/cuda/gru_compute.cu
+++ b/lite/kernels/cuda/gru_compute.cu
@@ -14,6 +14,7 @@
 #include "lite/kernels/cuda/gru_compute.h"
 
 #include <string>
+#include <vector>
 
 #include "lite/backends/cuda/cuda_utils.h"
 #include "lite/backends/cuda/math/bias.h"
@@ -273,6 +274,8 @@ void GRUCompute::Run() {
   auto& param = this->template Param<param_t>();
 
   auto* input = param.input;
+  T* x_data =
+      const_cast<lite::Tensor*>(input)->template mutable_data<T>(TARGET(kCUDA));
   lite::Tensor* h0{nullptr};
   if (param.h0) {
     h0 = const_cast<lite::Tensor*>(param.h0);
@@ -289,7 +292,7 @@
   lite::Tensor* hidden = param.hidden;
   T* batch_reset_hidden_prev_data =
       batch_reset_hidden_prev->template mutable_data<T>(TARGET(kCUDA));
-  hidden->template mutable_data<T>(TARGET(kCUDA));
+  T* out_data = hidden->template mutable_data<T>(TARGET(kCUDA));
   T* batch_gate_data = batch_gate->template mutable_data<T>(TARGET(kCUDA));
   T* batch_hidden_data = batch_hidden->template mutable_data<T>(TARGET(kCUDA));
   bool is_reverse = param.is_reverse;
@@ -300,14 +303,28 @@
 
   auto hidden_dims = hidden->dims();
   int frame_size = hidden_dims[1];
-  lite::cuda::math::LoDTensor2BatchFunctor<T> batch_func;
-  batch_func(*input, batch_gate, is_reverse, stream);
+  LoD offset_vec_vec = input->lod();
+  std::vector<int> offset(offset_vec_vec[offset_vec_vec.size() - 1].size());
+  for (size_t i = 0; i < offset_vec_vec[offset_vec_vec.size() - 1].size();
+       ++i) {
+    offset[i] = static_cast<int>(offset_vec_vec[offset_vec_vec.size() - 1][i]);
+  }
+  bool need_process = seq_utils_.GetSortedMap(offset, stream);
+  int emit_length = seq_utils_.GetEmitOffsetVec().size() - 1;
+  auto emit_offset_vec = seq_utils_.GetEmitOffsetVec();
+  if (need_process) {
+    seq_utils_.Seq2SortedSeq(
+        input->template data<T>(), batch_gate_data, 3 * frame_size, stream);
+    x_data = batch_gate_data;
+    out_data = batch_hidden_data;
+  }
 
   if (bias) {
+    // TODO(wilber): validate when bias is not nullptr
     lite::cuda::math::RowwiseAdd<T> add_bias;
-    add_bias(batch_gate_data,
+    add_bias(x_data,
              bias->template data<T>(),
-             batch_gate_data,
+             x_data,
              frame_size,
              batch_gate->numel(),
             stream);
@@ -320,6 +337,7 @@
     // Since the batch computing for GRU reorders the input sequences
     // according to their length. The initialized cell state also needs
     // to reorder.
+    // TODO(wilber): validate when h0 is not nullptr
    ordered_h0_.Resize(h0->dims());
     lite::cuda::math::CopyMatrixRowsFunctor<T> row_shuffle;
     row_shuffle(*h0, &ordered_h0_, batch_gate->lod()[2], true, stream);
@@ -327,15 +345,13 @@
   } else {
     gru_value.prev_out_value = nullptr;
   }
-  auto batch_starts = batch_gate->lod()[0];
-  size_t num_batch = batch_starts.size() - 1;
-  for (size_t n = 0; n < num_batch; ++n) {
-    int bstart = static_cast<int>(batch_starts[n]);
-    int bend = static_cast<int>(batch_starts[n + 1]);
+  for (size_t n = 0; n < emit_length; ++n) {
+    int bstart = emit_offset_vec[n];
+    int bend = emit_offset_vec[n + 1];
     int cur_batch_size = bend - bstart;
 
-    gru_value.output_value = batch_hidden_data + bstart * frame_size;
-    gru_value.gate_value = batch_gate_data + bstart * frame_size * 3;
+    gru_value.output_value = out_data + bstart * frame_size;
+    gru_value.gate_value = x_data + bstart * frame_size * 3;
     gru_value.reset_output_value =
         batch_reset_hidden_prev_data + bstart * frame_size;
 
@@ -349,10 +365,13 @@
                         &context);
     gru_value.prev_out_value = gru_value.output_value;
   }
-
-  lite::cuda::math::Batch2LoDTensorFunctor<T> to_seq;
-  batch_hidden->set_lod(batch_gate->lod());
-  to_seq(*batch_hidden, hidden, stream);
+  if (need_process) {
+    seq_utils_.SortedSeq2Seq(batch_hidden_data,
+                             hidden->mutable_data<T>(TARGET(kCUDA)),
+                             frame_size,
+                             stream);
+  }
+  hidden->set_lod(input->lod());
 }
 
 }  // namespace cuda
diff --git a/lite/kernels/cuda/gru_compute.h b/lite/kernels/cuda/gru_compute.h
index 070deca2c54b919d1afeb856633d94fe5919eabd..f710404378a213a5d9b5c18769c6be05f735404f 100644
--- a/lite/kernels/cuda/gru_compute.h
+++ b/lite/kernels/cuda/gru_compute.h
@@ -16,6 +16,7 @@
 #include <memory>
 
 #include "lite/backends/cuda/math/gemm.h"
+#include "lite/backends/cuda/math/sequence_helper.h"
 #include "lite/core/kernel.h"
 #include "lite/operators/op_params.h"
 
@@ -38,6 +39,7 @@ class GRUCompute : public KernelLite {
  private:
  std::unique_ptr<lite::cuda::math::Gemm<T, T>> gemm_impl_{nullptr};
  lite::Tensor ordered_h0_;
+  lite::cuda::math::SeqSortedseqTranseUtil seq_utils_;
 };
 
 }  // namespace cuda
diff --git a/lite/operators/softmax_op.cc b/lite/operators/softmax_op.cc
index 4e4aa9c9812978eb2a7cf49bc1067e28131c8663..8a45df119ff36fee044a063c6adfee6136b750be 100644
--- a/lite/operators/softmax_op.cc
+++ b/lite/operators/softmax_op.cc
@@ -55,6 +55,8 @@ bool SoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   if (opdesc.HasAttr("use_cudnn")) {
     param_.use_cudnn = opdesc.GetAttr<bool>("use_cudnn");
   }
+  // TODO(wilber): use cudnn by default when compiled with cuda.
+  param_.use_cudnn = true;
   return true;
 }
 
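
Usage note (illustrative, not part of the patch): the sketch below shows how a caller such as the GRU kernel above is expected to drive SeqSortedseqTranseUtil — build the LoD offsets, ask GetSortedMap for the row permutation, gather with Seq2SortedSeq, walk the per-step ranges in GetEmitOffsetVec, then scatter back with SortedSeq2Seq. The offsets {0, 2, 5}, the width word_size, the buffer names, and the use of plain cudaMalloc are assumptions made for the example; error handling is omitted.

// Minimal sketch assuming the Paddle-Lite CUDA backend headers are available.
#include <cuda_runtime.h>

#include <vector>

#include "lite/backends/cuda/math/sequence_helper.h"

using paddle::lite::cuda::math::SeqSortedseqTranseUtil;

void RunPerStep(const float* x_dev, float* out_dev, cudaStream_t stream) {
  const int word_size = 16;                   // row width per time step
  const std::vector<int> offset = {0, 2, 5};  // two sequences, lengths 2 and 3
  const int word_sum = offset.back();         // 5 rows in total

  SeqSortedseqTranseUtil seq_utils;  // is_reverse defaults to false
  // Build the row permutation; returns false when no reordering is needed
  // (a single sequence, or all sequences of length 1).
  bool need_process = seq_utils.GetSortedMap(offset, stream);

  float* sorted_x = nullptr;
  float* sorted_out = nullptr;
  cudaMalloc(&sorted_x, word_sum * word_size * sizeof(float));
  cudaMalloc(&sorted_out, word_sum * word_size * sizeof(float));

  const float* step_in = x_dev;
  float* step_out = out_dev;
  if (need_process) {
    // Gather rows so that the rows belonging to one time step are contiguous.
    seq_utils.Seq2SortedSeq(x_dev, sorted_x, word_size, stream);
    step_in = sorted_x;
    step_out = sorted_out;
  }

  // emit_offset[n] .. emit_offset[n + 1] are the active rows of time step n,
  // which is how gru_compute.cu walks its recurrent loop above.
  const std::vector<int>& emit_offset = seq_utils.GetEmitOffsetVec();
  for (size_t n = 0; n + 1 < emit_offset.size(); ++n) {
    int begin = emit_offset[n];
    int rows = emit_offset[n + 1] - begin;
    // ... launch the per-step kernel on step_in/step_out + begin * word_size
    //     over `rows` rows ...
    (void)begin;
    (void)rows;
  }
  (void)step_in;
  (void)step_out;

  if (need_process) {
    // Scatter the results back into the original sequence order.
    seq_utils.SortedSeq2Seq(sorted_out, out_dev, word_size, stream);
  }
  cudaFree(sorted_x);
  cudaFree(sorted_out);
}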