From d94c936bd5814281582e6e3a7847d73277b438c7 Mon Sep 17 00:00:00 2001
From: dangqingqing
Date: Tue, 7 Nov 2017 13:43:36 +0800
Subject: [PATCH] Enhance unit testing.

1. Users can disable peephole connections.
2. Gradient computation can be skipped for some inputs.
---
 paddle/operators/lstm_op.cc                        |   9 +-
 paddle/operators/lstm_op.h                         |  12 +-
 .../operators/math/detail/lstm_cpu_kernel.h        |  40 ++---
 .../operators/math/detail/lstm_gpu_kernel.h        |  14 +-
 paddle/operators/math/sequence2batch.h             |  11 +-
 .../paddle/v2/framework/tests/test_lstm_op.py      | 142 +++++++++++++++---
 python/paddle/v2/optimizer.py                      |   2 +-
 7 files changed, 167 insertions(+), 63 deletions(-)

diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc
index 6c6c3f6e17..dc64b3f2c4 100644
--- a/paddle/operators/lstm_op.cc
+++ b/paddle/operators/lstm_op.cc
@@ -164,16 +164,19 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
                          "(string, default: sigmoid)"
                          "The activation for input gate, forget gate and output "
                          "gate, `sigmoid` by default.")
-        .SetDefault("sigmoid");
+        .SetDefault("sigmoid")
+        .InEnum({"sigmoid", "tanh", "relu", "identity"});
     AddAttr<std::string>("cell_activation",
                          "(string, default: tanh)"
                          "The activation for cell output, `tanh` by defalut.")
-        .SetDefault("tanh");
+        .SetDefault("tanh")
+        .InEnum({"sigmoid", "tanh", "relu", "identity"});
     AddAttr<std::string>("candidate_activation",
                          "(string, default: tanh)"
                          "The activation for candidate hidden state, "
                          "`tanh` by default.")
-        .SetDefault("tanh");
+        .SetDefault("tanh")
+        .InEnum({"sigmoid", "tanh", "relu", "identity"});
     AddComment(R"DOC(
 Long-Short Term Memory (LSTM) Operator.
diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h
index 2e0bbbeca0..26856f4a6e 100644
--- a/paddle/operators/lstm_op.h
+++ b/paddle/operators/lstm_op.h
@@ -69,7 +69,7 @@ class LSTMKernel : public framework::OpKernel<T> {
     }
 
     math::LstmMetaValue<T> lstm_value;
-    if (bias) {
+    if (bias && ctx.Attr<bool>("use_peepholes")) {
       T* bias_data = const_cast<T*>(bias->data<T>());
       // the code style in LstmMetaValue will be updated later.
       lstm_value.checkIg = bias_data + 4 * frame_size;
@@ -85,6 +85,7 @@ class LSTMKernel : public framework::OpKernel<T> {
     Tensor ordered_c0;
     if (cell_t0) {
       math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
+      ordered_c0.mutable_data<T>(cell_t0->dims(), ctx.GetPlace());
       const size_t* order = batch_gate->lod()[2].data();
       row_shuffle(device_ctx, *cell_t0, order, ordered_c0, true);
       lstm_value.prevStateValue = ordered_c0.data<T>();
@@ -124,6 +125,7 @@ class LSTMKernel : public framework::OpKernel<T> {
     } else if (hidden_t0) {
       math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
       Tensor ordered_h0;
+      ordered_h0.mutable_data<T>(hidden_t0->dims(), ctx.GetPlace());
       const size_t* order = batch_gate->lod()[2].data();
       row_shuffle(device_ctx, *hidden_t0, order, ordered_h0, true);
       math::matmul<Place, T>(device_ctx, ordered_h0, false, *weight, false,
@@ -199,7 +201,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(frame_size, out_dims[1]);
 
     math::LstmMetaValue<T> lstm_value;
-    if (bias) {
+    if (bias && ctx.Attr<bool>("use_peepholes")) {
       T* bias_data = const_cast<T*>(bias->data<T>());
       lstm_value.checkIg = bias_data + 4 * frame_size;
       lstm_value.checkFg = lstm_value.checkIg + frame_size;
@@ -211,9 +213,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     }
 
     math::LstmMetaGrad<T> lstm_grad;
+
     if (bias && bias_g) {
-      T* bias_g_data = const_cast<T*>(bias_g->mutable_data<T>(ctx.GetPlace()));
+      bias_g->mutable_data<T>(ctx.GetPlace());
       zero(device_ctx, bias_g, static_cast<T>(0.0));
+    }
+    if (bias && bias_g && ctx.Attr<bool>("use_peepholes")) {
+      T* bias_g_data = bias_g->data<T>();
       lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size;
       lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size;
       lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size;
diff --git a/paddle/operators/math/detail/lstm_cpu_kernel.h b/paddle/operators/math/detail/lstm_cpu_kernel.h
index f5b0dd85c9..fc3ad0ce58 100644
--- a/paddle/operators/math/detail/lstm_cpu_kernel.h
+++ b/paddle/operators/math/detail/lstm_cpu_kernel.h
@@ -52,9 +52,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
     rValueIg = valueIg[i];
     rValueFg = valueFg[i];
     rValueOg = valueOg[i];
-    rCheckI = value.checkIg[i];
-    rCheckF = value.checkFg[i];
-    rCheckO = value.checkOg[i];
+    rCheckI = value.checkIg ? value.checkIg[i] : 0;
+    rCheckF = value.checkFg ? value.checkFg[i] : 0;
+    rCheckO = value.checkOg ? value.checkOg[i] : 0;
 
     if (value.prevStateValue) {
       rPrevState = value.prevStateValue[i];
@@ -114,9 +114,9 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
     rValueIg = valueIg[i];
     rValueFg = valueFg[i];
     rValueOg = valueOg[i];
-    rCheckI = value.checkIg[i];
-    rCheckF = value.checkFg[i];
-    rCheckO = value.checkOg[i];
+    rCheckI = value.checkIg ? value.checkIg[i] : 0;
+    rCheckF = value.checkFg ? value.checkFg[i] : 0;
+    rCheckO = value.checkOg ? value.checkOg[i] : 0;
     rState = value.stateValue[i];
     rStateAtv = value.stateActiveValue[i];
     rOutputGrad = grad.outputGrad[i];
@@ -155,9 +155,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, int frameSize,
   __m256 rValueIg;
   __m256 rValueFg;
   __m256 rValueOg;
-  __m256 rCheckI;
-  __m256 rCheckF;
-  __m256 rCheckO;
+  __m256 rCheckI = _mm256_set1_ps(0.0f);
+  __m256 rCheckF = _mm256_set1_ps(0.0f);
+  __m256 rCheckO = _mm256_set1_ps(0.0f);
   __m256 rState;
   __m256 rPrevState = _mm256_set1_ps(0.0f);
   __m256 rStateAtv;
@@ -173,9 +173,11 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, int frameSize,
     rValueIg = valueIg[i];
     rValueFg = valueFg[i];
     rValueOg = valueOg[i];
-    rCheckI = ((__m256 *)value.checkIg)[i];
-    rCheckF = ((__m256 *)value.checkFg)[i];
-    rCheckO = ((__m256 *)value.checkOg)[i];
+    if (value.checkIg) {
+      rCheckI = ((__m256 *)value.checkIg)[i];
+      rCheckF = ((__m256 *)value.checkFg)[i];
+      rCheckO = ((__m256 *)value.checkOg)[i];
+    }
 
     if (value.prevStateValue) {
       rPrevState = ((__m256 *)value.prevStateValue)[i];
@@ -216,9 +218,9 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
   __m256 rState;
   __m256 rStateAtv;
   __m256 rOutputGrad;
-  __m256 rCheckI;
-  __m256 rCheckF;
-  __m256 rCheckO;
+  __m256 rCheckI = _mm256_set1_ps(0.0f);
+  __m256 rCheckF = _mm256_set1_ps(0.0f);
+  __m256 rCheckO = _mm256_set1_ps(0.0f);
   __m256 rCheckIGrad;
   __m256 rCheckFGrad;
   __m256 rCheckOGrad;
@@ -237,9 +239,11 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
     rValueIg = valueIg[i];
     rValueFg = valueFg[i];
     rValueOg = valueOg[i];
-    rCheckI = ((__m256 *)value.checkIg)[i];
-    rCheckF = ((__m256 *)value.checkFg)[i];
-    rCheckO = ((__m256 *)value.checkOg)[i];
+    if (value.checkIg) {
+      rCheckI = ((__m256 *)value.checkIg)[i];
+      rCheckF = ((__m256 *)value.checkFg)[i];
+      rCheckO = ((__m256 *)value.checkOg)[i];
+    }
     rState = ((__m256 *)value.stateValue)[i];
     rStateAtv = ((__m256 *)value.stateActiveValue)[i];
     rOutputGrad = ((__m256 *)grad.outputGrad)[i];
diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h
index 41a54a359d..e8ac61e009 100644
--- a/paddle/operators/math/detail/lstm_gpu_kernel.h
+++ b/paddle/operators/math/detail/lstm_gpu_kernel.h
@@ -55,9 +55,10 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
   T rValueIg;
   T rValueFg;
   T rValueOg;
-  T rCheckI = value.checkIg[frameIdx];
-  T rCheckF = value.checkFg[frameIdx];
-  T rCheckO = value.checkOg[frameIdx];
+
+  T rCheckI = value.checkIg ? value.checkIg[frameIdx] : 0;
+  T rCheckF = value.checkFg ? value.checkFg[frameIdx] : 0;
+  T rCheckO = value.checkOg ? value.checkOg[frameIdx] : 0;
 
   rValueIn = value.gateValue[frameIdx];
   rValueIg = value.gateValue[frameIdx + frameSize];
@@ -121,9 +122,10 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
   T rStateGrad;
   T rStateAtv;
   T rOutputGrad;
-  T rCheckI = value.checkIg[frameIdx];
-  T rCheckF = value.checkFg[frameIdx];
-  T rCheckO = value.checkOg[frameIdx];
+  T rCheckI = value.checkIg ? value.checkIg[frameIdx] : 0;
+  T rCheckF = value.checkFg ? value.checkFg[frameIdx] : 0;
+  T rCheckO = value.checkOg ? value.checkOg[frameIdx] : 0;
+
   T rCheckIGrad;
   T rCheckFGrad;
   T rCheckOGrad;
diff --git a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h
index 4942b7d9a1..794c7d4397 100644
--- a/paddle/operators/math/sequence2batch.h
+++ b/paddle/operators/math/sequence2batch.h
@@ -31,7 +31,7 @@ class CopyMatrixRowsFunctor {
   // The indexed rows are based on the input index.
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& src, const size_t* index,
-                  framework::Tensor* dst, bool is_src_index);
+                  framework::Tensor& dst, bool is_src_index);
 };
 
 template <typename Place, typename T>
@@ -57,7 +57,7 @@ class LoDTensor2BatchFunctor {
                   bool is_reverse = false) const {
     if (!is_cal_batch_lod) {
       auto lods = batch.lod();
-      PADDLE_ENFORCE_LE(lods.size(), 2UL);
+      PADDLE_ENFORCE_GT(lods.size(), 2UL);
       PADDLE_ENFORCE_EQ(lods[1].size(),
                         static_cast<size_t>(lod_tensor.dims()[0]));
       CopyMatrixRowsFunctor<Place, T> to_batch;
@@ -68,8 +68,6 @@ class LoDTensor2BatchFunctor {
     auto lods = lod_tensor.lod();
     auto lod = lods[0];
     PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
-    PADDLE_ENFORCE_EQ(lod_tensor.dims()[0],
-                      static_cast<int64_t>(lod.size() - 1));
 
     std::vector<SeqInfo> seq_info;
     for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
@@ -112,7 +110,7 @@ class LoDTensor2BatchFunctor {
     int num_batch = seq_info[0].length;
     batch_lods[0].resize(static_cast<size_t>(num_batch + 1));
     // batch_lods[1] is the raw index in the input LoDTensor
-    batch_lods[1].resize(static_cast<size_t>(seq_info.size()));
+    batch_lods[1].resize(static_cast<size_t>(lod_tensor.dims()[0]));
     // batch_lods[2] is the sort order for the input LoDTensor.
     batch_lods[2].resize(seq_info.size());
@@ -152,8 +150,7 @@ class Batch2LoDTensorFunctor {
                   const framework::LoDTensor& batch,
                   framework::LoDTensor& lod_tensor) const {
     auto in_lod = batch.lod();
-    PADDLE_ENFORCE_LT(in_lod.size(), 2UL,
-                      "The LoD size of input `batch` should be 2.");
+    PADDLE_ENFORCE_GT(in_lod.size(), 2UL);
     PADDLE_ENFORCE_EQ(in_lod[1].size(),
                       static_cast<size_t>(lod_tensor.dims()[0]));
     CopyMatrixRowsFunctor<Place, T> to_seq;
diff --git a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/framework/tests/test_lstm_op.py
index 2b8ba1fcdc..a4bb99cd7d 100644
--- a/python/paddle/v2/framework/tests/test_lstm_op.py
+++ b/python/paddle/v2/framework/tests/test_lstm_op.py
@@ -117,9 +117,9 @@ class TestLstmOp(OpTest):
         self.act_cell = 'tanh'
         self.act_cand = 'tanh'
 
-        self.has_initial_state = True
-        self.has_bias = True
+        self.has_initial_state = False
         self.is_reverse = False
+        self.use_peepholes = True
 
     def setUp(self):
         self.set_argument()
@@ -129,21 +129,27 @@ class TestLstmOp(OpTest):
         N = len(self.lod[0]) - 1
 
         x = np.random.normal(size=(T, 4 * self.D)).astype('float64')
-        h0 = np.zeros((N, self.D)).astype('float64')
-        c0 = np.zeros((N, self.D)).astype('float64')
+        if self.has_initial_state:
+            h0 = np.random.normal(size=(N, self.D)).astype('float64')
+            c0 = np.random.normal(size=(N, self.D)).astype('float64')
+        else:
+            h0 = np.zeros((N, self.D)).astype('float64')
+            c0 = np.zeros((N, self.D)).astype('float64')
         w = np.random.normal(size=(self.D, 4 * self.D)).astype('float64')
-        b = np.random.normal(size=(1, 7 * self.D)).astype('float64')
+        if self.use_peepholes:
+            b = np.random.normal(size=(1, 7 * self.D)).astype('float64')
+        else:
+            b = np.random.normal(size=(1, 4 * self.D)).astype('float64')
 
-        w_b = b[:, 0:4 * self.D] if self.has_bias else None
-        w_c = b[:, 4 * self.D:] if self.has_bias else None
+        w_b = b[:, 0:4 * self.D]
+        w_c = b[:, 4 * self.D:] if self.use_peepholes else None
         h, c = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
                     ACTVATION[self.act_gate], ACTVATION[self.act_cell],
                     ACTVATION[self.act_cand])
 
         self.inputs = {'Input': (x, self.lod), 'Weight': w}
-        if self.has_bias:
-            self.inputs['Bias'] = b
+
+        self.inputs['Bias'] = b
 
         if self.has_initial_state:
             self.inputs['H0'] = h0
@@ -154,18 +160,17 @@
             'Cell': (c, self.lod),
         }
         self.attrs = {
-            'use_peepholes': True,
+            'use_peepholes': self.use_peepholes,
             'is_reverse': self.is_reverse,
             'gate_activation': self.act_gate,
             'cell_activation': self.act_cell,
             'candidate_activation': self.act_cand
         }
 
-    def not_test_check_output(self):
+    def test_check_output(self):
         self.check_output(atol=1e-8)
 
-    #TODO(qingqing) add more unit testing case
-    def not_test_check_grad(self):
+    def test_check_grad(self):
         # TODO(qingqing) remove folowing lines after the check_grad is refined.
         N = len(self.lod[0]) - 1
         self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
@@ -174,8 +179,38 @@
         self.check_grad(
             ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4)
 
+    def test_check_grad_ignore_bias(self):
+        N = len(self.lod[0]) - 1
+        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchCellPreAct'] = np.zeros(
+            (N, self.D)).astype('float64')
+        self.check_grad(
+            ['Input', 'Weight'], ['Hidden'],
+            max_relative_error=5e-4,
+            no_grad_set=set('Bias'))
+
+    def test_check_grad_ignore_weight(self):
+        N = len(self.lod[0]) - 1
+        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchCellPreAct'] = np.zeros(
+            (N, self.D)).astype('float64')
+        self.check_grad(
+            ['Input', 'Bias'], ['Hidden'],
+            max_relative_error=5e-4,
+            no_grad_set=set('Weight'))
+
+    def test_check_grad_ignore_input(self):
+        N = len(self.lod[0]) - 1
+        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchCellPreAct'] = np.zeros(
+            (N, self.D)).astype('float64')
+        self.check_grad(
+            ['Weight', 'Bias'], ['Hidden'],
+            max_relative_error=5e-4,
+            no_grad_set=set('Input'))
+
-class TestLstmOpHasNoInitial(TestLstmOp):
+class TestLstmOpHasInitial(TestLstmOp):
     def set_argument(self):
         self.lod = [[0, 2, 5, 7]]
         self.D = 16
@@ -184,12 +219,52 @@
         self.act_cell = 'tanh'
         self.act_cand = 'tanh'
 
-        self.has_initial_state = False
+        self.has_initial_state = True
         self.is_reverse = True
-        self.has_bias = True
+        self.use_peepholes = True
 
+    def test_check_grad(self):
+        # TODO(qingqing) remove following lines after the check_grad is refined.
+        N = len(self.lod[0]) - 1
+        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchCellPreAct'] = np.zeros(
+            (N, self.D)).astype('float64')
+        self.check_grad(
+            ['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Hidden'],
+            max_relative_error=5e-4)
 
-class TestLstmOpHasNoBias(TestLstmOp):
+    # In order to speed up, skip the following tests
+    def test_check_grad_ignore_bias(self):
+        return
+
+    def test_check_grad_ignore_weight(self):
+        return
+
+    def test_check_grad_ignore_input(self):
+        return
+
+    def test_check_grad_ignore_h0(self):
+        N = len(self.lod[0]) - 1
+        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchCellPreAct'] = np.zeros(
+            (N, self.D)).astype('float64')
+        self.check_grad(
+            ['Input', 'Weight', 'Bias', 'C0'], ['Hidden'],
+            max_relative_error=5e-4,
+            no_grad_set=set('H0'))
+
+    def test_check_grad_ignore_c0(self):
+        N = len(self.lod[0]) - 1
+        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchCellPreAct'] = np.zeros(
+            (N, self.D)).astype('float64')
+        self.check_grad(
+            ['Input', 'Weight', 'Bias', 'H0'], ['Hidden'],
+            max_relative_error=5e-4,
+            no_grad_set=set('C0'))
+
+
+class TestLstmOpReverse(TestLstmOp):
     def set_argument(self):
         self.lod = [[0, 2, 5, 7]]
         self.D = 16
@@ -198,15 +273,22 @@
         self.act_cell = 'tanh'
         self.act_cand = 'tanh'
 
-        self.has_initial_state = True
-        self.is_reverse = False
-        self.has_bias = False
+        self.has_initial_state = False
+        self.is_reverse = True
+        self.use_peepholes = True
 
-    def test_check_output(self):
-        self.check_output(atol=1e-8)
+    # In order to speed up, skip the following tests
+    def test_check_grad_ignore_bias(self):
+        return
 
+    def test_check_grad_ignore_weight(self):
+        return
 
-class TestLstmOpRerverse(TestLstmOp):
+    def test_check_grad_ignore_input(self):
+        return
+
+
+class TestLstmOpNotUsePeepholes(TestLstmOp):
     def set_argument(self):
         self.lod = [[0, 2, 5, 7]]
         self.D = 16
@@ -215,9 +297,19 @@
         self.act_cell = 'tanh'
         self.act_cand = 'tanh'
 
-        self.has_initial_state = True
+        self.has_initial_state = False
         self.is_reverse = True
-        self.has_bias = True
+        self.use_peepholes = False
+
+    # In order to speed up, skip the following tests
+    def test_check_grad_ignore_bias(self):
+        return
+
+    def test_check_grad_ignore_weight(self):
+        return
+
+    def test_check_grad_ignore_input(self):
+        return
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/v2/optimizer.py b/python/paddle/v2/optimizer.py
index 94d706b1d6..caef5f484e 100644
--- a/python/paddle/v2/optimizer.py
+++ b/python/paddle/v2/optimizer.py
@@ -102,7 +102,7 @@ class Momentum(Optimizer):
 
     .. math::
 
-        v_{t} &= k * v_{t-1} - \\gamma_t / (g_{t} + \\lambda w_{t-1}) \\\\
+        v_{t} &= k * v_{t-1} - \\gamma_t (g_{t} + \\lambda w_{t-1}) \\\\
         w_{t} &= w_{t-1} + v_{t} \\\\
 
     where, :math:`k` is momentum, :math:`\\lambda` is decay rate,
-- 
GitLab
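
For anyone exercising the new test knob locally, the sketch below shows how a subclass picks up the `use_peepholes` switch this patch adds. It is a minimal sketch, assuming the `TestLstmOp` base class from this patch's `test_lstm_op.py` is importable from the working directory; the subclass name `TestLstmOpPeepholeSketch` is hypothetical and not part of the patch.

import unittest

# Minimal sketch: TestLstmOpPeepholeSketch is a hypothetical name; the
# attribute values mirror the configurations used by this patch's test cases.
from test_lstm_op import TestLstmOp


class TestLstmOpPeepholeSketch(TestLstmOp):
    def set_argument(self):
        self.lod = [[0, 2, 5, 7]]  # three sequences of lengths 2, 3, 2
        self.D = 16                # hidden size
        self.act_gate = 'sigmoid'
        self.act_cell = 'tanh'
        self.act_cand = 'tanh'
        self.has_initial_state = False
        self.is_reverse = False
        # With peepholes disabled, Bias carries only the 4*D gate bias
        # (instead of 7*D with the check* peephole weights appended).
        self.use_peepholes = False


if __name__ == '__main__':
    unittest.main()

With `use_peepholes = False`, `setUp` builds a bias of shape (1, 4*D) and passes `w_c = None` to the reference `lstm` implementation, matching the branch the C++ kernels take when the `checkIg`/`checkFg`/`checkOg` pointers are null.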