diff --git a/paddle/operators/lstmp_op.cc b/paddle/operators/lstmp_op.cc
index 266612294c69ba7e80d7822a26ed377318b678c0..932e76e913756148b9872f139723e824a3421565 100644
--- a/paddle/operators/lstmp_op.cc
+++ b/paddle/operators/lstmp_op.cc
@@ -217,7 +217,7 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
 Long-Short Term Memory with Recurrent Projection (LSTMP) Operator.
 
-LATMP is stand LSTM appended by a recurrent projection layer to reduce the
+LSTMP is standard LSTM appended with a recurrent projection layer to reduce the
 number of parameters, espeacially when the output size is relative large.
 The formula is as follows:
 
@@ -232,7 +232,7 @@ o_t = \sigma(W_{ox}x_{t} + W_{oh}r_{t-1} + W_{oc}c_t + b_o) \\
 
 h_t = o_t \odot act_h(c_t)
 
-r_t = act_h'(W_{rh}h_t)
+r_t = act_{h'}(W_{rh}h_t)
 $$
 
 where the W terms denote weight matrices (e.g. $W_{xi}$ is the matrix
diff --git a/paddle/operators/lstmp_op.h b/paddle/operators/lstmp_op.h
index 9467ccdb5a8e9d398daeadec1ee68ac5ef6aa0fc..0048f7e1c6aecb897893d108a897283090105518 100644
--- a/paddle/operators/lstmp_op.h
+++ b/paddle/operators/lstmp_op.h
@@ -365,10 +365,18 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
         ActGradCompute(cell_act, place, cur_proj_dev, cur_proj_dev,
                        proj_g_dev, proj_g_dev);
       }
+      /* hidden state backward */
       Tensor out_g = batch_hidden_g.Slice(bstart, bend);
       math::matmul<DeviceContext, T>(device_ctx, proj_g, false, *proj_weight,
                                      true, static_cast<T>(1.0), &out_g,
                                      static_cast<T>(0.0));
+      /* projection weight backward */
+      if (proj_weight_g) {
+        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+        math::matmul<DeviceContext, T>(device_ctx, hidden_t, true, proj_g,
+                                       false, static_cast<T>(1.0),
+                                       proj_weight_g, static_cast<T>(1.0));
+      }
 
       Tensor gate = batch_gate->Slice(bstart, bend);
       Tensor cell = batch_cell.Slice(bstart, bend);
@@ -407,19 +415,12 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
                                        static_cast<T>(1.0), &pre_proj_g,
                                        static_cast<T>(1.0));
         if (weight_g) {
-          /* backward weight */
+          /* weight backward */
           auto pre_proj = batch_proj.Slice(pre_h_start, pre_h_end);
           math::matmul<DeviceContext, T>(device_ctx, pre_proj, true, gate_g,
                                          false, static_cast<T>(1.0), weight_g,
                                          static_cast<T>(1.0));
         }
-        if (proj_weight_g) {
-          /* backward proj weigh */
-          Tensor hidden_t = batch_hidden->Slice(bstart, bend);
-          math::matmul<DeviceContext, T>(device_ctx, hidden_t, true, proj_g,
-                                         false, static_cast<T>(1.0),
-                                         proj_weight_g, static_cast<T>(1.0));
-        }
       } else {
         if (h0 && weight_g) {
           ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
@@ -444,7 +445,6 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
         ActGradCompute(cell_act, place, proj0_dev, proj0_dev, proj0_g_dev,
                        proj0_g_dev);
       }
-      // Tensor proj0_g = proj_g.Slice(bstart, bend);
       if (h0_g) {
         math::matmul<DeviceContext, T>(
             device_ctx, proj0_g, false, *proj_weight, true,
diff --git a/python/paddle/v2/fluid/tests/test_lstmp_op.py b/python/paddle/v2/fluid/tests/test_lstmp_op.py
index 81e06063fc7d725a64014a2b63887b0b22efdce0..a0f6955d7728770158ae25940844f76911d5594f 100644
--- a/python/paddle/v2/fluid/tests/test_lstmp_op.py
+++ b/python/paddle/v2/fluid/tests/test_lstmp_op.py
@@ -207,8 +207,8 @@ class TestLstmOp(OpTest):
         self.outputs['BatchCellPreAct'] = np.zeros(
             (N, self.D)).astype('float64')
         self.check_grad(
-            ['Input', 'Weight', 'Bias'], ['Projection'],
-            max_relative_error=5e-3)
+            ['Input', 'Weight', 'ProjWeight', 'Bias'], ['Projection'],
+            max_relative_error=1e-2)
 
 
 class TestLstmOpHasInitial(TestLstmOp):
@@ -235,8 +235,9 @@ class TestLstmOpHasInitial(TestLstmOp):
         self.outputs['BatchCellPreAct'] = np.zeros(
             (N, self.D)).astype('float64')
         self.check_grad(
-            ['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Projection'],
-            max_relative_error=5e-3)
+            ['Input', 'Weight', 'ProjWeight', 'Bias', 'H0', 'C0'],
+            ['Projection'],
+            max_relative_error=1e-2)
 
     def test_check_grad_ingore_bias(self):
         N = len(self.lod[0]) - 1
@@ -246,8 +247,8 @@ class TestLstmOpHasInitial(TestLstmOp):
         self.outputs['BatchCellPreAct'] = np.zeros(
             (N, self.D)).astype('float64')
         self.check_grad(
-            ['Input', 'Weight'], ['Projection'],
-            max_relative_error=5e-3,
+            ['Input', 'ProjWeight', 'Weight'], ['Projection'],
+            max_relative_error=1e-2,
             no_grad_set=set('Bias'))
 
     def test_check_grad_ingore_weight(self):
@@ -258,10 +259,22 @@ class TestLstmOpHasInitial(TestLstmOp):
         self.outputs['BatchCellPreAct'] = np.zeros(
             (N, self.D)).astype('float64')
         self.check_grad(
-            ['Input', 'Bias'], ['Projection'],
-            max_relative_error=5e-3,
+            ['Input', 'ProjWeight', 'Bias'], ['Projection'],
+            max_relative_error=1e-2,
             no_grad_set=set('Weight'))
 
+    def test_check_grad_ingore_proj_weight(self):
+        N = len(self.lod[0]) - 1
+        self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
+        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
+        self.outputs['BatchCellPreAct'] = np.zeros(
+            (N, self.D)).astype('float64')
+        self.check_grad(
+            ['Input', 'Weight', 'Bias'], ['Projection'],
+            max_relative_error=1e-2,
+            no_grad_set=set('ProjWeight'))
+
     def test_check_grad_ingore_input(self):
         N = len(self.lod[0]) - 1
         self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
@@ -270,8 +283,8 @@ class TestLstmOpHasInitial(TestLstmOp):
         self.outputs['BatchCellPreAct'] = np.zeros(
             (N, self.D)).astype('float64')
         self.check_grad(
-            ['Weight', 'Bias'], ['Projection'],
-            max_relative_error=5e-3,
+            ['Weight', 'ProjWeight', 'Bias'], ['Projection'],
+            max_relative_error=1e-2,
             no_grad_set=set('Input'))
 
     def test_check_grad_ingore_h0(self):
@@ -282,8 +295,8 @@ class TestLstmOpHasInitial(TestLstmOp):
         self.outputs['BatchCellPreAct'] = np.zeros(
             (N, self.D)).astype('float64')
         self.check_grad(
-            ['Input', 'Weight', 'Bias', 'C0'], ['Projection'],
-            max_relative_error=5e-3,
+            ['Input', 'Weight', 'ProjWeight', 'Bias', 'C0'], ['Projection'],
+            max_relative_error=1e-2,
             no_grad_set=set('H0'))
 
     def test_check_grad_ingore_c0(self):
@@ -294,8 +307,8 @@ class TestLstmOpHasInitial(TestLstmOp):
         self.outputs['BatchCellPreAct'] = np.zeros(
             (N, self.D)).astype('float64')
         self.check_grad(
-            ['Input', 'Weight', 'Bias', 'H0'], ['Projection'],
-            max_relative_error=5e-3,
+            ['Input', 'Weight', 'ProjWeight', 'Bias', 'H0'], ['Projection'],
+            max_relative_error=1e-2,
             no_grad_set=set('C0'))
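Note on the relocated gradient: with an identity projection activation, $r_t = W_{rh}h_t$ gives $\partial L/\partial W_{rh} = \sum_t h_t^T (\partial L/\partial r_t)$, which is exactly what the new per-timestep `math::matmul(hidden_t, true, proj_g, false, ...)` call accumulates into `proj_weight_g`. The removed copy of this block sat inside the branch that only runs when a previous time step exists, so it missed the first step's contribution. A minimal numpy sketch of the accumulation with a finite-difference check (sizes, names, and the scalar loss are hypothetical, not the operator's test harness):

```python
import numpy as np

# Hypothetical sizes: T time steps, D hidden units, P projection units.
T, D, P = 3, 4, 2
rng = np.random.RandomState(0)
h = rng.randn(T, D)      # hidden states h_t
W_rh = rng.randn(D, P)   # projection weight
r_g = rng.randn(T, P)    # upstream gradients dL/dr_t (identity act_h')

# Accumulate dL/dW_rh step by step, mirroring the kernel's
# matmul(hidden_t, transA=true, proj_g, transB=false, beta=1).
W_rh_g = np.zeros_like(W_rh)
for t in range(T):
    W_rh_g += np.outer(h[t], r_g[t])  # h_t^T * dL/dr_t

# Finite-difference check on the scalar loss L = sum_t <W_rh^T h_t, dL/dr_t>,
# whose exact gradient w.r.t. W_rh is the accumulated sum above.
def loss(W):
    return np.sum((h @ W) * r_g)

eps = 1e-6
W_pert = W_rh.copy()
W_pert[0, 0] += eps
numeric = (loss(W_pert) - loss(W_rh)) / eps
assert np.isclose(numeric, W_rh_g[0, 0], rtol=1e-4)
print("dW_rh[0, 0]: analytic %.6f, numeric %.6f" % (W_rh_g[0, 0], numeric))
```

The updated tests mirror this: `ProjWeight` joins the inputs checked by `check_grad` in every case, with the tolerance loosened from 5e-3 to 1e-2.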