From 76beff86a0f8e0d6856691b2968bafa52bf3a859 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 24 Jan 2018 01:34:54 -0800 Subject: [PATCH] Make the projection activation configurable --- paddle/operators/lstmp_op.cc | 76 +++++++++---------- paddle/operators/lstmp_op.h | 14 ++-- python/paddle/v2/fluid/tests/test_lstmp_op.py | 41 +++++----- 3 files changed, 66 insertions(+), 65 deletions(-) diff --git a/paddle/operators/lstmp_op.cc b/paddle/operators/lstmp_op.cc index 85be64f44..14469c708 100644 --- a/paddle/operators/lstmp_op.cc +++ b/paddle/operators/lstmp_op.cc @@ -23,27 +23,29 @@ class LSTMPOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(Input) of LSTMP should not be null."); + "Input(Input) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("Weight"), - "Input(Weight) of LSTMP should not be null."); + "Input(Weight) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("ProjWeight"), - "Input(ProjWeight) of LSTMP should not be null."); + "Input(ProjWeight) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("Bias"), - "Input(Bias) of LSTMP should not be null."); + "Input(Bias) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Projection"), - "Output(Projection) of LSTMP should not be null."); + "Output(Projection) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Cell"), - "Output(Cell) of LSTMP should not be null."); + "Output(Cell) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasOutput("BatchGate"), - "Output(BatchGate) of LSTMP should not be null."); + "Output(BatchGate) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"), - "Output(BatchGate) of LSTMP should not be null."); + "Output(BatchCellPreAct) of LSTMP operator should not be " + "null."); PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"), - "Output(BatchHidden) of LSTMP should not be null."); + "Output(BatchHidden) of LSTMP operator should not be null."); auto in_dims = ctx->GetInputDim("Input"); - PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2."); + PADDLE_ENFORCE_EQ(in_dims.size(), 2, + "Input(X)'s rank of LSTMP operator must be 2."); int frame_size = in_dims[1] / 4; auto w_dims = ctx->GetInputDim("Weight"); @@ -68,8 +70,8 @@ class LSTMPOp : public framework::OperatorWithKernel { if (ctx->HasInput("H0")) { PADDLE_ENFORCE(ctx->HasInput("C0"), - "Input(C0) and Input(H0) of LSTMP should not " - "be null at the same time."); + "Input(C0) of LSTMP operator should not be null when " + "Input(H0) is provided."); auto h_dims = ctx->GetInputDim("H0"); auto c_dims = ctx->GetInputDim("C0"); PADDLE_ENFORCE(h_dims == c_dims, @@ -132,8 +134,7 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("C0", "(Tensor, optional) the initial cell state is an optional " "input. This is a tensor with shape (N x D), where N is the " - "batch size. Only one of `H0` and `C0` can be NULL at the same " - "time.") + "batch size. `C0` should not be null if `H0` is provided.") .AsDispensable(); AddInput("Weight", "(Tensor) the learnable hidden-hidden weights."
@@ -211,13 +212,12 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { "`tanh` by default.") .SetDefault("tanh") .InEnum({"sigmoid", "tanh", "relu", "identity"}); - AddAttr("share_cell_act", - "(bool, defalut: True) " - "whether to share the activation of cell output with the " - "projection layer. When set to `False`, the projection " - "is simple linear, otherwise it will go through an " - "activation function same as `cell_activation`.") - .SetDefault(true); + AddAttr("proj_activation", + "(string, default: tanh) " + "The activation for projection output, " + "`tanh` by default.") + .SetDefault("tanh") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); AddComment(R"DOC( Long-Short Term Memory with recurrent Projection layer (LSTMP) Operator. @@ -226,20 +226,21 @@ original hidden state to a lower-dimensional one, which is proposed to reduce the number of total parameters and furthermore computational complexity for the LSTM, espeacially for the case that the size of output units is relative large (https://research.google.com/pubs/archive/43905.pdf). + The formula is as follows: $$ -i_t = \sigma(W_{ix}x_{t} + W_{ih}r_{t-1} + W_{ic}c_{t-1} + b_i) \\ +i_t = \sigma(W_{ix}x_{t} + W_{ir}r_{t-1} + W_{ic}c_{t-1} + b_i) \\ -f_t = \sigma(W_{fx}x_{t} + W_{fh}r_{t-1} + W_{fc}c_{t-1} + b_f) \\ +f_t = \sigma(W_{fx}x_{t} + W_{fr}r_{t-1} + W_{fc}c_{t-1} + b_f) \\ -\tilde{c_t} = act_g(W_{cx}x_t + W_{ch}r_{t-1} + b_c) \\ +\tilde{c_t} = act_g(W_{cx}x_t + W_{cr}r_{t-1} + b_c) \\ -o_t = \sigma(W_{ox}x_{t} + W_{oh}r_{t-1} + W_{oc}c_t + b_o) \\ +o_t = \sigma(W_{ox}x_{t} + W_{or}r_{t-1} + W_{oc}c_t + b_o) \\ -c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c_t} +c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c_t} \\ -h_t = o_t \odot act_h(c_t) +h_t = o_t \odot act_h(c_t) \\ r_t = \overline{act_h}(W_{rh}h_t) $$ @@ -259,9 +260,8 @@ input and previous hidden state. The $\odot$ is the element-wise product of the vectors. $act_g$ and $act_h$ are the cell input and cell output activation functions and `tanh` is usually -used for them. $\overline{act_h}$ is the activation function for the projection -layer. When `share_cell_act` set to `False`, $\overline{act_h}$ is an -identity activation, otherwise it will be same as $act_h$. +used for them. $\overline{act_h}$ is the activation function for the +projection output, usually `identity` or the same as $act_h$. Note that these $W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}$ operations on the input $x_{t}$ are NOT included in this operator.
@@ -277,22 +277,22 @@ class LSTMPGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(Input) of LSTMP should not be null."); + "Input(Input) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("Projection"), - "Input(Projection) of LSTMP should not be null."); + "Input(Projection) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("Cell"), - "Input(Cell) of LSTMP should not be null."); + "Input(Cell) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("Weight"), - "Input(Weight) of LSTMP should not be null."); + "Input(Weight) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("ProjWeight"), - "Input(ProjWeight) of LSTMP should not be null."); + "Input(ProjWeight) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("Bias"), - "Input(Bias) of LSTMP should not be null."); + "Input(Bias) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("BatchGate"), - "Input(BatchGate) of LSTMP should not be null."); + "Input(BatchGate) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"), - "Input(BatchGate) of LSTMP should not be null."); + "Input(BatchCellPreAct) of LSTMP operator should not be null."); auto SetOutGradDim = [&ctx](const std::string& name) { auto g_name = framework::GradVarName(name); diff --git a/paddle/operators/lstmp_op.h b/paddle/operators/lstmp_op.h index 0048f7e1c..9dc37615f 100644 --- a/paddle/operators/lstmp_op.h +++ b/paddle/operators/lstmp_op.h @@ -136,7 +136,8 @@ class LSTMPKernel : public framework::OpKernel { ctx.Attr("cell_activation")); auto cand_act = math::detail::GetActivationType( ctx.Attr("candidate_activation")); - auto share_cell_act = ctx.Attr("share_cell_act"); + auto proj_act = math::detail::GetActivationType( + ctx.Attr("proj_activation")); auto& place = *ctx.template device_context().eigen_device(); for (size_t n = 0; n < num_batch; n++) { @@ -174,7 +175,7 @@ class LSTMPKernel : public framework::OpKernel { math::matmul(device_ctx, ordered_h0, false, *proj_weight, false, static_cast(1.0), ordered_proj0, static_cast(0.0)); - if (share_cell_act) { + if (proj_act != math::detail::ActivationType::kIdentity) { auto proj0_dev = EigenMatrix::From(*ordered_proj0); ActCompute(cell_act, place, proj0_dev, proj0_dev); } @@ -194,7 +195,7 @@ class LSTMPKernel : public framework::OpKernel { math::matmul(device_ctx, hidden_t, false, *proj_weight, false, static_cast(1.0), &proj_t, static_cast(0.0)); - if (share_cell_act) { + if (proj_act != math::detail::ActivationType::kIdentity) { auto proj_t_dev = EigenMatrix::From(proj_t); ActCompute(cell_act, place, proj_t_dev, proj_t_dev); } @@ -348,7 +349,8 @@ class LSTMPGradKernel : public framework::OpKernel { ctx.Attr("cell_activation")); auto cand_act = math::detail::GetActivationType( ctx.Attr("candidate_activation")); - auto share_cell_act = ctx.Attr("share_cell_act"); + auto proj_act = math::detail::GetActivationType( + ctx.Attr("proj_activation")); auto& place = *ctx.template device_context().eigen_device(); auto batch_starts = batch_gate->lod()[0]; @@ -359,7 +361,7 @@ class LSTMPGradKernel : public framework::OpKernel { Tensor cur_proj = batch_proj.Slice(bstart, bend); Tensor proj_g = batch_proj_g.Slice(bstart, bend); - if (share_cell_act) { + if (proj_act != math::detail::ActivationType::kIdentity) { auto cur_proj_dev = EigenMatrix::From(cur_proj); auto proj_g_dev =
EigenMatrix::From(proj_g); ActGradCompute(cell_act, place, cur_proj_dev, cur_proj_dev, proj_g_dev, @@ -439,7 +441,7 @@ class LSTMPGradKernel : public framework::OpKernel { math::matmul(device_ctx, gate_g, false, *weight, true, static_cast(1.0), &proj0_g, static_cast(0.0)); - if (share_cell_act) { + if (proj_act != math::detail::ActivationType::kIdentity) { auto proj0_dev = EigenMatrix::From(*ordered_proj0); auto proj0_g_dev = EigenMatrix::From(proj0_g); ActGradCompute(cell_act, place, proj0_dev, proj0_dev, proj0_g_dev, diff --git a/python/paddle/v2/fluid/tests/test_lstmp_op.py b/python/paddle/v2/fluid/tests/test_lstmp_op.py index 8835cae50..08fc32e11 100644 --- a/python/paddle/v2/fluid/tests/test_lstmp_op.py +++ b/python/paddle/v2/fluid/tests/test_lstmp_op.py @@ -41,7 +41,7 @@ def relu(x): return np.maximum(x, 0) -ACTVATION = { +ACTIVATION = { 'identity': identity, 'sigmoid': sigmoid, 'tanh': tanh, @@ -63,8 +63,9 @@ def lstmp( act_gate=None, act_cell=None, act_cand=None, - share_cell_act=True): - def _step(x, w_r, w_rh, w_c, r_pre, c_pre, act_gate, act_cell, act_cand): + act_proj=None): + def _step(x, w_r, w_rh, w_c, r_pre, c_pre, act_gate, act_cell, act_cand, + act_proj): g = np.dot(r_pre, w_r) # 1 x 4D g = g + x g = np.reshape(g, (1, g.size)) @@ -86,8 +87,7 @@ def lstmp( h = g_o * act_cell(c) # projection r = np.dot(h, w_rh) - if share_cell_act: - r = act_cell(r) + r = act_proj(r) return r, c def _reverse(x, lod): @@ -110,13 +110,12 @@ def lstmp( seq_len = offset[i + 1] - offset[i] x = input[offset[i]:offset[i + 1], :] r_pre = np.dot(h0[i], w_rh) # 1 x P - if share_cell_act: - r_pre = act_cell(r_pre) + r_pre = act_proj(r_pre) c_pre = c0[i] # 1 x D for j in range(seq_len): # compute one step r_pre, c_pre = _step(x[j], w_r, w_rh, w_c, r_pre, c_pre, act_gate, - act_cell, act_cand) + act_cell, act_cand, act_proj) projection.append(r_pre.flatten()) cell.append(c_pre.flatten()) @@ -131,7 +130,7 @@ def lstmp( return projection, cell -class TestLstmOp(OpTest): +class TestLstmpOp(OpTest): def set_argument(self): self.lod = [[0, 2, 5, 7]] # hidden size @@ -142,8 +141,8 @@ class TestLstmOp(OpTest): self.act_gate = 'sigmoid' self.act_cell = 'tanh' self.act_cand = 'tanh' + self.act_proj = self.act_cell - self.share_cell_act = True self.has_initial_state = False self.is_reverse = False self.use_peepholes = True @@ -172,8 +171,8 @@ class TestLstmOp(OpTest): w_c = b[:, 4 * self.D:] if self.use_peepholes else None w_rh = np.random.normal(size=(self.D, self.P)).astype('float64') r, c = lstmp(x, self.lod, h0, c0, w, w_rh, w_b, w_c, self.is_reverse, - ACTVATION[self.act_gate], ACTVATION[self.act_cell], - ACTVATION[self.act_cand], self.share_cell_act) + ACTIVATION[self.act_gate], ACTIVATION[self.act_cell], + ACTIVATION[self.act_cand], ACTIVATION[self.act_proj]) self.inputs = {'Input': (x, self.lod), 'Weight': w, 'ProjWeight': w_rh} @@ -193,7 +192,7 @@ class TestLstmOp(OpTest): 'gate_activation': self.act_gate, 'cell_activation': self.act_cell, 'candidate_activation': self.act_cand, - 'share_cell_act': self.share_cell_act + 'proj_activation': self.act_proj } def test_check_output(self): @@ -212,7 +211,7 @@ class TestLstmOp(OpTest): max_relative_error=1e-2) -class TestLstmOpHasInitial(TestLstmOp): +class TestLstmpOpHasInitial(TestLstmpOp): def set_argument(self): self.lod = [[0, 2, 5, 7]] self.D = 16 @@ -221,8 +220,8 @@ class TestLstmOpHasInitial(TestLstmOp): self.act_gate = 'sigmoid' self.act_cell = 'tanh' self.act_cand = 'tanh' + self.act_proj = self.act_cell - self.share_cell_act = True 
self.has_initial_state = True self.is_reverse = True self.use_peepholes = True @@ -313,7 +312,7 @@ class TestLstmOpHasInitial(TestLstmOp): no_grad_set=set('C0')) -class TestLstmOpRerverse(TestLstmOp): +class TestLstmpOpReverse(TestLstmpOp): def set_argument(self): self.lod = [[0, 2, 5, 7]] self.D = 16 @@ -322,14 +321,14 @@ class TestLstmOpRerverse(TestLstmOp): self.act_gate = 'sigmoid' self.act_cell = 'tanh' self.act_cand = 'tanh' + self.act_proj = self.act_cell - self.share_cell_act = True self.has_initial_state = False self.is_reverse = True self.use_peepholes = True -class TestLstmOpNotUsePeepholes(TestLstmOp): +class TestLstmpOpNotUsePeepholes(TestLstmpOp): def set_argument(self): self.lod = [[0, 2, 5, 7]] self.D = 16 @@ -338,14 +337,14 @@ class TestLstmOpNotUsePeepholes(TestLstmOp): self.act_gate = 'sigmoid' self.act_cell = 'tanh' self.act_cand = 'tanh' + self.act_proj = self.act_cell - self.share_cell_act = True self.has_initial_state = False self.is_reverse = False self.use_peepholes = False -class TestLstmOpNotShareCellAct(TestLstmOp): +class TestLstmpOpLinearProjection(TestLstmpOp): def set_argument(self): self.lod = [[0, 2, 5, 7]] self.D = 16 @@ -354,8 +353,8 @@ class TestLstmOpNotShareCellAct(TestLstmOp): self.act_gate = 'sigmoid' self.act_cell = 'tanh' self.act_cand = 'tanh' + self.act_proj = 'identity' - self.share_cell_act = False self.has_initial_state = False self.is_reverse = False self.use_peepholes = True -- GitLab
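
For reference, the projection step whose activation this patch makes configurable, r_t = act_proj(h_t . W_rh), can be sketched in NumPy as follows. This is a minimal illustration, not part of the patch: the ACTIVATION table and the act_proj/proj_activation naming mirror the updated test above, while the project helper, h_t, w_rh, and the shapes D and P are hypothetical placeholders.

    import numpy as np

    # Same activation table as the reference test in test_lstmp_op.py.
    ACTIVATION = {
        'identity': lambda x: x,
        'sigmoid': lambda x: 1. / (1. + np.exp(-x)),
        'tanh': np.tanh,
        'relu': lambda x: np.maximum(x, 0),
    }


    def project(h_t, w_rh, proj_activation='tanh'):
        """Project the hidden state h_t (1 x D) down to r_t (1 x P)."""
        act_proj = ACTIVATION[proj_activation]
        return act_proj(np.dot(h_t, w_rh))


    D, P = 16, 10
    h_t = np.random.normal(size=(1, D))
    w_rh = np.random.normal(size=(D, P))

    # 'tanh' reproduces the old share_cell_act=True behavior (the cell output
    # activation reused for the projection); 'identity' reproduces
    # share_cell_act=False (a purely linear projection).
    r_tanh = project(h_t, w_rh, proj_activation='tanh')
    r_linear = project(h_t, w_rh, proj_activation='identity')

Replacing the boolean share_cell_act with a string proj_activation keeps both of the old behaviors reachable ('tanh' and 'identity') while also allowing 'sigmoid' and 'relu' for the projection output, matching the other activation attributes of the operator.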