未验证 提交 dad43f81 编写于 作者: W Wilber 提交者: GitHub

optimize search_grnn test=develop (#2608)

optimize search_grnn
上级 d5434aa2
......@@ -69,44 +69,16 @@ void BatchTranspose2DCUDAImpl(const int N,
const int W,
const T* input,
T* out,
CUDAContext* ctx) {
cudaStream_t* stream) {
const int dh = (H + kTileDim - 1) / kTileDim;
const int dw = (W + kTileDim - 1) / kTileDim;
BatchTranspose2DCUDAKernel<
T><<<N * dh * dw, dim3(kTileDim, kBlockRows), 0, ctx->exec_stream()>>>(
T><<<N * dh * dw, dim3(kTileDim, kBlockRows), 0, *stream>>>(
N, H, W, dh, dw, input, out);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
#define TYPE_SPECIALIZED_CUDA_NCHW2NHWC(T) \
template <> \
void NCHW2NHWC<T>(const int N, \
const int C, \
const int HxW, \
const T* X, \
T* Y, \
CUDAContext* ctx) { \
BatchTranspose2DCUDAImpl<T>(N, C, HxW, X, Y, ctx); \
}
TYPE_SPECIALIZED_CUDA_NCHW2NHWC(float)
TYPE_SPECIALIZED_CUDA_NCHW2NHWC(int8_t)
#undef TYPE_SPECIALIZED_CUDA_NCHW2NHWC
#define TYPE_SPECIALIZED_CUDA_NHWC2NCHW(T) \
template <> \
void NHWC2NCHW<T>(const int N, \
const int C, \
const int HxW, \
const T* X, \
T* Y, \
CUDAContext* ctx) { \
BatchTranspose2DCUDAImpl<T>(N, HxW, C, X, Y, ctx); \
}
TYPE_SPECIALIZED_CUDA_NHWC2NCHW(float)
TYPE_SPECIALIZED_CUDA_NHWC2NCHW(int8_t)
#undef TYPE_SPECIALIZED_CUDA_NHWC2NCHW
template <typename T>
__global__ void TransposeCUDAKernel(const int size,
const int ndim,
......@@ -136,7 +108,9 @@ void TransposeCUDAImpl(const std::vector<int64_t>& X_dims,
const std::vector<int>& axes,
const T* X,
T* Y,
CUDAContext* ctx) {
lite::Tensor* Y_dims_,
lite::Tensor* strides_,
cudaStream_t* stream) {
CHECK_EQ(X_dims.size(), axes.size()) << "dimension size should be equal";
int ndim = X_dims.size();
std::vector<int> strides(ndim, 0);
......@@ -156,37 +130,68 @@ void TransposeCUDAImpl(const std::vector<int64_t>& X_dims,
size *= X_dims[i];
}
lite::Tensor Y_dims_, strides_;
Y_dims_.Resize(std::vector<int64_t>({ndim}));
int* d_y_dims = Y_dims_.mutable_data<int>(TARGET(kCUDA));
CopySync<TARGET(kCUDA)>(
d_y_dims, Y_dims.data(), sizeof(int) * Y_dims.size(), IoDirection::HtoD);
Y_dims_->Resize(std::vector<int64_t>({ndim}));
int* d_y_dims = Y_dims_->mutable_data<int>(TARGET(kCUDA));
TargetWrapperCuda::MemcpyAsync(d_y_dims,
Y_dims.data(),
sizeof(int) * Y_dims.size(),
IoDirection::HtoD,
*stream);
strides_.Resize(std::vector<int64_t>({ndim}));
int* d_strides = strides_.mutable_data<int>(TARGET(kCUDA));
CopySync<TARGET(kCUDA)>(d_strides,
strides.data(),
sizeof(int) * strides.size(),
IoDirection::HtoD);
strides_->Resize(std::vector<int64_t>({ndim}));
int* d_strides = strides_->mutable_data<int>(TARGET(kCUDA));
TargetWrapperCuda::MemcpyAsync(d_strides,
strides.data(),
sizeof(int) * strides.size(),
IoDirection::HtoD,
*stream);
const int M = (size + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
TransposeCUDAKernel<<<M, CUDA_NUM_THREADS, 0, ctx->exec_stream()>>>(
TransposeCUDAKernel<<<M, CUDA_NUM_THREADS, 0, *stream>>>(
size, ndim, d_strides, d_y_dims, X, Y);
auto e = cudaGetLastError();
CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);
}
#define TYPE_SPECIALIZED_CUDA_TRANSPOSE(T) \
template <> \
void Transpose<T>(const std::vector<int64_t>& X_dims, \
const std::vector<int>& axes, \
const T* X, \
T* Y, \
CUDAContext* ctx) { \
TransposeCUDAImpl<T>(X_dims, axes, X, Y, ctx); \
}
TYPE_SPECIALIZED_CUDA_TRANSPOSE(float)
#undef TYPE_SPECIALIZED_CUDA_TRANSPOSEF
template <typename T>
void Transpose<T>::NCHW2NHWC(
int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream) {
BatchTranspose2DCUDAImpl<T>(N, C, HxW, X, Y, stream);
}
template <typename T>
void Transpose<T>::NHWC2NCHW(
int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream) {
BatchTranspose2DCUDAImpl<T>(N, HxW, C, X, Y, stream);
}
template <typename T>
void Transpose<T>::transpose(T* dst,
const T* src,
const std::vector<int64_t>& src_dims,
const std::vector<int>& axes,
cudaStream_t* stream) {
TransposeCUDAImpl<T>(src_dims, axes, src, dst, &Y_dims_, &strides_, stream);
}
// template <typename T>
// void Transpose<T>::transpose(T* dst,
// const T* src,
// const std::vector<int>& src_dims,
// const std::vector<int>& axes,
// cudaStream_t* stream) {
// std::vector<int64_t> _src_dims(src_dims.size(), 0);
// std::transform(
// src_dims.begin(),
// src_dims.end(),
// _src_dims.begin(),
// [](int data) -> int64_t { return static_cast<int64_t>(data); });
// TransposeCUDAImpl<T>(_src_dims, axes, src, dst, &Y_dims_, &strides_,
// stream);
//}
template class Transpose<int8_t>;
template class Transpose<float>;
} // namespace math
} // namespace cuda
......
......@@ -26,17 +26,27 @@ namespace cuda {
namespace math {
template <typename T>
void NCHW2NHWC(int N, int C, int HxW, const T* X, T* Y, CUDAContext* context);
class Transpose {
public:
void NCHW2NHWC(int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream);
template <typename T>
void NHWC2NCHW(int N, int C, int HxW, const T* X, T* Y, CUDAContext* context);
void NHWC2NCHW(int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream);
template <typename T>
void Transpose(const std::vector<int64_t>& X_dims,
const std::vector<int>& axes,
const T* X,
T* Y,
CUDAContext* ctx);
void transpose(T* dst,
const T* src,
const std::vector<int64_t>& src_dims,
const std::vector<int>& axes,
cudaStream_t* stream);
// void transpose(T* dst,
// const T* src,
// const std::vector<int>& src_dims,
// const std::vector<int>& axes,
// cudaStream_t* stream);
private:
lite::Tensor Y_dims_, strides_; // for transpose.
};
} // namespace math
} // namespace cuda
......
......@@ -26,7 +26,7 @@ add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS
${lite_kernel_deps} cudnn_pool)
add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps})
add_kernel(search_seq_depadding_compute_cuda CUDA extra SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
add_kernel(search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS ${lite_kernel_deps} cuda_gemm ${math_cuda})
add_kernel(sequence_reverse_compute_cuda CUDA basic SRCS sequence_reverse_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_arithmetic_compute_cuda CUDA basic SRCS sequence_arithmetic_compute.cu DEPS ${lite_kernel_deps})
......
......@@ -29,7 +29,7 @@ inline DDim trim_singular_dims(const DDim& dims) {
}
std::vector<int64_t> trim_dims;
trim_dims.resize(actual_dims_size);
for (int i = 0; i < actual_dims_size; ++i) {
for (size_t i = 0; i < actual_dims_size; ++i) {
trim_dims[i] = dims[i];
}
if (trim_dims.size() == 0) {
......@@ -41,6 +41,7 @@ inline DDim trim_singular_dims(const DDim& dims) {
#define NCHWTONHWC(type) \
auto& param = this->template Param<param_t>(); \
auto& ctx = this->ctx_->template As<CUDAContext>(); \
auto stream = ctx.exec_stream(); \
auto input = param.x->template data<type>(); \
auto input_dim = param.x->dims(); \
DDim input_trim_dim = trim_singular_dims(input_dim); \
......@@ -56,11 +57,12 @@ inline DDim trim_singular_dims(const DDim& dims) {
int w = input_dim[3]; \
param.y->Resize({n, h, w, c}); \
auto output = param.y->template mutable_data<type>(TARGET(kCUDA)); \
lite::cuda::math::NCHW2NHWC<type>(n, c, h * w, input, output, &ctx);
trans.NCHW2NHWC(n, c, h* w, input, output, &stream);
#define NHWCTONCHW(type) \
auto& param = this->template Param<param_t>(); \
auto& ctx = this->ctx_->template As<CUDAContext>(); \
auto stream = ctx.exec_stream(); \
auto input = param.x->template data<type>(); \
auto input_dim = param.x->dims(); \
DDim input_trim_dim = trim_singular_dims(input_dim); \
......@@ -76,7 +78,7 @@ inline DDim trim_singular_dims(const DDim& dims) {
int c = input_dim[3]; \
param.y->Resize({n, c, h, w}); \
auto output = param.y->template mutable_data<type>(TARGET(kCUDA)); \
lite::cuda::math::NHWC2NCHW<type>(n, c, h * w, input, output, &ctx);
trans.NHWC2NCHW(n, c, h* w, input, output, &stream);
void NCHWToNHWCCompute::Run() { NCHWTONHWC(float) }
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include "lite/backends/cuda/math/transpose.h"
#include "lite/core/kernel.h"
namespace paddle {
......@@ -25,6 +26,9 @@ class NCHWToNHWCCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
using param_t = operators::LayoutParam;
void Run() override;
virtual ~NCHWToNHWCCompute() = default;
private:
lite::cuda::math::Transpose<float> trans;
};
class NCHWToNHWCComputeInt8
......@@ -33,6 +37,9 @@ class NCHWToNHWCComputeInt8
using param_t = operators::LayoutParam;
void Run() override;
virtual ~NCHWToNHWCComputeInt8() = default;
private:
lite::cuda::math::Transpose<int8_t> trans;
};
class NHWCToNCHWCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
......@@ -40,6 +47,9 @@ class NHWCToNCHWCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
using param_t = operators::LayoutParam;
void Run() override;
virtual ~NHWCToNCHWCompute() = default;
private:
lite::cuda::math::Transpose<float> trans;
};
class NHWCToNCHWComputeInt8
......@@ -48,6 +58,9 @@ class NHWCToNCHWComputeInt8
using param_t = operators::LayoutParam;
void Run() override;
virtual ~NHWCToNCHWComputeInt8() = default;
private:
lite::cuda::math::Transpose<int8_t> trans;
};
} // namespace cuda
......
......@@ -14,6 +14,7 @@
#pragma once
#include <memory>
#include <vector>
#include "lite/backends/cuda/blas.h"
#include "lite/backends/cuda/math/gemm.h"
#include "lite/core/kernel.h"
......@@ -23,6 +24,53 @@ namespace lite {
namespace kernels {
namespace cuda {
class SeqSortedseqTranseUtil {
public:
explicit SeqSortedseqTranseUtil(bool is_reverse = false, bool is_bi = false)
: _is_reverse(is_reverse),
_is_bi(is_bi),
_dev_map_vec(nullptr),
_dev_map_vec_length(0) {}
~SeqSortedseqTranseUtil() {
if (_dev_map_vec != nullptr) {
TargetWrapperCuda::Free(static_cast<void*>(_dev_map_vec));
}
}
std::vector<int>& get_length_index() { return _length_index; }
std::vector<int>& get_emit_offset_vec() { return _emit_offset_vec; }
std::vector<int>& get_map_vec() { return _map_vec; }
int* get_dev_map_vec() { return _dev_map_vec; }
int get_emit_length() { return _emit_length; }
template <typename Dtype>
void seq_2_sorted_seq(const Dtype* input,
Dtype* output,
int word_size,
cudaStream_t stream);
template <typename Dtype>
void sorted_seq_2_seq(const Dtype* input,
Dtype* output,
int hidden_size,
cudaStream_t stream);
bool get_sorted_map(const std::vector<int>& offset_vec,
cudaStream_t stream_id);
private:
std::vector<int> _length_index;
std::vector<int> _emit_offset_vec;
std::vector<int> _map_vec;
int _emit_length;
bool _is_reverse;
bool _is_bi;
int* _dev_map_vec;
int _dev_map_vec_length;
};
class SearchGrnnCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
public:
......@@ -34,10 +82,26 @@ class SearchGrnnCompute
virtual ~SearchGrnnCompute() = default;
private:
std::shared_ptr<Tensor> idx_sorted_by_width_cpu;
// Weights preprocess:
// wi need to be transpose, the axes should be (2, 0, 1)
// wh0 should transpose, {wh1 wh2} need be transpose, the axes should be {2,
// 0, 1}
void WeightsPreprocess();
private:
std::unique_ptr<lite::cuda::math::Gemm<float, float>> gemm_impl_;
void PrepareLayout(const Tensor* input);
void CopyBack(float* from, float* to, int step);
lite::Tensor _temp_tensor_in;
lite::Tensor _temp_tensor_out;
lite::Tensor _temp_wx;
lite::Tensor _temp_wh;
lite::Tensor _temp_zero;
lite::Tensor _temp_weights_h2h;
lite::Tensor _wi;
lite::Tensor _wh;
SeqSortedseqTranseUtil _seq_util;
};
} // namespace cuda
......
......@@ -25,6 +25,7 @@ namespace cuda {
void TransposeCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
const lite::Tensor* X = param.x;
lite::Tensor* Out = param.output;
......@@ -39,8 +40,7 @@ void TransposeCompute::Run() {
// NCHW -> NHWC
if (axes.size() == 4 && axes[0] == 0 && axes[1] == 2 && axes[2] == 3 &&
axes[3] == 1) {
lite::cuda::math::NCHW2NHWC(
dims[0], dims[1], dims[2] * dims[3], in, out, &ctx);
trans.NCHW2NHWC(dims[0], dims[1], dims[2] * dims[3], in, out, &stream);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
return;
......@@ -49,14 +49,13 @@ void TransposeCompute::Run() {
// NHWC -> NCHW
if (axes.size() == 4 && axes[0] == 0 && axes[1] == 3 && axes[2] == 1 &&
axes[3] == 2) {
lite::cuda::math::NHWC2NCHW(
dims[0], dims[3], dims[1] * dims[2], in, out, &ctx);
trans.NHWC2NCHW(dims[0], dims[3], dims[1] * dims[2], in, out, &stream);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
return;
}
lite::cuda::math::Transpose(dims, axes, in, out, &ctx);
trans.transpose(out, in, dims, axes, &stream);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
......
......@@ -29,7 +29,7 @@ class TransposeCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
virtual ~TransposeCompute() = default;
private:
lite::Tensor axes_, dims_;
lite::cuda::math::Transpose<float> trans;
};
} // namespace cuda
......
......@@ -238,7 +238,7 @@ TEST(transpose, normal) {
lite::Tensor x, x_cpu, x_ref;
lite::Tensor out, out_cpu, out_ref;
int C = 6, H = 7, W = 8;
int C = 3, H = 128, W = 128;
std::vector<int> axes({2, 0, 1});
x.Resize({C, H, W});
out.Resize({W, C, H});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册