提交 3e1b914f 编写于 作者: Q Qiao Longfei

update gru op forward kernel

上级 7a81ab86
......@@ -137,6 +137,10 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, defalut: False) "
"whether to compute reversed GRU.")
.SetDefault(false);
AddAttr<bool>("origin_mode",
"bool"
"use origin mode in article https://arxiv.org/abs/1412.3555")
.SetDefault(false);
AddComment(R"DOC(
GRU Operator implements part calculations of the complete GRU as following:
......@@ -221,6 +225,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
using DeviceContext = paddle::platform::CPUDeviceContext;
bool origin_mode = context.Attr<bool>("origin_mode");
auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight");
......@@ -327,7 +332,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
math::detail::forward_final_output(
math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
cur_batch_size, active_node);
cur_batch_size, active_node, origin_mode);
gru_value.prev_out_value = gru_value.output_value;
}
......@@ -351,7 +356,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
math::GRUUnitFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
active_gate);
active_gate, origin_mode);
gru_value.prev_out_value = gru_value.output_value;
}
......
......@@ -21,6 +21,7 @@ template <typename DeviceContext, typename T>
class GRUKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
bool origin_mode = context.Attr<bool>("origin_mode");
auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight");
......@@ -87,7 +88,7 @@ class GRUKernel : public framework::OpKernel<T> {
gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
math::GRUUnitFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
active_gate);
active_gate, origin_mode);
gru_value.prev_out_value = gru_value.output_value;
}
......
......@@ -41,6 +41,7 @@ template <typename DeviceContext, typename T>
class GRUGradKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
bool origin_mode = context.Attr<bool>("origin_mode");
auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>();
......@@ -146,7 +147,7 @@ class GRUGradKernel : public framework::OpKernel<T> {
math::GRUUnitGradFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, active_node,
active_gate);
active_gate, origin_mode);
}
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
......
......@@ -56,7 +56,8 @@ template <class OpFinalOutput, typename T>
void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value,
T *output_value, int frame_size,
ActivationType active_node) {
ActivationType active_node,
bool origin_mode) {
T r_value_update_gate;
T r_value_frame_state;
T r_prev_out = 0;
......@@ -72,7 +73,7 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
}
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
&r_output, active_node);
&r_output, active_node, origin_mode);
frame_state[i] = r_value_frame_state;
output_value[i] = r_output;
......@@ -146,7 +147,8 @@ template <class OpFinalOutput, typename T>
void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value,
T *output_value, int frame_size,
ActivationType active_node) {
ActivationType active_node,
bool origin_mode) {
#ifdef __AVX__
__m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
__m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f);
......@@ -180,7 +182,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
}
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
&r_output, active_node);
&r_output, active_node, origin_mode);
_mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
r_value_frame_state);
......@@ -190,7 +192,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
if (rest > 0) {
i = n - block;
op_final_output(&r_value_update_gate_last, &r_value_frame_state_last,
&r_prev_out_last, &r_output, active_node);
&r_prev_out_last, &r_output, active_node, origin_mode);
_mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
r_value_frame_state_last);
......@@ -227,17 +229,18 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
template <class OpFinalOutput, typename T>
inline void forward_final_output(OpFinalOutput op_final_output,
GRUMetaValue<T> value, int frame_size,
int batch_size, ActivationType active_node) {
int batch_size, ActivationType active_node,
bool origin_mode) {
for (int b = 0; b < batch_size; b++) {
if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
(sizeof(T) == 4)) {
hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
value.prev_out_value, value.output_value,
frame_size, active_node);
frame_size, active_node, origin_mode);
} else {
hl_naive_gru_forward_final_output(
op_final_output, value.gate_value, value.prev_out_value,
value.output_value, frame_size, active_node);
value.output_value, frame_size, active_node, origin_mode);
}
value.gate_value += frame_size * 3;
......
......@@ -72,7 +72,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value,
T *output_value, int frame_size,
int batch_size,
ActivationType active_node) {
ActivationType active_node,
bool origin_mode) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
......@@ -94,7 +95,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
}
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
&r_output, active_node);
&r_output, active_node, origin_mode);
gate_value[frame_idx + frame_size * 2] = r_value_frame_state;
output_value[frame_idx] = r_output;
......
......@@ -57,10 +57,16 @@ class gru_finalOutput {
public:
HOSTDEVICE void operator()(T *value_update_gate, T *value_frame_state,
T *prev_out, T *value_output,
ActivationType act_input) {
ActivationType act_input, bool origin_mode) {
*value_frame_state = activation(*value_frame_state, act_input);
*value_output = *prev_out - ((*value_update_gate) * (*prev_out)) +
((*value_update_gate) * (*value_frame_state));
if (origin_mode) {
*value_output = ((*value_update_gate) * (*prev_out)) +
*value_frame_state -
((*value_update_gate) * (*value_frame_state));
} else {
*value_output = *prev_out - ((*value_update_gate) * (*prev_out)) +
((*value_update_gate) * (*value_frame_state));
}
}
#ifndef __NVCC__
#ifndef __AVX__
......@@ -69,11 +75,20 @@ class gru_finalOutput {
static const bool avx = true;
HOSTDEVICE void operator()(__m256 *value_update_gate,
__m256 *value_frame_state, __m256 *prev_out,
__m256 *value_output, ActivationType act_input) {
__m256 *value_output, ActivationType act_input,
bool origin_mode) {
*value_frame_state = activation(*value_frame_state, act_input);
*value_output = _mm256_add_ps(
_mm256_sub_ps(*prev_out, _mm256_mul_ps(*value_update_gate, *prev_out)),
_mm256_mul_ps(*value_update_gate, *value_frame_state));
if (origin_mode) {
*value_output = _mm256_sub_ps(
_mm256_add_ps(_mm256_mul_ps(*value_update_gate, *prev_out),
*value_frame_state),
_mm256_mul_ps(*value_update_gate, *value_frame_state));
} else {
*value_output = _mm256_add_ps(
_mm256_sub_ps(*prev_out,
_mm256_mul_ps(*value_update_gate, *prev_out)),
_mm256_mul_ps(*value_update_gate, *value_frame_state));
}
}
#endif
#endif
......
......@@ -23,7 +23,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
static void compute(const platform::CPUDeviceContext &context,
GRUMetaValue<T> value, int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
const detail::ActivationType active_gate,
bool origin_mode) {
#ifndef __NVCC__
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
if (value.prev_out_value) {
......@@ -43,7 +44,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
}
detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
frame_size, batch_size, active_node);
frame_size, batch_size, active_node,
origin_mode);
#endif
}
};
......@@ -54,7 +56,8 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
const detail::ActivationType active_gate,
bool origin_mode) {
#ifndef __NVCC__
detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
grad, frame_size, batch_size, active_node);
......
......@@ -24,7 +24,8 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext &context,
GRUMetaValue<T> value, int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
const detail::ActivationType active_gate,
bool origin_mode) {
auto stream = context.stream();
dim3 threads;
dim3 grid;
......@@ -73,14 +74,14 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
T><<<grid, threads, 0, stream>>>(
detail::forward::gru_finalOutput<T>(), value.gate_value,
value.prev_out_value, value.output_value, frame_size, batch_size,
active_node);
active_node, origin_mode);
} else {
detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
/* is_batch= */ true,
T><<<grid, threads, 0, stream>>>(
detail::forward::gru_finalOutput<T>(), value.gate_value,
value.prev_out_value, value.output_value, frame_size, batch_size,
active_node);
active_node, origin_mode);
}
}
};
......
......@@ -44,7 +44,8 @@ struct GRUUnitFunctor {
static void compute(const DeviceContext &context, GRUMetaValue<T> value,
int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate);
const detail::ActivationType active_gate,
bool origin_mode);
};
template <typename DeviceContext, typename T>
......@@ -52,7 +53,8 @@ struct GRUUnitGradFunctor {
static void compute(const DeviceContext &context, GRUMetaValue<T> value,
GRUMetaGrad<T> grad, int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate);
const detail::ActivationType active_gate,
bool origin_mode);
};
} // namespace math
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册