Unverified commit adba4384, authored by 乔龙飞 Qiao Longfei and committed by GitHub

Merge pull request #15161 from jacquesqiao/gru-add-mode

gru add origin mode
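For context, the flag toggles between the two standard GRU output conventions. Written with update gate u_t, candidate state \tilde{h}_t and previous state h_{t-1}, the equations implemented by this change are (summarized from the docstring updates further below):

```latex
% origin_mode = False (current default, Chung et al., arXiv:1412.3555)
h_t = (1 - u_t) \odot h_{t-1} + u_t \odot \tilde{h}_t

% origin_mode = True ("origin" GRU, Cho et al., arXiv:1406.1078)
h_t = u_t \odot h_{t-1} + (1 - u_t) \odot \tilde{h}_t
```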
...@@ -70,8 +70,8 @@ paddle.fluid.layers.fc ArgSpec(args=['input', 'size', 'num_flatten_dims', 'param ...@@ -70,8 +70,8 @@ paddle.fluid.layers.fc ArgSpec(args=['input', 'size', 'num_flatten_dims', 'param
paddle.fluid.layers.embedding ArgSpec(args=['input', 'size', 'is_sparse', 'is_distributed', 'padding_idx', 'param_attr', 'dtype'], varargs=None, keywords=None, defaults=(False, False, None, None, 'float32')) paddle.fluid.layers.embedding ArgSpec(args=['input', 'size', 'is_sparse', 'is_distributed', 'padding_idx', 'param_attr', 'dtype'], varargs=None, keywords=None, defaults=(False, False, None, None, 'float32'))
paddle.fluid.layers.dynamic_lstm ArgSpec(args=['input', 'size', 'h_0', 'c_0', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'float32', None)) paddle.fluid.layers.dynamic_lstm ArgSpec(args=['input', 'size', 'h_0', 'c_0', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'float32', None))
paddle.fluid.layers.dynamic_lstmp ArgSpec(args=['input', 'size', 'proj_size', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'proj_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'tanh', 'float32', None)) paddle.fluid.layers.dynamic_lstmp ArgSpec(args=['input', 'size', 'proj_size', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'proj_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'tanh', 'float32', None))
paddle.fluid.layers.dynamic_gru ArgSpec(args=['input', 'size', 'param_attr', 'bias_attr', 'is_reverse', 'gate_activation', 'candidate_activation', 'h_0'], varargs=None, keywords=None, defaults=(None, None, False, 'sigmoid', 'tanh', None)) paddle.fluid.layers.dynamic_gru ArgSpec(args=['input', 'size', 'param_attr', 'bias_attr', 'is_reverse', 'gate_activation', 'candidate_activation', 'h_0', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, False, 'sigmoid', 'tanh', None, False))
paddle.fluid.layers.gru_unit ArgSpec(args=['input', 'hidden', 'size', 'param_attr', 'bias_attr', 'activation', 'gate_activation'], varargs=None, keywords=None, defaults=(None, None, 'tanh', 'sigmoid')) paddle.fluid.layers.gru_unit ArgSpec(args=['input', 'hidden', 'size', 'param_attr', 'bias_attr', 'activation', 'gate_activation', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, 'tanh', 'sigmoid', False))
paddle.fluid.layers.linear_chain_crf ArgSpec(args=['input', 'label', 'param_attr'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.linear_chain_crf ArgSpec(args=['input', 'label', 'param_attr'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.crf_decoding ArgSpec(args=['input', 'param_attr', 'label'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.crf_decoding ArgSpec(args=['input', 'param_attr', 'label'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.cos_sim ArgSpec(args=['X', 'Y'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.cos_sim ArgSpec(args=['X', 'Y'], varargs=None, keywords=None, defaults=None)
......
...@@ -137,6 +137,10 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -137,6 +137,10 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, defalut: False) " "(bool, defalut: False) "
"whether to compute reversed GRU.") "whether to compute reversed GRU.")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("origin_mode",
"bool"
"use origin mode in article https://arxiv.org/abs/1412.3555")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
GRU Operator implements part calculations of the complete GRU as following: GRU Operator implements part calculations of the complete GRU as following:
...@@ -221,6 +225,7 @@ class GRUCPUKernel : public framework::OpKernel<T> { ...@@ -221,6 +225,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
public: public:
void BatchCompute(const framework::ExecutionContext& context) const { void BatchCompute(const framework::ExecutionContext& context) const {
using DeviceContext = paddle::platform::CPUDeviceContext; using DeviceContext = paddle::platform::CPUDeviceContext;
bool origin_mode = context.Attr<bool>("origin_mode");
auto* input = context.Input<LoDTensor>("Input"); auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0"); auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight"); auto* weight = context.Input<Tensor>("Weight");
...@@ -327,7 +332,7 @@ class GRUCPUKernel : public framework::OpKernel<T> { ...@@ -327,7 +332,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
math::detail::forward_final_output( math::detail::forward_final_output(
math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size, math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
cur_batch_size, active_node); cur_batch_size, active_node, origin_mode);
gru_value.prev_out_value = gru_value.output_value; gru_value.prev_out_value = gru_value.output_value;
} }
...@@ -351,7 +356,7 @@ class GRUCPUKernel : public framework::OpKernel<T> { ...@@ -351,7 +356,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
math::GRUUnitFunctor<DeviceContext, T>::compute( math::GRUUnitFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size, active_node, dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
active_gate); active_gate, origin_mode);
gru_value.prev_out_value = gru_value.output_value; gru_value.prev_out_value = gru_value.output_value;
} }
......
...@@ -21,6 +21,7 @@ template <typename DeviceContext, typename T> ...@@ -21,6 +21,7 @@ template <typename DeviceContext, typename T>
class GRUKernel : public framework::OpKernel<T> { class GRUKernel : public framework::OpKernel<T> {
public: public:
void BatchCompute(const framework::ExecutionContext& context) const { void BatchCompute(const framework::ExecutionContext& context) const {
bool origin_mode = context.Attr<bool>("origin_mode");
auto* input = context.Input<LoDTensor>("Input"); auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0"); auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight"); auto* weight = context.Input<Tensor>("Weight");
...@@ -87,7 +88,7 @@ class GRUKernel : public framework::OpKernel<T> { ...@@ -87,7 +88,7 @@ class GRUKernel : public framework::OpKernel<T> {
gru_value.reset_output_value = reset_hidden_prev_t.data<T>(); gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
math::GRUUnitFunctor<DeviceContext, T>::compute( math::GRUUnitFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size, active_node, dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
active_gate); active_gate, origin_mode);
gru_value.prev_out_value = gru_value.output_value; gru_value.prev_out_value = gru_value.output_value;
} }
......
...@@ -41,6 +41,7 @@ template <typename DeviceContext, typename T> ...@@ -41,6 +41,7 @@ template <typename DeviceContext, typename T>
class GRUGradKernel : public framework::OpKernel<T> { class GRUGradKernel : public framework::OpKernel<T> {
public: public:
void BatchCompute(const framework::ExecutionContext& context) const { void BatchCompute(const framework::ExecutionContext& context) const {
bool origin_mode = context.Attr<bool>("origin_mode");
auto* h0 = context.Input<Tensor>("H0"); auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight"); auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>(); const T* weight_data = weight->data<T>();
...@@ -146,7 +147,7 @@ class GRUGradKernel : public framework::OpKernel<T> { ...@@ -146,7 +147,7 @@ class GRUGradKernel : public framework::OpKernel<T> {
math::GRUUnitGradFunctor<DeviceContext, T>::compute( math::GRUUnitGradFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, active_node, dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, active_node,
active_gate); active_gate, origin_mode);
} }
if (input_grad) { if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace()); input_grad->mutable_data<T>(context.GetPlace());
......
...@@ -111,6 +111,13 @@ class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -111,6 +111,13 @@ class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
"The activation type used in update gate and reset gate.") "The activation type used in update gate and reset gate.")
.SetDefault(sigmoid) .SetDefault(sigmoid)
.InEnum({identity, sigmoid, tanh, relu}); .InEnum({identity, sigmoid, tanh, relu});
AddAttr<bool>("origin_mode",
"bool"
"use origin mode in article <Learning Phrase Representations "
"using RNN Encoder–Decoder\n"
"for Statistical Machine "
"Translation>(https://arxiv.org/pdf/1406.1078.pdf)")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
GRUUnit Operator implements partial calculations of the GRU unit as following: GRUUnit Operator implements partial calculations of the GRU unit as following:
......
...@@ -113,7 +113,11 @@ class GRUUnitKernel : public framework::OpKernel<T> { ...@@ -113,7 +113,11 @@ class GRUUnitKernel : public framework::OpKernel<T> {
auto c = g.slice(c_offsets, extents); // output candidate auto c = g.slice(c_offsets, extents); // output candidate
// calculate final output // calculate final output
h.device(place) = u * (c - h_p) + h_p; if (context.Attr<bool>("origin_mode")) {
h.device(place) = c + u * (h_p - c); // (1 - u) * c + u * h_p
} else {
h.device(place) = u * (c - h_p) + h_p; // u * c + (1 - u) * h_p
}
} }
}; };
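As a quick illustration of the two branches in GRUUnitKernel above, here is a small NumPy sketch (not part of the patch; u, c and h_p stand for the update gate, output candidate and previous hidden state) confirming that each expression is the convex combination named in its comment:

```python
import numpy as np

rng = np.random.RandomState(0)
u = 1.0 / (1.0 + np.exp(-rng.randn(4, 8)))   # update gate, values in (0, 1)
c = np.tanh(rng.randn(4, 8))                 # output candidate
h_p = rng.randn(4, 8)                        # previous hidden state

# default mode: u * (c - h_p) + h_p  ==  u * c + (1 - u) * h_p
assert np.allclose(u * (c - h_p) + h_p, u * c + (1 - u) * h_p)

# origin mode:  c + u * (h_p - c)    ==  (1 - u) * c + u * h_p
assert np.allclose(c + u * (h_p - c), (1 - u) * c + u * h_p)
```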
...@@ -180,11 +184,19 @@ class GRUUnitGradKernel : public framework::OpKernel<T> { ...@@ -180,11 +184,19 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
auto c = g.slice(c_offsets, extents); // output candidate auto c = g.slice(c_offsets, extents); // output candidate
// backward for unactivated update gate // backward for unactivated update gate
if (context.Attr<bool>("origin_mode")) {
ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
d_g.slice(u_offsets, extents), d_h * (h_p - c));
// backward for unactivated output candidate
ActGradCompute(context.Attr<int>("activation"), place, c, c,
d_g.slice(c_offsets, extents), d_h * (1 - u));
} else {
ActGradCompute(context.Attr<int>("gate_activation"), place, u, u, ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
d_g.slice(u_offsets, extents), d_h * (c - h_p)); d_g.slice(u_offsets, extents), d_h * (c - h_p));
// backward for unactivated output candidate // backward for unactivated output candidate
ActGradCompute(context.Attr<int>("activation"), place, c, c, ActGradCompute(context.Attr<int>("activation"), place, c, c,
d_g.slice(c_offsets, extents), d_h * u); d_g.slice(c_offsets, extents), d_h * u);
}
// backward for reset_hidden_prev // backward for reset_hidden_prev
auto blas = math::GetBlas<DeviceContext, T>(context); auto blas = math::GetBlas<DeviceContext, T>(context);
blas.GEMM(false, true, batch_size, frame_size, frame_size, 1, blas.GEMM(false, true, batch_size, frame_size, frame_size, 1,
...@@ -213,7 +225,11 @@ class GRUUnitGradKernel : public framework::OpKernel<T> { ...@@ -213,7 +225,11 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
T* hidden_prev_grad_data = T* hidden_prev_grad_data =
hidden_prev_grad->mutable_data<T>(context.GetPlace()); hidden_prev_grad->mutable_data<T>(context.GetPlace());
auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad); auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u); if (context.Attr<bool>("origin_mode")) {
d_h_p.device(place) = d_r_h_p * r + d_h * u;
} else {
d_h_p.device(place) = d_r_h_p * r + d_h * (1 - u);
}
blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1, blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1,
gate_grad_data, frame_size * 3, weight_data, frame_size * 2, 1, gate_grad_data, frame_size * 3, weight_data, frame_size * 2, 1,
hidden_prev_grad_data, frame_size); hidden_prev_grad_data, frame_size);
......
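For reference, the gradients that the two branches of GRUUnitGradKernel implement follow directly from the output equations (d_h denotes the gradient flowing into h; the gate/candidate activation chain is applied afterwards by ActGradCompute):

```latex
% origin_mode = False:  h = u \odot c + (1 - u) \odot h_p
\partial L/\partial u   = d_h \odot (c - h_p), \quad
\partial L/\partial c   = d_h \odot u, \quad
\partial L/\partial h_p \mathrel{+}= d_h \odot (1 - u)

% origin_mode = True:   h = (1 - u) \odot c + u \odot h_p
\partial L/\partial u   = d_h \odot (h_p - c), \quad
\partial L/\partial c   = d_h \odot (1 - u), \quad
\partial L/\partial h_p \mathrel{+}= d_h \odot u
```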
...@@ -56,7 +56,8 @@ template <class OpFinalOutput, typename T> ...@@ -56,7 +56,8 @@ template <class OpFinalOutput, typename T>
void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value, T *gate_value, T *prev_output_value,
T *output_value, int frame_size, T *output_value, int frame_size,
ActivationType active_node) { ActivationType active_node,
bool origin_mode) {
T r_value_update_gate; T r_value_update_gate;
T r_value_frame_state; T r_value_frame_state;
T r_prev_out = 0; T r_prev_out = 0;
...@@ -72,7 +73,7 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, ...@@ -72,7 +73,7 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
} }
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
&r_output, active_node); &r_output, active_node, origin_mode);
frame_state[i] = r_value_frame_state; frame_state[i] = r_value_frame_state;
output_value[i] = r_output; output_value[i] = r_output;
...@@ -146,7 +147,8 @@ template <class OpFinalOutput, typename T> ...@@ -146,7 +147,8 @@ template <class OpFinalOutput, typename T>
void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value, T *gate_value, T *prev_output_value,
T *output_value, int frame_size, T *output_value, int frame_size,
ActivationType active_node) { ActivationType active_node,
bool origin_mode) {
#ifdef __AVX__ #ifdef __AVX__
__m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f); __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
__m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f); __m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f);
...@@ -180,7 +182,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, ...@@ -180,7 +182,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
} }
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
&r_output, active_node); &r_output, active_node, origin_mode);
_mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i), _mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
r_value_frame_state); r_value_frame_state);
...@@ -190,7 +192,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, ...@@ -190,7 +192,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
if (rest > 0) { if (rest > 0) {
i = n - block; i = n - block;
op_final_output(&r_value_update_gate_last, &r_value_frame_state_last, op_final_output(&r_value_update_gate_last, &r_value_frame_state_last,
&r_prev_out_last, &r_output, active_node); &r_prev_out_last, &r_output, active_node, origin_mode);
_mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i), _mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
r_value_frame_state_last); r_value_frame_state_last);
...@@ -227,17 +229,18 @@ inline void forward_reset_output(OpResetOutput op_reset_output, ...@@ -227,17 +229,18 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
template <class OpFinalOutput, typename T> template <class OpFinalOutput, typename T>
inline void forward_final_output(OpFinalOutput op_final_output, inline void forward_final_output(OpFinalOutput op_final_output,
GRUMetaValue<T> value, int frame_size, GRUMetaValue<T> value, int frame_size,
int batch_size, ActivationType active_node) { int batch_size, ActivationType active_node,
bool origin_mode) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) && if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
(sizeof(T) == 4)) { (sizeof(T) == 4)) {
hl_avx_gru_forward_final_output(op_final_output, value.gate_value, hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
value.prev_out_value, value.output_value, value.prev_out_value, value.output_value,
frame_size, active_node); frame_size, active_node, origin_mode);
} else { } else {
hl_naive_gru_forward_final_output( hl_naive_gru_forward_final_output(
op_final_output, value.gate_value, value.prev_out_value, op_final_output, value.gate_value, value.prev_out_value,
value.output_value, frame_size, active_node); value.output_value, frame_size, active_node, origin_mode);
} }
value.gate_value += frame_size * 3; value.gate_value += frame_size * 3;
...@@ -253,7 +256,8 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, ...@@ -253,7 +256,8 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value, T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *output_grad, T *prev_out_grad, T *output_grad,
int frame_size, int frame_size,
ActivationType active_node) { ActivationType active_node,
bool origin_mode) {
T r_update_gate_value; T r_update_gate_value;
T r_update_gate_grad; T r_update_gate_grad;
T r_frame_state_value; T r_frame_state_value;
...@@ -279,7 +283,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, ...@@ -279,7 +283,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
op_state_grad(&r_update_gate_value, &r_update_gate_grad, op_state_grad(&r_update_gate_value, &r_update_gate_grad,
&r_frame_state_value, &r_frame_state_grad, &r_prev_out_value, &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value,
&r_prev_out_grad, &r_out_grad, active_node); &r_prev_out_grad, &r_out_grad, active_node, origin_mode);
update_gate_grad[i] = r_update_gate_grad; update_gate_grad[i] = r_update_gate_grad;
frame_state_grad[i] = r_frame_state_grad; frame_state_grad[i] = r_frame_state_grad;
...@@ -338,8 +342,8 @@ template <class OpStateGrad, typename T> ...@@ -338,8 +342,8 @@ template <class OpStateGrad, typename T>
void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value, T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *output_grad, T *prev_out_grad, T *output_grad,
int frame_size, int frame_size, ActivationType active_node,
ActivationType active_node) { bool origin_mode) {
#ifdef __AVX__ #ifdef __AVX__
__m256 r_update_gate_value; __m256 r_update_gate_value;
__m256 r_update_gate_grad; __m256 r_update_gate_grad;
...@@ -368,7 +372,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, ...@@ -368,7 +372,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
op_state_grad(&r_update_gate_value, &r_update_gate_grad, op_state_grad(&r_update_gate_value, &r_update_gate_grad,
&r_frame_state_value, &r_frame_state_grad, &r_prev_out_value, &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value,
&r_prev_out_grad, &r_out_grad, active_node); &r_prev_out_grad, &r_out_grad, active_node, origin_mode);
update_gate_grad[i] = r_update_gate_grad; update_gate_grad[i] = r_update_gate_grad;
frame_state_grad[i] = r_frame_state_grad; frame_state_grad[i] = r_frame_state_grad;
...@@ -431,16 +435,18 @@ template <class OpStateGrad, typename T> ...@@ -431,16 +435,18 @@ template <class OpStateGrad, typename T>
inline void backward_state_grad(OpStateGrad op_state_grad, inline void backward_state_grad(OpStateGrad op_state_grad,
GRUMetaValue<T> value, GRUMetaGrad<T> grad, GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size, int frame_size, int batch_size,
ActivationType active_node) { ActivationType active_node, bool origin_mode) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_backward_state_grad( hl_avx_gru_backward_state_grad(op_state_grad, value.gate_value,
op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value, grad.gate_grad, value.prev_out_value,
grad.prev_out_grad, grad.output_grad, frame_size, active_node); grad.prev_out_grad, grad.output_grad,
frame_size, active_node, origin_mode);
} else { } else {
hl_naive_gru_backward_state_grad( hl_naive_gru_backward_state_grad(op_state_grad, value.gate_value,
op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value, grad.gate_grad, value.prev_out_value,
grad.prev_out_grad, grad.output_grad, frame_size, active_node); grad.prev_out_grad, grad.output_grad,
frame_size, active_node, origin_mode);
} }
value.gate_value += frame_size * 3; value.gate_value += frame_size * 3;
......
...@@ -72,7 +72,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, ...@@ -72,7 +72,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value, T *gate_value, T *prev_output_value,
T *output_value, int frame_size, T *output_value, int frame_size,
int batch_size, int batch_size,
ActivationType active_node) { ActivationType active_node,
bool origin_mode) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return; if (frame_idx >= frame_size) return;
int batch_idx = 0; int batch_idx = 0;
...@@ -94,7 +95,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, ...@@ -94,7 +95,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
} }
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
&r_output, active_node); &r_output, active_node, origin_mode);
gate_value[frame_idx + frame_size * 2] = r_value_frame_state; gate_value[frame_idx + frame_size * 2] = r_value_frame_state;
output_value[frame_idx] = r_output; output_value[frame_idx] = r_output;
...@@ -109,7 +110,8 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value, ...@@ -109,7 +110,8 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value, T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *output_grad, T *prev_out_grad, T *output_grad,
int frame_size, int batch_size, int frame_size, int batch_size,
ActivationType active_node) { ActivationType active_node,
bool origin_mode) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return; if (frame_idx >= frame_size) return;
int batch_idx = 0; int batch_idx = 0;
...@@ -139,7 +141,7 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value, ...@@ -139,7 +141,7 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
op_state_grad(&r_update_gate_value, &r_update_gate_grad, &r_frame_state_value, op_state_grad(&r_update_gate_value, &r_update_gate_grad, &r_frame_state_value,
&r_frame_state_grad, &r_prev_out_value, &r_prev_out_grad, &r_frame_state_grad, &r_prev_out_value, &r_prev_out_grad,
&r_out_grad, active_node); &r_out_grad, active_node, origin_mode);
gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad; gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad; gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad;
......
...@@ -57,11 +57,17 @@ class gru_finalOutput { ...@@ -57,11 +57,17 @@ class gru_finalOutput {
public: public:
HOSTDEVICE void operator()(T *value_update_gate, T *value_frame_state, HOSTDEVICE void operator()(T *value_update_gate, T *value_frame_state,
T *prev_out, T *value_output, T *prev_out, T *value_output,
ActivationType act_input) { ActivationType act_input, bool origin_mode) {
*value_frame_state = activation(*value_frame_state, act_input); *value_frame_state = activation(*value_frame_state, act_input);
if (origin_mode) {
*value_output = ((*value_update_gate) * (*prev_out)) +
*value_frame_state -
((*value_update_gate) * (*value_frame_state));
} else {
*value_output = *prev_out - ((*value_update_gate) * (*prev_out)) + *value_output = *prev_out - ((*value_update_gate) * (*prev_out)) +
((*value_update_gate) * (*value_frame_state)); ((*value_update_gate) * (*value_frame_state));
} }
}
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
static const bool avx = false; static const bool avx = false;
...@@ -69,12 +75,21 @@ class gru_finalOutput { ...@@ -69,12 +75,21 @@ class gru_finalOutput {
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 *value_update_gate, HOSTDEVICE void operator()(__m256 *value_update_gate,
__m256 *value_frame_state, __m256 *prev_out, __m256 *value_frame_state, __m256 *prev_out,
__m256 *value_output, ActivationType act_input) { __m256 *value_output, ActivationType act_input,
bool origin_mode) {
*value_frame_state = activation(*value_frame_state, act_input); *value_frame_state = activation(*value_frame_state, act_input);
if (origin_mode) {
*value_output = _mm256_sub_ps(
_mm256_add_ps(_mm256_mul_ps(*value_update_gate, *prev_out),
*value_frame_state),
_mm256_mul_ps(*value_update_gate, *value_frame_state));
} else {
*value_output = _mm256_add_ps( *value_output = _mm256_add_ps(
_mm256_sub_ps(*prev_out, _mm256_mul_ps(*value_update_gate, *prev_out)), _mm256_sub_ps(*prev_out,
_mm256_mul_ps(*value_update_gate, *prev_out)),
_mm256_mul_ps(*value_update_gate, *value_frame_state)); _mm256_mul_ps(*value_update_gate, *value_frame_state));
} }
}
#endif #endif
#endif #endif
}; };
...@@ -88,14 +103,24 @@ class gru_stateGrad { ...@@ -88,14 +103,24 @@ class gru_stateGrad {
HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate, HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate,
T *value_frame_state, T *grad_frame_state, T *value_frame_state, T *grad_frame_state,
T *value_prev_out, T *grad_prev_out, T *value_prev_out, T *grad_prev_out,
T *grad_output, ActivationType act_input) { T *grad_output, ActivationType act_input,
*grad_update_gate = (*grad_output * (*value_frame_state)); bool origin_mode) {
*grad_update_gate -= (*grad_output * (*value_prev_out)); if (origin_mode) {
*grad_prev_out -= (*grad_output * (*value_update_gate)); *grad_update_gate =
*grad_prev_out += *grad_output; (*grad_output) * ((*value_prev_out) - (*value_frame_state));
*grad_prev_out += (*grad_output * (*value_update_gate));
*grad_frame_state = activation(
*grad_output * (static_cast<T>(1.0) - (*value_update_gate)),
*value_frame_state, act_input);
} else {
*grad_update_gate =
(*grad_output) * ((*value_frame_state) - (*value_prev_out));
*grad_prev_out +=
(*grad_output * (static_cast<T>(1.0) - *value_update_gate));
*grad_frame_state = activation(*grad_output * (*value_update_gate), *grad_frame_state = activation(*grad_output * (*value_update_gate),
*value_frame_state, act_input); *value_frame_state, act_input);
} }
}
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
static const bool avx = false; static const bool avx = false;
...@@ -106,18 +131,28 @@ class gru_stateGrad { ...@@ -106,18 +131,28 @@ class gru_stateGrad {
__m256 *value_frame_state, __m256 *value_frame_state,
__m256 *grad_frame_state, __m256 *value_prev_out, __m256 *grad_frame_state, __m256 *value_prev_out,
__m256 *grad_prev_out, __m256 *grad_output, __m256 *grad_prev_out, __m256 *grad_output,
ActivationType act_input) { ActivationType act_input, bool origin_mode) {
*grad_update_gate = _mm256_mul_ps(*grad_output, *value_frame_state); if (origin_mode) {
*grad_update_gate = _mm256_sub_ps( *grad_update_gate = _mm256_mul_ps(
*grad_update_gate, _mm256_mul_ps(*grad_output, *value_prev_out)); *grad_output, _mm256_sub_ps(*value_prev_out, *value_frame_state));
*grad_prev_out = _mm256_add_ps( *grad_prev_out = _mm256_add_ps(
_mm256_sub_ps(*grad_prev_out, *grad_prev_out, _mm256_mul_ps(*grad_output, *value_update_gate));
_mm256_mul_ps(*grad_output, *value_update_gate)), *grad_frame_state = activation(
*grad_output); _mm256_mul_ps(*grad_output, _mm256_sub_ps(_mm256_set1_ps(1.0f),
*value_update_gate)),
*value_frame_state, act_input);
} else {
*grad_update_gate = _mm256_mul_ps(
*grad_output, _mm256_sub_ps(*value_frame_state, *value_prev_out));
*grad_prev_out = _mm256_add_ps(
*grad_prev_out,
_mm256_mul_ps(*grad_output, _mm256_sub_ps(_mm256_set1_ps(1.0f),
*value_update_gate)));
*grad_frame_state = *grad_frame_state =
activation(_mm256_mul_ps(*grad_output, *value_update_gate), activation(_mm256_mul_ps(*grad_output, *value_update_gate),
*value_frame_state, act_input); *value_frame_state, act_input);
} }
}
#endif #endif
#endif #endif
}; };
......
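The origin-mode branch of gru_stateGrad can be sanity-checked with a finite-difference comparison against the analytic gradient it uses, d_h * (h_p - c), for the output h = (1 - u) * c + u * h_p. A minimal NumPy sketch (illustration only; all names here are mine):

```python
import numpy as np

def h_origin(u, c, h_p):
    # origin-mode GRU output: h = (1 - u) * c + u * h_p
    return (1.0 - u) * c + u * h_p

rng = np.random.RandomState(1)
n = 6
u, c, h_p = rng.rand(n), rng.randn(n), rng.randn(n)
d_h = rng.randn(n)              # upstream gradient dL/dh
eps = 1e-6

# numeric dL/du by central differences on L = dot(d_h, h)
d_u_numeric = np.zeros(n)
for i in range(n):
    delta = eps * np.eye(n)[i]
    d_u_numeric[i] = (np.dot(d_h, h_origin(u + delta, c, h_p)) -
                      np.dot(d_h, h_origin(u - delta, c, h_p))) / (2 * eps)

# matches the analytic gradient used by the origin_mode branch
assert np.allclose(d_u_numeric, d_h * (h_p - c), atol=1e-5)
```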
...@@ -23,7 +23,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> { ...@@ -23,7 +23,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
static void compute(const platform::CPUDeviceContext &context, static void compute(const platform::CPUDeviceContext &context,
GRUMetaValue<T> value, int frame_size, int batch_size, GRUMetaValue<T> value, int frame_size, int batch_size,
const detail::ActivationType active_node, const detail::ActivationType active_node,
const detail::ActivationType active_gate) { const detail::ActivationType active_gate,
bool origin_mode) {
#ifndef __NVCC__ #ifndef __NVCC__
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context); auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
if (value.prev_out_value) { if (value.prev_out_value) {
...@@ -43,7 +44,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> { ...@@ -43,7 +44,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
} }
detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value, detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
frame_size, batch_size, active_node); frame_size, batch_size, active_node,
origin_mode);
#endif #endif
} }
}; };
...@@ -54,10 +56,12 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> { ...@@ -54,10 +56,12 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
GRUMetaValue<T> value, GRUMetaGrad<T> grad, GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size, int frame_size, int batch_size,
const detail::ActivationType active_node, const detail::ActivationType active_node,
const detail::ActivationType active_gate) { const detail::ActivationType active_gate,
bool origin_mode) {
#ifndef __NVCC__ #ifndef __NVCC__
detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value, detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
grad, frame_size, batch_size, active_node); grad, frame_size, batch_size, active_node,
origin_mode);
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context); auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
if (value.prev_out_value && grad.prev_out_grad) { if (value.prev_out_value && grad.prev_out_grad) {
blas.GEMM(false, true, batch_size, frame_size, frame_size, 1, blas.GEMM(false, true, batch_size, frame_size, frame_size, 1,
......
...@@ -24,7 +24,8 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> { ...@@ -24,7 +24,8 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext &context, static void compute(const platform::CUDADeviceContext &context,
GRUMetaValue<T> value, int frame_size, int batch_size, GRUMetaValue<T> value, int frame_size, int batch_size,
const detail::ActivationType active_node, const detail::ActivationType active_node,
const detail::ActivationType active_gate) { const detail::ActivationType active_gate,
bool origin_mode) {
auto stream = context.stream(); auto stream = context.stream();
dim3 threads; dim3 threads;
dim3 grid; dim3 grid;
...@@ -73,14 +74,14 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> { ...@@ -73,14 +74,14 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
T><<<grid, threads, 0, stream>>>( T><<<grid, threads, 0, stream>>>(
detail::forward::gru_finalOutput<T>(), value.gate_value, detail::forward::gru_finalOutput<T>(), value.gate_value,
value.prev_out_value, value.output_value, frame_size, batch_size, value.prev_out_value, value.output_value, frame_size, batch_size,
active_node); active_node, origin_mode);
} else { } else {
detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>, detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
/* is_batch= */ true, /* is_batch= */ true,
T><<<grid, threads, 0, stream>>>( T><<<grid, threads, 0, stream>>>(
detail::forward::gru_finalOutput<T>(), value.gate_value, detail::forward::gru_finalOutput<T>(), value.gate_value,
value.prev_out_value, value.output_value, frame_size, batch_size, value.prev_out_value, value.output_value, frame_size, batch_size,
active_node); active_node, origin_mode);
} }
} }
}; };
...@@ -91,7 +92,8 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> { ...@@ -91,7 +92,8 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
GRUMetaValue<T> value, GRUMetaGrad<T> grad, GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size, int frame_size, int batch_size,
const detail::ActivationType active_node, const detail::ActivationType active_node,
const detail::ActivationType active_gate) { const detail::ActivationType active_gate,
bool origin_mode) {
auto stream = context.stream(); auto stream = context.stream();
dim3 threads; dim3 threads;
dim3 grid; dim3 grid;
...@@ -111,14 +113,14 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> { ...@@ -111,14 +113,14 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
/* is_batch= */ false><<<grid, threads, 0, stream>>>( /* is_batch= */ false><<<grid, threads, 0, stream>>>(
detail::backward::gru_stateGrad<T>(), value.gate_value, detail::backward::gru_stateGrad<T>(), value.gate_value,
grad.gate_grad, value.prev_out_value, grad.prev_out_grad, grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
grad.output_grad, frame_size, batch_size, active_node); grad.output_grad, frame_size, batch_size, active_node, origin_mode);
} else { } else {
detail::KeGruBackwardStateGrad< detail::KeGruBackwardStateGrad<
detail::backward::gru_stateGrad<T>, detail::backward::gru_stateGrad<T>,
/* is_batch= */ true><<<grid, threads, 0, stream>>>( /* is_batch= */ true><<<grid, threads, 0, stream>>>(
detail::backward::gru_stateGrad<T>(), value.gate_value, detail::backward::gru_stateGrad<T>(), value.gate_value,
grad.gate_grad, value.prev_out_value, grad.prev_out_grad, grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
grad.output_grad, frame_size, batch_size, active_node); grad.output_grad, frame_size, batch_size, active_node, origin_mode);
} }
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context); auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
......
...@@ -44,7 +44,8 @@ struct GRUUnitFunctor { ...@@ -44,7 +44,8 @@ struct GRUUnitFunctor {
static void compute(const DeviceContext &context, GRUMetaValue<T> value, static void compute(const DeviceContext &context, GRUMetaValue<T> value,
int frame_size, int batch_size, int frame_size, int batch_size,
const detail::ActivationType active_node, const detail::ActivationType active_node,
const detail::ActivationType active_gate); const detail::ActivationType active_gate,
bool origin_mode);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -52,7 +53,8 @@ struct GRUUnitGradFunctor { ...@@ -52,7 +53,8 @@ struct GRUUnitGradFunctor {
static void compute(const DeviceContext &context, GRUMetaValue<T> value, static void compute(const DeviceContext &context, GRUMetaValue<T> value,
GRUMetaGrad<T> grad, int frame_size, int batch_size, GRUMetaGrad<T> grad, int frame_size, int batch_size,
const detail::ActivationType active_node, const detail::ActivationType active_node,
const detail::ActivationType active_gate); const detail::ActivationType active_gate,
bool origin_mode);
}; };
} // namespace math } // namespace math
......
...@@ -864,12 +864,14 @@ def dynamic_gru(input, ...@@ -864,12 +864,14 @@ def dynamic_gru(input,
is_reverse=False, is_reverse=False,
gate_activation='sigmoid', gate_activation='sigmoid',
candidate_activation='tanh', candidate_activation='tanh',
h_0=None): h_0=None,
origin_mode=False):
""" """
**Gated Recurrent Unit (GRU) Layer** **Gated Recurrent Unit (GRU) Layer**
if origin_mode is False, then the equation of a gru step is from paper
`Empirical Evaluation of Gated Recurrent Neural Networks on Sequence
Modeling <https://arxiv.org/pdf/1412.3555.pdf>`_ .
The formula is as follows: The formula is as follows:
...@@ -883,6 +885,21 @@ def dynamic_gru(input, ...@@ -883,6 +885,21 @@ def dynamic_gru(input,
h_t & = (1-u_t) \odot h_{t-1} + u_t \odot \\tilde{h_t} h_t & = (1-u_t) \odot h_{t-1} + u_t \odot \\tilde{h_t}
if origin_mode is True, then the equation of a gru step is from paper
`Learning Phrase Representations using RNN Encoder-Decoder for Statistical
Machine Translation <https://arxiv.org/pdf/1406.1078.pdf>`_ .
.. math::
u_t & = act_g(W_{ux}x_{t} + W_{uh}h_{t-1} + b_u)
r_t & = act_g(W_{rx}x_{t} + W_{rh}h_{t-1} + b_r)
\\tilde{h_t} & = act_c(W_{cx}x_{t} + W_{ch}(r_t \odot h_{t-1}) + b_c)
h_t & = u_t \odot h_{t-1} + (1-u_t) \odot \\tilde{h_t}
The :math:`\odot` is the element-wise product of the vectors. :math:`act_g` The :math:`\odot` is the element-wise product of the vectors. :math:`act_g`
is the update gate and reset gate activation function and :math:`sigmoid` is the update gate and reset gate activation function and :math:`sigmoid`
is usually used for it. :math:`act_c` is the activation function for is usually used for it. :math:`act_c` is the activation function for
...@@ -980,7 +997,8 @@ def dynamic_gru(input, ...@@ -980,7 +997,8 @@ def dynamic_gru(input,
attrs={ attrs={
'is_reverse': is_reverse, 'is_reverse': is_reverse,
'gate_activation': gate_activation, 'gate_activation': gate_activation,
'activation': candidate_activation 'activation': candidate_activation,
'origin_mode': origin_mode
}) })
return hidden return hidden
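A usage sketch of the new flag in dynamic_gru, based on the ArgSpec above (the surrounding calls follow the usual Fluid pattern; dimensions are illustrative):

```python
import paddle.fluid as fluid

dict_dim, emb_dim, hidden_dim = 10000, 64, 512

data = fluid.layers.data(name='sequence', shape=[1], dtype='int64', lod_level=1)
emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
# dynamic_gru expects its input already projected to 3 * hidden_dim
x = fluid.layers.fc(input=emb, size=hidden_dim * 3)
# origin_mode=True selects the Cho et al. (arXiv:1406.1078) formulation
hidden = fluid.layers.dynamic_gru(input=x, size=hidden_dim, origin_mode=True)
```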
...@@ -991,9 +1009,27 @@ def gru_unit(input, ...@@ -991,9 +1009,27 @@ def gru_unit(input,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
activation='tanh', activation='tanh',
gate_activation='sigmoid'): gate_activation='sigmoid',
origin_mode=False):
""" """
**GRU unit layer**
if origin_mode is True, then the equation of a gru step is from paper
`Learning Phrase Representations using RNN Encoder-Decoder for Statistical
Machine Translation <https://arxiv.org/pdf/1406.1078.pdf>`_
.. math::
u_t & = actGate(xu_{t} + W_u h_{t-1} + b_u)
r_t & = actGate(xr_{t} + W_r h_{t-1} + b_r)
m_t & = actNode(xm_t + W_c dot(r_t, h_{t-1}) + b_m)
h_t & = dot(u_t, h_{t-1}) + dot((1-u_t), m_t)
if origin_mode is False, then the equation of a gru step is from paper
`Empirical Evaluation of Gated Recurrent Neural Networks on Sequence
Modeling <https://arxiv.org/pdf/1412.3555.pdf>`_
.. math:: .. math::
u_t & = actGate(xu_{t} + W_u h_{t-1} + b_u) u_t & = actGate(xu_{t} + W_u h_{t-1} + b_u)
...@@ -1002,7 +1038,8 @@ def gru_unit(input, ...@@ -1002,7 +1038,8 @@ def gru_unit(input,
m_t & = actNode(xm_t + W_c dot(r_t, h_{t-1}) + b_m) m_t & = actNode(xm_t + W_c dot(r_t, h_{t-1}) + b_m)
h_t & = dot((1-u_t), h_{t-1}) + dot(u_t, m_t)
The inputs of gru unit includes :math:`z_t`, :math:`h_{t-1}`. In terms The inputs of gru unit includes :math:`z_t`, :math:`h_{t-1}`. In terms
of the equation above, the :math:`z_t` is split into 3 parts - of the equation above, the :math:`z_t` is split into 3 parts -
......
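And a corresponding sketch for gru_unit within a single step (hypothetical tensor names; in this Fluid version gru_unit takes size = 3 * hidden_dim and, assuming the usual return tuple, yields the updated hidden state, the reset hidden state and the gate values):

```python
import paddle.fluid as fluid

dict_dim, emb_dim, hidden_dim = 10000, 64, 512

step_data = fluid.layers.data(name='step_data', shape=[1], dtype='int64')
emb = fluid.layers.embedding(input=step_data, size=[dict_dim, emb_dim])
x = fluid.layers.fc(input=emb, size=hidden_dim * 3)
pre_hidden = fluid.layers.data(
    name='pre_hidden', shape=[hidden_dim], dtype='float32')

# size is 3 * hidden_dim; origin_mode=True selects the arXiv:1406.1078 mode
hidden, reset_hidden_pre, gate = fluid.layers.gru_unit(
    input=x, hidden=pre_hidden, size=hidden_dim * 3, origin_mode=True)
```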
...@@ -31,7 +31,8 @@ def gru( ...@@ -31,7 +31,8 @@ def gru(
is_reverse, is_reverse,
act_state, act_state,
act_gate, act_gate,
dtype='float32'): dtype='float32',
origin_mode=False):
def _seq_to_batch(lod, is_reverse): def _seq_to_batch(lod, is_reverse):
idx_in_seq_list = [] idx_in_seq_list = []
seq_lens = lod[0] seq_lens = lod[0]
...@@ -66,6 +67,9 @@ def gru( ...@@ -66,6 +67,9 @@ def gru(
w_c = w.flatten()[D * D * 2:].reshape((D, D)) w_c = w.flatten()[D * D * 2:].reshape((D, D))
c = act_state(np.dot(r_h_p, w_c) + g[:, D * 2:]) c = act_state(np.dot(r_h_p, w_c) + g[:, D * 2:])
g = np.hstack((u_r, c)) g = np.hstack((u_r, c))
if origin_mode:
h = (1 - u) * c + u * h_p
else:
h = u * c + (1 - u) * h_p h = u * c + (1 - u) * h_p
return g, r_h_p, h return g, r_h_p, h
...@@ -110,6 +114,7 @@ class TestGRUOp(OpTest): ...@@ -110,6 +114,7 @@ class TestGRUOp(OpTest):
self.act_state = 'tanh' self.act_state = 'tanh'
self.act_gate = 'sigmoid' self.act_gate = 'sigmoid'
self.dtype = 'float64' self.dtype = 'float64'
self.origin_mode = False
self.set_confs() self.set_confs()
T = sum(self.lod[0]) T = sum(self.lod[0])
...@@ -126,7 +131,8 @@ class TestGRUOp(OpTest): ...@@ -126,7 +131,8 @@ class TestGRUOp(OpTest):
batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru( batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru(
input, self.lod, h0, weight, bias, self.is_reverse, input, self.lod, h0, weight, bias, self.is_reverse,
ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype) ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype,
self.origin_mode)
self.inputs = {'Input': (input, self.lod), 'Weight': weight} self.inputs = {'Input': (input, self.lod), 'Weight': weight}
if self.with_bias: if self.with_bias:
...@@ -145,7 +151,8 @@ class TestGRUOp(OpTest): ...@@ -145,7 +151,8 @@ class TestGRUOp(OpTest):
self.attrs = { self.attrs = {
'activation': self.act_state, 'activation': self.act_state,
'gate_activation': self.act_gate, 'gate_activation': self.act_gate,
'is_reverse': self.is_reverse 'is_reverse': self.is_reverse,
'origin_mode': self.origin_mode
} }
def test_check_output(self): def test_check_output(self):
...@@ -155,12 +162,24 @@ class TestGRUOp(OpTest): ...@@ -155,12 +162,24 @@ class TestGRUOp(OpTest):
self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden']) self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])
class TestGRUOriginMode(TestGRUOp):
def set_confs(self):
self.origin_mode = True
class TestGRUOp2(TestGRUOp): class TestGRUOp2(TestGRUOp):
def set_confs(self): def set_confs(self):
self.D = 19 self.D = 19
self.dtype = 'float32' self.dtype = 'float32'
class TestGRUOp2OriginMode(TestGRUOp):
def set_confs(self):
self.D = 19
self.dtype = 'float32'
self.origin_mode = True
class TestGRUOpNoInitial(TestGRUOp): class TestGRUOpNoInitial(TestGRUOp):
def set_confs(self): def set_confs(self):
self.with_h0 = False self.with_h0 = False
...@@ -182,5 +201,11 @@ class TestGRUOpReverse(TestGRUOp): ...@@ -182,5 +201,11 @@ class TestGRUOpReverse(TestGRUOp):
self.is_reverse = True self.is_reverse = True
class TestGRUOpReverseOriginMode(TestGRUOp):
def set_confs(self):
self.is_reverse = True
self.origin_mode = True
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -53,7 +53,7 @@ class TestGRUUnitOp(OpTest): ...@@ -53,7 +53,7 @@ class TestGRUUnitOp(OpTest):
GRUActivationType.relu: relu, GRUActivationType.relu: relu,
} }
def set_inputs(self): def set_inputs(self, origin_mode=False):
batch_size = self.batch_size batch_size = self.batch_size
frame_size = self.frame_size frame_size = self.frame_size
self.op_type = 'gru_unit' self.op_type = 'gru_unit'
...@@ -68,10 +68,11 @@ class TestGRUUnitOp(OpTest): ...@@ -68,10 +68,11 @@ class TestGRUUnitOp(OpTest):
} }
self.attrs = { self.attrs = {
'activation': GRUActivationType.tanh, 'activation': GRUActivationType.tanh,
'gate_activation': GRUActivationType.sigmoid 'gate_activation': GRUActivationType.sigmoid,
'origin_mode': origin_mode
} }
def set_outputs(self): def set_outputs(self, origin_mode=False):
# GRU calculations # GRU calculations
batch_size = self.batch_size batch_size = self.batch_size
frame_size = self.frame_size frame_size = self.frame_size
...@@ -93,6 +94,9 @@ class TestGRUUnitOp(OpTest): ...@@ -93,6 +94,9 @@ class TestGRUUnitOp(OpTest):
c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) + c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
g[:, frame_size * 2:]) g[:, frame_size * 2:])
g = np.hstack((u_r, c)) g = np.hstack((u_r, c))
if origin_mode:
h = (1 - u) * c + u * h_p
else:
h = u * c + (1 - u) * h_p h = u * c + (1 - u) * h_p
self.outputs = { self.outputs = {
'Gate': g.astype('float64'), 'Gate': g.astype('float64'),
...@@ -111,8 +115,14 @@ class TestGRUUnitOp(OpTest): ...@@ -111,8 +115,14 @@ class TestGRUUnitOp(OpTest):
self.check_grad(['Input', 'HiddenPrev', 'Weight'], ['Hidden']) self.check_grad(['Input', 'HiddenPrev', 'Weight'], ['Hidden'])
class TestGRUUnitOpOriginMode(TestGRUUnitOp):
def setUp(self):
self.set_inputs(origin_mode=True)
self.set_outputs(origin_mode=True)
class TestGRUUnitOpWithBias(TestGRUUnitOp): class TestGRUUnitOpWithBias(TestGRUUnitOp):
def set_inputs(self): def set_inputs(self, origin_mode=False):
batch_size = self.batch_size batch_size = self.batch_size
frame_size = self.frame_size frame_size = self.frame_size
super(TestGRUUnitOpWithBias, self).set_inputs() super(TestGRUUnitOpWithBias, self).set_inputs()
...@@ -120,7 +130,8 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp): ...@@ -120,7 +130,8 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp):
-0.1, 0.1, (1, frame_size * 3)).astype('float64') -0.1, 0.1, (1, frame_size * 3)).astype('float64')
self.attrs = { self.attrs = {
'activation': GRUActivationType.identity, 'activation': GRUActivationType.identity,
'gate_activation': GRUActivationType.sigmoid 'gate_activation': GRUActivationType.sigmoid,
'origin_mode': origin_mode
} }
def test_check_grad(self): def test_check_grad(self):
...@@ -132,5 +143,11 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp): ...@@ -132,5 +143,11 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp):
no_grad_set=set('Input')) no_grad_set=set('Input'))
class TestGRUUnitOpWithBiasOriginMode(TestGRUUnitOpWithBias):
def setUp(self):
self.set_inputs(origin_mode=True)
self.set_outputs(origin_mode=True)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()