From e10af895de5d66fd67358d28efb90b6b31307f15 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 4 Jan 2019 10:15:20 +0800 Subject: [PATCH] update gru grad op test=develop --- paddle/fluid/operators/gru_unit_op.h | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/gru_unit_op.h b/paddle/fluid/operators/gru_unit_op.h index 9f56a0ef731..ed0c689cfe5 100644 --- a/paddle/fluid/operators/gru_unit_op.h +++ b/paddle/fluid/operators/gru_unit_op.h @@ -113,8 +113,7 @@ class GRUUnitKernel : public framework::OpKernel { auto c = g.slice(c_offsets, extents); // output candidate // calculate final output - bool origin_mode = context.Attr("origin_mode"); - if (origin_mode) { + if (context.Attr("origin_mode")) { h.device(place) = c + u * (h_p - c); // (1 - u) * c + u * h_p } else { h.device(place) = u * (c - h_p) + h_p; // u * c + (1 - u) * h_p @@ -218,7 +217,11 @@ class GRUUnitGradKernel : public framework::OpKernel { T* hidden_prev_grad_data = hidden_prev_grad->mutable_data(context.GetPlace()); auto d_h_p = EigenMatrix::From(*hidden_prev_grad); - d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u); + if (context.Attr("origin_mode")) { + d_h_p.device(place) = d_r_h_p * (u.constant(T(1)) - u) + d_h * r; + } else { + d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u); + } blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1, gate_grad_data, frame_size * 3, weight_data, frame_size * 2, 1, hidden_prev_grad_data, frame_size); -- GitLab