Unverified commit ab7b2855, authored by: W Wilber, committed by: GitHub

[CUDA] [Kernel] Optimize gru. (#4062)

Parent: be13a60a
......@@ -20,6 +20,7 @@ nv_library(cuda_batched_gemm SRCS batched_gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_strided_gemm SRCS strided_gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_sequence_padding SRCS sequence_padding.cu DEPS ${cuda_static_deps})
nv_library(cuda_bias SRCS bias.cu DEPS ${cuda_static_deps})
nv_library(cuda_sequence_helper SRCS sequence_helper.cu DEPS ${cuda_static_deps})
set (
math_cuda
......@@ -39,6 +40,7 @@ set (
cuda_sequence_padding
cuda_bias
cudnn_helper
cuda_sequence_helper
)
set(math_cuda "${math_cuda}" CACHE GLOBAL "math cuda")
......@@ -55,31 +55,32 @@ bool CudnnConv2D<T, Ptype_out>::create(const operators::ConvParam& param,
CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->input_desc_,
CUDNN_TENSOR_NCHW,
GetCudnnDataType<Ptype_out>(),
cudnn::cudnnTypeWrapper<T>::type,
batch,
ic,
ih,
iw));
CUDNN_CHECK(cudnnSetFilter4dDescriptor(this->filter_desc_,
GetCudnnDataType<Ptype_out>(),
cudnn::cudnnTypeWrapper<T>::type,
CUDNN_TENSOR_NCHW,
oc,
ic / param.groups,
kh,
kw));
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(this->conv_desc_,
ph,
pw,
sh,
sw,
dh,
dw,
CUDNN_CROSS_CORRELATION,
GetCudnnDataType<Ptype_out>()));
CUDNN_CHECK(
cudnnSetConvolution2dDescriptor(this->conv_desc_,
ph,
pw,
sh,
sw,
dh,
dw,
CUDNN_CROSS_CORRELATION,
cudnn::cudnnTypeWrapper<T>::type));
CUDNN_CHECK(cudnnSetConvolutionGroupCount(this->conv_desc_, param.groups));
CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->output_desc_,
CUDNN_TENSOR_NCHW,
GetCudnnDataType<Ptype_out>(),
cudnn::cudnnTypeWrapper<T>::type,
batch,
oc,
oh,
......@@ -179,7 +180,7 @@ bool CudnnConv2D<T, Ptype_out>::create(const operators::ConvParam& param,
int dim_bias[] = {1, oc, 1, 1};
int stride_bias[] = {oc, 1, 1, 1};
cudnnSetTensorNdDescriptor(this->bias_desc_,
GetCudnnDataType<Ptype_out>(),
cudnn::cudnnTypeWrapper<T>::type,
4,
dim_bias,
stride_bias);
......
......@@ -21,17 +21,7 @@ namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <>
cudnnDataType_t GetCudnnDataType<PRECISION(kFloat)>() {
return CUDNN_DATA_FLOAT;
}
template <>
cudnnDataType_t GetCudnnDataType<PRECISION(kFP16)>() {
return CUDNN_DATA_HALF;
}
namespace cudnn {} // namespace cudnn
} // namespace math
} // namespace cuda
} // namespace lite
......
......@@ -25,10 +25,97 @@ namespace paddle {
namespace lite {
namespace cuda {
namespace math {
namespace cudnn {
template <lite_api::PrecisionType PType>
cudnnDataType_t GetCudnnDataType();
template <typename T>
class cudnnTypeWrapper;
template <>
class cudnnTypeWrapper<float> {
public:
static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
typedef const float ScalingParamType;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0f;
return &v;
}
static ScalingParamType* kZero() {
static ScalingParamType v = 0.0f;
return &v;
}
};
template <>
class cudnnTypeWrapper<half> {
public:
static const cudnnDataType_t type = CUDNN_DATA_HALF;
typedef const half ScalingParamType;
static ScalingParamType* kOne() {
static ScalingParamType v = __float2half(1.0f);
return &v;
}
static ScalingParamType* kZero() {
static ScalingParamType v = __float2half(0.0f);
return &v;
}
};
struct ParamsRegion {
ParamsRegion() : offset_(nullptr), size_(0) {}
ParamsRegion(void* offset, size_t size) : offset_(offset), size_(size) {}
~ParamsRegion() {}
ParamsRegion& operator=(const ParamsRegion& right) {
offset_ = right.offset_;
size_ = right.size_;
return *this;
}
bool operator==(const ParamsRegion& right) {
bool comp_eq = true;
comp_eq = comp_eq && (offset_ == right.offset_);
    comp_eq = comp_eq && (size_ == right.size_);
return comp_eq;
}
void* offset_;
size_t size_;
};
template <typename T>
class TensorDescriptors {
public:
TensorDescriptors(size_t n,
const std::vector<std::vector<int>>& dim,
const std::vector<std::vector<int>>& stride) {
descs_.resize(n);
CHECK_EQ(dim.size(), stride.size())
<< "dim size should be equal to stride size";
for (size_t i = 0; i < n; ++i) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&descs_[i]));
CUDNN_CHECK(cudnnSetTensorNdDescriptor(descs_[i],
cudnnTypeWrapper<T>::type,
dim[i].size(),
dim[i].data(),
stride[i].data()));
}
}
~TensorDescriptors() {
for (auto desc : descs_) {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc));
}
}
const cudnnTensorDescriptor_t* descs() const { return descs_.data(); }
int size() const { return descs_.size(); }
private:
std::vector<cudnnTensorDescriptor_t> descs_;
};
} // namespace cudnn
} // namespace math
} // namespace cuda
} // namespace lite
......
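A minimal sketch of how the new cudnnTypeWrapper would typically be consumed, assuming a cuDNN handle and tensor descriptors created elsewhere; handle, bias_desc_, bias_data, output_desc_ and output_data are placeholder names for this illustration, not part of the patch:
// Hypothetical call site: kOne()/kZero() supply the alpha/beta scaling
// pointers cuDNN expects, already in the tensor's own precision (float or half).
CUDNN_CHECK(cudnnAddTensor(handle,
                           cudnn::cudnnTypeWrapper<T>::kOne(),
                           bias_desc_,
                           bias_data,
                           cudnn::cudnnTypeWrapper<T>::kOne(),
                           output_desc_,
                           output_data));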
......@@ -54,7 +54,7 @@ bool CudnnSoftmax<T, Ptype>::Create(const operators::SoftmaxParam& param,
const int stride_c = H * stride_h;
const int stride_n = C * stride_c;
CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(bottom_desc_,
GetCudnnDataType<Ptype>(),
cudnn::cudnnTypeWrapper<T>::type,
N,
C,
H,
......@@ -64,7 +64,7 @@ bool CudnnSoftmax<T, Ptype>::Create(const operators::SoftmaxParam& param,
stride_h,
stride_w));
CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(top_desc_,
GetCudnnDataType<Ptype>(),
cudnn::cudnnTypeWrapper<T>::type,
N,
C,
H,
......
......@@ -30,17 +30,12 @@ __global__ void CopyMatrixRowsKernel(const T* src,
int height,
int width,
bool is_src_index) {
int idx = threadIdx.x;
int idy = threadIdx.y;
int row_id = blockDim.y * blockIdx.x + idy;
if (row_id < height) {
int src_idx = is_src_index ? index[row_id] : row_id;
int dst_idx = is_src_index ? row_id : index[row_id];
const T* src_data = src + src_idx * width;
T* dst_data = dst + dst_idx * width;
for (int i = idx; i < width; i += blockDim.x) {
dst_data[i] = src_data[i];
}
CUDA_KERNEL_LOOP(tid, height * width) {
int row = tid / width;
int idx = tid % width;
int src_row = is_src_index ? index[row] : row;
int dst_row = is_src_index ? row : index[row];
dst[dst_row * width + idx] = src[src_row * width + idx];
}
}
......@@ -69,9 +64,8 @@ void CopyMatrixRowsFunctor<T>::operator()(
sizeof(uint64_t) * index_lod.size(),
IoDirection::HtoD,
stream);
dim3 threads(128, 8);
dim3 grids((height + threads.y - 1) / threads.y);
CopyMatrixRowsKernel<T><<<grids, threads, 0, stream>>>(
CopyMatrixRowsKernel<
T><<<CUDA_GET_BLOCKS(height * width), CUDA_NUM_THREADS, 0, stream>>>(
src_data, dst_data, index_tensor_data, height, width, is_src_index);
CUDA_POST_KERNEL_CHECK;
}
......
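For reference, CUDA_KERNEL_LOOP (from lite/backends/cuda/cuda_utils.h) is the usual grid-stride loop macro; the expansion below is an assumption based on the common Caffe/Paddle pattern rather than a copy from this patch:
// Assumed expansion: each thread strides over the flattened index space,
// so a single 1-D launch covers all height * width elements.
#define CUDA_KERNEL_LOOP(i, n)                                  \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)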
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/sequence_helper.h"
#include "lite/backends/cuda/math/utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename Dtype>
__global__ void Map2Out(
Dtype* output, const Dtype* input, const int* map, int count, int lastdim) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < count) {
int seq = tid / lastdim;
output[map[seq] * lastdim + tid % lastdim] = input[tid];
}
}
template <typename Dtype>
__global__ void Map2In(
Dtype* output, const Dtype* input, const int* map, int count, int lastdim) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < count) {
int seq = tid / lastdim;
output[tid] = input[map[seq] * lastdim + tid % lastdim];
}
}
template <typename Dtype>
void Map2OutFunc(const Dtype* input,
Dtype* output,
int word_size,
int seq_sum,
cudaStream_t stream,
int* dev_map_vec) {
int count = seq_sum * word_size;
int block_dim = count;
int grid_dim = 1;
if (count > 1024) {
block_dim = 256;
grid_dim = (count + block_dim - 1) / block_dim;
}
Map2Out<<<grid_dim, block_dim, 0, stream>>>(
output, input, dev_map_vec, count, word_size);
}
template <typename Dtype>
void Map2InFunc(const Dtype* input,
Dtype* output,
int hidden_size,
int seq_sum,
cudaStream_t stream,
int* dev_map_vec) {
int count = seq_sum * hidden_size;
int block_dim = count;
int grid_dim = 1;
if (count > 1024) {
block_dim = 256;
grid_dim = (count + block_dim - 1) / block_dim;
}
Map2In<<<grid_dim, block_dim, 0, stream>>>(
output, input, dev_map_vec, count, hidden_size);
}
template <typename Dtype>
void SeqSortedseqTranseUtil::Seq2SortedSeq(const Dtype* input,
Dtype* output,
int word_size,
cudaStream_t stream) {
int seq_sum = map_vec_.size();
Map2OutFunc(input, output, word_size, seq_sum, stream, dev_map_vec_);
}
template <typename Dtype>
void SeqSortedseqTranseUtil::SortedSeq2Seq(const Dtype* input,
Dtype* output,
int hidden_size,
cudaStream_t stream) {
int seq_sum = map_vec_.size();
Map2InFunc(input, output, hidden_size, seq_sum, stream, dev_map_vec_);
}
bool SeqSortedseqTranseUtil::GetSortedMap(const std::vector<int>& offset_vec,
cudaStream_t stream_id) {
int batch_size = offset_vec.size() - 1;
int word_sum = offset_vec[offset_vec.size() - 1];
std::vector<int> length_vec(batch_size);
length_index_.resize(batch_size);
int emit_length = 0;
if (batch_size == 1) {
emit_length = offset_vec[1] - offset_vec[0];
emit_offset_vec_.resize(emit_length + 1);
for (int i = 0; i <= emit_length; ++i) {
emit_offset_vec_[i] = i;
}
return false;
}
int max_len = 0;
for (int i = 0; i < offset_vec.size() - 1; ++i) {
int len = offset_vec[i + 1] - offset_vec[i];
max_len = max_len > len ? max_len : len;
length_vec[i] = len;
length_index_[i] = i;
}
emit_length = max_len;
if (max_len == 1) {
emit_offset_vec_.resize(2);
emit_offset_vec_[0] = 0;
emit_offset_vec_[1] = emit_length * batch_size;
return false;
}
std::stable_sort(length_index_.begin(),
length_index_.end(),
[&length_vec](int i1, int i2) {
return length_vec[i1] > length_vec[i2];
});
emit_offset_vec_.resize(max_len + 1);
map_vec_.resize(word_sum);
if (word_sum > dev_map_vec_length_) {
if (dev_map_vec_ != nullptr) {
TargetWrapperCuda::Free(static_cast<void*>(dev_map_vec_));
}
dev_map_vec_ =
static_cast<int*>(TargetWrapperCuda::Malloc(sizeof(int) * word_sum));
dev_map_vec_length_ = word_sum;
}
int target_word_id = 0;
std::vector<int> length_vec_cnt = length_vec;
int last_batch_size = batch_size;
for (int word_id_in_seq = 0; word_id_in_seq < max_len; word_id_in_seq++) {
emit_offset_vec_[word_id_in_seq] = target_word_id;
for (int batch_id = 0; batch_id < last_batch_size; batch_id++) {
int old_batch_id = length_index_[batch_id];
if (length_vec_cnt[old_batch_id] > 0) {
int inner_word_id_in_seq = word_id_in_seq;
if (is_reverse_) {
inner_word_id_in_seq = length_vec[old_batch_id] - 1 - word_id_in_seq;
}
int old_word_id = offset_vec[old_batch_id] + inner_word_id_in_seq;
map_vec_[old_word_id] = target_word_id;
length_vec_cnt[old_batch_id]--;
target_word_id++;
} else {
last_batch_size--;
break;
}
}
}
TargetWrapperCuda::MemcpyAsync(dev_map_vec_,
map_vec_.data(),
sizeof(int) * word_sum,
IoDirection::HtoD,
stream_id);
emit_offset_vec_[max_len] = word_sum;
emit_length_ = emit_length;
return true;
}
template void SeqSortedseqTranseUtil::Seq2SortedSeq(const float* input,
float* output,
int word_size,
cudaStream_t stream);
template void SeqSortedseqTranseUtil::SortedSeq2Seq(const float* input,
float* output,
int hidden_size,
cudaStream_t stream);
template void SeqSortedseqTranseUtil::Seq2SortedSeq(const half* input,
half* output,
int word_size,
cudaStream_t stream);
template void SeqSortedseqTranseUtil::SortedSeq2Seq(const half* input,
half* output,
int hidden_size,
cudaStream_t stream);
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
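A worked trace of GetSortedMap above, for offset_vec = {0, 3, 4, 6} (three sequences of lengths 3, 1, 2, with is_reverse_ = false): the batches are stably sorted by length into order {0, 2, 1}, and words are then emitted time step by time step across the still-active sequences:
// length_index_    = {0, 2, 1}           // batch ids sorted by length, descending
// map_vec_         = {0, 3, 5, 2, 1, 4}  // original word id -> sorted word id
// emit_offset_vec_ = {0, 3, 5, 6}        // word range handled at each time step
// emit_length_     = 3                   // number of time steps to run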
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include "lite/backends/cuda/target_wrapper.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
class SeqSortedseqTranseUtil {
public:
explicit SeqSortedseqTranseUtil(bool is_reverse = false, bool is_bi = false)
: is_reverse_(is_reverse),
is_bi_(is_bi),
dev_map_vec_(nullptr),
dev_map_vec_length_(0) {}
~SeqSortedseqTranseUtil() {
if (dev_map_vec_ != nullptr) {
TargetWrapperCuda::Free(static_cast<void*>(dev_map_vec_));
}
}
std::vector<int>& GetLengthIndex() { return length_index_; }
std::vector<int>& GetEmitOffsetVec() { return emit_offset_vec_; }
std::vector<int>& GetMapVec() { return map_vec_; }
int* GetDevMapVec() { return dev_map_vec_; }
int GetEmitLength() { return emit_length_; }
template <typename Dtype>
void Seq2SortedSeq(const Dtype* input,
Dtype* output,
int word_size,
cudaStream_t stream);
template <typename Dtype>
void SortedSeq2Seq(const Dtype* input,
Dtype* output,
int hidden_size,
cudaStream_t stream);
bool GetSortedMap(const std::vector<int>& offset_vec, cudaStream_t stream_id);
private:
std::vector<int> length_index_;
std::vector<int> emit_offset_vec_;
std::vector<int> map_vec_;
int emit_length_;
bool is_reverse_;
bool is_bi_;
int* dev_map_vec_;
int dev_map_vec_length_;
};
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
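A minimal usage sketch of the helper declared above, mirroring how gru_compute.cc drives it further down; x_in, x_sorted, h_sorted, h_out, word_size, hidden_size and stream are placeholders for the caller's device buffers and CUDA stream:
lite::cuda::math::SeqSortedseqTranseUtil seq_utils;
std::vector<int> offset = {0, 3, 4, 6};  // LoD offsets of three sequences
if (seq_utils.GetSortedMap(offset, stream)) {
  // Gather input rows into length-sorted, time-major order before the RNN loop.
  seq_utils.Seq2SortedSeq(x_in, x_sorted, word_size, stream);
  // ... run the per-time-step GRU kernels on the sorted layout ...
  // Scatter the hidden states back to the original sequence order.
  seq_utils.SortedSeq2Seq(h_sorted, h_out, hidden_size, stream);
}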
......@@ -133,7 +133,7 @@ lite_cc_library(type_system SRCS type_system.cc DEPS tensor target_wrapper)
lite_cc_library(program SRCS program.cc
DEPS op kernel model_parser ${ops} ${cpp_wrapper}
PROFILE_DEPS lite_profiler
CUDA_DEPS nvtx_wrapper)
CUDA_DEPS nvtx_wrapper cuda_type_trans)
if (NOT LITE_ON_TINY_PUBLISH)
lite_cc_library(optimizer SRCS optimizer.cc DEPS mir_pass_manager model_parser program)
......
......@@ -14,6 +14,7 @@
#include "lite/kernels/cuda/gru_compute.h"
#include <string>
#include <vector>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/bias.h"
......@@ -273,6 +274,8 @@ void GRUCompute<T, PType>::Run() {
auto& param = this->template Param<param_t>();
auto* input = param.input;
T* x_data =
const_cast<lite::Tensor*>(input)->template mutable_data<T>(TARGET(kCUDA));
lite::Tensor* h0{nullptr};
if (param.h0) {
h0 = const_cast<lite::Tensor*>(param.h0);
......@@ -289,7 +292,7 @@ void GRUCompute<T, PType>::Run() {
lite::Tensor* hidden = param.hidden;
T* batch_reset_hidden_prev_data =
batch_reset_hidden_prev->template mutable_data<T>(TARGET(kCUDA));
hidden->template mutable_data<T>(TARGET(kCUDA));
T* out_data = hidden->template mutable_data<T>(TARGET(kCUDA));
T* batch_gate_data = batch_gate->template mutable_data<T>(TARGET(kCUDA));
T* batch_hidden_data = batch_hidden->template mutable_data<T>(TARGET(kCUDA));
bool is_reverse = param.is_reverse;
......@@ -300,14 +303,28 @@ void GRUCompute<T, PType>::Run() {
auto hidden_dims = hidden->dims();
int frame_size = hidden_dims[1];
lite::cuda::math::LoDTensor2BatchFunctor<T> batch_func;
batch_func(*input, batch_gate, is_reverse, stream);
LoD offset_vec_vec = input->lod();
std::vector<int> offset(offset_vec_vec[offset_vec_vec.size() - 1].size());
for (size_t i = 0; i < offset_vec_vec[offset_vec_vec.size() - 1].size();
++i) {
offset[i] = static_cast<int>(offset_vec_vec[offset_vec_vec.size() - 1][i]);
}
bool need_process = seq_utils_.GetSortedMap(offset, stream);
int emit_length = seq_utils_.GetEmitOffsetVec().size() - 1;
auto emit_offset_vec = seq_utils_.GetEmitOffsetVec();
if (need_process) {
seq_utils_.Seq2SortedSeq(
input->template data<T>(), batch_gate_data, 3 * frame_size, stream);
x_data = batch_gate_data;
out_data = batch_hidden_data;
}
if (bias) {
// TODO(wilber): validate when bias is not nullptr
lite::cuda::math::RowwiseAdd<T> add_bias;
add_bias(batch_gate_data,
add_bias(x_data,
bias->template data<T>(),
batch_gate_data,
x_data,
frame_size,
batch_gate->numel(),
stream);
......@@ -320,6 +337,7 @@ void GRUCompute<T, PType>::Run() {
// Since the batch computing for GRU reorders the input sequences
// according to their length. The initialized cell state also needs
// to reorder.
// TODO(wilber): validate when h0 is not nullptr
ordered_h0_.Resize(h0->dims());
lite::cuda::math::CopyMatrixRowsFunctor<T> row_shuffle;
row_shuffle(*h0, &ordered_h0_, batch_gate->lod()[2], true, stream);
......@@ -327,15 +345,13 @@ void GRUCompute<T, PType>::Run() {
} else {
gru_value.prev_out_value = nullptr;
}
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
for (size_t n = 0; n < num_batch; ++n) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
for (size_t n = 0; n < emit_length; ++n) {
int bstart = emit_offset_vec[n];
int bend = emit_offset_vec[n + 1];
int cur_batch_size = bend - bstart;
gru_value.output_value = batch_hidden_data + bstart * frame_size;
gru_value.gate_value = batch_gate_data + bstart * frame_size * 3;
gru_value.output_value = out_data + bstart * frame_size;
gru_value.gate_value = x_data + bstart * frame_size * 3;
gru_value.reset_output_value =
batch_reset_hidden_prev_data + bstart * frame_size;
......@@ -349,10 +365,13 @@ void GRUCompute<T, PType>::Run() {
&context);
gru_value.prev_out_value = gru_value.output_value;
}
lite::cuda::math::Batch2LoDTensorFunctor<T> to_seq;
batch_hidden->set_lod(batch_gate->lod());
to_seq(*batch_hidden, hidden, stream);
if (need_process) {
seq_utils_.SortedSeq2Seq(batch_hidden_data,
hidden->mutable_data<T>(TARGET(kCUDA)),
frame_size,
stream);
}
hidden->set_lod(input->lod());
}
} // namespace cuda
......
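Continuing the {0, 3, 4, 6} example from sequence_helper.cu above, emit_offset_vec is {0, 3, 5, 6}, so this loop runs three GRU steps over batches of 3, 2 and 1 rows; each step feeds its output_value back as prev_out_value, which is exactly what the length-sorted, time-major layout keeps contiguous.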
......@@ -16,6 +16,7 @@
#include <memory>
#include "lite/backends/cuda/math/gemm.h"
#include "lite/backends/cuda/math/sequence_helper.h"
#include "lite/core/kernel.h"
#include "lite/operators/op_params.h"
......@@ -38,6 +39,7 @@ class GRUCompute : public KernelLite<TARGET(kCUDA), PType> {
private:
std::unique_ptr<lite::cuda::math::Gemm<T, T>> gemm_impl_{nullptr};
lite::Tensor ordered_h0_;
lite::cuda::math::SeqSortedseqTranseUtil seq_utils_;
};
} // namespace cuda
......
......@@ -55,6 +55,8 @@ bool SoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
if (opdesc.HasAttr("use_cudnn")) {
param_.use_cudnn = opdesc.GetAttr<bool>("use_cudnn");
}
// TODO(wilber): use cudnn default when compile with cuda.
param_.use_cudnn = true;
return true;
}
......