Unverified commit 611c9b37, authored by Wilber, committed by GitHub

[CUDA] [Kernel] Add gru cuda fp32 kernel. (#3953)

Parent 00b0344b
@@ -11,10 +11,13 @@ nv_library(cuda_transpose SRCS transpose.cu DEPS ${cuda_static_deps})
nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale cuda_type_trans ${cuda_static_deps})
nv_library(cuda_elementwise SRCS elementwise.cu DEPS ${cuda_static_deps})
nv_library(cudnn_pool SRCS cudnn_pool.cc DEPS ${cuda_static_deps})
nv_library(cuda_gru_forward SRCS gru_forward.cu DEPS cuda_activation ${cuda_static_deps})
nv_library(cuda_sequence2batch SRCS sequence2batch.cu DEPS ${cuda_static_deps})
nv_library(cuda_gemm SRCS gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_batched_gemm SRCS batched_gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_strided_gemm SRCS strided_gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_sequence_padding SRCS sequence_padding.cu DEPS ${cuda_static_deps})
nv_library(cuda_bias SRCS bias.cu DEPS ${cuda_static_deps})
set (
math_cuda
@@ -25,10 +28,13 @@ set (
cuda_transpose
cuda_elementwise
cudnn_pool
cuda_gru_forward
cuda_sequence2batch
cuda_gemm
cuda_batched_gemm
cuda_strided_gemm
cuda_sequence_padding
cuda_bias
)
set(math_cuda "${math_cuda}" CACHE GLOBAL "math cuda")
@@ -21,6 +21,20 @@ namespace lite {
namespace cuda {
namespace math {
ActivationType GetActiveType(const std::string& act) {
if (act == "sigmoid") {
return kSigmoid;
} else if (act == "relu") {
return kReLU;
} else if (act == "tanh") {
return kTanh;
} else if (act == "identity") {
return kIdentity;
} else {
LOG(FATAL) << "unsupported activation: " << act;
}
}
template <typename T>
__global__ void relu_kernel(const int num,
const float alpha,
...
@@ -17,11 +17,22 @@
#include <cuda_runtime.h>
#include <string>
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
enum ActivationType {
kSigmoid,
kReLU,
kTanh,
kIdentity,
};
ActivationType GetActiveType(const std::string& act);
// fp32 and half
template <typename T>
void relu(int num, const T* din, T* dout, float alpha, cudaStream_t stream);
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/math/bias.h"
#include <iostream>
#include "lite/backends/cuda/cuda_utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
__global__ void RowwiseAddKernel(
const T* a, const T* b, T* c, int width, int num) {
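// Adds the length-`width` bias vector b to every row of a, writing to c;
// num is the total element count of a.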
CUDA_KERNEL_LOOP(i, num) {
int h = i / width;
int w = i - h * width;
c[i] = a[i] + b[w];
}
}
template <typename T>
void RowwiseAdd<T>::operator()(const T* input,
const T* bias,
T* output,
const int width,
const int count,
const cudaStream_t& stream) {
RowwiseAddKernel<T><<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(
input, bias, output, width, count);
CUDA_POST_KERNEL_CHECK;
}
template struct RowwiseAdd<float>;
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include "lite/backends/cuda/cuda_utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
struct RowwiseAdd {
void operator()(const T* input,
const T* bias,
T* output,
const int width,
const int count,
const cudaStream_t& stream);
};
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "lite/backends/cuda/math/gru_forward.h"
#include "lite/core/device_info.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
__global__ void GruForwardResetOutput(
T* gate_value,
T* reset_output_value,
T* prev_output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_gate,
bool is_batch) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
if (is_batch) {
batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batch_idx >= batch_size) return;
gate_value += batch_idx * 3 * frame_size;
reset_output_value += batch_idx * frame_size;
}
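// gate_value stores three frame_size-wide slices per row:
// [update gate | reset gate | candidate state].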
T prev_out = 0;
T reset_out_val;
T update_gate_value = gate_value[frame_idx + frame_size * 0];
T reset_gate_value = gate_value[frame_idx + frame_size * 1];
if (prev_output_value) {
if (is_batch) {
prev_output_value += batch_idx * frame_size;
}
prev_out = prev_output_value[frame_idx];
}
if (active_gate == lite::cuda::math::ActivationType::kSigmoid) {
update_gate_value = Sigmoid(update_gate_value);
reset_gate_value = Sigmoid(reset_gate_value);
} else if (active_gate == lite::cuda::math::ActivationType::kReLU) {
update_gate_value = ReLU(update_gate_value);
reset_gate_value = ReLU(reset_gate_value);
} else if (active_gate == lite::cuda::math::ActivationType::kTanh) {
update_gate_value = Tanh(update_gate_value);
reset_gate_value = Tanh(reset_gate_value);
}
reset_out_val = prev_out * reset_gate_value;
gate_value[frame_idx + frame_size * 0] = update_gate_value;
gate_value[frame_idx + frame_size * 1] = reset_gate_value;
reset_output_value[frame_idx] = reset_out_val;
}
template <typename T>
__global__ void GruForwardFinalOutput(
T* gate_value,
T* prev_output_value,
T* output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_node,
bool origin_mode,
bool is_batch) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
if (is_batch) {
batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batch_idx >= batch_size) {
return;
}
gate_value += batch_idx * 3 * frame_size;
output_value += batch_idx * frame_size;
}
T output;
T prev_out = 0;
T update_gate_value = gate_value[frame_idx + frame_size * 0];
T state_frame_value = gate_value[frame_idx + frame_size * 2];
if (prev_output_value) {
if (is_batch) prev_output_value += batch_idx * frame_size;
prev_out = prev_output_value[frame_idx];
}
if (active_node == lite::cuda::math::ActivationType::kSigmoid) {
state_frame_value = Sigmoid(state_frame_value);
} else if (active_node == lite::cuda::math::ActivationType::kReLU) {
state_frame_value = ReLU(state_frame_value);
} else if (active_node == lite::cuda::math::ActivationType::kTanh) {
state_frame_value = Tanh(state_frame_value);
}
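// origin_mode keeps the original GRU formulation
// h_t = u_t * h_{t-1} + (1 - u_t) * c_t; otherwise the default variant
// h_t = (1 - u_t) * h_{t-1} + u_t * c_t is used.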
if (origin_mode) {
output = update_gate_value * prev_out + state_frame_value -
update_gate_value * state_frame_value;
} else {
output = prev_out - update_gate_value * prev_out +
update_gate_value * state_frame_value;
}
gate_value[frame_idx + frame_size * 2] = state_frame_value;
output_value[frame_idx] = output;
}
template __global__ void GruForwardFinalOutput<float>(
float* gate_value,
float* prev_output_value,
float* output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_node,
bool origin_mode,
bool is_batch);
template __global__ void GruForwardResetOutput<float>(
float* gate_value,
float* reset_output_value,
float* prev_output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_gate,
bool is_batch);
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cudnn.h>
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/activation.h"
#include "lite/core/context.h"
#include "lite/core/target_wrapper.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename Dtype>
inline __device__ Dtype Sigmoid(const Dtype a) {
return static_cast<Dtype>(1.0) / (static_cast<Dtype>(1.0) + expf(-a));
}
template <typename Dtype>
inline __device__ Dtype ReLU(const Dtype a) {
return a > static_cast<Dtype>(0.f) ? a : static_cast<Dtype>(0.f);
}
template <typename Dtype>
inline __device__ Dtype Tanh(const Dtype a) {
Dtype tmp = static_cast<Dtype>(-2.0) * a;
return (static_cast<Dtype>(2.0) / (static_cast<Dtype>(1.0) + expf(tmp))) -
static_cast<Dtype>(1.0);
}
template <typename T>
__global__ void GruForwardResetOutput(
T* gate_value,
T* reset_output_value,
T* prev_output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_gate,
bool is_batch);
template <typename T>
__global__ void GruForwardFinalOutput(
T* gate_value,
T* prev_output_value,
T* output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_node,
bool origin_mode,
bool is_batch);
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/sequence2batch.h"
#include "lite/backends/cuda/math/utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
__global__ void CopyMatrixRowsKernel(const T* src,
T* dst,
const uint64_t* index,
int height,
int width,
bool is_src_index) {
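// Each thread row (idy) copies one matrix row; the thread columns (idx)
// stride over that row's elements. is_src_index decides whether `index`
// addresses the source row or the destination row.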
int idx = threadIdx.x;
int idy = threadIdx.y;
int row_id = blockDim.y * blockIdx.x + idy;
if (row_id < height) {
int src_idx = is_src_index ? index[row_id] : row_id;
int dst_idx = is_src_index ? row_id : index[row_id];
const T* src_data = src + src_idx * width;
T* dst_data = dst + dst_idx * width;
for (int i = idx; i < width; i += blockDim.x) {
dst_data[i] = src_data[i];
}
}
}
template <typename T>
void CopyMatrixRowsFunctor<T>::operator()(
const lite::Tensor& src,
lite::Tensor* dst,
const std::vector<uint64_t>& index_lod,
bool is_src_index,
const cudaStream_t& stream) {
auto src_dims = src.dims();
auto dst_dims = dst->dims();
CHECK_EQ(src_dims.size(), 2) << "The src must be a matrix with rank 2.";
CHECK_EQ(dst_dims.size(), 2) << "The dst must be a matrix with rank 2.";
CHECK_EQ(src_dims[1], dst_dims[1])
<< "The width of src and dst must be the same.";
int height = dst_dims[0];
int width = dst_dims[1];
const auto* src_data = src.data<T>();
auto* dst_data = dst->template mutable_data<T>(TARGET(kCUDA));
index_tensor_.Resize({static_cast<int64_t>(index_lod.size())});
auto* index_tensor_data = index_tensor_.mutable_data<uint64_t>(TARGET(kCUDA));
TargetWrapperCuda::MemcpyAsync(index_tensor_data,
index_lod.data(),
sizeof(uint64_t) * index_lod.size(),
IoDirection::HtoD,
stream);
dim3 threads(128, 8);
dim3 grids((height + threads.y - 1) / threads.y);
CopyMatrixRowsKernel<T><<<grids, threads, 0, stream>>>(
src_data, dst_data, index_tensor_data, height, width, is_src_index);
CUDA_POST_KERNEL_CHECK;
}
template class CopyMatrixRowsFunctor<float>;
template class LoDTensor2BatchFunctor<float>;
template class Batch2LoDTensorFunctor<float>;
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <string>
#include <vector>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/context.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
class CopyMatrixRowsFunctor {
public:
void operator()(const lite::Tensor& src,
lite::Tensor* dst,
const std::vector<uint64_t>& index_lod,
bool is_src_index,
const cudaStream_t& stream);
private:
lite::Tensor index_tensor_;
};
template <typename T>
class LoDTensor2BatchFunctor {
struct SeqInfo {
SeqInfo(size_t start, size_t length, size_t seq_idx)
: start_(start), length_(length), seq_idx_(seq_idx) {}
size_t start_;
size_t length_;
size_t seq_idx_;
};
public:
void operator()(const lite::Tensor& lod_tensor,
lite::Tensor* batch_tensor,
bool is_reverse,
const cudaStream_t& stream) const {
auto lods = lod_tensor.lod();
CHECK_EQ(lods.size(), 1UL) << "Only one-level LoD sequences are supported now.";
const auto& lod = lods[0];
std::vector<SeqInfo> seq_info;
for (int seq_id = 0; seq_id < static_cast<int>(lod.size()) - 1; ++seq_id) {
size_t length = lod[seq_id + 1] - lod[seq_id];
seq_info.emplace_back(lod[seq_id], length, seq_id);
}
std::sort(seq_info.begin(), seq_info.end(), [](SeqInfo a, SeqInfo b) {
return a.length_ > b.length_;
});
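// Sequences are sorted by decreasing length so that, at every time step,
// the still-active sequences form a contiguous prefix of the batch.
// batch_lods[0]: start offset of each time step in the reordered batch;
// batch_lods[1]: for each batch row, the corresponding row in the original
// LoDTensor; batch_lods[2]: the original index of each sorted sequence.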
LoD batch_lods;
batch_lods.emplace_back(std::vector<uint64_t>{0});
batch_lods.emplace_back(std::vector<uint64_t>{0});
batch_lods.emplace_back(std::vector<uint64_t>{0});
size_t max_seqlen = seq_info[0].length_;
batch_lods[0].resize(max_seqlen + 1);
batch_lods[1].resize(static_cast<size_t>(lod_tensor.dims()[0]));
batch_lods[2].resize(seq_info.size());
auto* batch_starts = batch_lods[0].data();
auto* seq2batch_idx = batch_lods[1].data();
batch_starts[0] = 0;
for (size_t n = 0; n < max_seqlen; ++n) {
size_t batch_id = batch_starts[n];
for (size_t i = 0; i < seq_info.size(); ++i) {
size_t seq_len = seq_info[i].length_;
size_t start = seq_info[i].start_;
if (n < seq_len) {
seq2batch_idx[batch_id] =
is_reverse ? start + seq_len - 1 - n : start + n;
++batch_id;
} else {
break;
}
}
batch_starts[n + 1] = batch_id;
}
auto* seq_order = batch_lods[2].data();
for (size_t i = 0; i < seq_info.size(); ++i) {
seq_order[i] = seq_info[i].seq_idx_;
}
batch_tensor->set_lod(batch_lods);
lite::cuda::math::CopyMatrixRowsFunctor<T> to_batch;
to_batch(lod_tensor, batch_tensor, batch_lods[1], true, stream);
CUDA_POST_KERNEL_CHECK;
}
};
template <typename T>
class Batch2LoDTensorFunctor {
public:
void operator()(const lite::Tensor& batch_tensor,
lite::Tensor* lod_tensor,
const cudaStream_t& stream) {
auto in_lod = batch_tensor.lod();
CHECK_GT(in_lod.size(), 2UL) << "The LoD of LoDTensor should include at "
"least 2-level sequence infomation.";
CHECK_EQ(in_lod[1].size(), static_cast<size_t>(lod_tensor->dims()[0]))
<< "The LoD information should be consistent with the dims.";
lite::cuda::math::CopyMatrixRowsFunctor<T> to_seq;
to_seq(batch_tensor, lod_tensor, in_lod[1], false, stream);
CUDA_POST_KERNEL_CHECK;
}
};
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
@@ -15,6 +15,7 @@
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/target_wrapper.h"
namespace paddle {
@@ -31,6 +32,16 @@ class TargetWrapper<TARGET(kCUDA)> {
static size_t num_devices();
static size_t maximum_stream() { return 0; }
static int GetComputeCapability() {
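// Returns major * 10 + minor, e.g. 75 for compute capability 7.5.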
int dev_id = GetCurDevice();
int major, minor;
CUDA_CALL(cudaDeviceGetAttribute(
&major, cudaDevAttrComputeCapabilityMajor, dev_id));
CUDA_CALL(cudaDeviceGetAttribute(
&minor, cudaDevAttrComputeCapabilityMinor, dev_id));
return major * 10 + minor;
}
static size_t GetCurDevice() {
int dev_id;
cudaGetDevice(&dev_id);
...
@@ -7,6 +7,7 @@ message(STATUS "compile with lite CUDA kernels")
# basic kernels
add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(fc_compute_cuda CUDA basic SRCS fc_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(gru_compute_cuda CUDA basic SRCS gru_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(matmul_compute_cuda CUDA basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(search_group_padding_compute_cuda CUDA basic SRCS search_group_padding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
@@ -69,6 +70,7 @@ nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_comp
#nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
nv_test(fc_compute_cuda_test SRCS fc_compute_test.cc DEPS fc_compute_cuda)
nv_test(gru_compute_cuda_test SRCS gru_compute_test.cc DEPS gru_compute_cuda)
nv_test(matmul_compute_cuda_test SRCS matmul_compute_test.cc DEPS matmul_compute_cuda)
nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/bias.h"
#include "lite/backends/cuda/math/gru_forward.h"
#include "lite/backends/cuda/math/sequence2batch.h"
#include "lite/backends/cuda/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/gru_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T>
struct GRUMetaValue {
T* gate_weight;
T* state_weight;
T* gate_value;
T* reset_output_value;
T* output_value;
T* prev_out_value;
};
template <typename T>
struct GRUUnitFunctor {
static void compute(GRUMetaValue<T> value,
int frame_size,
int batch_size,
const lite::cuda::math::ActivationType& active_node,
const lite::cuda::math::ActivationType& active_gate,
bool origin_mode,
lite::cuda::math::Gemm<T, T>* blas,
CUDAContext* context) {
dim3 threads, grids;
if (batch_size == 1) {
int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(frame_per_block, 1);
grids = dim3(frame_blocks, 1);
} else {
threads = dim3(32, 32);
grids = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
}
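// Accumulate h_{t-1} * gate_weight (frame_size x 2*frame_size) into the
// update/reset slices of gate_value, which already hold the projected
// inputs (plus bias, if any).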
if (value.prev_out_value) {
CHECK(blas->init(false,
false,
batch_size,
frame_size * 2,
frame_size,
frame_size,
frame_size * 2,
frame_size * 3,
context));
blas->run(1.0f,
1.0f,
value.prev_out_value,
value.gate_weight,
value.gate_value,
context);
}
CUDA_POST_KERNEL_CHECK;
lite::cuda::math::GruForwardResetOutput<
T><<<grids, threads, 0, context->exec_stream()>>>(
value.gate_value,
value.reset_output_value,
value.prev_out_value,
frame_size,
batch_size,
active_gate,
batch_size == 1);
CUDA_POST_KERNEL_CHECK;
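// Accumulate (r_t (*) h_{t-1}) * state_weight (frame_size x frame_size)
// into the candidate slice of gate_value.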
if (value.prev_out_value) {
CHECK(blas->init(false,
false,
batch_size,
frame_size,
frame_size,
frame_size,
frame_size,
frame_size * 3,
context));
blas->run(1.0f,
1.0f,
value.reset_output_value,
value.state_weight,
value.gate_value + frame_size * 2,
context);
}
CUDA_POST_KERNEL_CHECK;
lite::cuda::math::GruForwardFinalOutput<
T><<<grids, threads, 0, context->exec_stream()>>>(value.gate_value,
value.prev_out_value,
value.output_value,
frame_size,
batch_size,
active_node,
origin_mode,
batch_size == 1);
CUDA_POST_KERNEL_CHECK;
}
};
template struct GRUUnitFunctor<float>;
template <typename T, PrecisionType PType>
void GRUCompute<T, PType>::PrepareForRun() {
gemm_impl_.reset(new lite::cuda::math::Gemm<T, T>);
}
template <typename T, PrecisionType PType>
void GRUCompute<T, PType>::Run() {
auto& context = this->ctx_->template As<CUDAContext>();
auto stream = context.exec_stream();
auto& param = this->template Param<param_t>();
auto* input = param.input;
lite::Tensor* h0{nullptr};
if (param.h0) {
h0 = const_cast<lite::Tensor*>(param.h0);
}
lite::Tensor* bias{nullptr};
if (param.bias) {
bias = const_cast<lite::Tensor*>(param.bias);
}
auto* weight = param.weight;
auto* weight_data = const_cast<T*>(weight->template data<T>());
auto* batch_gate = param.batch_gate;
auto* batch_reset_hidden_prev = param.batch_reset_hidden_prev;
auto* batch_hidden = param.batch_hidden;
auto* hidden = param.hidden;
auto* batch_reset_hidden_prev_data =
batch_reset_hidden_prev->template mutable_data<T>(TARGET(kCUDA));
hidden->template mutable_data<T>(TARGET(kCUDA));
auto* batch_gate_data = batch_gate->template mutable_data<T>(TARGET(kCUDA));
auto* batch_hidden_data =
batch_hidden->template mutable_data<T>(TARGET(kCUDA));
bool is_reverse = param.is_reverse;
auto active_node = lite::cuda::math::GetActiveType(param.activation);
auto active_gate = lite::cuda::math::GetActiveType(param.gate_activation);
bool origin_mode = param.origin_mode;
auto hidden_dims = hidden->dims();
int frame_size = hidden_dims[1];
lite::cuda::math::LoDTensor2BatchFunctor<T> batch_func;
batch_func(*input, batch_gate, is_reverse, stream);
if (bias) {
lite::cuda::math::RowwiseAdd<T> add_bias;
add_bias(batch_gate_data,
bias->template data<T>(),
batch_gate_data,
frame_size,
batch_gate->numel(),
stream);
}
GRUMetaValue<T> gru_value;
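// Weight layout: the leading frame_size x (2 * frame_size) block holds the
// update/reset gate weights; the trailing frame_size x frame_size block
// holds the candidate (state) weights.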
gru_value.gate_weight = weight_data;
gru_value.state_weight = weight_data + 2 * frame_size * frame_size;
if (h0) {
// Since the batch computation for GRU reorders the input sequences
// according to their length, the initial hidden state also needs to be
// reordered to match.
ordered_h0_.Resize(h0->dims());
lite::cuda::math::CopyMatrixRowsFunctor<T> row_shuffle;
row_shuffle(*h0, &ordered_h0_, batch_gate->lod()[2], true, stream);
gru_value.prev_out_value = ordered_h0_.mutable_data<T>(TARGET(kCUDA));
} else {
gru_value.prev_out_value = nullptr;
}
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
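// One iteration per time step: rows batch_starts[n] to batch_starts[n + 1]
// are the sequences still active at step n.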
for (size_t n = 0; n < num_batch; ++n) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
gru_value.output_value = batch_hidden_data + bstart * frame_size;
gru_value.gate_value = batch_gate_data + bstart * frame_size * 3;
gru_value.reset_output_value =
batch_reset_hidden_prev_data + bstart * frame_size;
GRUUnitFunctor<T>::compute(gru_value,
frame_size,
cur_batch_size,
active_node,
active_gate,
origin_mode,
gemm_impl_.get(),
&context);
gru_value.prev_out_value = gru_value.output_value;
}
lite::cuda::math::Batch2LoDTensorFunctor<T> to_seq;
batch_hidden->set_lod(batch_gate->lod());
to_seq(*batch_hidden, hidden, stream);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
using GRUFp32 =
paddle::lite::kernels::cuda::GRUCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(gru, kCUDA, kFloat, kNCHW, GRUFp32, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("H0", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Weight", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("BatchGate", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("BatchResetHiddenPrev", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("BatchHidden", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/cuda/math/gemm.h"
#include "lite/core/kernel.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T, PrecisionType PType>
class GRUCompute : public KernelLite<TARGET(kCUDA), PType> {
public:
using param_t = operators::GRUParam;
void PrepareForRun() override;
void Run() override;
virtual ~GRUCompute() = default;
private:
std::unique_ptr<lite::cuda::math::Gemm<T, T>> gemm_impl_{nullptr};
lite::Tensor ordered_h0_;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/gru_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
#include "lite/utils/float16.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class GRUTest : public ::testing::Test {
protected:
GRUTest()
: batch_(12),
frame_size_(128),
activation_("tanh"),
gate_activation_("sigmoid"),
is_reverse_(false),
origin_mode_(false),
x_shape_({batch_, frame_size_ * 3}),
w_shape_({frame_size_, frame_size_ * 3}),
out_shape_({batch_, frame_size_}),
lod_({{0, 4, 9, 12}}) {
x_ref_.Resize(lite::DDim(x_shape_));
x_gpu_.Resize(lite::DDim(x_shape_));
x_ref_.set_lod(lod_);
w_ref_.Resize(lite::DDim(w_shape_));
w_gpu_.Resize(lite::DDim(w_shape_));
auto x_ref_data = x_ref_.mutable_data<float>();
auto w_ref_data = w_ref_.mutable_data<float>();
for (int64_t i = 0; i < x_ref_.numel(); i++) {
x_ref_data[i] = static_cast<float>(i % 10 * 0.2);
}
for (int64_t i = 0; i < w_ref_.numel(); i++) {
w_ref_data[i] = static_cast<float>(i % 10 * 0.2);
}
out_ref_.Resize(lite::DDim(out_shape_));
out_cpu_.Resize(out_ref_.dims());
out_gpu_.Resize(out_ref_.dims());
batch_gate_gpu_.Resize(lite::DDim(x_shape_));
batch_hidden_gpu_.Resize(lite::DDim(out_shape_));
batch_reset_hidden_gpu_.Resize(lite::DDim(out_shape_));
RunBaseLine();
InitParamAndContext();
}
void InitParamAndContext() {
ctx_.reset(new KernelContext);
cudaStreamCreate(&stream_);
auto& context = ctx_->As<CUDAContext>();
context.SetExecStream(stream_);
param_.input = &x_gpu_;
param_.weight = &w_gpu_;
param_.gate_activation = gate_activation_;
param_.activation = activation_;
param_.is_reverse = is_reverse_;
param_.origin_mode = origin_mode_;
param_.hidden = &out_gpu_;
param_.batch_gate = &batch_gate_gpu_;
param_.batch_reset_hidden_prev = &batch_reset_hidden_gpu_;
param_.batch_hidden = &batch_hidden_gpu_;
}
void InitFloatInput() {
x_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(x_ref_.data<float>(),
x_gpu_.dims());
x_gpu_.set_lod(x_ref_.lod());
w_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(w_ref_.data<float>(),
w_gpu_.dims());
}
void RunBaseLine() {}
int batch_, frame_size_;
std::string activation_, gate_activation_;
bool is_reverse_, origin_mode_;
std::vector<int64_t> x_shape_, w_shape_, out_shape_;
LoD lod_;
lite::Tensor x_ref_, w_ref_, out_ref_;
lite::Tensor x_gpu_, w_gpu_;
lite::Tensor x_half_, w_half_;
lite::Tensor batch_gate_gpu_;
lite::Tensor batch_hidden_gpu_;
lite::Tensor batch_reset_hidden_gpu_;
lite::Tensor out_cpu_, out_gpu_;
operators::GRUParam param_;
std::unique_ptr<KernelContext> ctx_;
cudaStream_t stream_;
};
TEST_F(GRUTest, TestFP32) {
InitFloatInput();
GRUCompute<float, PRECISION(kFloat)> kernel;
kernel.SetParam(param_);
kernel.SetContext(std::move(ctx_));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp32, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms on average.";
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle