diff --git a/lite/backends/cuda/math/transpose.cu b/lite/backends/cuda/math/transpose.cu
index cebcece812dc584d0921edea2fef8f129e430b56..c50840fe269657965db8c58b171fce6819009775 100644
--- a/lite/backends/cuda/math/transpose.cu
+++ b/lite/backends/cuda/math/transpose.cu
@@ -69,44 +69,16 @@ void BatchTranspose2DCUDAImpl(const int N,
                               const int W,
                               const T* input,
                               T* out,
-                              CUDAContext* ctx) {
+                              cudaStream_t* stream) {
   const int dh = (H + kTileDim - 1) / kTileDim;
   const int dw = (W + kTileDim - 1) / kTileDim;
   BatchTranspose2DCUDAKernel<
-      T><<<N * dh * dw, dim3(kTileDim, kBlockRows), 0, ctx->exec_stream()>>>(
+      T><<<N * dh * dw, dim3(kTileDim, kBlockRows), 0, *stream>>>(
       N, H, W, dh, dw, input, out);
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
 
-#define TYPE_SPECIALIZED_CUDA_NCHW2NHWC(T)              \
-  template <>                                           \
-  void NCHW2NHWC<T>(const int N,                        \
-                    const int C,                        \
-                    const int HxW,                      \
-                    const T* X,                         \
-                    T* Y,                               \
-                    CUDAContext* ctx) {                 \
-    BatchTranspose2DCUDAImpl<T>(N, C, HxW, X, Y, ctx);  \
-  }
-TYPE_SPECIALIZED_CUDA_NCHW2NHWC(float)
-TYPE_SPECIALIZED_CUDA_NCHW2NHWC(int8_t)
-#undef TYPE_SPECIALIZED_CUDA_NCHW2NHWC
-
-#define TYPE_SPECIALIZED_CUDA_NHWC2NCHW(T)              \
-  template <>                                           \
-  void NHWC2NCHW<T>(const int N,                        \
-                    const int C,                        \
-                    const int HxW,                      \
-                    const T* X,                         \
-                    T* Y,                               \
-                    CUDAContext* ctx) {                 \
-    BatchTranspose2DCUDAImpl<T>(N, HxW, C, X, Y, ctx);  \
-  }
-TYPE_SPECIALIZED_CUDA_NHWC2NCHW(float)
-TYPE_SPECIALIZED_CUDA_NHWC2NCHW(int8_t)
-#undef TYPE_SPECIALIZED_CUDA_NHWC2NCHW
-
 template <typename T>
 __global__ void TransposeCUDAKernel(const int size,
                                     const int ndim,
@@ -136,7 +108,9 @@ void TransposeCUDAImpl(const std::vector<int64_t>& X_dims,
                        const std::vector<int>& axes,
                        const T* X,
                        T* Y,
-                       CUDAContext* ctx) {
+                       lite::Tensor* Y_dims_,
+                       lite::Tensor* strides_,
+                       cudaStream_t* stream) {
   CHECK_EQ(X_dims.size(), axes.size()) << "dimension size should be equal";
   int ndim = X_dims.size();
   std::vector<int> strides(ndim, 0);
@@ -156,37 +130,68 @@ void TransposeCUDAImpl(const std::vector<int64_t>& X_dims,
     size *= X_dims[i];
   }
 
-  lite::Tensor Y_dims_, strides_;
-  Y_dims_.Resize(std::vector<int64_t>({ndim}));
-  int* d_y_dims = Y_dims_.mutable_data<int>(TARGET(kCUDA));
-  CopySync<TARGET(kCUDA)>(
-      d_y_dims, Y_dims.data(), sizeof(int) * Y_dims.size(), IoDirection::HtoD);
+  Y_dims_->Resize(std::vector<int64_t>({ndim}));
+  int* d_y_dims = Y_dims_->mutable_data<int>(TARGET(kCUDA));
+  TargetWrapperCuda::MemcpyAsync(d_y_dims,
+                                 Y_dims.data(),
+                                 sizeof(int) * Y_dims.size(),
+                                 IoDirection::HtoD,
+                                 *stream);
 
-  strides_.Resize(std::vector<int64_t>({ndim}));
-  int* d_strides = strides_.mutable_data<int>(TARGET(kCUDA));
-  CopySync<TARGET(kCUDA)>(d_strides,
-                          strides.data(),
-                          sizeof(int) * strides.size(),
-                          IoDirection::HtoD);
+  strides_->Resize(std::vector<int64_t>({ndim}));
+  int* d_strides = strides_->mutable_data<int>(TARGET(kCUDA));
+  TargetWrapperCuda::MemcpyAsync(d_strides,
+                                 strides.data(),
+                                 sizeof(int) * strides.size(),
+                                 IoDirection::HtoD,
+                                 *stream);
 
   const int M = (size + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
-  TransposeCUDAKernel<<<M, CUDA_NUM_THREADS, 0, ctx->exec_stream()>>>(
+  TransposeCUDAKernel<<<M, CUDA_NUM_THREADS, 0, *stream>>>(
       size, ndim, d_strides, d_y_dims, X, Y);
   auto e = cudaGetLastError();
   CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);
 }
 
-#define TYPE_SPECIALIZED_CUDA_TRANSPOSE(T)                \
-  template <>                                             \
-  void Transpose<T>(const std::vector<int64_t>& X_dims,   \
-                    const std::vector<int>& axes,         \
-                    const T* X,                           \
-                    T* Y,                                 \
-                    CUDAContext* ctx) {                   \
-    TransposeCUDAImpl<T>(X_dims, axes, X, Y, ctx);        \
-  }
-TYPE_SPECIALIZED_CUDA_TRANSPOSE(float)
-#undef TYPE_SPECIALIZED_CUDA_TRANSPOSEF
+template <typename T>
+void Transpose<T>::NCHW2NHWC(
+    int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream) {
+  BatchTranspose2DCUDAImpl<T>(N, C, HxW, X, Y, stream);
+}
+
+template <typename T>
+void Transpose<T>::NHWC2NCHW(
+    int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream) {
+  BatchTranspose2DCUDAImpl<T>(N, HxW, C, X, Y, stream);
+}
+
+template <typename T>
+void Transpose<T>::transpose(T* dst,
+                             const T* src,
+                             const std::vector<int64_t>& src_dims,
+                             const std::vector<int>& axes,
+                             cudaStream_t* stream) {
+  TransposeCUDAImpl(src_dims, axes, src, dst, &Y_dims_, &strides_, stream);
+}
+
+// template <typename T>
+// void Transpose<T>::transpose(T* dst,
+//                              const T* src,
+//                              const std::vector<int>& src_dims,
+//                              const std::vector<int>& axes,
+//                              cudaStream_t* stream) {
+//   std::vector<int64_t> _src_dims(src_dims.size(), 0);
+//   std::transform(
+//       src_dims.begin(),
+//       src_dims.end(),
+//       _src_dims.begin(),
+//       [](int data) -> int64_t { return static_cast<int64_t>(data); });
+//   TransposeCUDAImpl(_src_dims, axes, src, dst, &Y_dims_, &strides_,
+//   stream);
// }
+
+template class Transpose<int8_t>;
+template class Transpose<float>;
 
 }  // namespace math
 }  // namespace cuda
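
Note on usage (not part of the patch): the removed free functions took a CUDAContext*, while the new Transpose<T> class only needs a raw cudaStream_t, so callers fetch the stream once and pass its address. A minimal sketch, assuming a CUDAContext named ctx obtained the way the kernels below obtain it; the tensor names and shapes here are illustrative only:

  lite::cuda::math::Transpose<float> trans;  // owns the device-side dims/strides scratch tensors
  cudaStream_t stream = ctx.exec_stream();   // stream taken from the kernel context
  trans.NCHW2NHWC(n, c, h * w, x_gpu, y_gpu, &stream);           // fused layout fast path
  trans.transpose(y_gpu, x_gpu, {c, h, w}, {2, 0, 1}, &stream);  // general permutation
  cudaStreamSynchronize(stream);             // results are ready only after the stream syncs
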
diff --git a/lite/backends/cuda/math/transpose.h b/lite/backends/cuda/math/transpose.h
index ba2464547b587f44cd9b0ce287a0d40d37d46411..ed52ba3b5590ab631c3c57a0472e16cb0ed51a91 100644
--- a/lite/backends/cuda/math/transpose.h
+++ b/lite/backends/cuda/math/transpose.h
@@ -26,17 +26,27 @@ namespace cuda {
 namespace math {
 
 template <typename T>
-void NCHW2NHWC(int N, int C, int HxW, const T* X, T* Y, CUDAContext* context);
+class Transpose {
+ public:
+  void NCHW2NHWC(int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream);
 
-template <typename T>
-void NHWC2NCHW(int N, int C, int HxW, const T* X, T* Y, CUDAContext* context);
+  void NHWC2NCHW(int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream);
 
-template <typename T>
-void Transpose(const std::vector<int64_t>& X_dims,
-               const std::vector<int>& axes,
-               const T* X,
-               T* Y,
-               CUDAContext* ctx);
+  void transpose(T* dst,
+                 const T* src,
+                 const std::vector<int64_t>& src_dims,
+                 const std::vector<int>& axes,
+                 cudaStream_t* stream);
+
+  // void transpose(T* dst,
+  //                const T* src,
+  //                const std::vector<int>& src_dims,
+  //                const std::vector<int>& axes,
+  //                cudaStream_t* stream);
+
+ private:
+  lite::Tensor Y_dims_, strides_;  // for transpose.
+};
 
 }  // namespace math
 }  // namespace cuda
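
For reference, the mapping that TransposeCUDAImpl prepares (output dims permuted by axes, plus the matching input strides) is the standard gather form of an N-d transpose. A host-side restatement of that semantics, assuming row-major layout; this is only a review sketch, not code from the patch:

  #include <vector>
  // y[i0, ..., i_{n-1}] = x[j0, ..., j_{n-1}] with j[axes[d]] = i_d
  void transpose_reference(float* dst, const float* src,
                           const std::vector<int64_t>& src_dims,
                           const std::vector<int>& axes) {
    int ndim = static_cast<int>(src_dims.size());
    std::vector<int64_t> x_strides(ndim, 1), y_dims(ndim);
    for (int i = ndim - 2; i >= 0; --i)
      x_strides[i] = x_strides[i + 1] * src_dims[i + 1];
    for (int i = 0; i < ndim; ++i) y_dims[i] = src_dims[axes[i]];
    int64_t total = 1;
    for (auto d : src_dims) total *= d;
    for (int64_t y = 0; y < total; ++y) {
      int64_t rem = y, x = 0;
      for (int d = ndim - 1; d >= 0; --d) {
        x += (rem % y_dims[d]) * x_strides[axes[d]];  // gather offset per output coordinate
        rem /= y_dims[d];
      }
      dst[y] = src[x];
    }
  }
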
diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt
index 4bf1cbf5210214befb3620f8b7d70923f41f98f2..313bb0da1cb2dad2e6867d266bc7acb1ac52a183 100644
--- a/lite/kernels/cuda/CMakeLists.txt
+++ b/lite/kernels/cuda/CMakeLists.txt
@@ -26,7 +26,7 @@
 add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps} cudnn_pool)
 add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(search_seq_depadding_compute_cuda CUDA extra SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps})
-add_kernel(search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
+add_kernel(search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS ${lite_kernel_deps} cuda_gemm ${math_cuda})
 add_kernel(sequence_reverse_compute_cuda CUDA basic SRCS sequence_reverse_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(sequence_arithmetic_compute_cuda CUDA basic SRCS sequence_arithmetic_compute.cu DEPS ${lite_kernel_deps})
diff --git a/lite/kernels/cuda/layout_compute.cc b/lite/kernels/cuda/layout_compute.cc
index 6b56d9e1de28cbec57b4b45aff1d1b237b1784b9..17462a5015142540e7b1d5cb9eb1e74acd9621b5 100644
--- a/lite/kernels/cuda/layout_compute.cc
+++ b/lite/kernels/cuda/layout_compute.cc
@@ -29,7 +29,7 @@ inline DDim trim_singular_dims(const DDim& dims) {
   }
   std::vector<int64_t> trim_dims;
   trim_dims.resize(actual_dims_size);
-  for (int i = 0; i < actual_dims_size; ++i) {
+  for (size_t i = 0; i < actual_dims_size; ++i) {
     trim_dims[i] = dims[i];
   }
   if (trim_dims.size() == 0) {
@@ -41,6 +41,7 @@ inline DDim trim_singular_dims(const DDim& dims) {
 #define NCHWTONHWC(type)                                              \
   auto& param = this->template Param<param_t>();                      \
   auto& ctx = this->ctx_->template As<CUDAContext>();                 \
+  auto stream = ctx.exec_stream();                                    \
   auto input = param.x->template data<type>();                        \
   auto input_dim = param.x->dims();                                   \
   DDim input_trim_dim = trim_singular_dims(input_dim);                \
@@ -56,11 +57,12 @@ inline DDim trim_singular_dims(const DDim& dims) {
   int w = input_dim[3];                                               \
   param.y->Resize({n, h, w, c});                                      \
   auto output = param.y->template mutable_data<type>(TARGET(kCUDA));  \
-  lite::cuda::math::NCHW2NHWC<type>(n, c, h * w, input, output, &ctx);
+  trans.NCHW2NHWC(n, c, h* w, input, output, &stream);
 
 #define NHWCTONCHW(type)                                              \
   auto& param = this->template Param<param_t>();                      \
   auto& ctx = this->ctx_->template As<CUDAContext>();                 \
+  auto stream = ctx.exec_stream();                                    \
   auto input = param.x->template data<type>();                        \
   auto input_dim = param.x->dims();                                   \
   DDim input_trim_dim = trim_singular_dims(input_dim);                \
@@ -76,7 +78,7 @@ inline DDim trim_singular_dims(const DDim& dims) {
   int c = input_dim[3];                                               \
   param.y->Resize({n, c, h, w});                                      \
   auto output = param.y->template mutable_data<type>(TARGET(kCUDA));  \
-  lite::cuda::math::NHWC2NCHW<type>(n, c, h * w, input, output, &ctx);
+  trans.NHWC2NCHW(n, c, h* w, input, output, &stream);
 
 void NCHWToNHWCCompute::Run() { NCHWTONHWC(float) }
 
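
For readability, this is roughly what NCHWTONHWC(float) now expands to inside NCHWToNHWCCompute::Run(); the trim_singular_dims early-out in the middle of the macro is elided here, so treat it as a sketch rather than the exact expansion:

  auto& param = this->template Param<param_t>();
  auto& ctx = this->ctx_->template As<CUDAContext>();
  auto stream = ctx.exec_stream();
  auto input = param.x->template data<float>();
  auto input_dim = param.x->dims();
  // ... singular-dimension handling elided ...
  int n = input_dim[0], c = input_dim[1], h = input_dim[2], w = input_dim[3];
  param.y->Resize({n, h, w, c});
  auto output = param.y->template mutable_data<float>(TARGET(kCUDA));
  trans.NCHW2NHWC(n, c, h * w, input, output, &stream);  // member trans, stream-based API
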
#pragma once +#include "lite/backends/cuda/math/transpose.h" #include "lite/core/kernel.h" namespace paddle { @@ -25,6 +26,9 @@ class NCHWToNHWCCompute : public KernelLite { using param_t = operators::LayoutParam; void Run() override; virtual ~NCHWToNHWCCompute() = default; + + private: + lite::cuda::math::Transpose trans; }; class NCHWToNHWCComputeInt8 @@ -33,6 +37,9 @@ class NCHWToNHWCComputeInt8 using param_t = operators::LayoutParam; void Run() override; virtual ~NCHWToNHWCComputeInt8() = default; + + private: + lite::cuda::math::Transpose trans; }; class NHWCToNCHWCompute : public KernelLite { @@ -40,6 +47,9 @@ class NHWCToNCHWCompute : public KernelLite { using param_t = operators::LayoutParam; void Run() override; virtual ~NHWCToNCHWCompute() = default; + + private: + lite::cuda::math::Transpose trans; }; class NHWCToNCHWComputeInt8 @@ -48,6 +58,9 @@ class NHWCToNCHWComputeInt8 using param_t = operators::LayoutParam; void Run() override; virtual ~NHWCToNCHWComputeInt8() = default; + + private: + lite::cuda::math::Transpose trans; }; } // namespace cuda diff --git a/lite/kernels/cuda/search_grnn_compute.cu b/lite/kernels/cuda/search_grnn_compute.cu index 468b66e5680c7d0e5879def9a888e10faa0bca32..2c1cb94a14d911d282d8e365ca0b818e7992461d 100644 --- a/lite/kernels/cuda/search_grnn_compute.cu +++ b/lite/kernels/cuda/search_grnn_compute.cu @@ -12,6 +12,7 @@ limitations under the License. */ #pragma once #include #include +#include "lite/backends/cuda/math/transpose.h" #include "lite/core/op_registry.h" #include "lite/kernels/cuda/search_grnn_compute.h" @@ -19,294 +20,469 @@ namespace paddle { namespace lite { namespace kernels { namespace cuda { + using Tensor = lite::Tensor; -template -T sigmoid(T z) { - return 1 / (1 + std::exp(-z)); +template +__global__ void trans_map2out( + Dtype* output, const Dtype* input, const int* map, int count, int lastdim) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < count) { + int seq = tid / lastdim; + output[map[seq] * lastdim + tid % lastdim] = input[tid]; + } } -template -__global__ void PreComputeKernel( - const int num, const T* w_x_e, const T* wz_x_e, T* tilde, T* z, T* hidden) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < num) { - tilde[index] = std::tanh(w_x_e[index]); - z[index] = 1 / (1 + std::exp(-wz_x_e[index])); - hidden[index] = (1. 
diff --git a/lite/kernels/cuda/search_grnn_compute.cu b/lite/kernels/cuda/search_grnn_compute.cu
index 468b66e5680c7d0e5879def9a888e10faa0bca32..2c1cb94a14d911d282d8e365ca0b818e7992461d 100644
--- a/lite/kernels/cuda/search_grnn_compute.cu
+++ b/lite/kernels/cuda/search_grnn_compute.cu
@@ -12,6 +12,7 @@ limitations under the License. */
 #pragma once
 #include <memory>
 #include <vector>
+#include "lite/backends/cuda/math/transpose.h"
 #include "lite/core/op_registry.h"
 #include "lite/kernels/cuda/search_grnn_compute.h"
 
@@ -19,294 +20,469 @@ namespace paddle {
 namespace lite {
 namespace kernels {
 namespace cuda {
+
 using Tensor = lite::Tensor;
 
-template <typename T>
-T sigmoid(T z) {
-  return 1 / (1 + std::exp(-z));
+template <typename Dtype>
+__global__ void trans_map2out(
+    Dtype* output, const Dtype* input, const int* map, int count, int lastdim) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < count) {
+    int seq = tid / lastdim;
+    output[map[seq] * lastdim + tid % lastdim] = input[tid];
+  }
 }
 
-template <typename T>
-__global__ void PreComputeKernel(
-    const int num, const T* w_x_e, const T* wz_x_e, T* tilde, T* z, T* hidden) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < num) {
-    tilde[index] = std::tanh(w_x_e[index]);
-    z[index] = 1 / (1 + std::exp(-wz_x_e[index]));
-    hidden[index] = (1. - z[index]) * tilde[index];
+template <typename Dtype>
+__global__ void trans_map2in(
+    Dtype* output, const Dtype* input, const int* map, int count, int lastdim) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < count) {
+    int seq = tid / lastdim;
+    output[tid] = input[map[seq] * lastdim + tid % lastdim];
+  }
 }
 
-template <typename T>
-__global__ void PostComputeKernel(const int start,
-                                  const int end,
-                                  const int cap_h,
-                                  const int w_tm1,
-                                  const T* wr_x_e,
-                                  const T* ur_x_h,
-                                  const T* wz_x_e,
-                                  const T* uz_x_h,
-                                  const T* w_x_e,
-                                  const T* u_x_h,
-                                  T* r,
-                                  T* z,
-                                  T* tilde,
-                                  T* hidden) {
-  int j = start + blockIdx.x * blockDim.x + threadIdx.x;
-  if (j < end) {
-    r[j] = 1 / (1 + std::exp(-(wr_x_e[j] + ur_x_h[j])));
-    z[j] = 1 / (1 + std::exp(-(wz_x_e[j] + uz_x_h[j])));
-    tilde[j] = std::tanh(w_x_e[j] + r[j] * u_x_h[j]);
-    hidden[j] = z[j] * hidden[j - cap_h * w_tm1] + (1.0 - z[j]) * tilde[j];
+template <typename Dtype>
+void trans_map2out_cfunc(const Dtype* input,
+                         Dtype* output,
+                         int word_size,
+                         int seq_sum,
+                         cudaStream_t stream,
+                         int* dev_map_vec) {
+  int count = seq_sum * word_size;
+  int block_dim = count;
+  int grid_dim = 1;
+
+  if (count > 1024) {
+    block_dim = 256;
+    grid_dim = (count + block_dim - 1) / block_dim;
   }
+
+  trans_map2out<<<grid_dim, block_dim, 0, stream>>>(
+      output, input, dev_map_vec, count, word_size);
 }
 
-void SearchGrnnCompute::PrepareForRun() {
-  gemm_impl_.reset(new lite::cuda::math::Gemm<float, float>);
+template <typename Dtype>
+void trans_map2in_cfunc(const Dtype* input,
+                        Dtype* output,
+                        int hidden_size,
+                        int seq_sum,
+                        cudaStream_t stream,
+                        int* dev_map_vec) {
+  int count = seq_sum * hidden_size;
+  int block_dim = count;
+  int grid_dim = 1;
+  if (count > 1024) {
+    block_dim = 256;
+    grid_dim = (count + block_dim - 1) / block_dim;
+  }
+
+  trans_map2in<<<grid_dim, block_dim, 0, stream>>>(
+      output, input, dev_map_vec, count, hidden_size);
 }
 
-void SearchGrnnCompute::PrepareLayout(const Tensor* input_blob) {
-  auto& param = this->Param<param_t>();
-  auto& context = this->ctx_->template As<CUDAContext>();
-  auto cuda_stream = context.exec_stream();
+template <typename Dtype>
+void SeqSortedseqTranseUtil::seq_2_sorted_seq(const Dtype* input,
+                                              Dtype* output,
+                                              int word_size,
+                                              cudaStream_t stream) {
+  int seq_sum = _map_vec.size();
+  trans_map2out_cfunc(input, output, word_size, seq_sum, stream, _dev_map_vec);
+}
+
+template <typename Dtype>
+void SeqSortedseqTranseUtil::sorted_seq_2_seq(const Dtype* input,
+                                              Dtype* output,
+                                              int hidden_size,
+                                              cudaStream_t stream) {
+  int seq_sum = _map_vec.size();
+  trans_map2in_cfunc(input, output, hidden_size, seq_sum, stream, _dev_map_vec);
+}
+
+bool SeqSortedseqTranseUtil::get_sorted_map(const std::vector<int>& offset_vec,
+                                            cudaStream_t stream_id) {
+  int batch_size = offset_vec.size() - 1;
+  int word_sum = offset_vec[offset_vec.size() - 1];
+  std::vector<int> length_vec(batch_size);
+  _length_index.resize(batch_size);
+  int emit_length = 0;
+
+  if (batch_size == 1) {
+    emit_length = offset_vec[1] - offset_vec[0];
+    _emit_offset_vec.resize(emit_length + 1);
+
+    for (int i = 0; i <= emit_length; ++i) {
+      _emit_offset_vec[i] = i;
+    }
 
-  auto* _input = input_blob;
-  int dim0 = _input->dims()[0];
-  int dim1 = 1;
-  if (_input->dims().size() > 1) {
-    dim1 = _input->dims()[1];
+    return false;
   }
-  int batch = _input->lod()[0].size() - 1;
-  auto& offset = _input->lod()[0];
-
-  idx_sorted_by_width_cpu = std::make_shared<Tensor>();
-  idx_sorted_by_width_cpu->Resize({batch});
-  int* idx_sorted_by_width_cpu_data =
-      idx_sorted_by_width_cpu->mutable_data<int>();
-
-  Tensor _width;
-  _width.Resize({batch});
-  int* width_data = _width.mutable_data<int>();
-  // sort sequence by width (descending) and find the largest width in the
-  // batch
-  for (int i = 0; i < batch; i++) {
-    width_data[i] = offset[i + 1] - offset[i];
-    idx_sorted_by_width_cpu_data[i] = i;
+
+  int max_len = 0;
+
+  for (int i = 0; i < offset_vec.size() - 1; ++i) {
+    int len = offset_vec[i + 1] - offset_vec[i];
+    max_len = max_len > len ? max_len : len;
+    length_vec[i] = len;
+    _length_index[i] = i;
+  }
+
+  emit_length = max_len;
+
+  if (max_len == 1) {
+    _emit_offset_vec.resize(2);
+    _emit_offset_vec[0] = 0;
+    _emit_offset_vec[1] = emit_length * batch_size;
+    return false;
   }
-  std::sort(idx_sorted_by_width_cpu_data,
-            idx_sorted_by_width_cpu_data + batch,
-            [&_width](int a, int b) {
-              return _width.data<int>()[a] > _width.data<int>()[b];
+
+  std::sort(_length_index.begin(),
+            _length_index.end(),
+            [&length_vec](int i1, int i2) {
+              return length_vec[i1] > length_vec[i2];
             });
-  int max_width = width_data[idx_sorted_by_width_cpu_data[0]];
-
-  // start of reorganizing the input
-  std::vector<size_t> new_offset;
-  new_offset.resize(max_width + 1);
-  new_offset[0] = 0;
-  int j = batch - 1;
-  int last_width = 0;
-  int sub_row = 0;
-  int sub_col = 0;
-
-  for (int i = 1; i <= max_width;) {
-    for (int k = j; k >= 0; --k) {
-      if (width_data[idx_sorted_by_width_cpu_data[k]] > last_width) {
-        sub_row = width_data[idx_sorted_by_width_cpu_data[k]] - last_width;
-        sub_col = k + 1;
-        for (int s = 0; s < sub_row; s++) {
-          new_offset[i] = new_offset[i - 1] + sub_col;
-          i++;
+
+  _emit_offset_vec.resize(max_len + 1);
+  _map_vec.resize(word_sum);
+
+  if (word_sum > _dev_map_vec_length) {
+    if (_dev_map_vec != nullptr) {
+      TargetWrapperCuda::Free(static_cast<void*>(_dev_map_vec));
+    }
+
+    _dev_map_vec =
+        static_cast<int*>(TargetWrapperCuda::Malloc(sizeof(int) * word_sum));
+    _dev_map_vec_length = word_sum;
+  }
+
+  int target_word_id = 0;
+  std::vector<int> length_vec_cnt = length_vec;
+  int last_batch_size = batch_size;
+  for (int word_id_in_seq = 0; word_id_in_seq < max_len; word_id_in_seq++) {
+    _emit_offset_vec[word_id_in_seq] = target_word_id;
+
+    for (int batch_id = 0; batch_id < last_batch_size; batch_id++) {
+      int old_batch_id = _length_index[batch_id];
+
+      if (length_vec_cnt[old_batch_id] > 0) {
+        int inner_word_id_in_seq = word_id_in_seq;
+
+        if (_is_reverse) {
+          inner_word_id_in_seq = length_vec[old_batch_id] - 1 - word_id_in_seq;
         }
-        // move on
-        last_width = width_data[idx_sorted_by_width_cpu_data[k]];
-        j = k - 1;
+
+        int old_word_id = offset_vec[old_batch_id] + inner_word_id_in_seq;
+        _map_vec[old_word_id] = target_word_id;
+        length_vec_cnt[old_batch_id]--;
+        target_word_id++;
+      } else {
+        last_batch_size--;
         break;
       }
     }
   }
 
-  // copying to the reorganized buffer
-  auto* _layout_input = new Tensor();
-  auto* _layout_input_gpu = param.layout_input;
-  if (_input->dims().size() == 1) {
-    // _layout_input.reshape_batch_sequence({dim0}, new_offset);
-    LOG(FATAL) << "_input->dims().size() = 1, error.";
-  } else {
-    // _layout_input.reshape_batch_sequence({dim0, dim1}, new_offset);
-    LoD new_lod;
-    new_lod.push_back(new_offset);
-    _layout_input->set_lod(new_lod);
-    _layout_input->Resize({dim0, dim1});
-    _layout_input_gpu->set_lod(new_lod);
-    _layout_input_gpu->Resize({dim0, dim1});
-  }
+  TargetWrapperCuda::MemcpyAsync(_dev_map_vec,
+                                 _map_vec.data(),
+                                 sizeof(int) * word_sum,
+                                 IoDirection::HtoD,
+                                 stream_id);
+  _emit_offset_vec[max_len] = word_sum;
+  _emit_length = emit_length;
+  return true;
+}
 
-  auto* new_emb = _layout_input->mutable_data<float>();
-  auto* input_cpu = new Tensor();
-  input_cpu->Resize(_input->dims());
-  auto* input_cpu_data = input_cpu->mutable_data<float>();
-  TargetW::MemcpyAsync(input_cpu_data,
-                       _input->data<float>(),
-                       _input->numel() * sizeof(float),
-                       IoDirection::DtoH,
-                       cuda_stream);
-  for (int i = 0; i < max_width; i++) {
-    int w = new_offset[i + 1] - new_offset[i];
-    auto* emb_start = new_emb + dim1 * new_offset[i];
-    for (int j = 0; j < w; ++j) {
-      memcpy(emb_start + dim1 * j,
-             input_cpu_data + dim1 * offset[idx_sorted_by_width_cpu_data[j]] +
-                 dim1 * i,
-             dim1 * sizeof(float));
-    }
   }
+
+template <typename Dtype>
+__global__ void transpose_2d(Dtype* output, const Dtype* input, int m, int n) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < m * n) {
+    int i = tid / n;
+    int j = tid % m;
+    output[tid] = input[j * n + i];
+  }
+}
+
+void SearchGrnnCompute::WeightsPreprocess() {
+  auto& param = this->Param<param_t>();
+  auto& context = this->ctx_->template As<CUDAContext>();
+  auto stream = context.exec_stream();
 
-  auto* _layout_input_gpu_data =
-      _layout_input_gpu->mutable_data<float>(TARGET(kCUDA));
-  TargetW::MemcpyAsync(_layout_input_gpu_data,
-                       new_emb,
-                       _layout_input->numel() * sizeof(float),
-                       IoDirection::HtoD,
-                       cuda_stream);
-  delete _layout_input;
-  delete input_cpu;
+  DDim idims = param.wi->dims();
+  DDim hdims = param.wh->dims();
+  _wi.Resize({idims[2], idims[0], idims[1]});
+  _wh.Resize({hdims[2], hdims[0], hdims[1]});
+  lite::cuda::math::Transpose<float> trans;
+  trans.transpose(_wi.mutable_data<float>(TARGET(kCUDA)),
+                  param.wi->data<float>(),
+                  idims.Vectorize(),
+                  {2, 0, 1},
+                  &stream);
+  trans.transpose(_wh.mutable_data<float>(TARGET(kCUDA)) + hdims[1] * hdims[2],
+                  param.wh->data<float>() + hdims[1] * hdims[2],
+                  {hdims[0] - 1, hdims[1], hdims[2]},
+                  {2, 0, 1},
+                  &stream);
+  trans.transpose(_wh.mutable_data<float>(TARGET(kCUDA)),
+                  param.wh->data<float>(),
+                  {hdims[1], hdims[2]},
+                  {1, 0},
+                  &stream);
+
+  // int thread_num = 512;
+  // int block_num = (hdims[1] * hdims[2] + thread_num - 1) / thread_num;
+  // transpose_2d<<<block_num, thread_num, 0, stream>>>(
+  //     _wh.mutable_data<float>(TARGET(kCUDA)),
+  //     param.wh->data<float>(),
+  //     hdims[1],
+  //     hdims[2]);
 }
 
-void SearchGrnnCompute::CopyBack(float* from, float* to, int step) {
+void SearchGrnnCompute::PrepareForRun() {
   auto& param = this->Param<param_t>();
   auto& context = this->ctx_->template As<CUDAContext>();
   auto stream = context.exec_stream();
-  auto* _input = param.x;
-  auto* _layout_input = param.layout_input;
-
-  const auto& offset = _input->lod()[0];
-  const auto& new_offset = _layout_input->lod()[0];
-  const auto* idx_sorted_by_width_cpu_data =
-      idx_sorted_by_width_cpu->data<int>();
-  for (size_t i = 0; i < _layout_input->lod()[0].size() - 1; ++i) {
-    int w = new_offset[i + 1] - new_offset[i];
-    for (int j = 0; j < w; j++) {
-      TargetW::MemcpyAsync(
-          to + step * (offset[idx_sorted_by_width_cpu_data[j]] + i),
-          from + (new_offset[i] + j) * step,
-          step * sizeof(float),
-          IoDirection::DtoD,
-          stream);
+  gemm_impl_.reset(new lite::cuda::math::Gemm<float, float>);
+  _seq_util = SeqSortedseqTranseUtil();
+
+  WeightsPreprocess();
+
+  int hidden_size = param.num_hidden;
+  int word_size = param.num_input;
+  int weights_h2h_size = hidden_size * hidden_size * 3;
+  int weights_i2h_size = hidden_size * word_size * 3;
+
+  lite::Tensor temp_weights_h2h_ori;
+  lite::Tensor temp_weights_h2h_swarp;
+  temp_weights_h2h_ori.Resize({weights_h2h_size});
+  temp_weights_h2h_swarp.Resize({weights_h2h_size});
+
+  TargetWrapperCuda::MemcpyAsync(temp_weights_h2h_ori.mutable_data<float>(),
+                                 _wh.data<float>(),
+                                 sizeof(float) * weights_h2h_size,
+                                 IoDirection::DtoH,
+                                 stream);
+  cudaStreamSynchronize(stream);
+
+  float* temp_tensor_ptr = temp_weights_h2h_swarp.mutable_data<float>();
+  memcpy(temp_tensor_ptr,
+         temp_weights_h2h_ori.data<float>(),
+         sizeof(float) * hidden_size * hidden_size);
+
+  float* rz_temp_tensor_ptr = temp_tensor_ptr + hidden_size * hidden_size;
+  const float* rz_weights_tensor_ptr =
+      temp_weights_h2h_ori.data<float>() + hidden_size * hidden_size;
+  for (int row = 0; row < hidden_size; row++) {
+    for (int block = 0; block < 2; block++) {
+      int block_offset = block * hidden_size;
+      for (int cow = 0; cow < hidden_size; cow++) {
+        rz_temp_tensor_ptr[block * hidden_size * hidden_size +
+                           row * hidden_size + cow] =
+            rz_weights_tensor_ptr[row * (2 * hidden_size) + cow + block_offset];
+      }
+    }
+  }
+
+  float* orz_temp_tensor_ptr = temp_tensor_ptr;
+  float* orz_weights_tensor_ptr = temp_weights_h2h_ori.mutable_data<float>();
+  for (int row = 0; row < hidden_size; row++) {
+    for (int block = 0; block < 3; block++) {
+      int block_offset = block * hidden_size;
+      for (int cow = 0; cow < hidden_size; cow++) {
+        orz_weights_tensor_ptr[row * (3 * hidden_size) + cow + block_offset] =
+            orz_temp_tensor_ptr[block * hidden_size * hidden_size +
+                                row * hidden_size + cow];
+      }
     }
   }
+
+  _temp_weights_h2h.Resize({weights_h2h_size});
+  TargetWrapperCuda::MemcpyAsync(
+      _temp_weights_h2h.mutable_data<float>(TARGET(kCUDA)),
+      temp_weights_h2h_ori.data<float>(),
+      sizeof(float) * weights_h2h_size,
+      IoDirection::HtoD,
+      stream);
+  cudaStreamSynchronize(stream);
+}
+
+template <typename Dtype>
+static inline __device__ Dtype Sigmoid(const Dtype a) {
+  return static_cast<Dtype>(1.0) / (static_cast<Dtype>(1.0) + expf(-a));
+}
+
+template <typename Dtype>
+static inline __device__ Dtype Tanh(const Dtype a) {
+  Dtype tmp = static_cast<Dtype>(-2.0) * a;
+  return (static_cast<Dtype>(2.0) / (static_cast<Dtype>(1.0) + expf(tmp))) -
+         static_cast<Dtype>(1.0);
+}
+
+template <typename Dtype>
+__global__ void cal_cudnn_kernel(const Dtype* w_x_r,
+                                 const Dtype* w_x_z,
+                                 const Dtype* w_x_o,
+                                 const Dtype* w_h_r,
+                                 const Dtype* w_h_z,
+                                 const Dtype* w_h_o,
+                                 int hidden_size,
+                                 int batch_size,
+                                 Dtype* output,
+                                 const Dtype* hidden_pre) {
+  const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
+  const int batch_id = thread_id / hidden_size;
+  const int index = thread_id % hidden_size;
+  if (index < hidden_size && batch_id < batch_size) {
+    int w_base_index = batch_id * hidden_size * 3 + index;
+    int h_base_index = batch_id * hidden_size + index;
+    Dtype hidden_pre_value = hidden_pre[h_base_index];
+    Dtype r = Sigmoid(w_x_r[w_base_index] + w_h_r[w_base_index]);
+    Dtype z = Sigmoid(w_x_z[w_base_index] + w_h_z[w_base_index]);
+    Dtype _h = Tanh(w_x_o[w_base_index] + w_h_o[w_base_index] * r);
+
+    output[h_base_index] =
+        (static_cast<Dtype>(1.0) - z) * _h + z * hidden_pre_value;
+  }
 }
 
 void SearchGrnnCompute::Run() {
-  CHECK(ctx_) << "running context should be set first";
   auto& param = this->Param<param_t>();
   auto& context = this->ctx_->template As<CUDAContext>();
   auto stream = context.exec_stream();
-  auto* bottom = param.x;
-  auto* wi = param.wi;
-  auto* wh = param.wh;
-  auto* top = param.out;
-  auto* _buffer = param.tmp_buffer;
-  int _cap_h = param.num_hidden;
-  int _cap_e = param.num_input;
-
-  int _cap_l = bottom->dims()[0];
-  int batch = bottom->lod()[0].size() - 1;
-
-  const auto& offset = bottom->lod()[0];
-  LoD top_lod;
-  top_lod.push_back(offset);
-  top->set_lod(top_lod);
-  std::vector<int64_t> top_dims_vec{_cap_l, _cap_h};
-  top->Resize(top_dims_vec);
-  auto* top_hidden = top->mutable_data<float>(TARGET(kCUDA));
-
-  const auto* dense_e2h = wi->data<float>();
-  const auto* dense_h2h = wh->data<float>();
-
-  const auto* e2h = dense_e2h;
-  const auto* e2hr = dense_e2h + 1 * _cap_e * _cap_h;
-  const auto* e2hz = dense_e2h + 2 * _cap_e * _cap_h;
-  const auto* h2h = dense_h2h;
-  const auto* h2hr = dense_h2h + 1 * _cap_h * _cap_h;
-  const auto* h2hz = dense_h2h + 2 * _cap_h * _cap_h;
-
-  PrepareLayout(bottom);
-
-  auto* _layout_input = param.layout_input;
-  auto* new_emb = _layout_input->data<float>();
-  const auto& new_offset = _layout_input->lod()[0];
-  int max_width = _layout_input->lod()[0].size() - 1;
-
-  // this buffer is used for book keeping info which will be used in bp
-  // buffer also needed in bp, so make it larger
-  _buffer->Resize({20, _cap_l, _cap_h});
-  auto* buffer_data = _buffer->mutable_data<float>(TARGET(kCUDA));
-  auto* w_x_e = buffer_data + 0 * _cap_l * _cap_h;
-  auto* wr_x_e = buffer_data + 1 * _cap_l * _cap_h;
-  auto* wz_x_e = buffer_data + 2 * _cap_l * _cap_h;
-  auto* u_x_h = buffer_data + 3 * _cap_l * _cap_h;
-  auto* ur_x_h = buffer_data + 4 * _cap_l * _cap_h;
-  auto* uz_x_h = buffer_data + 5 * _cap_l * _cap_h;
-  auto* r = buffer_data + 6 * _cap_l * _cap_h;
-  auto* z = buffer_data + 7 * _cap_l * _cap_h;
-  auto* tilde = buffer_data + 8 * _cap_l * _cap_h;
-  // the internal hidden
-  auto* hidden = buffer_data + 19 * _cap_l * _cap_h;
-
-  gemm_impl_->init(false, true, _cap_l, _cap_h, _cap_e, &context);
-  gemm_impl_->run(1.0f, 0.0f, new_emb, e2h, w_x_e, &context);
-  gemm_impl_->init(false, true, _cap_l, _cap_h, _cap_e, &context);
-  gemm_impl_->run(1.0f, 0.0f, new_emb, e2hr, wr_x_e, &context);
-  gemm_impl_->init(false, true, _cap_l, _cap_h, _cap_e, &context);
-  gemm_impl_->run(1.0f, 0.0f, new_emb, e2hz, wz_x_e, &context);
-
-  // precompute hidden0
-  int num = batch * _cap_h;
-  int threads = 512;
-  int blocks = (num + threads - 1) / threads;
-  PreComputeKernel<<<blocks, threads, 0, stream>>>(
-      num, w_x_e, wz_x_e, tilde, z, hidden);
-
-  // recurrence
-  for (int i = 1; i < max_width; i++) {
-    int w_tm1 = new_offset[i] - new_offset[i - 1];
-    int w = new_offset[i + 1] - new_offset[i];
-
-    // precompute hidden i-1 to hidden i
-    auto* htm1 = hidden + new_offset[i - 1] * _cap_h;
-
-    gemm_impl_->init(false, true, w, _cap_h, _cap_h, &context);
-    gemm_impl_->run(
-        1.0f, 0.0f, htm1, h2h, u_x_h + new_offset[i] * _cap_h, &context);
-    gemm_impl_->init(false, true, w, _cap_h, _cap_h, &context);
-    gemm_impl_->run(
-        1.0f, 0.0f, htm1, h2hr, ur_x_h + new_offset[i] * _cap_h, &context);
-    gemm_impl_->init(false, true, w, _cap_h, _cap_h, &context);
-    gemm_impl_->run(
-        1.0f, 0.0f, htm1, h2hz, uz_x_h + new_offset[i] * _cap_h, &context);
-
-    // compute the gate and hidden
-    int start = new_offset[i] * _cap_h;
-    int end = (new_offset[i] + w) * _cap_h;
-    PostComputeKernel<<<blocks, threads, 0, stream>>>(start,
-                                                      end,
-                                                      _cap_h,
-                                                      w_tm1,
-                                                      wr_x_e,
-                                                      ur_x_h,
-                                                      wz_x_e,
-                                                      uz_x_h,
-                                                      w_x_e,
-                                                      u_x_h,
-                                                      r,
-                                                      z,
-                                                      tilde,
-                                                      hidden);
+  auto* x = param.x;
+  LoD offset_vec_vec = x->lod();
+  std::vector<int> offset(offset_vec_vec[offset_vec_vec.size() - 1].size());
+  for (size_t i = 0; i < offset_vec_vec[offset_vec_vec.size() - 1].size();
+       ++i) {
+    offset[i] = static_cast<int>(offset_vec_vec[offset_vec_vec.size() - 1][i]);
+  }
+  const float* x_data = x->data<float>();
+  auto* dout = param.out;
+  std::vector<int64_t> out_dims_vec{x->dims()[0], param.num_hidden};
+  dout->Resize(out_dims_vec);
+  float* dout_data = dout->mutable_data<float>(TARGET(kCUDA));
+  auto* wi = &_wi;
+  auto* wh = &_wh;
+
+  const float* weights_i2h = wi->data<float>();
+  const float* weights_h2h = wh->data<float>();
+
+  int batch_size = offset.size() - 1;
+  int seq_sum = x->dims()[0];
+  bool is_batched = offset.size() > 2;
+  int hidden_size = param.num_hidden;
+  int word_size = param.num_input;
+  int o_offset = 0;
+  int r_offset = 1;
+  int z_offset = 2;
+
+  is_batched = _seq_util.get_sorted_map(offset, stream);
+  std::vector<int> emit_offset_vec = _seq_util.get_emit_offset_vec();
+  int emit_length = emit_offset_vec.size() - 1;
+
+  if (is_batched) {
+    std::vector<int64_t> seq_shape{1, 1, seq_sum, word_size};
+    _temp_tensor_in.Resize(seq_shape);
+    std::vector<int64_t> seq_out_shape{1, 1, seq_sum, hidden_size};
+    _temp_tensor_out.Resize(seq_out_shape);
+    _seq_util.seq_2_sorted_seq(
+        x_data,
+        _temp_tensor_in.mutable_data<float>(TARGET(kCUDA)),
+        word_size,
+        stream);
+    x_data = _temp_tensor_in.data<float>();
+    dout_data = _temp_tensor_out.mutable_data<float>(TARGET(kCUDA));
+  }
+
+  std::vector<int64_t> shape_wx({seq_sum, 1, 3, hidden_size});
+  _temp_wx.Resize(shape_wx);
+
+  std::vector<int64_t> shape_wh({1, batch_size, 3, hidden_size});
+  _temp_wh.Resize(shape_wh);
+
+  gemm_impl_->init(false, false, seq_sum, 3 * hidden_size, word_size, &context);
+  gemm_impl_->run(1.0f,
+                  0.0f,
+                  x_data,
+                  weights_i2h,
+                  _temp_wx.mutable_data<float>(TARGET(kCUDA)),
+                  &context);
+
+  std::vector<int64_t> shape_zero({batch_size * hidden_size});
+  _temp_zero.Resize(shape_zero);
+
+  TargetWrapperCuda::MemsetAsync(_temp_zero.mutable_data<float>(TARGET(kCUDA)),
+                                 0,
+                                 sizeof(float) * batch_size * hidden_size,
+                                 stream);
+
+  const float* h = _temp_zero.data<float>();
+  for (int word_id = 0; word_id < emit_length; word_id++) {
+    int real_word_id = word_id;
+    int last_word_id = word_id - 1;
+    int emit_word_id_start = emit_offset_vec[real_word_id];
+    int emit_word_id_end = emit_offset_vec[real_word_id + 1];
+    int emit_word_length = emit_word_id_end - emit_word_id_start;
+
+    const float* hidden_in;
+    float* hidden_out = dout_data + emit_offset_vec[real_word_id] * hidden_size;
+
+    if (word_id == 0) {
+      hidden_in = h;
+    } else {
+      hidden_in = dout_data + emit_offset_vec[last_word_id] * hidden_size;
+    }
+
+    float* w_x_r = _temp_wx.mutable_data<float>(TARGET(kCUDA)) +
+                   r_offset * hidden_size +
+                   emit_word_id_start * hidden_size * 3;
+    float* w_x_z = _temp_wx.mutable_data<float>(TARGET(kCUDA)) +
+                   z_offset * hidden_size +
+                   emit_word_id_start * hidden_size * 3;
+    float* w_x_o = _temp_wx.mutable_data<float>(TARGET(kCUDA)) +
+                   o_offset * hidden_size +
+                   emit_word_id_start * hidden_size * 3;
+
+    float* w_h_r =
+        _temp_wh.mutable_data<float>(TARGET(kCUDA)) + r_offset * hidden_size;
+    float* w_h_z =
+        _temp_wh.mutable_data<float>(TARGET(kCUDA)) + z_offset * hidden_size;
+    float* w_h_o =
+        _temp_wh.mutable_data<float>(TARGET(kCUDA)) + o_offset * hidden_size;
+    gemm_impl_->init(
+        false, false, emit_word_length, 3 * hidden_size, hidden_size, &context);
+    gemm_impl_->run(1.0f,
+                    0.0f,
+                    hidden_in,
+                    _temp_weights_h2h.data<float>(),
+                    _temp_wh.mutable_data<float>(TARGET(kCUDA)),
+                    &context);
+
+    const float* w_o = weights_h2h;
+    const int block_dim = 512;
+    const int grid_dim =
+        (emit_word_length * hidden_size + block_dim - 1) / block_dim;
+    cal_cudnn_kernel<<<grid_dim, block_dim, 0, stream>>>(w_x_r,
+                                                         w_x_z,
+                                                         w_x_o,
+                                                         w_h_r,
+                                                         w_h_z,
+                                                         w_h_o,
+                                                         hidden_size,
+                                                         emit_word_length,
+                                                         hidden_out,
+                                                         hidden_in);
+  }
+
+  if (is_batched) {
+    _seq_util.sorted_seq_2_seq(_temp_tensor_out.data<float>(),
+                               dout->mutable_data<float>(TARGET(kCUDA)),
+                               hidden_size,
+                               stream);
   }
-  CopyBack(hidden, top_hidden, _cap_h);
+
+  dout->set_lod(x->lod());
 }
 
 }  // namespace cuda
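
The gate arithmetic inside cal_cudnn_kernel is the standard GRU cell, with w_x_* coming from the x-side GEMM into _temp_wx and w_h_* from the per-step GEMM into _temp_wh. A scalar restatement for one hidden unit (host-side sketch, not code from the patch; the device Tanh above is the algebraically equivalent 2/(1+exp(-2a))-1 form):

  #include <math.h>
  inline float gru_unit(float w_x_r, float w_x_z, float w_x_o,
                        float w_h_r, float w_h_z, float w_h_o, float h_prev) {
    float r = 1.f / (1.f + expf(-(w_x_r + w_h_r)));  // reset gate
    float z = 1.f / (1.f + expf(-(w_x_z + w_h_z)));  // update gate
    float h_cand = tanhf(w_x_o + r * w_h_o);         // candidate state
    return (1.f - z) * h_cand + z * h_prev;          // new hidden state
  }
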
diff --git a/lite/kernels/cuda/search_grnn_compute.h b/lite/kernels/cuda/search_grnn_compute.h
index 73d84635d06f578f68bd844fe275d99595e70fc8..42def3eb36d954a1f87d613a424085916a08b625 100644
--- a/lite/kernels/cuda/search_grnn_compute.h
+++ b/lite/kernels/cuda/search_grnn_compute.h
@@ -14,6 +14,7 @@
 
 #pragma once
 #include <memory>
+#include <vector>
 #include "lite/backends/cuda/blas.h"
 #include "lite/backends/cuda/math/gemm.h"
 #include "lite/core/kernel.h"
@@ -23,6 +24,53 @@ namespace lite {
 namespace kernels {
 namespace cuda {
 
+class SeqSortedseqTranseUtil {
+ public:
+  explicit SeqSortedseqTranseUtil(bool is_reverse = false, bool is_bi = false)
+      : _is_reverse(is_reverse),
+        _is_bi(is_bi),
+        _dev_map_vec(nullptr),
+        _dev_map_vec_length(0) {}
+
+  ~SeqSortedseqTranseUtil() {
+    if (_dev_map_vec != nullptr) {
+      TargetWrapperCuda::Free(static_cast<void*>(_dev_map_vec));
+    }
+  }
+
+  std::vector<int>& get_length_index() { return _length_index; }
+  std::vector<int>& get_emit_offset_vec() { return _emit_offset_vec; }
+  std::vector<int>& get_map_vec() { return _map_vec; }
+  int* get_dev_map_vec() { return _dev_map_vec; }
+  int get_emit_length() { return _emit_length; }
+
+  template <typename Dtype>
+  void seq_2_sorted_seq(const Dtype* input,
+                        Dtype* output,
+                        int word_size,
+                        cudaStream_t stream);
+
+  template <typename Dtype>
+  void sorted_seq_2_seq(const Dtype* input,
+                        Dtype* output,
+                        int hidden_size,
+                        cudaStream_t stream);
+
+  bool get_sorted_map(const std::vector<int>& offset_vec,
+                      cudaStream_t stream_id);
+
+ private:
+  std::vector<int> _length_index;
+  std::vector<int> _emit_offset_vec;
+  std::vector<int> _map_vec;
+  int _emit_length;
+
+  bool _is_reverse;
+  bool _is_bi;
+  int* _dev_map_vec;
+  int _dev_map_vec_length;
+};
+
 class SearchGrnnCompute
     : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
  public:
@@ -34,10 +82,26 @@ class SearchGrnnCompute
   virtual ~SearchGrnnCompute() = default;
 
  private:
-  std::shared_ptr<Tensor> idx_sorted_by_width_cpu;
+  // Weights preprocess:
+  // wi need to be transpose, the axes should be (2, 0, 1)
+  // wh0 should transpose, {wh1 wh2} need be transpose, the axes should be {2,
+  // 0, 1}
+  void WeightsPreprocess();
+
+ private:
   std::unique_ptr<lite::cuda::math::Gemm<float, float>> gemm_impl_;
-  void PrepareLayout(const Tensor* input);
-  void CopyBack(float* from, float* to, int step);
+
+  lite::Tensor _temp_tensor_in;
+  lite::Tensor _temp_tensor_out;
+  lite::Tensor _temp_wx;
+  lite::Tensor _temp_wh;
+  lite::Tensor _temp_zero;
+  lite::Tensor _temp_weights_h2h;
+
+  lite::Tensor _wi;
+  lite::Tensor _wh;
+
+  SeqSortedseqTranseUtil _seq_util;
 };
 
 }  // namespace cuda
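
The call sequence for the new SeqSortedseqTranseUtil is easiest to see in one place. A condensed sketch of how SearchGrnnCompute::Run() drives it; x_sorted and out_sorted are illustrative names for the _temp_tensor_in/_temp_tensor_out buffers, everything else is from the patch:

  SeqSortedseqTranseUtil seq_util;
  // LoD offsets {0, 2, 5}: two sequences of lengths 2 and 3, word_sum = 5
  bool batched = seq_util.get_sorted_map({0, 2, 5}, stream);
  // get_emit_offset_vec() now groups word slots by time step across the
  // length-sorted batch, so each step is contiguous and one GEMM covers it.
  if (batched) {
    seq_util.seq_2_sorted_seq(x, x_sorted, word_size, stream);        // gather
  }
  // ... per-step GEMM into _temp_wh plus cal_cudnn_kernel over the emit offsets ...
  if (batched) {
    seq_util.sorted_seq_2_seq(out_sorted, out, hidden_size, stream);  // scatter back
  }
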
diff --git a/lite/kernels/cuda/transpose_compute.cu b/lite/kernels/cuda/transpose_compute.cu
index 0050e5e0f6d67f4eacaadc675b98417b9436b006..c5693c674c573d7c9f59034dd3c0985c9d94a22f 100644
--- a/lite/kernels/cuda/transpose_compute.cu
+++ b/lite/kernels/cuda/transpose_compute.cu
@@ -25,6 +25,7 @@ namespace cuda {
 void TransposeCompute::Run() {
   auto& param = this->Param<param_t>();
   auto& ctx = this->ctx_->template As<CUDAContext>();
+  auto stream = ctx.exec_stream();
 
   const lite::Tensor* X = param.x;
   lite::Tensor* Out = param.output;
@@ -39,8 +40,7 @@ void TransposeCompute::Run() {
   // NCHW -> NHWC
   if (axes.size() == 4 && axes[0] == 0 && axes[1] == 2 && axes[2] == 3 &&
       axes[3] == 1) {
-    lite::cuda::math::NCHW2NHWC(
-        dims[0], dims[1], dims[2] * dims[3], in, out, &ctx);
+    trans.NCHW2NHWC(dims[0], dims[1], dims[2] * dims[3], in, out, &stream);
     cudaError_t error = cudaGetLastError();
     if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
     return;
@@ -49,14 +49,13 @@ void TransposeCompute::Run() {
   // NHWC -> NCHW
   if (axes.size() == 4 && axes[0] == 0 && axes[1] == 3 && axes[2] == 1 &&
       axes[3] == 2) {
-    lite::cuda::math::NHWC2NCHW(
-        dims[0], dims[3], dims[1] * dims[2], in, out, &ctx);
+    trans.NHWC2NCHW(dims[0], dims[3], dims[1] * dims[2], in, out, &stream);
     cudaError_t error = cudaGetLastError();
     if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
     return;
   }
 
-  lite::cuda::math::Transpose(dims, axes, in, out, &ctx);
+  trans.transpose(out, in, dims, axes, &stream);
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
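
As a small illustration of the dispatch in TransposeCompute::Run() (hypothetical shapes, not taken from the tests): 4-D axes {0, 2, 3, 1} and {0, 3, 1, 2} hit the fused NCHW/NHWC kernels, and every other permutation falls through to trans.transpose():

  std::vector<int64_t> dims = {2, 3, 4, 5};  // N, C, H, W
  std::vector<int> axes = {0, 2, 3, 1};      // fused NCHW2NHWC path
  // axes = {0, 3, 1, 2};                    // fused NHWC2NCHW path
  // axes = {0, 1, 3, 2};                    // generic trans.transpose(out, in, dims, axes, &stream)

Note that cudaGetLastError() after an asynchronous launch only reports launch-time errors; execution errors on exec_stream() would surface at a later synchronization point.
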
diff --git a/lite/kernels/cuda/transpose_compute.h b/lite/kernels/cuda/transpose_compute.h
index f85f43993d60cc9dbe5e665a3b2b0fffcbcbc7c9..273d072231fb0608deb9ed729bdf153395ee983f 100644
--- a/lite/kernels/cuda/transpose_compute.h
+++ b/lite/kernels/cuda/transpose_compute.h
@@ -29,7 +29,7 @@ class TransposeCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
   virtual ~TransposeCompute() = default;
 
  private:
-  lite::Tensor axes_, dims_;
+  lite::cuda::math::Transpose<float> trans;
 };
 
 }  // namespace cuda
diff --git a/lite/kernels/cuda/transpose_compute_test.cc b/lite/kernels/cuda/transpose_compute_test.cc
index 517f761b61268d2c664f74bdb338ffb79f8841f8..bf0d803a14a5f0e47c96128b953ae72a18798205 100644
--- a/lite/kernels/cuda/transpose_compute_test.cc
+++ b/lite/kernels/cuda/transpose_compute_test.cc
@@ -238,7 +238,7 @@ TEST(transpose, normal) {
   lite::Tensor x, x_cpu, x_ref;
   lite::Tensor out, out_cpu, out_ref;
 
-  int C = 6, H = 7, W = 8;
+  int C = 3, H = 128, W = 128;
   std::vector<int> axes({2, 0, 1});
   x.Resize({C, H, W});
   out.Resize({W, C, H});
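
The enlarged test case (C=3, H=128, W=128 with axes {2, 0, 1}) can be cross-checked against a plain host loop such as the one below; this is an illustrative reference, not the helper the test file actually uses:

  // expected out dims are {W, C, H}; out(w, c, h) = x(c, h, w)
  for (int c = 0; c < C; ++c)
    for (int h = 0; h < H; ++h)
      for (int w = 0; w < W; ++w)
        out_ref[(w * C + c) * H + h] = x[(c * H + h) * W + w];
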