diff --git a/paddle/operators/gru_unit_op.h b/paddle/operators/gru_unit_op.h index 81818b0a0aeacc82df3161746362a273b726035f..050430d3252d05236219cd5ced5a792c21413c1f 100644 --- a/paddle/operators/gru_unit_op.h +++ b/paddle/operators/gru_unit_op.h @@ -146,35 +146,27 @@ class GRUUnitGradKernel : public framework::OpKernel { auto* weight_grad = context.Output(framework::GradVarName("Weight")); auto* bias_grad = context.Output(framework::GradVarName("Bias")); - input_grad->mutable_data(context.GetPlace()); - hidden_prev_grad->mutable_data(context.GetPlace()); - weight_grad->mutable_data(context.GetPlace()); Tensor gate_grad; - gate_grad.mutable_data(input->dims(), context.GetPlace()); Tensor reset_hidden_prev_grad; - reset_hidden_prev_grad.mutable_data(reset_hidden_prev->dims(), - context.GetPlace()); - - int batch_size = input->dims()[0]; - int frame_size = hidden_prev->dims()[1]; const T* hidden_prev_data = hidden_prev->data(); - T* hidden_prev_grad_data = hidden_prev_grad->data(); const T* weight_data = weight->data(); - T* weight_grad_data = weight_grad->data(); - T* gate_grad_data = gate_grad.data(); + T* gate_grad_data = + gate_grad.mutable_data(input->dims(), context.GetPlace()); const T* reset_hidden_prev_data = reset_hidden_prev->data(); - T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.data(); + T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.mutable_data( + reset_hidden_prev->dims(), context.GetPlace()); auto h_p = EigenMatrix::From(*hidden_prev); auto g = EigenMatrix::From(*gate); auto d_h = EigenMatrix::From(*hidden_grad); - auto d_x = EigenMatrix::From(*input_grad); - auto d_h_p = EigenMatrix::From(*hidden_prev_grad); auto d_g = EigenMatrix::From(gate_grad); auto d_r_h_p = EigenMatrix::From(reset_hidden_prev_grad); auto place = context.GetEigenDevice(); + int batch_size = input->dims()[0]; + int frame_size = hidden_prev->dims()[1]; + Eigen::array extents({{batch_size, frame_size}}); Eigen::array u_offsets({{0, 0}}); auto u = g.slice(u_offsets, extents); // update gate @@ -195,28 +187,42 @@ class GRUUnitGradKernel : public framework::OpKernel { gate_grad_data + frame_size * 2, frame_size * 3, weight_data + frame_size * frame_size * 2, frame_size, 0, reset_hidden_prev_grad_data, frame_size); - // backward for state_weight - math::gemm( - context.device_context(), true, false, frame_size, frame_size, - batch_size, 1, reset_hidden_prev_data, frame_size, - gate_grad_data + frame_size * 2, frame_size * 3, 0, - weight_grad_data + frame_size * frame_size * 2, frame_size); // backward for unactivated reset gate ActGradCompute(context.Attr("gate_activation"), place, r, r, d_g.slice(r_offsets, extents), d_r_h_p * h_p); - // backward for update_gate_weight and reset_gate_weight - math::gemm(context.device_context(), true, false, frame_size, - frame_size * 2, batch_size, 1, hidden_prev_data, - frame_size, gate_grad_data, frame_size * 3, 0, - weight_grad_data, frame_size * 2); + // backward for weight + if (weight_grad) { + T* weight_grad_data = weight_grad->mutable_data(context.GetPlace()); + // backward for state_weight + math::gemm( + context.device_context(), true, false, frame_size, frame_size, + batch_size, 1, reset_hidden_prev_data, frame_size, + gate_grad_data + frame_size * 2, frame_size * 3, 0, + weight_grad_data + frame_size * frame_size * 2, frame_size); + + // backward for update_gate_weight and reset_gate_weight + math::gemm(context.device_context(), true, false, frame_size, + frame_size * 2, batch_size, 1, hidden_prev_data, + frame_size, gate_grad_data, frame_size * 3, 0, + weight_grad_data, frame_size * 2); + } // backward for hidden_prev - d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u); - math::gemm(context.device_context(), false, true, batch_size, - frame_size, frame_size * 2, 1, gate_grad_data, - frame_size * 3, weight_data, frame_size * 2, 1, - hidden_prev_grad_data, frame_size); + if (hidden_prev_grad) { + T* hidden_prev_grad_data = + hidden_prev_grad->mutable_data(context.GetPlace()); + auto d_h_p = EigenMatrix::From(*hidden_prev_grad); + d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u); + math::gemm(context.device_context(), false, true, batch_size, + frame_size, frame_size * 2, 1, gate_grad_data, + frame_size * 3, weight_data, frame_size * 2, 1, + hidden_prev_grad_data, frame_size); + } // backward for input - d_x.device(place) = d_g; + if (input_grad) { + input_grad->mutable_data(context.GetPlace()); + auto d_x = EigenMatrix::From(*input_grad); + d_x.device(place) = d_g; + } // backward for bias if (bias_grad) { bias_grad->mutable_data(context.GetPlace()); diff --git a/python/paddle/v2/fluid/tests/test_gru_unit_op.py b/python/paddle/v2/fluid/tests/test_gru_unit_op.py index beedcf7f428752058e58272cb49dc57aeae90c63..501d5aa5797d6def708338692f0861657f951ef7 100644 --- a/python/paddle/v2/fluid/tests/test_gru_unit_op.py +++ b/python/paddle/v2/fluid/tests/test_gru_unit_op.py @@ -28,8 +28,8 @@ def relu(x): class TestGRUUnitOp(OpTest): - batch_size = 3 - frame_size = 5 + batch_size = 5 + frame_size = 10 activate = { GRUActivationType.identity: identity, GRUActivationType.sigmoid: sigmoid, @@ -92,9 +92,7 @@ class TestGRUUnitOp(OpTest): self.check_output() def test_check_grad(self): - self.check_grad( - ['Input', 'HiddenPrev', 'Weight'], ['Hidden'], - max_relative_error=0.007) + self.check_grad(['Input', 'HiddenPrev', 'Weight'], ['Hidden']) class TestGRUUnitOpWithBias(TestGRUUnitOp): @@ -110,9 +108,12 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp): } def test_check_grad(self): + self.check_grad(['Input', 'HiddenPrev', 'Weight', 'Bias'], ['Hidden']) + + def test_check_grad_ingore_input(self): self.check_grad( - ['Input', 'HiddenPrev', 'Weight', 'Bias'], ['Hidden'], - max_relative_error=0.007) + ['HiddenPrev', 'Weight', 'Bias'], ['Hidden'], + no_grad_set=set('Input')) if __name__ == '__main__':