Commit b65a6dc9 authored by W Wilber, committed by GitHub

optimize search_grnn test=develop (#2608)

optimize search_grnn
Parent 1dbcd51d
...@@ -69,44 +69,16 @@ void BatchTranspose2DCUDAImpl(const int N,
const int W, const int W,
const T* input, const T* input,
T* out, T* out,
CUDAContext* ctx) { cudaStream_t* stream) {
const int dh = (H + kTileDim - 1) / kTileDim; const int dh = (H + kTileDim - 1) / kTileDim;
const int dw = (W + kTileDim - 1) / kTileDim; const int dw = (W + kTileDim - 1) / kTileDim;
BatchTranspose2DCUDAKernel< BatchTranspose2DCUDAKernel<
T><<<N * dh * dw, dim3(kTileDim, kBlockRows), 0, ctx->exec_stream()>>>( T><<<N * dh * dw, dim3(kTileDim, kBlockRows), 0, *stream>>>(
N, H, W, dh, dw, input, out); N, H, W, dh, dw, input, out);
cudaError_t error = cudaGetLastError(); cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
} }
#define TYPE_SPECIALIZED_CUDA_NCHW2NHWC(T) \
template <> \
void NCHW2NHWC<T>(const int N, \
const int C, \
const int HxW, \
const T* X, \
T* Y, \
CUDAContext* ctx) { \
BatchTranspose2DCUDAImpl<T>(N, C, HxW, X, Y, ctx); \
}
TYPE_SPECIALIZED_CUDA_NCHW2NHWC(float)
TYPE_SPECIALIZED_CUDA_NCHW2NHWC(int8_t)
#undef TYPE_SPECIALIZED_CUDA_NCHW2NHWC
#define TYPE_SPECIALIZED_CUDA_NHWC2NCHW(T) \
template <> \
void NHWC2NCHW<T>(const int N, \
const int C, \
const int HxW, \
const T* X, \
T* Y, \
CUDAContext* ctx) { \
BatchTranspose2DCUDAImpl<T>(N, HxW, C, X, Y, ctx); \
}
TYPE_SPECIALIZED_CUDA_NHWC2NCHW(float)
TYPE_SPECIALIZED_CUDA_NHWC2NCHW(int8_t)
#undef TYPE_SPECIALIZED_CUDA_NHWC2NCHW
template <typename T> template <typename T>
__global__ void TransposeCUDAKernel(const int size, __global__ void TransposeCUDAKernel(const int size,
const int ndim, const int ndim,
...@@ -136,7 +108,9 @@ void TransposeCUDAImpl(const std::vector<int64_t>& X_dims,
const std::vector<int>& axes, const std::vector<int>& axes,
const T* X, const T* X,
T* Y, T* Y,
CUDAContext* ctx) { lite::Tensor* Y_dims_,
lite::Tensor* strides_,
cudaStream_t* stream) {
CHECK_EQ(X_dims.size(), axes.size()) << "dimension size should be equal"; CHECK_EQ(X_dims.size(), axes.size()) << "dimension size should be equal";
int ndim = X_dims.size(); int ndim = X_dims.size();
std::vector<int> strides(ndim, 0); std::vector<int> strides(ndim, 0);
...@@ -156,37 +130,68 @@ void TransposeCUDAImpl(const std::vector<int64_t>& X_dims,
size *= X_dims[i]; size *= X_dims[i];
} }
lite::Tensor Y_dims_, strides_; Y_dims_->Resize(std::vector<int64_t>({ndim}));
Y_dims_.Resize(std::vector<int64_t>({ndim})); int* d_y_dims = Y_dims_->mutable_data<int>(TARGET(kCUDA));
int* d_y_dims = Y_dims_.mutable_data<int>(TARGET(kCUDA)); TargetWrapperCuda::MemcpyAsync(d_y_dims,
CopySync<TARGET(kCUDA)>( Y_dims.data(),
d_y_dims, Y_dims.data(), sizeof(int) * Y_dims.size(), IoDirection::HtoD); sizeof(int) * Y_dims.size(),
IoDirection::HtoD,
*stream);
strides_.Resize(std::vector<int64_t>({ndim})); strides_->Resize(std::vector<int64_t>({ndim}));
int* d_strides = strides_.mutable_data<int>(TARGET(kCUDA)); int* d_strides = strides_->mutable_data<int>(TARGET(kCUDA));
CopySync<TARGET(kCUDA)>(d_strides, TargetWrapperCuda::MemcpyAsync(d_strides,
strides.data(), strides.data(),
sizeof(int) * strides.size(), sizeof(int) * strides.size(),
IoDirection::HtoD); IoDirection::HtoD,
*stream);
const int M = (size + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; const int M = (size + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
TransposeCUDAKernel<<<M, CUDA_NUM_THREADS, 0, ctx->exec_stream()>>>( TransposeCUDAKernel<<<M, CUDA_NUM_THREADS, 0, *stream>>>(
size, ndim, d_strides, d_y_dims, X, Y); size, ndim, d_strides, d_y_dims, X, Y);
auto e = cudaGetLastError(); auto e = cudaGetLastError();
CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e); CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);
} }
#define TYPE_SPECIALIZED_CUDA_TRANSPOSE(T) \ template <typename T>
template <> \ void Transpose<T>::NCHW2NHWC(
void Transpose<T>(const std::vector<int64_t>& X_dims, \ int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream) {
const std::vector<int>& axes, \ BatchTranspose2DCUDAImpl<T>(N, C, HxW, X, Y, stream);
const T* X, \ }
T* Y, \
CUDAContext* ctx) { \ template <typename T>
TransposeCUDAImpl<T>(X_dims, axes, X, Y, ctx); \ void Transpose<T>::NHWC2NCHW(
} int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream) {
TYPE_SPECIALIZED_CUDA_TRANSPOSE(float) BatchTranspose2DCUDAImpl<T>(N, HxW, C, X, Y, stream);
#undef TYPE_SPECIALIZED_CUDA_TRANSPOSEF }
template <typename T>
void Transpose<T>::transpose(T* dst,
const T* src,
const std::vector<int64_t>& src_dims,
const std::vector<int>& axes,
cudaStream_t* stream) {
TransposeCUDAImpl<T>(src_dims, axes, src, dst, &Y_dims_, &strides_, stream);
}
// template <typename T>
// void Transpose<T>::transpose(T* dst,
// const T* src,
// const std::vector<int>& src_dims,
// const std::vector<int>& axes,
// cudaStream_t* stream) {
// std::vector<int64_t> _src_dims(src_dims.size(), 0);
// std::transform(
// src_dims.begin(),
// src_dims.end(),
// _src_dims.begin(),
// [](int data) -> int64_t { return static_cast<int64_t>(data); });
// TransposeCUDAImpl<T>(_src_dims, axes, src, dst, &Y_dims_, &strides_,
// stream);
//}
template class Transpose<int8_t>;
template class Transpose<float>;
} // namespace math } // namespace math
} // namespace cuda } // namespace cuda
......
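Reviewer note: to summarize this file's change, the free functions NCHW2NHWC / NHWC2NCHW / Transpose (which took a CUDAContext*) are folded into a Transpose<T> class that takes a raw cudaStream_t*, the host-to-device copies of the dims/strides arrays move from CopySync to TargetWrapperCuda::MemcpyAsync on that stream, and Y_dims_ / strides_ become member scratch tensors so TransposeCUDAImpl no longer allocates temporaries on every call.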
...@@ -26,17 +26,27 @@ namespace cuda {
namespace math { namespace math {
template <typename T> template <typename T>
void NCHW2NHWC(int N, int C, int HxW, const T* X, T* Y, CUDAContext* context); class Transpose {
public:
void NCHW2NHWC(int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream);
template <typename T> void NHWC2NCHW(int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream);
void NHWC2NCHW(int N, int C, int HxW, const T* X, T* Y, CUDAContext* context);
template <typename T> void transpose(T* dst,
void Transpose(const std::vector<int64_t>& X_dims, const T* src,
const std::vector<int>& axes, const std::vector<int64_t>& src_dims,
const T* X, const std::vector<int>& axes,
T* Y, cudaStream_t* stream);
CUDAContext* ctx);
// void transpose(T* dst,
// const T* src,
// const std::vector<int>& src_dims,
// const std::vector<int>& axes,
// cudaStream_t* stream);
private:
lite::Tensor Y_dims_, strides_; // for transpose.
};
} // namespace math } // namespace math
} // namespace cuda } // namespace cuda
......
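Reviewer note: a minimal caller-side sketch of the reworked Transpose<T> interface declared above (n, c, h, w, x_gpu, y_gpu and the surrounding ctx are illustrative placeholders, not part of this patch):

  lite::cuda::math::Transpose<float> trans;       // owns the Y_dims_/strides_ scratch tensors
  cudaStream_t stream = ctx.exec_stream();        // callers now pass an explicit stream
  // Layout change for a contiguous N x C x (H*W) buffer.
  trans.NCHW2NHWC(n, c, h * w, x_gpu, y_gpu, &stream);
  // The same permutation via the general path; note dst comes first.
  trans.transpose(y_gpu, x_gpu, {n, c, h, w}, {0, 2, 3, 1}, &stream);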
...@@ -26,7 +26,7 @@ add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS
${lite_kernel_deps} cudnn_pool) ${lite_kernel_deps} cudnn_pool)
add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps}) add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps})
add_kernel(search_seq_depadding_compute_cuda CUDA extra SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps}) add_kernel(search_seq_depadding_compute_cuda CUDA extra SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS ${lite_kernel_deps} cuda_gemm) add_kernel(search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS ${lite_kernel_deps} cuda_gemm ${math_cuda})
add_kernel(sequence_reverse_compute_cuda CUDA basic SRCS sequence_reverse_compute.cu DEPS ${lite_kernel_deps}) add_kernel(sequence_reverse_compute_cuda CUDA basic SRCS sequence_reverse_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps}) add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_arithmetic_compute_cuda CUDA basic SRCS sequence_arithmetic_compute.cu DEPS ${lite_kernel_deps}) add_kernel(sequence_arithmetic_compute_cuda CUDA basic SRCS sequence_arithmetic_compute.cu DEPS ${lite_kernel_deps})
......
...@@ -29,7 +29,7 @@ inline DDim trim_singular_dims(const DDim& dims) {
} }
std::vector<int64_t> trim_dims; std::vector<int64_t> trim_dims;
trim_dims.resize(actual_dims_size); trim_dims.resize(actual_dims_size);
for (int i = 0; i < actual_dims_size; ++i) { for (size_t i = 0; i < actual_dims_size; ++i) {
trim_dims[i] = dims[i]; trim_dims[i] = dims[i];
} }
if (trim_dims.size() == 0) { if (trim_dims.size() == 0) {
...@@ -41,6 +41,7 @@ inline DDim trim_singular_dims(const DDim& dims) {
#define NCHWTONHWC(type) \ #define NCHWTONHWC(type) \
auto& param = this->template Param<param_t>(); \ auto& param = this->template Param<param_t>(); \
auto& ctx = this->ctx_->template As<CUDAContext>(); \ auto& ctx = this->ctx_->template As<CUDAContext>(); \
auto stream = ctx.exec_stream(); \
auto input = param.x->template data<type>(); \ auto input = param.x->template data<type>(); \
auto input_dim = param.x->dims(); \ auto input_dim = param.x->dims(); \
DDim input_trim_dim = trim_singular_dims(input_dim); \ DDim input_trim_dim = trim_singular_dims(input_dim); \
...@@ -56,11 +57,12 @@ inline DDim trim_singular_dims(const DDim& dims) {
int w = input_dim[3]; \ int w = input_dim[3]; \
param.y->Resize({n, h, w, c}); \ param.y->Resize({n, h, w, c}); \
auto output = param.y->template mutable_data<type>(TARGET(kCUDA)); \ auto output = param.y->template mutable_data<type>(TARGET(kCUDA)); \
lite::cuda::math::NCHW2NHWC<type>(n, c, h * w, input, output, &ctx); trans.NCHW2NHWC(n, c, h* w, input, output, &stream);
#define NHWCTONCHW(type) \ #define NHWCTONCHW(type) \
auto& param = this->template Param<param_t>(); \ auto& param = this->template Param<param_t>(); \
auto& ctx = this->ctx_->template As<CUDAContext>(); \ auto& ctx = this->ctx_->template As<CUDAContext>(); \
auto stream = ctx.exec_stream(); \
auto input = param.x->template data<type>(); \ auto input = param.x->template data<type>(); \
auto input_dim = param.x->dims(); \ auto input_dim = param.x->dims(); \
DDim input_trim_dim = trim_singular_dims(input_dim); \ DDim input_trim_dim = trim_singular_dims(input_dim); \
...@@ -76,7 +78,7 @@ inline DDim trim_singular_dims(const DDim& dims) {
int c = input_dim[3]; \ int c = input_dim[3]; \
param.y->Resize({n, c, h, w}); \ param.y->Resize({n, c, h, w}); \
auto output = param.y->template mutable_data<type>(TARGET(kCUDA)); \ auto output = param.y->template mutable_data<type>(TARGET(kCUDA)); \
lite::cuda::math::NHWC2NCHW<type>(n, c, h * w, input, output, &ctx); trans.NHWC2NCHW(n, c, h* w, input, output, &stream);
void NCHWToNHWCCompute::Run() { NCHWTONHWC(float) } void NCHWToNHWCCompute::Run() { NCHWTONHWC(float) }
......
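Reviewer note: with the new stream plumbing, NCHWTONHWC(float) roughly expands to the following (a simplified sketch that omits the trim_singular_dims handling; `trans` is the Transpose<float> member added to each layout kernel below):

  auto& param = this->template Param<param_t>();
  auto& ctx = this->ctx_->template As<CUDAContext>();
  auto stream = ctx.exec_stream();
  const float* input = param.x->template data<float>();
  auto input_dim = param.x->dims();
  int n = input_dim[0], c = input_dim[1], h = input_dim[2], w = input_dim[3];
  param.y->Resize({n, h, w, c});
  float* output = param.y->template mutable_data<float>(TARGET(kCUDA));
  trans.NCHW2NHWC(n, c, h * w, input, output, &stream);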
...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/backends/cuda/math/transpose.h"
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
...@@ -25,6 +26,9 @@ class NCHWToNHWCCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
using param_t = operators::LayoutParam; using param_t = operators::LayoutParam;
void Run() override; void Run() override;
virtual ~NCHWToNHWCCompute() = default; virtual ~NCHWToNHWCCompute() = default;
private:
lite::cuda::math::Transpose<float> trans;
}; };
class NCHWToNHWCComputeInt8 class NCHWToNHWCComputeInt8
...@@ -33,6 +37,9 @@ class NCHWToNHWCComputeInt8
using param_t = operators::LayoutParam; using param_t = operators::LayoutParam;
void Run() override; void Run() override;
virtual ~NCHWToNHWCComputeInt8() = default; virtual ~NCHWToNHWCComputeInt8() = default;
private:
lite::cuda::math::Transpose<int8_t> trans;
}; };
class NHWCToNCHWCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> { class NHWCToNCHWCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
...@@ -40,6 +47,9 @@ class NHWCToNCHWCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
using param_t = operators::LayoutParam; using param_t = operators::LayoutParam;
void Run() override; void Run() override;
virtual ~NHWCToNCHWCompute() = default; virtual ~NHWCToNCHWCompute() = default;
private:
lite::cuda::math::Transpose<float> trans;
}; };
class NHWCToNCHWComputeInt8 class NHWCToNCHWComputeInt8
...@@ -48,6 +58,9 @@ class NHWCToNCHWComputeInt8
using param_t = operators::LayoutParam; using param_t = operators::LayoutParam;
void Run() override; void Run() override;
virtual ~NHWCToNCHWComputeInt8() = default; virtual ~NHWCToNCHWComputeInt8() = default;
private:
lite::cuda::math::Transpose<int8_t> trans;
}; };
} // namespace cuda } // namespace cuda
......
...@@ -12,6 +12,7 @@ limitations under the License. */
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "lite/backends/cuda/math/transpose.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/cuda/search_grnn_compute.h" #include "lite/kernels/cuda/search_grnn_compute.h"
...@@ -19,294 +20,469 @@ namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
namespace cuda { namespace cuda {
using Tensor = lite::Tensor; using Tensor = lite::Tensor;
template <typename T> template <typename Dtype>
T sigmoid(T z) { __global__ void trans_map2out(
return 1 / (1 + std::exp(-z)); Dtype* output, const Dtype* input, const int* map, int count, int lastdim) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < count) {
int seq = tid / lastdim;
output[map[seq] * lastdim + tid % lastdim] = input[tid];
}
} }
template <typename T> template <typename Dtype>
__global__ void PreComputeKernel( __global__ void trans_map2in(
const int num, const T* w_x_e, const T* wz_x_e, T* tilde, T* z, T* hidden) { Dtype* output, const Dtype* input, const int* map, int count, int lastdim) {
int index = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (index < num) { if (tid < count) {
tilde[index] = std::tanh(w_x_e[index]); int seq = tid / lastdim;
z[index] = 1 / (1 + std::exp(-wz_x_e[index])); output[tid] = input[map[seq] * lastdim + tid % lastdim];
hidden[index] = (1. - z[index]) * tilde[index];
} }
} }
template <typename T> template <typename Dtype>
__global__ void PostComputeKernel(const int start, void trans_map2out_cfunc(const Dtype* input,
const int end, Dtype* output,
const int cap_h, int word_size,
const int w_tm1, int seq_sum,
const T* wr_x_e, cudaStream_t stream,
const T* ur_x_h, int* dev_map_vec) {
const T* wz_x_e, int count = seq_sum * word_size;
const T* uz_x_h, int block_dim = count;
const T* w_x_e, int grid_dim = 1;
const T* u_x_h,
T* r, if (count > 1024) {
T* z, block_dim = 256;
T* tilde, grid_dim = (count + block_dim - 1) / block_dim;
T* hidden) {
int j = start + blockIdx.x * blockDim.x + threadIdx.x;
if (j < end) {
r[j] = 1 / (1 + std::exp(-(wr_x_e[j] + ur_x_h[j])));
z[j] = 1 / (1 + std::exp(-(wz_x_e[j] + uz_x_h[j])));
tilde[j] = std::tanh(w_x_e[j] + r[j] * u_x_h[j]);
hidden[j] = z[j] * hidden[j - cap_h * w_tm1] + (1.0 - z[j]) * tilde[j];
} }
trans_map2out<<<grid_dim, block_dim, 0, stream>>>(
output, input, dev_map_vec, count, word_size);
} }
void SearchGrnnCompute::PrepareForRun() { template <typename Dtype>
gemm_impl_.reset(new lite::cuda::math::Gemm<float, float>); void trans_map2in_cfunc(const Dtype* input,
Dtype* output,
int hidden_size,
int seq_sum,
cudaStream_t stream,
int* dev_map_vec) {
int count = seq_sum * hidden_size;
int block_dim = count;
int grid_dim = 1;
if (count > 1024) {
block_dim = 256;
grid_dim = (count + block_dim - 1) / block_dim;
}
trans_map2in<<<grid_dim, block_dim, 0, stream>>>(
output, input, dev_map_vec, count, hidden_size);
} }
void SearchGrnnCompute::PrepareLayout(const Tensor* input_blob) { template <typename Dtype>
auto& param = this->Param<param_t>(); void SeqSortedseqTranseUtil::seq_2_sorted_seq(const Dtype* input,
auto& context = this->ctx_->template As<CUDAContext>(); Dtype* output,
auto cuda_stream = context.exec_stream(); int word_size,
cudaStream_t stream) {
int seq_sum = _map_vec.size();
trans_map2out_cfunc(input, output, word_size, seq_sum, stream, _dev_map_vec);
}
template <typename Dtype>
void SeqSortedseqTranseUtil::sorted_seq_2_seq(const Dtype* input,
Dtype* output,
int hidden_size,
cudaStream_t stream) {
int seq_sum = _map_vec.size();
trans_map2in_cfunc(input, output, hidden_size, seq_sum, stream, _dev_map_vec);
}
bool SeqSortedseqTranseUtil::get_sorted_map(const std::vector<int>& offset_vec,
cudaStream_t stream_id) {
int batch_size = offset_vec.size() - 1;
int word_sum = offset_vec[offset_vec.size() - 1];
std::vector<int> length_vec(batch_size);
_length_index.resize(batch_size);
int emit_length = 0;
if (batch_size == 1) {
emit_length = offset_vec[1] - offset_vec[0];
_emit_offset_vec.resize(emit_length + 1);
for (int i = 0; i <= emit_length; ++i) {
_emit_offset_vec[i] = i;
}
auto* _input = input_blob; return false;
int dim0 = _input->dims()[0];
int dim1 = 1;
if (_input->dims().size() > 1) {
dim1 = _input->dims()[1];
} }
int batch = _input->lod()[0].size() - 1;
auto& offset = _input->lod()[0]; int max_len = 0;
idx_sorted_by_width_cpu = std::make_shared<Tensor>(); for (int i = 0; i < offset_vec.size() - 1; ++i) {
idx_sorted_by_width_cpu->Resize({batch}); int len = offset_vec[i + 1] - offset_vec[i];
int* idx_sorted_by_width_cpu_data = max_len = max_len > len ? max_len : len;
idx_sorted_by_width_cpu->mutable_data<int>(); length_vec[i] = len;
_length_index[i] = i;
Tensor _width; }
_width.Resize({batch});
int* width_data = _width.mutable_data<int>(); emit_length = max_len;
// sort sequence by width (descending) and find the largest width in the
// batch if (max_len == 1) {
for (int i = 0; i < batch; i++) { _emit_offset_vec.resize(2);
width_data[i] = offset[i + 1] - offset[i]; _emit_offset_vec[0] = 0;
idx_sorted_by_width_cpu_data[i] = i; _emit_offset_vec[1] = emit_length * batch_size;
return false;
} }
std::sort(idx_sorted_by_width_cpu_data,
idx_sorted_by_width_cpu_data + batch, std::sort(_length_index.begin(),
[&_width](int a, int b) { _length_index.end(),
return _width.data<int>()[a] > _width.data<int>()[b]; [&length_vec](int i1, int i2) {
return length_vec[i1] > length_vec[i2];
}); });
int max_width = width_data[idx_sorted_by_width_cpu_data[0]];
_emit_offset_vec.resize(max_len + 1);
// start of reorganizing the input _map_vec.resize(word_sum);
std::vector<size_t> new_offset;
new_offset.resize(max_width + 1); if (word_sum > _dev_map_vec_length) {
new_offset[0] = 0; if (_dev_map_vec != nullptr) {
int j = batch - 1; TargetWrapperCuda::Free(static_cast<void*>(_dev_map_vec));
int last_width = 0; }
int sub_row = 0;
int sub_col = 0; _dev_map_vec =
static_cast<int*>(TargetWrapperCuda::Malloc(sizeof(int) * word_sum));
for (int i = 1; i <= max_width;) { _dev_map_vec_length = word_sum;
for (int k = j; k >= 0; --k) { }
if (width_data[idx_sorted_by_width_cpu_data[k]] > last_width) {
sub_row = width_data[idx_sorted_by_width_cpu_data[k]] - last_width; int target_word_id = 0;
sub_col = k + 1; std::vector<int> length_vec_cnt = length_vec;
for (int s = 0; s < sub_row; s++) { int last_batch_size = batch_size;
new_offset[i] = new_offset[i - 1] + sub_col; for (int word_id_in_seq = 0; word_id_in_seq < max_len; word_id_in_seq++) {
i++; _emit_offset_vec[word_id_in_seq] = target_word_id;
for (int batch_id = 0; batch_id < last_batch_size; batch_id++) {
int old_batch_id = _length_index[batch_id];
if (length_vec_cnt[old_batch_id] > 0) {
int inner_word_id_in_seq = word_id_in_seq;
if (_is_reverse) {
inner_word_id_in_seq = length_vec[old_batch_id] - 1 - word_id_in_seq;
} }
// move on
last_width = width_data[idx_sorted_by_width_cpu_data[k]]; int old_word_id = offset_vec[old_batch_id] + inner_word_id_in_seq;
j = k - 1; _map_vec[old_word_id] = target_word_id;
length_vec_cnt[old_batch_id]--;
target_word_id++;
} else {
last_batch_size--;
break; break;
} }
} }
} }
// copying to the reorganized buffer TargetWrapperCuda::MemcpyAsync(_dev_map_vec,
auto* _layout_input = new Tensor(); _map_vec.data(),
auto* _layout_input_gpu = param.layout_input; sizeof(int) * word_sum,
if (_input->dims().size() == 1) { IoDirection::HtoD,
// _layout_input.reshape_batch_sequence({dim0}, new_offset); stream_id);
LOG(FATAL) << "_input->dims().size() = 1, error."; _emit_offset_vec[max_len] = word_sum;
} else { _emit_length = emit_length;
// _layout_input.reshape_batch_sequence({dim0, dim1}, new_offset); return true;
LoD new_lod; }
new_lod.push_back(new_offset);
_layout_input->set_lod(new_lod);
_layout_input->Resize({dim0, dim1});
_layout_input_gpu->set_lod(new_lod);
_layout_input_gpu->Resize({dim0, dim1});
}
auto* new_emb = _layout_input->mutable_data<float>(); template <typename Dtype>
auto* input_cpu = new Tensor(); __global__ void transpose_2d(Dtype* output, const Dtype* input, int m, int n) {
input_cpu->Resize(_input->dims()); int tid = blockIdx.x * blockDim.x + threadIdx.x;
auto* input_cpu_data = input_cpu->mutable_data<float>(); if (tid < m * n) {
TargetW::MemcpyAsync(input_cpu_data, int i = tid / n;
_input->data<float>(), int j = tid % m;
_input->numel() * sizeof(float), output[tid] = input[j * n + i];
IoDirection::DtoH,
cuda_stream);
for (int i = 0; i < max_width; i++) {
int w = new_offset[i + 1] - new_offset[i];
auto* emb_start = new_emb + dim1 * new_offset[i];
for (int j = 0; j < w; ++j) {
memcpy(emb_start + dim1 * j,
input_cpu_data + dim1 * offset[idx_sorted_by_width_cpu_data[j]] +
dim1 * i,
dim1 * sizeof(float));
}
} }
}
void SearchGrnnCompute::WeightsPreprocess() {
auto& param = this->Param<param_t>();
auto& context = this->ctx_->template As<CUDAContext>();
auto stream = context.exec_stream();
auto* _layout_input_gpu_data = DDim idims = param.wi->dims();
_layout_input_gpu->mutable_data<float>(TARGET(kCUDA)); DDim hdims = param.wh->dims();
TargetW::MemcpyAsync(_layout_input_gpu_data, _wi.Resize({idims[2], idims[0], idims[1]});
new_emb, _wh.Resize({hdims[2], hdims[0], hdims[1]});
_layout_input->numel() * sizeof(float), lite::cuda::math::Transpose<float> trans;
IoDirection::HtoD, trans.transpose(_wi.mutable_data<float>(TARGET(kCUDA)),
cuda_stream); param.wi->data<float>(),
delete _layout_input; idims.Vectorize(),
delete input_cpu; {2, 0, 1},
&stream);
trans.transpose(_wh.mutable_data<float>(TARGET(kCUDA)) + hdims[1] * hdims[2],
param.wh->data<float>() + hdims[1] * hdims[2],
{hdims[0] - 1, hdims[1], hdims[2]},
{2, 0, 1},
&stream);
trans.transpose(_wh.mutable_data<float>(TARGET(kCUDA)),
param.wh->data<float>(),
{hdims[1], hdims[2]},
{1, 0},
&stream);
// int thread_num = 512;
// int block_num = (hdims[1] * hdims[2] + thread_num - 1) / thread_num;
// transpose_2d<<<block_num, thread_num, 0, stream>>>(
// _wh.mutable_data<float>(TARGET(kCUDA)),
// param.wh->data<float>(),
// hdims[1],
// hdims[2]);
} }
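Reviewer note: a small worked example of the permutation convention used by trans.transpose above (the output shape is the input shape permuted by axes, matching TransposeCUDAImpl): for a source of shape {d0, d1, d2} and axes {2, 0, 1}, the destination has shape {d2, d0, d1} and dst[k][i][j] = src[i][j][k]; the {1, 0} call on the first hidden_size x hidden_size block of wh is the ordinary 2-D matrix transpose, which the commented-out transpose_2d kernel would also provide.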
void SearchGrnnCompute::CopyBack(float* from, float* to, int step) { void SearchGrnnCompute::PrepareForRun() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
auto& context = this->ctx_->template As<CUDAContext>(); auto& context = this->ctx_->template As<CUDAContext>();
auto stream = context.exec_stream(); auto stream = context.exec_stream();
auto* _input = param.x; gemm_impl_.reset(new lite::cuda::math::Gemm<float, float>);
auto* _layout_input = param.layout_input; _seq_util = SeqSortedseqTranseUtil();
const auto& offset = _input->lod()[0]; WeightsPreprocess();
const auto& new_offset = _layout_input->lod()[0];
const auto* idx_sorted_by_width_cpu_data = int hidden_size = param.num_hidden;
idx_sorted_by_width_cpu->data<int>(); int word_size = param.num_input;
for (size_t i = 0; i < _layout_input->lod()[0].size() - 1; ++i) { int weights_h2h_size = hidden_size * hidden_size * 3;
int w = new_offset[i + 1] - new_offset[i]; int weights_i2h_size = hidden_size * word_size * 3;
for (int j = 0; j < w; j++) {
TargetW::MemcpyAsync( lite::Tensor temp_weights_h2h_ori;
to + step * (offset[idx_sorted_by_width_cpu_data[j]] + i), lite::Tensor temp_weights_h2h_swarp;
from + (new_offset[i] + j) * step, temp_weights_h2h_ori.Resize({weights_h2h_size});
step * sizeof(float), temp_weights_h2h_swarp.Resize({weights_h2h_size});
IoDirection::DtoD,
stream); TargetWrapperCuda::MemcpyAsync(temp_weights_h2h_ori.mutable_data<float>(),
_wh.data<float>(),
sizeof(float) * weights_h2h_size,
IoDirection::DtoH,
stream);
cudaStreamSynchronize(stream);
float* temp_tensor_ptr = temp_weights_h2h_swarp.mutable_data<float>();
memcpy(temp_tensor_ptr,
temp_weights_h2h_ori.data<float>(),
sizeof(float) * hidden_size * hidden_size);
float* rz_temp_tensor_ptr = temp_tensor_ptr + hidden_size * hidden_size;
const float* rz_weights_tensor_ptr =
temp_weights_h2h_ori.data<float>() + hidden_size * hidden_size;
for (int row = 0; row < hidden_size; row++) {
for (int block = 0; block < 2; block++) {
int block_offset = block * hidden_size;
for (int cow = 0; cow < hidden_size; cow++) {
rz_temp_tensor_ptr[block * hidden_size * hidden_size +
row * hidden_size + cow] =
rz_weights_tensor_ptr[row * (2 * hidden_size) + cow + block_offset];
}
}
}
float* orz_temp_tensor_ptr = temp_tensor_ptr;
float* orz_weights_tensor_ptr = temp_weights_h2h_ori.mutable_data<float>();
for (int row = 0; row < hidden_size; row++) {
for (int block = 0; block < 3; block++) {
int block_offset = block * hidden_size;
for (int cow = 0; cow < hidden_size; cow++) {
orz_weights_tensor_ptr[row * (3 * hidden_size) + cow + block_offset] =
orz_temp_tensor_ptr[block * hidden_size * hidden_size +
row * hidden_size + cow];
}
} }
} }
_temp_weights_h2h.Resize({weights_h2h_size});
TargetWrapperCuda::MemcpyAsync(
_temp_weights_h2h.mutable_data<float>(TARGET(kCUDA)),
temp_weights_h2h_ori.data<float>(),
sizeof(float) * weights_h2h_size,
IoDirection::HtoD,
stream);
cudaStreamSynchronize(stream);
}
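Reviewer note: restating the host-side reshuffle above with H = hidden_size: the last 2*H*H entries of the transposed h2h weights are read as src[row * 2H + block * H + col] (block in {0, 1}) and unpacked into block-major temp[block * H * H + row * H + col], while the first H*H entries stay as block 0; the second loop then re-interleaves all three blocks per row, so the device tensor _temp_weights_h2h ends up with w[row * 3H + block * H + col] = temp[block * H * H + row * H + col]. This matches the {*, 3, hidden_size} layout that the per-step GEMM writes into _temp_wh and that cal_cudnn_kernel indexes via the o/r/z offsets.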
template <typename Dtype>
static inline __device__ Dtype Sigmoid(const Dtype a) {
return static_cast<Dtype>(1.0) / (static_cast<Dtype>(1.0) + expf(-a));
}
template <typename Dtype>
static inline __device__ Dtype Tanh(const Dtype a) {
Dtype tmp = static_cast<Dtype>(-2.0) * a;
return (static_cast<Dtype>(2.0) / (static_cast<Dtype>(1.0) + expf(tmp))) -
static_cast<Dtype>(1.0);
}
template <typename Dtype>
__global__ void cal_cudnn_kernel(const Dtype* w_x_r,
const Dtype* w_x_z,
const Dtype* w_x_o,
const Dtype* w_h_r,
const Dtype* w_h_z,
const Dtype* w_h_o,
int hidden_size,
int batch_size,
Dtype* output,
const Dtype* hidden_pre) {
const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
const int batch_id = thread_id / hidden_size;
const int index = thread_id % hidden_size;
if (index < hidden_size && batch_id < batch_size) {
int w_base_index = batch_id * hidden_size * 3 + index;
int h_base_index = batch_id * hidden_size + index;
Dtype hidden_pre_value = hidden_pre[h_base_index];
Dtype r = Sigmoid(w_x_r[w_base_index] + w_h_r[w_base_index]);
Dtype z = Sigmoid(w_x_z[w_base_index] + w_h_z[w_base_index]);
Dtype _h = Tanh(w_x_o[w_base_index] + w_h_o[w_base_index] * r);
output[h_base_index] =
(static_cast<Dtype>(1.0) - z) * _h + z * hidden_pre_value;
}
} }
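Reviewer note: per hidden unit, cal_cudnn_kernel implements the standard GRU cell update, where w_x_* hold the input-to-hidden GEMM results and w_h_* the hidden-to-hidden GEMM results for the o/r/z blocks:

  r      = Sigmoid(w_x_r + w_h_r)
  z      = Sigmoid(w_x_z + w_h_z)
  h_cand = Tanh(w_x_o + r * w_h_o)
  h_out  = (1 - z) * h_cand + z * h_prev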
void SearchGrnnCompute::Run() { void SearchGrnnCompute::Run() {
CHECK(ctx_) << "running context should be set first";
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
auto& context = this->ctx_->template As<CUDAContext>(); auto& context = this->ctx_->template As<CUDAContext>();
auto stream = context.exec_stream(); auto stream = context.exec_stream();
auto* bottom = param.x; auto* x = param.x;
auto* wi = param.wi; LoD offset_vec_vec = x->lod();
auto* wh = param.wh; std::vector<int> offset(offset_vec_vec[offset_vec_vec.size() - 1].size());
auto* top = param.out; for (size_t i = 0; i < offset_vec_vec[offset_vec_vec.size() - 1].size();
auto* _buffer = param.tmp_buffer; ++i) {
int _cap_h = param.num_hidden; offset[i] = static_cast<int>(offset_vec_vec[offset_vec_vec.size() - 1][i]);
int _cap_e = param.num_input; }
const float* x_data = x->data<float>();
int _cap_l = bottom->dims()[0]; auto* dout = param.out;
int batch = bottom->lod()[0].size() - 1; std::vector<int64_t> out_dims_vec{x->dims()[0], param.num_hidden};
dout->Resize(out_dims_vec);
const auto& offset = bottom->lod()[0]; float* dout_data = dout->mutable_data<float>(TARGET(kCUDA));
LoD top_lod; auto* wi = &_wi;
top_lod.push_back(offset); auto* wh = &_wh;
top->set_lod(top_lod);
std::vector<int64_t> top_dims_vec{_cap_l, _cap_h}; const float* weights_i2h = wi->data<float>();
top->Resize(top_dims_vec); const float* weights_h2h = wh->data<float>();
auto* top_hidden = top->mutable_data<float>(TARGET(kCUDA));
int batch_size = offset.size() - 1;
const auto* dense_e2h = wi->data<float>(); int seq_sum = x->dims()[0];
const auto* dense_h2h = wh->data<float>(); bool is_batched = offset.size() > 2;
int hidden_size = param.num_hidden;
const auto* e2h = dense_e2h; int word_size = param.num_input;
const auto* e2hr = dense_e2h + 1 * _cap_e * _cap_h; int o_offset = 0;
const auto* e2hz = dense_e2h + 2 * _cap_e * _cap_h; int r_offset = 1;
const auto* h2h = dense_h2h; int z_offset = 2;
const auto* h2hr = dense_h2h + 1 * _cap_h * _cap_h;
const auto* h2hz = dense_h2h + 2 * _cap_h * _cap_h; is_batched = _seq_util.get_sorted_map(offset, stream);
std::vector<int> emit_offset_vec = _seq_util.get_emit_offset_vec();
PrepareLayout(bottom); int emit_length = emit_offset_vec.size() - 1;
auto* _layout_input = param.layout_input; if (is_batched) {
auto* new_emb = _layout_input->data<float>(); std::vector<int64_t> seq_shape{1, 1, seq_sum, word_size};
const auto& new_offset = _layout_input->lod()[0]; _temp_tensor_in.Resize(seq_shape);
int max_width = _layout_input->lod()[0].size() - 1; std::vector<int64_t> seq_out_shape{1, 1, seq_sum, hidden_size};
_temp_tensor_out.Resize(seq_out_shape);
// this buffer is used for book keeping info which will be used in bp _seq_util.seq_2_sorted_seq(
// buffer also needed in bp, so make it larger x_data,
_buffer->Resize({20, _cap_l, _cap_h}); _temp_tensor_in.mutable_data<float>(TARGET(kCUDA)),
auto* buffer_data = _buffer->mutable_data<float>(TARGET(kCUDA)); word_size,
auto* w_x_e = buffer_data + 0 * _cap_l * _cap_h; stream);
auto* wr_x_e = buffer_data + 1 * _cap_l * _cap_h; x_data = _temp_tensor_in.data<float>();
auto* wz_x_e = buffer_data + 2 * _cap_l * _cap_h; dout_data = _temp_tensor_out.mutable_data<float>(TARGET(kCUDA));
auto* u_x_h = buffer_data + 3 * _cap_l * _cap_h; }
auto* ur_x_h = buffer_data + 4 * _cap_l * _cap_h;
auto* uz_x_h = buffer_data + 5 * _cap_l * _cap_h; std::vector<int64_t> shape_wx({seq_sum, 1, 3, hidden_size});
auto* r = buffer_data + 6 * _cap_l * _cap_h; _temp_wx.Resize(shape_wx);
auto* z = buffer_data + 7 * _cap_l * _cap_h;
auto* tilde = buffer_data + 8 * _cap_l * _cap_h; std::vector<int64_t> shape_wh({1, batch_size, 3, hidden_size});
// the internal hidden _temp_wh.Resize(shape_wh);
auto* hidden = buffer_data + 19 * _cap_l * _cap_h;
gemm_impl_->init(false, false, seq_sum, 3 * hidden_size, word_size, &context);
gemm_impl_->init(false, true, _cap_l, _cap_h, _cap_e, &context); gemm_impl_->run(1.0f,
gemm_impl_->run(1.0f, 0.0f, new_emb, e2h, w_x_e, &context); 0.0f,
gemm_impl_->init(false, true, _cap_l, _cap_h, _cap_e, &context); x_data,
gemm_impl_->run(1.0f, 0.0f, new_emb, e2hr, wr_x_e, &context); weights_i2h,
gemm_impl_->init(false, true, _cap_l, _cap_h, _cap_e, &context); _temp_wx.mutable_data<float>(TARGET(kCUDA)),
gemm_impl_->run(1.0f, 0.0f, new_emb, e2hz, wz_x_e, &context); &context);
// precompute hidden0 std::vector<int64_t> shape_zero({batch_size * hidden_size});
int num = batch * _cap_h; _temp_zero.Resize(shape_zero);
int threads = 512;
int blocks = (num + threads - 1) / threads; TargetWrapperCuda::MemsetAsync(_temp_zero.mutable_data<float>(TARGET(kCUDA)),
PreComputeKernel<<<blocks, threads, 0, stream>>>( 0,
num, w_x_e, wz_x_e, tilde, z, hidden); sizeof(float) * batch_size * hidden_size,
stream);
// recurrence
for (int i = 1; i < max_width; i++) { const float* h = _temp_zero.data<float>();
int w_tm1 = new_offset[i] - new_offset[i - 1]; for (int word_id = 0; word_id < emit_length; word_id++) {
int w = new_offset[i + 1] - new_offset[i]; int real_word_id = word_id;
int last_word_id = word_id - 1;
// precompute hidden i-1 to hidden i int emit_word_id_start = emit_offset_vec[real_word_id];
auto* htm1 = hidden + new_offset[i - 1] * _cap_h; int emit_word_id_end = emit_offset_vec[real_word_id + 1];
int emit_word_length = emit_word_id_end - emit_word_id_start;
gemm_impl_->init(false, true, w, _cap_h, _cap_h, &context);
gemm_impl_->run( const float* hidden_in;
1.0f, 0.0f, htm1, h2h, u_x_h + new_offset[i] * _cap_h, &context); float* hidden_out = dout_data + emit_offset_vec[real_word_id] * hidden_size;
gemm_impl_->init(false, true, w, _cap_h, _cap_h, &context);
gemm_impl_->run( if (word_id == 0) {
1.0f, 0.0f, htm1, h2hr, ur_x_h + new_offset[i] * _cap_h, &context); hidden_in = h;
gemm_impl_->init(false, true, w, _cap_h, _cap_h, &context); } else {
gemm_impl_->run( hidden_in = dout_data + emit_offset_vec[last_word_id] * hidden_size;
1.0f, 0.0f, htm1, h2hz, uz_x_h + new_offset[i] * _cap_h, &context); }
// compute the gate and hidden float* w_x_r = _temp_wx.mutable_data<float>(TARGET(kCUDA)) +
int start = new_offset[i] * _cap_h; r_offset * hidden_size +
int end = (new_offset[i] + w) * _cap_h; emit_word_id_start * hidden_size * 3;
PostComputeKernel<<<blocks, threads, 0, stream>>>(start, float* w_x_z = _temp_wx.mutable_data<float>(TARGET(kCUDA)) +
end, z_offset * hidden_size +
_cap_h, emit_word_id_start * hidden_size * 3;
w_tm1, float* w_x_o = _temp_wx.mutable_data<float>(TARGET(kCUDA)) +
wr_x_e, o_offset * hidden_size +
ur_x_h, emit_word_id_start * hidden_size * 3;
wz_x_e,
uz_x_h, float* w_h_r =
w_x_e, _temp_wh.mutable_data<float>(TARGET(kCUDA)) + r_offset * hidden_size;
u_x_h, float* w_h_z =
r, _temp_wh.mutable_data<float>(TARGET(kCUDA)) + z_offset * hidden_size;
z, float* w_h_o =
tilde, _temp_wh.mutable_data<float>(TARGET(kCUDA)) + o_offset * hidden_size;
hidden); gemm_impl_->init(
false, false, emit_word_length, 3 * hidden_size, hidden_size, &context);
gemm_impl_->run(1.0f,
0.0f,
hidden_in,
_temp_weights_h2h.data<float>(),
_temp_wh.mutable_data<float>(TARGET(kCUDA)),
&context);
const float* w_o = weights_h2h;
const int block_dim = 512;
const int grid_dim =
(emit_word_length * hidden_size + block_dim - 1) / block_dim;
cal_cudnn_kernel<<<grid_dim, block_dim, 0, stream>>>(w_x_r,
w_x_z,
w_x_o,
w_h_r,
w_h_z,
w_h_o,
hidden_size,
emit_word_length,
hidden_out,
hidden_in);
}
if (is_batched) {
_seq_util.sorted_seq_2_seq(_temp_tensor_out.data<float>(),
dout->mutable_data<float>(TARGET(kCUDA)),
hidden_size,
stream);
} }
CopyBack(hidden, top_hidden, _cap_h); dout->set_lod(x->lod());
} }
} // namespace cuda } // namespace cuda
......
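Reviewer note: the shape bookkeeping in the new Run() in brief: emit_offset_vec partitions the time-major word order built by get_sorted_map, so step t covers rows [emit_offset_vec[t], emit_offset_vec[t+1]); _temp_wx (seq_sum x 3*hidden_size) is precomputed once from the inputs, and each step runs one GEMM of hidden_in (emit_word_length x hidden_size) against _temp_weights_h2h (hidden_size x 3*hidden_size) into _temp_wh before cal_cudnn_kernel writes that step's hidden state into dout_data.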
...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <vector>
#include "lite/backends/cuda/blas.h" #include "lite/backends/cuda/blas.h"
#include "lite/backends/cuda/math/gemm.h" #include "lite/backends/cuda/math/gemm.h"
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
...@@ -23,6 +24,53 @@ namespace lite {
namespace kernels { namespace kernels {
namespace cuda { namespace cuda {
class SeqSortedseqTranseUtil {
public:
explicit SeqSortedseqTranseUtil(bool is_reverse = false, bool is_bi = false)
: _is_reverse(is_reverse),
_is_bi(is_bi),
_dev_map_vec(nullptr),
_dev_map_vec_length(0) {}
~SeqSortedseqTranseUtil() {
if (_dev_map_vec != nullptr) {
TargetWrapperCuda::Free(static_cast<void*>(_dev_map_vec));
}
}
std::vector<int>& get_length_index() { return _length_index; }
std::vector<int>& get_emit_offset_vec() { return _emit_offset_vec; }
std::vector<int>& get_map_vec() { return _map_vec; }
int* get_dev_map_vec() { return _dev_map_vec; }
int get_emit_length() { return _emit_length; }
template <typename Dtype>
void seq_2_sorted_seq(const Dtype* input,
Dtype* output,
int word_size,
cudaStream_t stream);
template <typename Dtype>
void sorted_seq_2_seq(const Dtype* input,
Dtype* output,
int hidden_size,
cudaStream_t stream);
bool get_sorted_map(const std::vector<int>& offset_vec,
cudaStream_t stream_id);
private:
std::vector<int> _length_index;
std::vector<int> _emit_offset_vec;
std::vector<int> _map_vec;
int _emit_length;
bool _is_reverse;
bool _is_bi;
int* _dev_map_vec;
int _dev_map_vec_length;
};
class SearchGrnnCompute class SearchGrnnCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)> { : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
public: public:
...@@ -34,10 +82,26 @@ class SearchGrnnCompute
virtual ~SearchGrnnCompute() = default; virtual ~SearchGrnnCompute() = default;
private: private:
std::shared_ptr<Tensor> idx_sorted_by_width_cpu; // Weights preprocess:
// wi need to be transpose, the axes should be (2, 0, 1)
// wh0 should transpose, {wh1 wh2} need be transpose, the axes should be {2,
// 0, 1}
void WeightsPreprocess();
private:
std::unique_ptr<lite::cuda::math::Gemm<float, float>> gemm_impl_; std::unique_ptr<lite::cuda::math::Gemm<float, float>> gemm_impl_;
void PrepareLayout(const Tensor* input);
void CopyBack(float* from, float* to, int step); lite::Tensor _temp_tensor_in;
lite::Tensor _temp_tensor_out;
lite::Tensor _temp_wx;
lite::Tensor _temp_wh;
lite::Tensor _temp_zero;
lite::Tensor _temp_weights_h2h;
lite::Tensor _wi;
lite::Tensor _wh;
SeqSortedseqTranseUtil _seq_util;
}; };
} // namespace cuda } // namespace cuda
......
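Reviewer note: a hedged sketch of the intended call sequence for SeqSortedseqTranseUtil, mirroring the new Run() (buffer names are illustrative):

  SeqSortedseqTranseUtil seq_util;                                    // batch-major <-> time-major reordering helper
  bool batched = seq_util.get_sorted_map(lod_offset, stream);         // build the word index map on host, copy to device
  if (batched) seq_util.seq_2_sorted_seq(x_gpu, x_sorted, word_size, stream);        // gather inputs into time-major order
  // ... per-time-step GEMM + cal_cudnn_kernel over emit_offset_vec ...
  if (batched) seq_util.sorted_seq_2_seq(out_sorted, out_gpu, hidden_size, stream);  // scatter results back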
...@@ -25,6 +25,7 @@ namespace cuda {
void TransposeCompute::Run() { void TransposeCompute::Run() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>(); auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
const lite::Tensor* X = param.x; const lite::Tensor* X = param.x;
lite::Tensor* Out = param.output; lite::Tensor* Out = param.output;
...@@ -39,8 +40,7 @@ void TransposeCompute::Run() {
// NCHW -> NHWC // NCHW -> NHWC
if (axes.size() == 4 && axes[0] == 0 && axes[1] == 2 && axes[2] == 3 && if (axes.size() == 4 && axes[0] == 0 && axes[1] == 2 && axes[2] == 3 &&
axes[3] == 1) { axes[3] == 1) {
lite::cuda::math::NCHW2NHWC( trans.NCHW2NHWC(dims[0], dims[1], dims[2] * dims[3], in, out, &stream);
dims[0], dims[1], dims[2] * dims[3], in, out, &ctx);
cudaError_t error = cudaGetLastError(); cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
return; return;
...@@ -49,14 +49,13 @@ void TransposeCompute::Run() {
// NHWC -> NCHW // NHWC -> NCHW
if (axes.size() == 4 && axes[0] == 0 && axes[1] == 3 && axes[2] == 1 && if (axes.size() == 4 && axes[0] == 0 && axes[1] == 3 && axes[2] == 1 &&
axes[3] == 2) { axes[3] == 2) {
lite::cuda::math::NHWC2NCHW( trans.NHWC2NCHW(dims[0], dims[3], dims[1] * dims[2], in, out, &stream);
dims[0], dims[3], dims[1] * dims[2], in, out, &ctx);
cudaError_t error = cudaGetLastError(); cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
return; return;
} }
lite::cuda::math::Transpose(dims, axes, in, out, &ctx); trans.transpose(out, in, dims, axes, &stream);
cudaError_t error = cudaGetLastError(); cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
} }
......
...@@ -29,7 +29,7 @@ class TransposeCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
virtual ~TransposeCompute() = default; virtual ~TransposeCompute() = default;
private: private:
lite::Tensor axes_, dims_; lite::cuda::math::Transpose<float> trans;
}; };
} // namespace cuda } // namespace cuda
......
...@@ -238,7 +238,7 @@ TEST(transpose, normal) {
lite::Tensor x, x_cpu, x_ref; lite::Tensor x, x_cpu, x_ref;
lite::Tensor out, out_cpu, out_ref; lite::Tensor out, out_cpu, out_ref;
int C = 6, H = 7, W = 8; int C = 3, H = 128, W = 128;
std::vector<int> axes({2, 0, 1}); std::vector<int> axes({2, 0, 1});
x.Resize({C, H, W}); x.Resize({C, H, W});
out.Resize({W, C, H}); out.Resize({W, C, H});
......