提交 3e1b914f 编写于 作者: Q Qiao Longfei

update gru op forward kernel

上级 7a81ab86
...@@ -137,6 +137,10 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -137,6 +137,10 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, defalut: False) " "(bool, defalut: False) "
"whether to compute reversed GRU.") "whether to compute reversed GRU.")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("origin_mode",
"bool"
"use origin mode in article https://arxiv.org/abs/1412.3555")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
GRU Operator implements part calculations of the complete GRU as following: GRU Operator implements part calculations of the complete GRU as following:
...@@ -221,6 +225,7 @@ class GRUCPUKernel : public framework::OpKernel<T> { ...@@ -221,6 +225,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
public: public:
void BatchCompute(const framework::ExecutionContext& context) const { void BatchCompute(const framework::ExecutionContext& context) const {
using DeviceContext = paddle::platform::CPUDeviceContext; using DeviceContext = paddle::platform::CPUDeviceContext;
bool origin_mode = context.Attr<bool>("origin_mode");
auto* input = context.Input<LoDTensor>("Input"); auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0"); auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight"); auto* weight = context.Input<Tensor>("Weight");
...@@ -327,7 +332,7 @@ class GRUCPUKernel : public framework::OpKernel<T> { ...@@ -327,7 +332,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
math::detail::forward_final_output( math::detail::forward_final_output(
math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size, math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
cur_batch_size, active_node); cur_batch_size, active_node, origin_mode);
gru_value.prev_out_value = gru_value.output_value; gru_value.prev_out_value = gru_value.output_value;
} }
...@@ -351,7 +356,7 @@ class GRUCPUKernel : public framework::OpKernel<T> { ...@@ -351,7 +356,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
math::GRUUnitFunctor<DeviceContext, T>::compute( math::GRUUnitFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size, active_node, dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
active_gate); active_gate, origin_mode);
gru_value.prev_out_value = gru_value.output_value; gru_value.prev_out_value = gru_value.output_value;
} }
......
...@@ -21,6 +21,7 @@ template <typename DeviceContext, typename T> ...@@ -21,6 +21,7 @@ template <typename DeviceContext, typename T>
class GRUKernel : public framework::OpKernel<T> { class GRUKernel : public framework::OpKernel<T> {
public: public:
void BatchCompute(const framework::ExecutionContext& context) const { void BatchCompute(const framework::ExecutionContext& context) const {
bool origin_mode = context.Attr<bool>("origin_mode");
auto* input = context.Input<LoDTensor>("Input"); auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0"); auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight"); auto* weight = context.Input<Tensor>("Weight");
...@@ -87,7 +88,7 @@ class GRUKernel : public framework::OpKernel<T> { ...@@ -87,7 +88,7 @@ class GRUKernel : public framework::OpKernel<T> {
gru_value.reset_output_value = reset_hidden_prev_t.data<T>(); gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
math::GRUUnitFunctor<DeviceContext, T>::compute( math::GRUUnitFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size, active_node, dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
active_gate); active_gate, origin_mode);
gru_value.prev_out_value = gru_value.output_value; gru_value.prev_out_value = gru_value.output_value;
} }
......
...@@ -41,6 +41,7 @@ template <typename DeviceContext, typename T> ...@@ -41,6 +41,7 @@ template <typename DeviceContext, typename T>
class GRUGradKernel : public framework::OpKernel<T> { class GRUGradKernel : public framework::OpKernel<T> {
public: public:
void BatchCompute(const framework::ExecutionContext& context) const { void BatchCompute(const framework::ExecutionContext& context) const {
bool origin_mode = context.Attr<bool>("origin_mode");
auto* h0 = context.Input<Tensor>("H0"); auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight"); auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>(); const T* weight_data = weight->data<T>();
...@@ -146,7 +147,7 @@ class GRUGradKernel : public framework::OpKernel<T> { ...@@ -146,7 +147,7 @@ class GRUGradKernel : public framework::OpKernel<T> {
math::GRUUnitGradFunctor<DeviceContext, T>::compute( math::GRUUnitGradFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, active_node, dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, active_node,
active_gate); active_gate, origin_mode);
} }
if (input_grad) { if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace()); input_grad->mutable_data<T>(context.GetPlace());
......
...@@ -56,7 +56,8 @@ template <class OpFinalOutput, typename T> ...@@ -56,7 +56,8 @@ template <class OpFinalOutput, typename T>
void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value, T *gate_value, T *prev_output_value,
T *output_value, int frame_size, T *output_value, int frame_size,
ActivationType active_node) { ActivationType active_node,
bool origin_mode) {
T r_value_update_gate; T r_value_update_gate;
T r_value_frame_state; T r_value_frame_state;
T r_prev_out = 0; T r_prev_out = 0;
...@@ -72,7 +73,7 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, ...@@ -72,7 +73,7 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
} }
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
&r_output, active_node); &r_output, active_node, origin_mode);
frame_state[i] = r_value_frame_state; frame_state[i] = r_value_frame_state;
output_value[i] = r_output; output_value[i] = r_output;
...@@ -146,7 +147,8 @@ template <class OpFinalOutput, typename T> ...@@ -146,7 +147,8 @@ template <class OpFinalOutput, typename T>
void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value, T *gate_value, T *prev_output_value,
T *output_value, int frame_size, T *output_value, int frame_size,
ActivationType active_node) { ActivationType active_node,
bool origin_mode) {
#ifdef __AVX__ #ifdef __AVX__
__m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f); __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
__m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f); __m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f);
...@@ -180,7 +182,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, ...@@ -180,7 +182,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
} }
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
&r_output, active_node); &r_output, active_node, origin_mode);
_mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i), _mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
r_value_frame_state); r_value_frame_state);
...@@ -190,7 +192,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, ...@@ -190,7 +192,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
if (rest > 0) { if (rest > 0) {
i = n - block; i = n - block;
op_final_output(&r_value_update_gate_last, &r_value_frame_state_last, op_final_output(&r_value_update_gate_last, &r_value_frame_state_last,
&r_prev_out_last, &r_output, active_node); &r_prev_out_last, &r_output, active_node, origin_mode);
_mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i), _mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
r_value_frame_state_last); r_value_frame_state_last);
...@@ -227,17 +229,18 @@ inline void forward_reset_output(OpResetOutput op_reset_output, ...@@ -227,17 +229,18 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
template <class OpFinalOutput, typename T> template <class OpFinalOutput, typename T>
inline void forward_final_output(OpFinalOutput op_final_output, inline void forward_final_output(OpFinalOutput op_final_output,
GRUMetaValue<T> value, int frame_size, GRUMetaValue<T> value, int frame_size,
int batch_size, ActivationType active_node) { int batch_size, ActivationType active_node,
bool origin_mode) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) && if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
(sizeof(T) == 4)) { (sizeof(T) == 4)) {
hl_avx_gru_forward_final_output(op_final_output, value.gate_value, hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
value.prev_out_value, value.output_value, value.prev_out_value, value.output_value,
frame_size, active_node); frame_size, active_node, origin_mode);
} else { } else {
hl_naive_gru_forward_final_output( hl_naive_gru_forward_final_output(
op_final_output, value.gate_value, value.prev_out_value, op_final_output, value.gate_value, value.prev_out_value,
value.output_value, frame_size, active_node); value.output_value, frame_size, active_node, origin_mode);
} }
value.gate_value += frame_size * 3; value.gate_value += frame_size * 3;
......
...@@ -72,7 +72,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, ...@@ -72,7 +72,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value, T *gate_value, T *prev_output_value,
T *output_value, int frame_size, T *output_value, int frame_size,
int batch_size, int batch_size,
ActivationType active_node) { ActivationType active_node,
bool origin_mode) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return; if (frame_idx >= frame_size) return;
int batch_idx = 0; int batch_idx = 0;
...@@ -94,7 +95,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, ...@@ -94,7 +95,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
} }
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
&r_output, active_node); &r_output, active_node, origin_mode);
gate_value[frame_idx + frame_size * 2] = r_value_frame_state; gate_value[frame_idx + frame_size * 2] = r_value_frame_state;
output_value[frame_idx] = r_output; output_value[frame_idx] = r_output;
......
...@@ -57,11 +57,17 @@ class gru_finalOutput { ...@@ -57,11 +57,17 @@ class gru_finalOutput {
public: public:
HOSTDEVICE void operator()(T *value_update_gate, T *value_frame_state, HOSTDEVICE void operator()(T *value_update_gate, T *value_frame_state,
T *prev_out, T *value_output, T *prev_out, T *value_output,
ActivationType act_input) { ActivationType act_input, bool origin_mode) {
*value_frame_state = activation(*value_frame_state, act_input); *value_frame_state = activation(*value_frame_state, act_input);
if (origin_mode) {
*value_output = ((*value_update_gate) * (*prev_out)) +
*value_frame_state -
((*value_update_gate) * (*value_frame_state));
} else {
*value_output = *prev_out - ((*value_update_gate) * (*prev_out)) + *value_output = *prev_out - ((*value_update_gate) * (*prev_out)) +
((*value_update_gate) * (*value_frame_state)); ((*value_update_gate) * (*value_frame_state));
} }
}
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
static const bool avx = false; static const bool avx = false;
...@@ -69,12 +75,21 @@ class gru_finalOutput { ...@@ -69,12 +75,21 @@ class gru_finalOutput {
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 *value_update_gate, HOSTDEVICE void operator()(__m256 *value_update_gate,
__m256 *value_frame_state, __m256 *prev_out, __m256 *value_frame_state, __m256 *prev_out,
__m256 *value_output, ActivationType act_input) { __m256 *value_output, ActivationType act_input,
bool origin_mode) {
*value_frame_state = activation(*value_frame_state, act_input); *value_frame_state = activation(*value_frame_state, act_input);
if (origin_mode) {
*value_output = _mm256_sub_ps(
_mm256_add_ps(_mm256_mul_ps(*value_update_gate, *prev_out),
*value_frame_state),
_mm256_mul_ps(*value_update_gate, *value_frame_state));
} else {
*value_output = _mm256_add_ps( *value_output = _mm256_add_ps(
_mm256_sub_ps(*prev_out, _mm256_mul_ps(*value_update_gate, *prev_out)), _mm256_sub_ps(*prev_out,
_mm256_mul_ps(*value_update_gate, *prev_out)),
_mm256_mul_ps(*value_update_gate, *value_frame_state)); _mm256_mul_ps(*value_update_gate, *value_frame_state));
} }
}
#endif #endif
#endif #endif
}; };
......
...@@ -23,7 +23,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> { ...@@ -23,7 +23,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
static void compute(const platform::CPUDeviceContext &context, static void compute(const platform::CPUDeviceContext &context,
GRUMetaValue<T> value, int frame_size, int batch_size, GRUMetaValue<T> value, int frame_size, int batch_size,
const detail::ActivationType active_node, const detail::ActivationType active_node,
const detail::ActivationType active_gate) { const detail::ActivationType active_gate,
bool origin_mode) {
#ifndef __NVCC__ #ifndef __NVCC__
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context); auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
if (value.prev_out_value) { if (value.prev_out_value) {
...@@ -43,7 +44,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> { ...@@ -43,7 +44,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
} }
detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value, detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
frame_size, batch_size, active_node); frame_size, batch_size, active_node,
origin_mode);
#endif #endif
} }
}; };
...@@ -54,7 +56,8 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> { ...@@ -54,7 +56,8 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
GRUMetaValue<T> value, GRUMetaGrad<T> grad, GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size, int frame_size, int batch_size,
const detail::ActivationType active_node, const detail::ActivationType active_node,
const detail::ActivationType active_gate) { const detail::ActivationType active_gate,
bool origin_mode) {
#ifndef __NVCC__ #ifndef __NVCC__
detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value, detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
grad, frame_size, batch_size, active_node); grad, frame_size, batch_size, active_node);
......
...@@ -24,7 +24,8 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> { ...@@ -24,7 +24,8 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext &context, static void compute(const platform::CUDADeviceContext &context,
GRUMetaValue<T> value, int frame_size, int batch_size, GRUMetaValue<T> value, int frame_size, int batch_size,
const detail::ActivationType active_node, const detail::ActivationType active_node,
const detail::ActivationType active_gate) { const detail::ActivationType active_gate,
bool origin_mode) {
auto stream = context.stream(); auto stream = context.stream();
dim3 threads; dim3 threads;
dim3 grid; dim3 grid;
...@@ -73,14 +74,14 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> { ...@@ -73,14 +74,14 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
T><<<grid, threads, 0, stream>>>( T><<<grid, threads, 0, stream>>>(
detail::forward::gru_finalOutput<T>(), value.gate_value, detail::forward::gru_finalOutput<T>(), value.gate_value,
value.prev_out_value, value.output_value, frame_size, batch_size, value.prev_out_value, value.output_value, frame_size, batch_size,
active_node); active_node, origin_mode);
} else { } else {
detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>, detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
/* is_batch= */ true, /* is_batch= */ true,
T><<<grid, threads, 0, stream>>>( T><<<grid, threads, 0, stream>>>(
detail::forward::gru_finalOutput<T>(), value.gate_value, detail::forward::gru_finalOutput<T>(), value.gate_value,
value.prev_out_value, value.output_value, frame_size, batch_size, value.prev_out_value, value.output_value, frame_size, batch_size,
active_node); active_node, origin_mode);
} }
} }
}; };
......
...@@ -44,7 +44,8 @@ struct GRUUnitFunctor { ...@@ -44,7 +44,8 @@ struct GRUUnitFunctor {
static void compute(const DeviceContext &context, GRUMetaValue<T> value, static void compute(const DeviceContext &context, GRUMetaValue<T> value,
int frame_size, int batch_size, int frame_size, int batch_size,
const detail::ActivationType active_node, const detail::ActivationType active_node,
const detail::ActivationType active_gate); const detail::ActivationType active_gate,
bool origin_mode);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -52,7 +53,8 @@ struct GRUUnitGradFunctor { ...@@ -52,7 +53,8 @@ struct GRUUnitGradFunctor {
static void compute(const DeviceContext &context, GRUMetaValue<T> value, static void compute(const DeviceContext &context, GRUMetaValue<T> value,
GRUMetaGrad<T> grad, int frame_size, int batch_size, GRUMetaGrad<T> grad, int frame_size, int batch_size,
const detail::ActivationType active_node, const detail::ActivationType active_node,
const detail::ActivationType active_gate); const detail::ActivationType active_gate,
bool origin_mode);
}; };
} // namespace math } // namespace math
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册