Commit 76beff86, authored by Yibing Liu

Make the projection activation configurable

Parent: db1f6a59
@@ -23,27 +23,29 @@ class LSTMPOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("Input"),
-                   "Input(Input) of LSTMP should not be null.");
+                   "Input(Input) of LSTMP operator should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Weight"),
-                   "Input(Weight) of LSTMP should not be null.");
+                   "Input(Weight) of LSTMP operator should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("ProjWeight"),
-                   "Input(ProjWeight) of LSTMP should not be null.");
+                   "Input(ProjWeight) of LSTMP operator should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Bias"),
-                   "Input(Bias) of LSTMP should not be null.");
+                   "Input(Bias) of LSTMP operator should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Projection"),
-                   "Output(Projection) of LSTMP should not be null.");
+                   "Output(Projection) of LSTMP operator should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Cell"),
-                   "Output(Cell) of LSTMP should not be null.");
+                   "Output(Cell) of LSTMP operator should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
-                   "Output(BatchGate) of LSTMP should not be null.");
+                   "Output(BatchGate) of LSTMP operator should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
-                   "Output(BatchGate) of LSTMP should not be null.");
+                   "Output(BatchCellPreAct) of LSTMP operator should not be "
+                   "null.");
     PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"),
-                   "Output(BatchHidden) of LSTMP should not be null.");
+                   "Output(BatchHidden) of LSTMP operator should not be null.");
     auto in_dims = ctx->GetInputDim("Input");
-    PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(in_dims.size(), 2,
+                      "Input(X)'s rank of LSTMP operator must be 2.");
     int frame_size = in_dims[1] / 4;
     auto w_dims = ctx->GetInputDim("Weight");
@@ -68,8 +70,8 @@ class LSTMPOp : public framework::OperatorWithKernel {
     if (ctx->HasInput("H0")) {
       PADDLE_ENFORCE(ctx->HasInput("C0"),
-                     "Input(C0) and Input(H0) of LSTMP should not "
-                     "be null at the same time.");
+                     "Input(C0) of LSTMP operator should not be null after "
+                     "Input(H0) provided.");
       auto h_dims = ctx->GetInputDim("H0");
       auto c_dims = ctx->GetInputDim("C0");
       PADDLE_ENFORCE(h_dims == c_dims,
@@ -132,8 +134,7 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("C0",
              "(Tensor, optional) the initial cell state is an optional "
              "input. This is a tensor with shape (N x D), where N is the "
-             "batch size. Only one of `H0` and `C0` can be NULL at the same "
-             "time.")
+             "batch size. `C0` should not be null if `H0` provided.")
         .AsDispensable();
     AddInput("Weight",
              "(Tensor) the learnable hidden-hidden weights."
@@ -211,13 +212,12 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
                   "`tanh` by default.")
         .SetDefault("tanh")
         .InEnum({"sigmoid", "tanh", "relu", "identity"});
-    AddAttr<bool>("share_cell_act",
-                  "(bool, defalut: True) "
-                  "whether to share the activation of cell output with the "
-                  "projection layer. When set to `False`, the projection "
-                  "is simple linear, otherwise it will go through an "
-                  "activation function same as `cell_activation`.")
-        .SetDefault(true);
+    AddAttr<std::string>("proj_activation",
+                         "(string, default: tanh)"
+                         "The activation for projection output, "
+                         "`tanh` by defalut.")
+        .SetDefault("tanh")
+        .InEnum({"sigmoid", "tanh", "relu", "identity"});
     AddComment(R"DOC(
 Long-Short Term Memory with recurrent Projection layer (LSTMP) Operator.
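For orientation, here is a minimal sketch of how the reworked attribute surfaces on the Python side. Only the four activation keys are confirmed by the operator test further down in this diff; treat the dict as illustrative, not a complete attribute list.

```python
# Activation-related attributes of the LSTMP op after this change.
# The string 'proj_activation' replaces the old boolean 'share_cell_act';
# any value listed in the InEnum set above is accepted.
lstmp_attrs = {
    'gate_activation': 'sigmoid',
    'cell_activation': 'tanh',
    'candidate_activation': 'tanh',
    'proj_activation': 'tanh',  # use 'identity' for a purely linear projection
}
```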
@@ -226,20 +226,21 @@ original hidden state to a lower-dimensional one, which is proposed to reduce
 the number of total parameters and furthermore computational complexity for
 the LSTM, espeacially for the case that the size of output units is relative
 large (https://research.google.com/pubs/archive/43905.pdf).

 The formula is as follows:

 $$
-i_t = \sigma(W_{ix}x_{t} + W_{ih}r_{t-1} + W_{ic}c_{t-1} + b_i) \\
-f_t = \sigma(W_{fx}x_{t} + W_{fh}r_{t-1} + W_{fc}c_{t-1} + b_f) \\
-\tilde{c_t} = act_g(W_{cx}x_t + W_{ch}r_{t-1} + b_c) \\
-o_t = \sigma(W_{ox}x_{t} + W_{oh}r_{t-1} + W_{oc}c_t + b_o) \\
-c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c_t}
-h_t = o_t \odot act_h(c_t)
+i_t = \sigma(W_{ix}x_{t} + W_{ir}r_{t-1} + W_{ic}c_{t-1} + b_i) \\
+f_t = \sigma(W_{fx}x_{t} + W_{fr}r_{t-1} + W_{fc}c_{t-1} + b_f) \\
+\tilde{c_t} = act_g(W_{cx}x_t + W_{cr}r_{t-1} + b_c) \\
+o_t = \sigma(W_{ox}x_{t} + W_{or}r_{t-1} + W_{oc}c_t + b_o) \\
+c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c_t} \\
+h_t = o_t \odot act_h(c_t) \\
 r_t = \overline{act_h}(W_{rh}h_t)
 $$
@@ -259,9 +260,8 @@ input and previous hidden state.
 The $\odot$ is the element-wise product of the vectors. $act_g$ and $act_h$
 are the cell input and cell output activation functions and `tanh` is usually
-used for them. $\overline{act_h}$ is the activation function for the projection
-layer. When `share_cell_act` set to `False`, $\overline{act_h}$ is an
-identity activation, otherwise it will be same as $act_h$.
+used for them. $\overline{act_h}$ is the activation function for the
+projection output, usually using `identity` or same as $act_h$.

 Note that these $W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}$
 operations on the input $x_{t}$ are NOT included in this operator.
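To make the formula concrete, here is a minimal NumPy sketch of a single LSTMP step with a configurable projection activation. It mirrors the reference implementation in the test file further down, but omits the peephole terms ($W_{ic}$, $W_{fc}$, $W_{oc}$) for brevity; the function name and gate ordering are illustrative, not the operator's actual memory layout.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstmp_step(x_gates, r_prev, c_prev, w_r, w_rh, act_proj=np.tanh):
    # x_gates: precomputed input projections W_{*x} x_t for the four gates, shape (1, 4D)
    # r_prev:  previous projected state r_{t-1}, shape (1, P)
    # c_prev:  previous cell state c_{t-1}, shape (1, D)
    # w_r:     recurrent weights mapping r_{t-1} to the four gates, shape (P, 4D)
    # w_rh:    projection weights W_{rh}, shape (D, P)
    g = x_gates + np.dot(r_prev, w_r)            # gate pre-activations
    g_i, g_f, g_c, g_o = np.split(g, 4, axis=1)  # illustrative gate ordering
    i = sigmoid(g_i)                             # input gate
    f = sigmoid(g_f)                             # forget gate
    c_tilde = np.tanh(g_c)                       # candidate cell, act_g = tanh
    o = sigmoid(g_o)                             # output gate
    c = f * c_prev + i * c_tilde                 # c_t
    h = o * np.tanh(c)                           # h_t, act_h = tanh
    r = act_proj(np.dot(h, w_rh))                # r_t; pass lambda v: v for a linear projection
    return r, c
```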
@@ -277,22 +277,22 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("Input"),
-                   "Input(Input) of LSTMP should not be null.");
+                   "Input(Input) of LSTMP operator should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Projection"),
-                   "Input(Projection) of LSTMP should not be null.");
+                   "Input(Projection) of LSTMP operator should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Cell"),
-                   "Input(Cell) of LSTMP should not be null.");
+                   "Input(Cell) of LSTMP operator should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Weight"),
-                   "Input(Weight) of LSTMP should not be null.");
+                   "Input(Weight) of LSTMP operator should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("ProjWeight"),
-                   "Input(ProjWeight) of LSTMP should not be null.");
+                   "Input(ProjWeight) of LSTMP operator should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Bias"),
-                   "Input(Bias) of LSTMP should not be null.");
+                   "Input(Bias) of LSTMP operator should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
-                   "Input(BatchGate) of LSTMP should not be null.");
+                   "Input(BatchGate) of LSTMP operator should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
-                   "Input(BatchGate) of LSTMP should not be null.");
+                   "Input(BatchGate) of LSTMP operator should not be null.");
     auto SetOutGradDim = [&ctx](const std::string& name) {
       auto g_name = framework::GradVarName(name);
......
@@ -136,7 +136,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
         ctx.Attr<std::string>("cell_activation"));
     auto cand_act = math::detail::GetActivationType(
         ctx.Attr<std::string>("candidate_activation"));
-    auto share_cell_act = ctx.Attr<bool>("share_cell_act");
+    auto proj_act = math::detail::GetActivationType(
+        ctx.Attr<std::string>("proj_activation"));
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
     for (size_t n = 0; n < num_batch; n++) {
@@ -174,7 +175,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
         math::matmul<DeviceContext, T>(device_ctx, ordered_h0, false,
                                        *proj_weight, false, static_cast<T>(1.0),
                                        ordered_proj0, static_cast<T>(0.0));
-        if (share_cell_act) {
+        if (proj_act != math::detail::ActivationType::kIdentity) {
           auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
           ActCompute(cell_act, place, proj0_dev, proj0_dev);
         }
@@ -194,7 +195,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
       math::matmul<DeviceContext, T>(device_ctx, hidden_t, false, *proj_weight,
                                      false, static_cast<T>(1.0), &proj_t,
                                      static_cast<T>(0.0));
-      if (share_cell_act) {
+      if (proj_act != math::detail::ActivationType::kIdentity) {
        auto proj_t_dev = EigenMatrix<T>::From(proj_t);
        ActCompute(cell_act, place, proj_t_dev, proj_t_dev);
      }
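In both the forward and the backward kernel, the boolean flag test is replaced by a check against the parsed activation type: the extra elementwise pass over the projection output is skipped when `proj_activation` is `identity`, which is what the old `share_cell_act = false` case did. A rough Python sketch of the intended control flow follows; the helper name and the reduced activation table are illustrative, not the kernel's API, and the real kernel performs the activation through `ActCompute` with an `ActivationType` enum.

```python
import numpy as np

# Subset of the ACTIVATION table used by the reference test below.
ACTIVATION = {'identity': lambda v: v, 'tanh': np.tanh}

def apply_projection(hidden, proj_weight, proj_act):
    proj = np.dot(hidden, proj_weight)    # r_t = h_t W_{rh}
    if proj_act != 'identity':            # mirrors: proj_act != ActivationType::kIdentity
        proj = ACTIVATION[proj_act](proj) # apply the configured projection activation
    return proj
```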
@@ -348,7 +349,8 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
         ctx.Attr<std::string>("cell_activation"));
     auto cand_act = math::detail::GetActivationType(
         ctx.Attr<std::string>("candidate_activation"));
-    auto share_cell_act = ctx.Attr<bool>("share_cell_act");
+    auto proj_act = math::detail::GetActivationType(
+        ctx.Attr<std::string>("proj_activation"));
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
     auto batch_starts = batch_gate->lod()[0];
@@ -359,7 +361,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
       Tensor cur_proj = batch_proj.Slice(bstart, bend);
       Tensor proj_g = batch_proj_g.Slice(bstart, bend);
-      if (share_cell_act) {
+      if (proj_act != math::detail::ActivationType::kIdentity) {
         auto cur_proj_dev = EigenMatrix<T>::From(cur_proj);
         auto proj_g_dev = EigenMatrix<T>::From(proj_g);
         ActGradCompute(cell_act, place, cur_proj_dev, cur_proj_dev, proj_g_dev,
@@ -439,7 +441,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
           math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight,
                                          true, static_cast<T>(1.0), &proj0_g,
                                          static_cast<T>(0.0));
-          if (share_cell_act) {
+          if (proj_act != math::detail::ActivationType::kIdentity) {
             auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
             auto proj0_g_dev = EigenMatrix<T>::From(proj0_g);
             ActGradCompute(cell_act, place, proj0_dev, proj0_dev, proj0_g_dev,
......
@@ -41,7 +41,7 @@ def relu(x):
     return np.maximum(x, 0)

-ACTVATION = {
+ACTIVATION = {
     'identity': identity,
     'sigmoid': sigmoid,
     'tanh': tanh,
@@ -63,8 +63,9 @@ def lstmp(
           act_gate=None,
           act_cell=None,
           act_cand=None,
-          share_cell_act=True):
-    def _step(x, w_r, w_rh, w_c, r_pre, c_pre, act_gate, act_cell, act_cand):
+          act_proj=None):
+    def _step(x, w_r, w_rh, w_c, r_pre, c_pre, act_gate, act_cell, act_cand,
+              act_proj):
         g = np.dot(r_pre, w_r)  # 1 x 4D
         g = g + x
         g = np.reshape(g, (1, g.size))
@@ -86,8 +87,7 @@ def lstmp(
         h = g_o * act_cell(c)
         # projection
         r = np.dot(h, w_rh)
-        if share_cell_act:
-            r = act_cell(r)
+        r = act_proj(r)
         return r, c

     def _reverse(x, lod):
@@ -110,13 +110,12 @@ def lstmp(
         seq_len = offset[i + 1] - offset[i]
         x = input[offset[i]:offset[i + 1], :]
         r_pre = np.dot(h0[i], w_rh)  # 1 x P
-        if share_cell_act:
-            r_pre = act_cell(r_pre)
+        r_pre = act_proj(r_pre)
         c_pre = c0[i]  # 1 x D
         for j in range(seq_len):
             # compute one step
             r_pre, c_pre = _step(x[j], w_r, w_rh, w_c, r_pre, c_pre, act_gate,
-                                 act_cell, act_cand)
+                                 act_cell, act_cand, act_proj)
             projection.append(r_pre.flatten())
             cell.append(c_pre.flatten())
@@ -131,7 +130,7 @@ def lstmp(
     return projection, cell

-class TestLstmOp(OpTest):
+class TestLstmpOp(OpTest):
     def set_argument(self):
         self.lod = [[0, 2, 5, 7]]
         # hidden size
@@ -142,8 +141,8 @@ class TestLstmOp(OpTest):
         self.act_gate = 'sigmoid'
         self.act_cell = 'tanh'
         self.act_cand = 'tanh'
+        self.act_proj = self.act_cell
-        self.share_cell_act = True
         self.has_initial_state = False
         self.is_reverse = False
         self.use_peepholes = True
@@ -172,8 +171,8 @@ class TestLstmOp(OpTest):
         w_c = b[:, 4 * self.D:] if self.use_peepholes else None
         w_rh = np.random.normal(size=(self.D, self.P)).astype('float64')
         r, c = lstmp(x, self.lod, h0, c0, w, w_rh, w_b, w_c, self.is_reverse,
-                     ACTVATION[self.act_gate], ACTVATION[self.act_cell],
-                     ACTVATION[self.act_cand], self.share_cell_act)
+                     ACTIVATION[self.act_gate], ACTIVATION[self.act_cell],
+                     ACTIVATION[self.act_cand], ACTIVATION[self.act_proj])

         self.inputs = {'Input': (x, self.lod), 'Weight': w, 'ProjWeight': w_rh}
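Correspondingly, a purely linear projection is now requested from the reference `lstmp()` by passing the identity activation instead of a flag. A hypothetical standalone call, mirroring the one above and the `'identity'` setting used by the last test case below (the variable names are assumed to be prepared as in `setUp`):

```python
# Linear projection: pass ACTIVATION['identity'] as act_proj.
r, c = lstmp(x, lod, h0, c0, w, w_rh, w_b, w_c, is_reverse,
             ACTIVATION['sigmoid'], ACTIVATION['tanh'],
             ACTIVATION['tanh'], ACTIVATION['identity'])
```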
@@ -193,7 +192,7 @@ class TestLstmOp(OpTest):
             'gate_activation': self.act_gate,
             'cell_activation': self.act_cell,
             'candidate_activation': self.act_cand,
-            'share_cell_act': self.share_cell_act
+            'proj_activation': self.act_proj
         }

     def test_check_output(self):
@@ -212,7 +211,7 @@ class TestLstmOp(OpTest):
             max_relative_error=1e-2)

-class TestLstmOpHasInitial(TestLstmOp):
+class TestLstmpOpHasInitial(TestLstmpOp):
     def set_argument(self):
         self.lod = [[0, 2, 5, 7]]
         self.D = 16
@@ -221,8 +220,8 @@ class TestLstmOpHasInitial(TestLstmOp):
         self.act_gate = 'sigmoid'
         self.act_cell = 'tanh'
         self.act_cand = 'tanh'
+        self.act_proj = self.act_cell
-        self.share_cell_act = True
         self.has_initial_state = True
         self.is_reverse = True
         self.use_peepholes = True
@@ -313,7 +312,7 @@ class TestLstmOpHasInitial(TestLstmOp):
             no_grad_set=set('C0'))

-class TestLstmOpRerverse(TestLstmOp):
+class TestLstmpOpRerverse(TestLstmpOp):
     def set_argument(self):
         self.lod = [[0, 2, 5, 7]]
         self.D = 16
@@ -322,14 +321,14 @@ class TestLstmOpRerverse(TestLstmOp):
         self.act_gate = 'sigmoid'
         self.act_cell = 'tanh'
         self.act_cand = 'tanh'
+        self.act_proj = self.act_cell
-        self.share_cell_act = True
         self.has_initial_state = False
         self.is_reverse = True
         self.use_peepholes = True

-class TestLstmOpNotUsePeepholes(TestLstmOp):
+class TestLstmpOpNotUsePeepholes(TestLstmpOp):
     def set_argument(self):
         self.lod = [[0, 2, 5, 7]]
         self.D = 16
@@ -338,14 +337,14 @@ class TestLstmOpNotUsePeepholes(TestLstmOp):
         self.act_gate = 'sigmoid'
         self.act_cell = 'tanh'
         self.act_cand = 'tanh'
+        self.act_proj = self.act_cell
-        self.share_cell_act = True
         self.has_initial_state = False
         self.is_reverse = False
         self.use_peepholes = False

-class TestLstmOpNotShareCellAct(TestLstmOp):
+class TestLstmpOpLinearProjection(TestLstmpOp):
     def set_argument(self):
         self.lod = [[0, 2, 5, 7]]
         self.D = 16
@@ -354,8 +353,8 @@ class TestLstmOpNotShareCellAct(TestLstmOp):
         self.act_gate = 'sigmoid'
         self.act_cell = 'tanh'
         self.act_cand = 'tanh'
+        self.act_proj = 'identity'
-        self.share_cell_act = False
         self.has_initial_state = False
         self.is_reverse = False
         self.use_peepholes = True
......