Unverified commit 796e2a57, authored by W Wilber, committed by GitHub

[CUDA] [Kernels] Add gru fp16 cuda kernel. (#3956)

Parent 14397ca0
......@@ -31,6 +31,17 @@ __global__ void RowwiseAddKernel(
c[i] = a[i] + b[w];
}
}
template <>
__global__ void RowwiseAddKernel(
const half* a, const half* b, half* c, int width, int num) {
CUDA_KERNEL_LOOP(i, num) {
int h = i / width;
int w = i - h * width;
c[i] = __hadd(a[i], b[w]);
}
}
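The half specialization above calls __hadd unconditionally, whereas the other fp16 kernels in this commit guard half intrinsics behind __CUDA_ARCH__ >= 530 with a float fallback. A minimal sketch of the same kernel written with that guard, assuming the existing CUDA_KERNEL_LOOP macro (illustration of the pattern only, not part of this diff):
template <>
__global__ void RowwiseAddKernel(
    const half* a, const half* b, half* c, int width, int num) {
  CUDA_KERNEL_LOOP(i, num) {
    int h = i / width;
    int w = i - h * width;
#if __CUDA_ARCH__ >= 530
    // Native fp16 add on sm_53 and newer.
    c[i] = __hadd(a[i], b[w]);
#else
    // Fallback: promote to float, add, convert back to half.
    c[i] = __float2half(__half2float(a[i]) + __half2float(b[w]));
#endif
  }
}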
template <typename T>
void RowwiseAdd<T>::operator()(const T* input,
const T* bias,
......@@ -44,6 +55,7 @@ void RowwiseAdd<T>::operator()(const T* input,
}
template struct RowwiseAdd<float>;
template struct RowwiseAdd<half>;
} // namespace math
} // namespace cuda
......
......@@ -22,6 +22,10 @@ namespace lite {
namespace cuda {
namespace math {
/*
* threads(frame_per_block, batch_per_block)
* grid(frame_blocks, batch_blocks)
*/
template <typename T>
__global__ void GruForwardResetOutput(
T* gate_value,
......@@ -33,6 +37,7 @@ __global__ void GruForwardResetOutput(
bool is_batch) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
if (is_batch) {
batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
......@@ -44,12 +49,14 @@ __global__ void GruForwardResetOutput(
T reset_out_val;
T update_gate_value = gate_value[frame_idx + frame_size * 0];
T reset_gate_value = gate_value[frame_idx + frame_size * 1];
if (prev_output_value) {
if (is_batch) {
prev_output_value += batch_idx * frame_size;
}
prev_out = prev_output_value[frame_idx];
}
if (active_gate == lite::cuda::math::ActivationType::kSigmoid) {
update_gate_value = Sigmoid(update_gate_value);
reset_gate_value = Sigmoid(reset_gate_value);
......@@ -60,12 +67,71 @@ __global__ void GruForwardResetOutput(
update_gate_value = Tanh(update_gate_value);
reset_gate_value = Tanh(reset_gate_value);
}
reset_out_val = prev_out * reset_gate_value;
gate_value[frame_idx + frame_size * 0] = update_gate_value;
gate_value[frame_idx + frame_size * 1] = reset_gate_value;
reset_output_value[frame_idx] = reset_out_val;
}
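As the threads/grid comment above notes, each thread handles one frame element and, in the batched case, the y dimension indexes the batch. A host-side launch sketch for the batched case, assuming float buffers gate_value, reset_output_value, prev_output_value, an ActivationType active_gate and a cudaStream_t stream are in scope (illustrative only; the actual dispatch is done by GRUUnitFunctor further down in this diff):
dim3 threads(32, 32);  // threads(frame_per_block, batch_per_block)
dim3 grids((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
GruForwardResetOutput<float><<<grids, threads, 0, stream>>>(
    gate_value, reset_output_value, prev_output_value,
    frame_size, batch_size, active_gate, /*is_batch=*/true);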
template <>
__global__ void GruForwardResetOutput(
half* gate_value,
half* reset_output_value,
half* prev_output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_gate,
bool is_batch) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
if (is_batch) {
batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batch_idx >= batch_size) return;
gate_value += batch_idx * 3 * frame_size;
reset_output_value += batch_idx * frame_size;
}
half prev_out = 0;
half reset_out_val;
half update_gate_value = gate_value[frame_idx + frame_size * 0];
half reset_gate_value = gate_value[frame_idx + frame_size * 1];
if (prev_output_value) {
if (is_batch) {
prev_output_value += batch_idx * frame_size;
}
prev_out = prev_output_value[frame_idx];
}
if (active_gate == ActivationType::kSigmoid) {
update_gate_value = Sigmoid(update_gate_value);
reset_gate_value = Sigmoid(reset_gate_value);
} else if (active_gate == ActivationType::kReLU) {
update_gate_value = ReLU(update_gate_value);
reset_gate_value = ReLU(reset_gate_value);
} else if (active_gate == ActivationType::kTanh) {
update_gate_value = Tanh(update_gate_value);
reset_gate_value = Tanh(reset_gate_value);
}
#if __CUDA_ARCH__ >= 530
reset_out_val = __hmul(prev_out, reset_gate_value);
#else
reset_out_val =
__float2half(__half2float(prev_out) * __half2float(reset_gate_value));
#endif
gate_value[frame_idx + frame_size * 0] = update_gate_value;
gate_value[frame_idx + frame_size * 1] = reset_gate_value;
reset_output_value[frame_idx] = reset_out_val;
}
/*
* threads(frame_per_block, batch_per_block)
* grid(frame_blocks, batch_blocks)
*/
template <typename T>
__global__ void GruForwardFinalOutput(
T* gate_value,
......@@ -87,14 +153,17 @@ __global__ void GruForwardFinalOutput(
gate_value += batch_idx * 3 * frame_size;
output_value += batch_idx * frame_size;
}
T output;
T prev_out = 0;
T update_gate_value = gate_value[frame_idx + frame_size * 0];
T state_frame_value = gate_value[frame_idx + frame_size * 2];
if (prev_output_value) {
if (is_batch) prev_output_value += batch_idx * frame_size;
prev_out = prev_output_value[frame_idx];
}
if (active_node == lite::cuda::math::ActivationType::kSigmoid) {
state_frame_value = Sigmoid(state_frame_value);
} else if (active_node == lite::cuda::math::ActivationType::kReLU) {
......@@ -102,6 +171,7 @@ __global__ void GruForwardFinalOutput(
} else if (active_node == lite::cuda::math::ActivationType::kTanh) {
state_frame_value = Tanh(state_frame_value);
}
if (origin_mode) {
output = update_gate_value * prev_out + state_frame_value -
update_gate_value * state_frame_value;
......@@ -109,6 +179,76 @@ __global__ void GruForwardFinalOutput(
output = prev_out - update_gate_value * prev_out +
update_gate_value * state_frame_value;
}
gate_value[frame_idx + frame_size * 2] = state_frame_value;
output_value[frame_idx] = output;
}
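For reference, in the code's own symbols (g_0, g_1, g_2 are the three frame_size-wide slices of gate_value, h_{t-1} is prev_output_value), the reset and final kernels compute:
u = \mathrm{act\_gate}(g_0), \quad r = \mathrm{act\_gate}(g_1), \quad \text{reset\_output} = r \odot h_{t-1}
\tilde{h} = \mathrm{act\_node}(g_2), \qquad h_t = (1-u) \odot h_{t-1} + u \odot \tilde{h} \quad \text{(default)}
h_t = u \odot h_{t-1} + (1-u) \odot \tilde{h} \quad \text{(origin\_mode)}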
template <>
__global__ void GruForwardFinalOutput(
half* gate_value,
half* prev_output_value,
half* output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_node,
bool origin_mode,
bool is_batch) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
if (is_batch) {
batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batch_idx >= batch_size) {
return;
}
gate_value += batch_idx * 3 * frame_size;
output_value += batch_idx * frame_size;
}
half output;
half prev_out = 0;
half update_gate_value = gate_value[frame_idx + frame_size * 0];
half state_frame_value = gate_value[frame_idx + frame_size * 2];
if (prev_output_value) {
if (is_batch) prev_output_value += batch_idx * frame_size;
prev_out = prev_output_value[frame_idx];
}
if (active_node == lite::cuda::math::ActivationType::kSigmoid) {
state_frame_value = Sigmoid(state_frame_value);
} else if (active_node == lite::cuda::math::ActivationType::kReLU) {
state_frame_value = ReLU(state_frame_value);
} else if (active_node == lite::cuda::math::ActivationType::kTanh) {
state_frame_value = Tanh(state_frame_value);
}
if (origin_mode) {
#if __CUDA_ARCH__ >= 530
output =
__hsub(__hadd(__hmul(update_gate_value, prev_out), state_frame_value),
__hmul(update_gate_value, state_frame_value));
#else
output = __float2half(
__half2float(update_gate_value) * __half2float(prev_out) +
__half2float(state_frame_value) -
__half2float(update_gate_value) * __half2float(state_frame_value));
#endif
} else {
#if __CUDA_ARCH__ >= 530
output = __hadd(__hsub(prev_out, __hmul(update_gate_value, prev_out)),
__hmul(update_gate_value, state_frame_value));
#else
output = __float2half(
__half2float(prev_out) -
__half2float(update_gate_value) * __half2float(prev_out) +
__half2float(update_gate_value) * __half2float(state_frame_value));
#endif
}
gate_value[frame_idx + frame_size * 2] = state_frame_value;
output_value[frame_idx] = output;
}
......@@ -122,6 +262,7 @@ template __global__ void GruForwardFinalOutput<float>(
lite::cuda::math::ActivationType active_node,
bool origin_mode,
bool is_batch);
template __global__ void GruForwardResetOutput<float>(
float* gate_value,
float* reset_output_value,
......
......@@ -34,10 +34,32 @@ template <typename Dtype>
inline __device__ Dtype Sigmoid(const Dtype a) {
return static_cast<Dtype>(1.0) / (static_cast<Dtype>(1.0) + expf(-a));
}
template <>
inline __device__ half Sigmoid(const half a) {
#if __CUDA_ARCH__ >= 530
const half tmp = __float2half(1.0f);
return __hdiv(tmp, __hadd(tmp, hexp(__hmul(__float2half(-1.f), a))));
#else
return __float2half(1.0f / (expf(-__half2float(a)) + 1.0f));
#endif
}
template <typename Dtype>
inline __device__ Dtype ReLU(const Dtype a) {
return a > static_cast<Dtype>(0.f) ? a : static_cast<Dtype>(0.f);
}
template <>
inline __device__ half ReLU(const half a) {
const half tmp = __float2half(0.f);
#if __CUDA_ARCH__ >= 530
return __hgt(a, tmp) ? a : tmp;
#else
return __float2half(__half2float(a) > 0.f ? __half2float(a) : 0.f);
#endif
}
template <typename Dtype>
inline __device__ Dtype Tanh(const Dtype a) {
Dtype tmp = static_cast<Dtype>(-2.0) * a;
......@@ -45,6 +67,18 @@ inline __device__ Dtype Tanh(const Dtype a) {
static_cast<Dtype>(1.0);
}
template <>
inline __device__ half Tanh(const half a) {
#if __CUDA_ARCH__ >= 530
half tmp = __float2half(1.0f);
half numerator = __hmul(__float2half(-2.0f), a);
return __hsub(__hdiv(__float2half(2.0f), __hadd(tmp, hexp(numerator))), tmp);
#else
float tmp = -2.0f * __half2float(a);
return __float2half(2.0f / (1.0f + expf(tmp)) - 1.0f);
#endif
}
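The half Tanh specialization relies on the same identity as the generic template, which avoids needing a dedicated fp16 tanh intrinsic:
\tanh(a) = \frac{2}{1 + e^{-2a}} - 1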
template <typename T>
__global__ void GruForwardResetOutput(
T* gate_value,
......@@ -54,6 +88,7 @@ __global__ void GruForwardResetOutput(
int batch_size,
lite::cuda::math::ActivationType active_gate,
bool is_batch);
template <typename T>
__global__ void GruForwardFinalOutput(
T* gate_value,
......@@ -65,6 +100,134 @@ __global__ void GruForwardFinalOutput(
bool origin_mode,
bool is_batch);
/*
* threads(tile_size, 1)
* grids(frame_blocks, 1)
*/
template <class T, int TiledSize>
__global__ void FastCollectiveGruGate(T* gate_value,
T* prev_output_value,
T* gate_weight,
T* reset_output,
int frame_size,
ActivationType active_node) {
T xt_0 = 0.0f;
T a0 = 0.0f;
T c0 = 0.0f;
T b0[TiledSize];
int col = blockIdx.x * blockDim.x + threadIdx.x;
int tiled_mask = ((1 << TiledSize) - 1);
// Tiled matrix multiply using warp-shuffle register broadcast, faster than
// going through shared memory.
if (prev_output_value) {
for (int k = 0; k < (((frame_size - 1) / TiledSize) + 1); ++k) {
a0 = 0;
if ((threadIdx.x + k * TiledSize) < frame_size) {
a0 = prev_output_value[threadIdx.x + (k * TiledSize)];
}
for (int i = 0; i < TiledSize; ++i) {
if (col < frame_size * 2 && (i + k * TiledSize) < frame_size) {
b0[i] = gate_weight[(i + k * TiledSize) * frame_size * 2 + col];
}
}
for (int i = 0; i < TiledSize; ++i) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
c0 = c0 + __shfl_sync(tiled_mask, a0, i, TiledSize) * b0[i];
#else
c0 = c0 + __shfl(a0, i, TiledSize) * b0[i];
#endif
}
}
}
__syncthreads();
if (col < frame_size * 2) {
xt_0 = gate_value[col];
c0 += xt_0;
if (active_node == ActivationType::kSigmoid) {
c0 = Sigmoid(c0);
} else if (active_node == ActivationType::kReLU) {
c0 = ReLU(c0);
} else if (active_node == ActivationType::kTanh) {
c0 = Tanh(c0);
}
gate_value[col] = c0;
if (frame_size <= col && col < frame_size * 2) {
T htp_0 = 0.0;
if (prev_output_value) {
htp_0 = prev_output_value[col - frame_size];
}
reset_output[col - frame_size] = c0 * htp_0;
} else if (col < frame_size) {
gate_value[col] = c0;
}
}
}
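FastCollectiveGruGate accumulates the h_{t-1} x gate_weight product with warp shuffles: each lane of a tile loads one element of h_{t-1} (a0) and a strip of weight entries (b0), and __shfl_sync broadcasts lane i's a0 so every lane multiplies it against its own b0[i]. A self-contained sketch of just that broadcast dot product, with hypothetical names (TiledDot, a_lane), to show the pattern in isolation:
// Each of the TiledSize lanes holds one vector element (a_lane) and TiledSize
// weight entries (b0). Broadcasting lane i's element with __shfl_sync lets the
// tile form a vector-matrix product without shared memory.
template <int TiledSize>
__device__ float TiledDot(float a_lane,
                          const float (&b0)[TiledSize],
                          unsigned mask) {
  float c = 0.f;
  for (int i = 0; i < TiledSize; ++i) {
    c += __shfl_sync(mask, a_lane, i, TiledSize) * b0[i];
  }
  return c;
}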
template <class T, int TiledSize>
__global__ void FastCollectiveGruOut(T* gate_weight,
T* prev_out_value,
T* output_value,
T* gate_value,
T* reset_value,
int frame_size,
ActivationType active_node,
bool origin_mode) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
T a0 = 0.0f;
T b0[TiledSize];
T c0 = 0.0f;
int tiled_mask = ((1 << TiledSize) - 1);
if (prev_out_value) {
for (int k = 0; k < ((frame_size - 1) / TiledSize + 1); ++k) {
a0 = 0;
if ((threadIdx.x + k * TiledSize) < frame_size) {
a0 = reset_value[threadIdx.x + k * TiledSize];
}
for (int i = 0; i < TiledSize; ++i) {
if (col < frame_size && (i + k * TiledSize) < frame_size) {
b0[i] = gate_weight[(i + k * TiledSize) * frame_size + col];
}
}
for (int i = 0; i < TiledSize; ++i) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
c0 = c0 + __shfl_sync(tiled_mask, a0, i, TiledSize) * b0[i];
#else
c0 = c0 + __shfl(a0, i, TiledSize) * b0[i];
#endif
}
}
}
__syncthreads();
if (col < frame_size) {
T xt_0 = gate_value[col + 2 * frame_size];
T gta_0 = gate_value[col];
T htp_0 = 0;
if (prev_out_value) {
htp_0 = prev_out_value[col];
}
c0 += xt_0;
if (active_node == ActivationType::kSigmoid) {
c0 = Sigmoid(c0);
} else if (active_node == ActivationType::kReLU) {
c0 = ReLU(c0);
} else if (active_node == ActivationType::kTanh) {
c0 = Tanh(c0);
}
gate_value[col + 2 * frame_size] = c0;
if (origin_mode) {
output_value[col] = htp_0 * gta_0 + (1 - gta_0) * c0;
} else {
output_value[col] = c0 * gta_0 + (1 - gta_0) * htp_0;
}
}
}
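Note that the fast-path outputs agree with the per-element kernels above: with u = gta_0, \tilde{h} = c0 and h_{t-1} = htp_0,
u \odot h_{t-1} + (1-u) \odot \tilde{h} = u \odot h_{t-1} + \tilde{h} - u \odot \tilde{h} \quad \text{(origin\_mode)}
u \odot \tilde{h} + (1-u) \odot h_{t-1} = h_{t-1} - u \odot h_{t-1} + u \odot \tilde{h} \quad \text{(default)}
which are exactly the expressions GruForwardFinalOutput evaluates.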
} // namespace math
} // namespace cuda
} // namespace lite
......
......@@ -77,8 +77,13 @@ void CopyMatrixRowsFunctor<T>::operator()(
}
template class CopyMatrixRowsFunctor<float>;
template class CopyMatrixRowsFunctor<half>;
template class LoDTensor2BatchFunctor<float>;
template class LoDTensor2BatchFunctor<half>;
template class Batch2LoDTensorFunctor<float>;
template class Batch2LoDTensorFunctor<half>;
} // namespace math
} // namespace cuda
......
......@@ -32,6 +32,9 @@ namespace math {
template <typename T>
class CopyMatrixRowsFunctor {
public:
// If is_src_index is true, copy the indexed rows of the input src to the
// output dst. If is_src_index is false, copy the input src to the indexed
// rows of the output dst. The indexed rows are given by the input index.
void operator()(const lite::Tensor& src,
lite::Tensor* dst,
const std::vector<uint64_t>& index_lod,
......@@ -44,6 +47,11 @@ class CopyMatrixRowsFunctor {
template <typename T>
class LoDTensor2BatchFunctor {
// Calculate the length of each sequence and
// sort sequence index by the length.
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
struct SeqInfo {
SeqInfo(size_t start, size_t length, size_t seq_idx)
: start_(start), length_(length), seq_idx_(seq_idx) {}
......@@ -60,21 +68,49 @@ class LoDTensor2BatchFunctor {
auto lods = lod_tensor.lod();
CHECK_EQ(lods.size(), 1UL) << "Only support one level sequence now.";
const auto& lod = lods[0];
std::vector<SeqInfo> seq_info;
for (int seq_id = 0; seq_id < static_cast<int>(lod.size()) - 1; ++seq_id) {
size_t length = lod[seq_id + 1] - lod[seq_id];
seq_info.emplace_back(lod[seq_id], length, seq_id);
}
std::sort(seq_info.begin(), seq_info.end(), [](SeqInfo a, SeqInfo b) {
return a.length_ > b.length_;
});
// Calculate the start position of each batch.
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// max_seqlen = 5,
// batchIndex = {b0, b1, b2, b3, b4}
// b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1
// batch_start_positions[6] = {0, 3, 6, 9, 11, 12}
// batch_start_positions[0] = 0
// batch_start_positions[1] = len(b0)
// batch_start_positions[2] = len(b0) + len(b1)
// ...
// seq2batch_idx[12] = {4, 0, 9,
// 5, 1, 10,
// 6, 2, 11,
// 7, 3,
// 8}
// seq_order = {1, 0, 2}, the sort order.
// where 1 is the second sequence,
// 0 is the first sequence,
// 2 is the third sequence.
LoD batch_lods;
batch_lods.emplace_back(std::vector<uint64_t>{0});
batch_lods.emplace_back(std::vector<uint64_t>{0});
batch_lods.emplace_back(std::vector<uint64_t>{0});
// batch_lods[0] is the start positions for batch LoDTensor
size_t max_seqlen = seq_info[0].length_;
batch_lods[0].resize(max_seqlen + 1);
// batch_lods[1] is the raw index in the input LoDTensor
batch_lods[1].resize(static_cast<size_t>(lod_tensor.dims()[0]));
// batch_lods[2] is the sort order for the input LoDTensor.
batch_lods[2].resize(seq_info.size());
auto* batch_starts = batch_lods[0].data();
......@@ -101,6 +137,7 @@ class LoDTensor2BatchFunctor {
}
batch_tensor->set_lod(batch_lods);
lite::cuda::math::CopyMatrixRowsFunctor<T> to_batch;
to_batch(lod_tensor, batch_tensor, batch_lods[1], true, stream);
CUDA_POST_KERNEL_CHECK;
......
......@@ -48,10 +48,69 @@ struct GRUUnitFunctor {
CUDAContext* context) {
dim3 threads, grids;
if (batch_size == 1) {
if (lite::TargetWrapperCuda::GetComputeCapability() >= 70) {
if (frame_size < 16) {
constexpr int tiled_size = 8;
int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
threads = dim3(tiled_size, 1);
grids = dim3(frame_blocks, 1);
lite::cuda::math::FastCollectiveGruGate<
T,
tiled_size><<<grids, threads, 0, context->exec_stream()>>>(
value.gate_value,
value.prev_out_value,
value.gate_weight,
value.reset_output_value,
frame_size,
active_gate);
frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
grids = dim3(frame_blocks, 1);
lite::cuda::math::FastCollectiveGruOut<
T,
tiled_size><<<grids, threads, 0, context->exec_stream()>>>(
value.state_weight,
value.prev_out_value,
value.output_value,
value.gate_value,
value.reset_output_value,
frame_size,
active_node,
origin_mode);
} else {
constexpr int tiled_size = 16;
int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
threads = dim3(tiled_size, 1);
grids = dim3(frame_blocks, 1);
lite::cuda::math::FastCollectiveGruGate<
T,
tiled_size><<<grids, threads, 0, context->exec_stream()>>>(
value.gate_value,
value.prev_out_value,
value.gate_weight,
value.reset_output_value,
frame_size,
active_gate);
frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
grids = dim3(frame_blocks, 1);
lite::cuda::math::FastCollectiveGruOut<
T,
tiled_size><<<grids, threads, 0, context->exec_stream()>>>(
value.state_weight,
value.prev_out_value,
value.output_value,
value.gate_value,
value.reset_output_value,
frame_size,
active_node,
origin_mode);
}
return;
} else {
int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(frame_per_block, 1);
grids = dim3(frame_blocks, 1);
}
} else {
threads = dim3(32, 32);
grids = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
......@@ -121,6 +180,90 @@ struct GRUUnitFunctor {
template struct GRUUnitFunctor<float>;
template <>
struct GRUUnitFunctor<half> {
static void compute(GRUMetaValue<half> value,
int frame_size,
int batch_size,
const lite::cuda::math::ActivationType& active_node,
const lite::cuda::math::ActivationType& active_gate,
bool origin_mode,
lite::cuda::math::Gemm<half, half>* blas,
CUDAContext* context) {
dim3 threads, grids;
if (batch_size == 1) {
int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(frame_per_block, 1);
grids = dim3(frame_blocks, 1);
} else {
threads = dim3(32, 32);
grids = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
}
if (value.prev_out_value) {
CHECK(blas->init(false,
false,
batch_size,
frame_size * 2,
frame_size,
frame_size,
frame_size * 2,
frame_size * 3,
context));
blas->run(1.0f,
1.0f,
value.prev_out_value,
value.gate_weight,
value.gate_value,
context);
}
CUDA_POST_KERNEL_CHECK;
lite::cuda::math::GruForwardResetOutput<
half><<<grids, threads, 0, context->exec_stream()>>>(
value.gate_value,
value.reset_output_value,
value.prev_out_value,
frame_size,
batch_size,
active_gate,
batch_size == 1);
CUDA_POST_KERNEL_CHECK;
if (value.prev_out_value) {
CHECK(blas->init(false,
false,
batch_size,
frame_size,
frame_size,
frame_size,
frame_size,
frame_size * 3,
context));
blas->run(1.0f,
1.0f,
value.reset_output_value,
value.state_weight,
value.gate_value + frame_size * 2,
context);
}
CUDA_POST_KERNEL_CHECK;
lite::cuda::math::GruForwardFinalOutput<
half><<<grids, threads, 0, context->exec_stream()>>>(
value.gate_value,
value.prev_out_value,
value.output_value,
frame_size,
batch_size,
active_node,
origin_mode,
batch_size == 1);
CUDA_POST_KERNEL_CHECK;
}
};
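Both GRUUnitFunctor specializations wrap the kernels with two GEMMs. Reading blas->init's arguments as (trans_a, trans_b, m, n, k, lda, ldb, ldc, ctx) — an assumption about the Gemm helper's signature — the calls compute, with F = frame_size and B = batch_size:
\text{gate\_value}[:, 0{:}2F] \mathrel{+}= h_{t-1} W_{u,r}, \qquad h_{t-1} \in \mathbb{R}^{B \times F},\; W_{u,r} \in \mathbb{R}^{F \times 2F}
\text{gate\_value}[:, 2F{:}3F] \mathrel{+}= (r \odot h_{t-1}) W_{c}, \qquad W_{c} \in \mathbb{R}^{F \times F}
i.e. the update/reset pre-activations consumed by GruForwardResetOutput and the candidate-state pre-activation consumed by GruForwardFinalOutput.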
template <typename T, PrecisionType PType>
void GRUCompute<T, PType>::PrepareForRun() {
gemm_impl_.reset(new lite::cuda::math::Gemm<T, T>);
......@@ -141,18 +284,17 @@ void GRUCompute<T, PType>::Run() {
if (param.bias) {
bias = const_cast<lite::Tensor*>(param.bias);
}
const lite::Tensor* weight = param.weight;
T* weight_data = const_cast<T*>(weight->template data<T>());
lite::Tensor* batch_gate = param.batch_gate;
lite::Tensor* batch_reset_hidden_prev = param.batch_reset_hidden_prev;
lite::Tensor* batch_hidden = param.batch_hidden;
lite::Tensor* hidden = param.hidden;
T* batch_reset_hidden_prev_data =
batch_reset_hidden_prev->template mutable_data<T>(TARGET(kCUDA));
hidden->template mutable_data<T>(TARGET(kCUDA));
auto* batch_gate_data = batch_gate->template mutable_data<T>(TARGET(kCUDA));
auto* batch_hidden_data =
batch_hidden->template mutable_data<T>(TARGET(kCUDA));
T* batch_gate_data = batch_gate->template mutable_data<T>(TARGET(kCUDA));
T* batch_hidden_data = batch_hidden->template mutable_data<T>(TARGET(kCUDA));
bool is_reverse = param.is_reverse;
auto active_node = lite::cuda::math::GetActiveType(param.activation);
auto active_gate = lite::cuda::math::GetActiveType(param.gate_activation);
......@@ -224,6 +366,8 @@ void GRUCompute<T, PType>::Run() {
using GRUFp32 =
paddle::lite::kernels::cuda::GRUCompute<float, PRECISION(kFloat)>;
using GRUFp16 = paddle::lite::kernels::cuda::GRUCompute<half, PRECISION(kFP16)>;
REGISTER_LITE_KERNEL(gru, kCUDA, kFloat, kNCHW, GRUFp32, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("H0", {LiteType::GetTensorTy(TARGET(kCUDA))})
......@@ -234,3 +378,20 @@ REGISTER_LITE_KERNEL(gru, kCUDA, kFloat, kNCHW, GRUFp32, def)
.BindOutput("BatchHidden", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(gru, kCUDA, kFP16, kNCHW, GRUFp16, def)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindInput("H0", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindInput("Weight",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("BatchGate",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("BatchResetHiddenPrev",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("BatchHidden",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("Hidden",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.Finalize();
......@@ -45,10 +45,13 @@ class GRUTest : public ::testing::Test {
x_ref_.Resize(lite::DDim(x_shape_));
x_gpu_.Resize(lite::DDim(x_shape_));
x_ref_.set_lod(lod_);
w_ref_.Resize(lite::DDim(w_shape_));
w_gpu_.Resize(lite::DDim(w_shape_));
auto x_ref_data = x_ref_.mutable_data<float>();
auto w_ref_data = w_ref_.mutable_data<float>();
for (int64_t i = 0; i < x_ref_.numel(); i++) {
x_ref_data[i] = static_cast<float>(i % 10 * 0.2);
}
......@@ -63,6 +66,7 @@ class GRUTest : public ::testing::Test {
batch_hidden_gpu_.Resize(lite::DDim(out_shape_));
batch_reset_hidden_gpu_.Resize(lite::DDim(out_shape_));
RunBaseLine();
InitParamAndContext();
}
......@@ -91,6 +95,22 @@ class GRUTest : public ::testing::Test {
w_gpu_.dims());
}
void InitHalfInput() {
x_half_.Resize(lite::DDim(x_shape_));
auto x_half_data = x_half_.mutable_data<half>();
for (int64_t i = 0; i < x_half_.numel(); i++) {
x_half_data[i] = half(lite::float16(x_ref_.data<float>()[i]));
}
x_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(x_half_data, x_gpu_.dims());
x_gpu_.set_lod(x_ref_.lod());
w_half_.Resize(w_ref_.dims());
auto w_half_data = w_half_.mutable_data<half>();
for (int64_t i = 0; i < w_half_.numel(); i++) {
w_half_data[i] = half(lite::float16(w_ref_.data<float>()[i]));
}
w_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(w_half_data, w_gpu_.dims());
}
void RunBaseLine() {}
int batch_, frame_size_;
......@@ -134,6 +154,29 @@ TEST_F(GRUTest, TestFP32) {
<< duration / FLAGS_repeats << " ms in average.";
}
TEST_F(GRUTest, TestFP16) {
InitHalfInput();
GRUCompute<half, PRECISION(kFP16)> kernel;
kernel.SetParam(param_);
kernel.SetContext(std::move(ctx_));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp16, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
}
} // namespace cuda
} // namespace kernels
} // namespace lite
......