未验证 提交 085260f3 编写于 作者: J Jack Zhou 提交者: GitHub

Add eigen gru and fix the dropout bug in the rnn

Add eigen gru and fix the dropout bug in the rnn 
上级 545df287
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include <type_traits> #include <type_traits>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/gru_compute.h" #include "paddle/fluid/operators/math/gru_compute.h"
...@@ -21,6 +23,10 @@ namespace paddle { ...@@ -21,6 +23,10 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
namespace detail { namespace detail {
using Array1 = Eigen::DSizes<int64_t, 1>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
#ifndef __NVCC__ #ifndef __NVCC__
...@@ -242,23 +248,46 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, ...@@ -242,23 +248,46 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
#endif #endif
} }
template <typename T>
inline void forward_reset_outputV2(const platform::CPUDeviceContext &context,
GRUMetaValue<T> value, int frame_size) {
auto &place = *context.eigen_device();
auto value_reset_gate =
typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
auto value_update_gate = typename EigenVector<T>::Type(
value.gate_value + frame_size, Array1(frame_size));
auto value_reset_output = typename EigenVector<T>::Type(
value.reset_output_value, Array1(frame_size));
auto value_reset_bias =
typename EigenVector<T>::ConstType(value.reset_bias, Array1(frame_size));
SigmoidFunctor<T>()(place, value_reset_gate, value_reset_gate);
SigmoidFunctor<T>()(place, value_update_gate, value_update_gate);
value_reset_output.device(place) =
(value_reset_output + value_reset_bias) * value_reset_gate;
}
template <class OpResetOutput, typename T> template <class OpResetOutput, typename T>
inline void forward_reset_output(OpResetOutput op_reset_output, inline void forward_reset_output(
GRUMetaValue<T> value, int frame_size, OpResetOutput op_reset_output, GRUMetaValue<T> value, int frame_size,
int batch_size, ActivationType active_gate, int batch_size, ActivationType active_gate, bool old_version = true,
bool old_version = true) { const platform::CPUDeviceContext *context = nullptr) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
if (OpResetOutput::avx && (frame_size > static_cast<int>(8 - 1)) && if (!old_version) {
(sizeof(T) == 4)) { // use eigen
hl_avx_gru_forward_reset_output( forward_reset_outputV2(*context, value, frame_size);
op_reset_output, value.gate_value, value.reset_output_value,
value.prev_out_value, frame_size, active_gate, old_version,
value.reset_bias);
} else { } else {
hl_naive_gru_forward_reset_output( if (OpResetOutput::avx && (frame_size & static_cast<int>(8 - 1)) &&
op_reset_output, value.gate_value, value.reset_output_value, (sizeof(T) == 4)) {
value.prev_out_value, frame_size, active_gate, old_version, hl_avx_gru_forward_reset_output(
value.reset_bias); op_reset_output, value.gate_value, value.reset_output_value,
value.prev_out_value, frame_size, active_gate, old_version,
value.reset_bias);
} else {
hl_naive_gru_forward_reset_output(
op_reset_output, value.gate_value, value.reset_output_value,
value.prev_out_value, frame_size, active_gate, old_version,
value.reset_bias);
}
} }
value.gate_value += frame_size * 3; value.gate_value += frame_size * 3;
value.reset_output_value += frame_size; value.reset_output_value += frame_size;
...@@ -268,25 +297,51 @@ inline void forward_reset_output(OpResetOutput op_reset_output, ...@@ -268,25 +297,51 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
} }
} }
template <typename T>
inline void forward_final_outputV2(const platform::CPUDeviceContext &context,
GRUMetaValue<T> value, int frame_size) {
auto &place = *context.eigen_device();
auto value_update_gate = typename EigenVector<T>::Type(
value.gate_value + frame_size, Array1(frame_size));
auto value_frame_state = typename EigenVector<T>::Type(
value.gate_value + 2 * frame_size, Array1(frame_size));
auto value_output =
typename EigenVector<T>::Type(value.output_value, Array1(frame_size));
TanhFunctor<T>()(place, value_frame_state, value_frame_state);
value_output.device(place) =
(static_cast<T>(1.0) - value_update_gate) * value_frame_state;
if (value.prev_out_value) {
auto value_prev_out = typename EigenVector<T>::ConstType(
value.prev_out_value, Array1(frame_size));
value_output.device(place) =
value_output + value_update_gate * value_prev_out;
}
}
template <class OpFinalOutput, typename T> template <class OpFinalOutput, typename T>
inline void forward_final_output(OpFinalOutput op_final_output, inline void forward_final_output(
GRUMetaValue<T> value, int frame_size, OpFinalOutput op_final_output, GRUMetaValue<T> value, int frame_size,
int batch_size, ActivationType active_node, int batch_size, ActivationType active_node, bool origin_mode,
bool origin_mode, bool old_version = true) { bool old_version = true,
const platform::CPUDeviceContext *context = nullptr) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) && if (!old_version) {
(sizeof(T) == 4)) { // eigen
hl_avx_gru_forward_final_output(op_final_output, value.gate_value, forward_final_outputV2(*context, value, frame_size);
value.prev_out_value, value.output_value,
frame_size, active_node, origin_mode,
old_version);
} else { } else {
hl_naive_gru_forward_final_output(op_final_output, value.gate_value, if (OpFinalOutput::avx && (frame_size & static_cast<int>(8 - 1)) &&
(sizeof(T) == 4)) {
hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
value.prev_out_value, value.prev_out_value,
value.output_value, frame_size, value.output_value, frame_size,
active_node, origin_mode, old_version); active_node, origin_mode, old_version);
} else {
hl_naive_gru_forward_final_output(
op_final_output, value.gate_value, value.prev_out_value,
value.output_value, frame_size, active_node, origin_mode,
old_version);
}
} }
value.gate_value += frame_size * 3; value.gate_value += frame_size * 3;
value.output_value += frame_size; value.output_value += frame_size;
if (value.prev_out_value) { if (value.prev_out_value) {
...@@ -664,23 +719,70 @@ inline void backward_reset_grad(OpResetGrad op_reset_grad, ...@@ -664,23 +719,70 @@ inline void backward_reset_grad(OpResetGrad op_reset_grad,
} }
} }
template <typename T>
inline void gru_backward(const platform::CPUDeviceContext &context,
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size) {
auto &place = *context.eigen_device();
auto value_reset_gate =
typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
auto grad_reset_gate =
typename EigenVector<T>::Type(grad.gate_grad, Array1(frame_size));
auto value_update_gate = typename EigenVector<T>::Type(
value.gate_value + frame_size, Array1(frame_size));
auto grad_update_gate = typename EigenVector<T>::Type(
grad.gate_grad + frame_size, Array1(frame_size));
auto value_frame_state = typename EigenVector<T>::Type(
value.gate_value + frame_size * 2, Array1(frame_size));
auto grad_frame_state = typename EigenVector<T>::Type(
grad.gate_grad + frame_size * 2, Array1(frame_size));
auto grad_output =
typename EigenVector<T>::Type(grad.output_grad, Array1(frame_size));
auto value_reset_output = typename EigenVector<T>::Type(
value.reset_output_value, Array1(frame_size));
auto grad_reset_output =
typename EigenVector<T>::Type(grad.reset_output_grad, Array1(frame_size));
if (value.prev_out_value) {
auto value_prev_out = typename EigenVector<T>::ConstType(
value.prev_out_value, Array1(frame_size));
SigmoidGradFunctor<T>()(place, 1 /*useless*/, value_update_gate,
(value_prev_out - value_frame_state) * grad_output,
grad_update_gate);
} else {
SigmoidGradFunctor<T>()(
place, 1 /*useless*/, value_update_gate,
static_cast<T>(-1) * value_frame_state * grad_output, grad_update_gate);
}
if (grad.prev_out_grad) {
auto grad_prev_out =
typename EigenVector<T>::Type(grad.prev_out_grad, Array1(frame_size));
grad_prev_out.device(place) =
grad_prev_out + grad_output * value_update_gate;
}
TanhGradFunctor<T>()(place, 1 /*useless*/, value_frame_state,
grad_output * (static_cast<T>(1.0) - value_update_gate),
grad_frame_state);
SigmoidGradFunctor<T>()(
place, 1 /*useless*/, value_reset_gate,
value_reset_output / value_reset_gate * grad_frame_state,
grad_reset_gate);
if (value.prev_out_value && grad.prev_out_grad) {
grad_reset_output.device(place) = value_reset_gate * grad_frame_state;
}
}
template <class OpGruGrad, typename T> template <class OpGruGrad, typename T>
inline void cpu_gru_backward(OpGruGrad op_gru_grad, GRUMetaValue<T> value, inline void cpu_gru_backward(const platform::CPUDeviceContext &context,
OpGruGrad op_gru_grad, GRUMetaValue<T> value,
GRUMetaGrad<T> grad, int frame_size, GRUMetaGrad<T> grad, int frame_size,
int batch_size, ActivationType active_node, int batch_size, ActivationType active_node,
ActivationType active_gate) { ActivationType active_gate) {
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
if (OpGruGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { // eigen
hl_avx_gru_backward( gru_backward(context, value, grad, frame_size);
op_gru_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
grad.prev_out_grad, value.reset_output_value, grad.reset_output_grad,
grad.output_grad, frame_size, active_node, active_gate);
} else {
hl_naive_gru_backward(
op_gru_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
grad.prev_out_grad, value.reset_output_value, grad.reset_output_grad,
grad.output_grad, frame_size, active_node, active_gate);
}
value.gate_value += frame_size * 3; value.gate_value += frame_size * 3;
value.reset_output_value += frame_size; value.reset_output_value += frame_size;
......
...@@ -42,7 +42,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> { ...@@ -42,7 +42,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
} }
detail::forward_reset_output(detail::forward::gru_resetOutput<T>(), value, detail::forward_reset_output(detail::forward::gru_resetOutput<T>(), value,
frame_size, batch_size, active_gate); frame_size, batch_size, active_gate, true,
&context);
if (value.prev_out_value) { if (value.prev_out_value) {
blas.GEMM(false, false, batch_size, frame_size, frame_size, 1, blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
...@@ -53,7 +54,7 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> { ...@@ -53,7 +54,7 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value, detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
frame_size, batch_size, active_node, frame_size, batch_size, active_node,
origin_mode); origin_mode, &context);
#endif #endif
} }
}; };
...@@ -116,7 +117,8 @@ struct GRUUnitFunctorV2<platform::CPUDeviceContext, T> { ...@@ -116,7 +117,8 @@ struct GRUUnitFunctorV2<platform::CPUDeviceContext, T> {
value.reset_output_value); value.reset_output_value);
} }
detail::forward_reset_output(detail::forward::gru_resetOutput<T>(), value, detail::forward_reset_output(detail::forward::gru_resetOutput<T>(), value,
frame_size, batch_size, active_gate, false); frame_size, batch_size, active_gate, false,
&context);
T *cell_state_value = value.gate_value + 2 * frame_size; T *cell_state_value = value.gate_value + 2 * frame_size;
T *reset_output_value = value.reset_output_value; T *reset_output_value = value.reset_output_value;
...@@ -129,7 +131,7 @@ struct GRUUnitFunctorV2<platform::CPUDeviceContext, T> { ...@@ -129,7 +131,7 @@ struct GRUUnitFunctorV2<platform::CPUDeviceContext, T> {
detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value, detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
frame_size, batch_size, active_node, true, frame_size, batch_size, active_node, true,
false); false, &context);
#endif #endif
} }
}; };
...@@ -144,8 +146,50 @@ struct GRUUnitGradFunctorV2<platform::CPUDeviceContext, T> { ...@@ -144,8 +146,50 @@ struct GRUUnitGradFunctorV2<platform::CPUDeviceContext, T> {
#ifndef __NVCC__ #ifndef __NVCC__
// calculate grad_update_gate, grad_frame_state, // calculate grad_update_gate, grad_frame_state,
// grad_reset_output, grad_reset_gate // grad_reset_output, grad_reset_gate
detail::cpu_gru_backward(detail::backward::gru<T>(), value, grad, detail::cpu_gru_backward(context, detail::backward::gru<T>(), value, grad,
frame_size, batch_size, active_node, active_gate); frame_size, batch_size, active_node, active_gate);
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
if (grad.prev_out_grad && value.prev_out_value) {
// update prev_out_grad
blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
grad.gate_grad, frame_size * 3, value.gate_weight, frame_size,
1, grad.prev_out_grad, frame_size);
blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
grad.gate_grad + frame_size, frame_size * 3,
value.gate_weight + frame_size * frame_size, frame_size, 1,
grad.prev_out_grad, frame_size);
blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
grad.reset_output_grad, frame_size, value.state_weight,
frame_size, 1, grad.prev_out_grad, frame_size);
// update weight_hh_grad
if (grad.gate_weight_grad) {
// reset gate
blas.GEMM(true, false, frame_size, frame_size, batch_size, 1,
grad.gate_grad, frame_size * 3, value.prev_out_value,
frame_size, 1, grad.gate_weight_grad, frame_size);
// update gate
blas.GEMM(true, false, frame_size, frame_size, batch_size, 1,
grad.gate_grad + frame_size, frame_size * 3,
value.prev_out_value, frame_size, 1,
grad.gate_weight_grad + frame_size * frame_size, frame_size);
// cell state
blas.GEMM(true, false, frame_size, frame_size, batch_size, 1,
grad.reset_output_grad, frame_size, value.prev_out_value,
frame_size, 1, grad.state_weight_grad, frame_size);
}
}
// update bias_hh_grad
T *gate_grad = grad.gate_grad;
T *bias_hh_grad = grad.bias_hh_grad;
T *state_bias_grad = grad.bias_hh_grad + 2 * frame_size;
T *reset_output_grad = grad.reset_output_grad;
for (int b = 0; b < batch_size; ++b) {
blas.VADD(2 * frame_size, bias_hh_grad, gate_grad, bias_hh_grad);
blas.VADD(frame_size, state_bias_grad, reset_output_grad,
state_bias_grad);
gate_grad += 3 * frame_size;
reset_output_grad += frame_size;
}
#endif #endif
} }
}; };
......
...@@ -38,7 +38,7 @@ struct GRUMetaGrad { ...@@ -38,7 +38,7 @@ struct GRUMetaGrad {
T *reset_output_grad; T *reset_output_grad;
T *output_grad; T *output_grad;
T *prev_out_grad; T *prev_out_grad;
T *state_bias_grad; T *bias_hh_grad;
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
......
此差异已折叠。
...@@ -294,7 +294,6 @@ def unstack(array, axis=0): ...@@ -294,7 +294,6 @@ def unstack(array, axis=0):
def dropout(array, p=0.5): def dropout(array, p=0.5):
if p == 0.0: if p == 0.0:
return array return array
mask = (np.random.uniform(size=array.shape) < (1 - p)).astype(array.dtype) mask = (np.random.uniform(size=array.shape) < (1 - p)).astype(array.dtype)
return array * (mask / (1 - p)) return array * (mask / (1 - p))
...@@ -390,11 +389,12 @@ class RNNMixin(LayerListMixin): ...@@ -390,11 +389,12 @@ class RNNMixin(LayerListMixin):
states = split_states(initial_states, self.num_directions == 2, states = split_states(initial_states, self.num_directions == 2,
self.state_components) self.state_components)
final_states = [] final_states = []
input_temp = inputs
for i, rnn_layer in enumerate(self): for i, rnn_layer in enumerate(self):
if i > 0: if i > 0:
inputs = dropout(inputs, self.dropout) input_temp = dropout(inputs, self.dropout)
outputs, final_state = rnn_layer(inputs, states[i], sequence_length) outputs, final_state = rnn_layer(input_temp, states[i],
sequence_length)
final_states.append(final_state) final_states.append(final_state)
inputs = outputs inputs = outputs
......
...@@ -53,6 +53,7 @@ class TestRNNOp(OpTest): ...@@ -53,6 +53,7 @@ class TestRNNOp(OpTest):
self.is_bidirec = False self.is_bidirec = False
self.mode = "LSTM" self.mode = "LSTM"
self.is_test = False self.is_test = False
self.dropout = 0.0
self.set_attrs() self.set_attrs()
self.direction_num = 2 if self.is_bidirec else 1 self.direction_num = 2 if self.is_bidirec else 1
...@@ -76,7 +77,8 @@ class TestRNNOp(OpTest): ...@@ -76,7 +77,8 @@ class TestRNNOp(OpTest):
hidden_size, hidden_size,
num_layers=self.num_layers, num_layers=self.num_layers,
time_major=True, time_major=True,
direction=direction) direction=direction,
dropout=self.dropout)
flat_w = get_params_for_net(rnn1) flat_w = get_params_for_net(rnn1)
output, (last_hidden, last_cell) = rnn1( output, (last_hidden, last_cell) = rnn1(
...@@ -101,7 +103,7 @@ class TestRNNOp(OpTest): ...@@ -101,7 +103,7 @@ class TestRNNOp(OpTest):
'PreState': [('init_h', init_h), ('init_c', init_c)], 'PreState': [('init_h', init_h), ('init_c', init_c)],
} }
self.attrs = { self.attrs = {
'dropout_prob': 0.0, 'dropout_prob': self.dropout,
'is_bidirec': self.is_bidirec, 'is_bidirec': self.is_bidirec,
'input_size': input_size, 'input_size': input_size,
'hidden_size': hidden_size, 'hidden_size': hidden_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册