Commit 76beff86 authored by Yibing Liu

Make the projection activation configurable

Parent db1f6a59
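This commit replaces the boolean share_cell_act attribute of the LSTMP operator with a string attribute proj_activation, so the activation applied to the recurrent projection output can be chosen independently of the cell activation. A minimal usage sketch is shown below, assuming the Python wrapper fluid.layers.dynamic_lstmp exposes the new attribute as a proj_activation keyword argument; the Python-side argument names and layer signature are assumptions for illustration, not part of this commit.

import paddle.fluid as fluid

# Hedged sketch: choose the projection activation independently of the
# cell activation. `proj_activation` is assumed to map onto the operator
# attribute added in this commit; 'identity' gives a plain linear projection.
x = fluid.layers.data(name='x', shape=[4 * 512], dtype='float32', lod_level=1)
proj_out, cell = fluid.layers.dynamic_lstmp(
    input=x,
    size=4 * 512,               # 4 * hidden size (concatenated gates)
    proj_size=256,              # dimension of the recurrent projection
    cell_activation='tanh',
    proj_activation='identity')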
......@@ -23,27 +23,29 @@ class LSTMPOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTMP should not be null.");
"Input(Input) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(Weight) of LSTMP should not be null.");
"Input(Weight) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ProjWeight"),
"Input(ProjWeight) of LSTMP should not be null.");
"Input(ProjWeight) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of LSTMP should not be null.");
"Input(Bias) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Projection"),
"Output(Projection) of LSTMP should not be null.");
"Output(Projection) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of LSTMP should not be null.");
"Output(Cell) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
"Output(BatchGate) of LSTMP should not be null.");
"Output(BatchGate) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
"Output(BatchGate) of LSTMP should not be null.");
"Output(BatchCellPreAct) of LSTMP operator should not be "
"null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"),
"Output(BatchHidden) of LSTMP should not be null.");
"Output(BatchHidden) of LSTMP operator should not be null.");
auto in_dims = ctx->GetInputDim("Input");
- PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2.");
+ PADDLE_ENFORCE_EQ(in_dims.size(), 2,
+                   "Input(X)'s rank of LSTMP operator must be 2.");
int frame_size = in_dims[1] / 4;
auto w_dims = ctx->GetInputDim("Weight");
......@@ -68,8 +70,8 @@ class LSTMPOp : public framework::OperatorWithKernel {
if (ctx->HasInput("H0")) {
PADDLE_ENFORCE(ctx->HasInput("C0"),
"Input(C0) and Input(H0) of LSTMP should not "
"be null at the same time.");
"Input(C0) of LSTMP operator should not be null after "
"Input(H0) provided.");
auto h_dims = ctx->GetInputDim("H0");
auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE(h_dims == c_dims,
......@@ -132,8 +134,7 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("C0",
"(Tensor, optional) the initial cell state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size. Only one of `H0` and `C0` can be NULL at the same "
"time.")
"batch size. `C0` should not be null if `H0` provided.")
.AsDispensable();
AddInput("Weight",
"(Tensor) the learnable hidden-hidden weights."
......@@ -211,13 +212,12 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
"`tanh` by default.")
.SetDefault("tanh")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddAttr<bool>("share_cell_act",
"(bool, defalut: True) "
"whether to share the activation of cell output with the "
"projection layer. When set to `False`, the projection "
"is simple linear, otherwise it will go through an "
"activation function same as `cell_activation`.")
.SetDefault(true);
AddAttr<std::string>("proj_activation",
"(string, default: tanh)"
"The activation for projection output, "
"`tanh` by defalut.")
.SetDefault("tanh")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddComment(R"DOC(
Long-Short Term Memory with recurrent Projection layer (LSTMP) Operator.
......@@ -226,20 +226,21 @@ original hidden state to a lower-dimensional one, which is proposed to reduce
the number of total parameters and furthermore computational complexity for
the LSTM, especially for the case that the size of output units is relatively
large (https://research.google.com/pubs/archive/43905.pdf).
The formula is as follows:
$$
- i_t = \sigma(W_{ix}x_{t} + W_{ih}r_{t-1} + W_{ic}c_{t-1} + b_i) \\
+ i_t = \sigma(W_{ix}x_{t} + W_{ir}r_{t-1} + W_{ic}c_{t-1} + b_i) \\
- f_t = \sigma(W_{fx}x_{t} + W_{fh}r_{t-1} + W_{fc}c_{t-1} + b_f) \\
+ f_t = \sigma(W_{fx}x_{t} + W_{fr}r_{t-1} + W_{fc}c_{t-1} + b_f) \\
- \tilde{c_t} = act_g(W_{cx}x_t + W_{ch}r_{t-1} + b_c) \\
+ \tilde{c_t} = act_g(W_{cx}x_t + W_{cr}r_{t-1} + b_c) \\
- o_t = \sigma(W_{ox}x_{t} + W_{oh}r_{t-1} + W_{oc}c_t + b_o) \\
+ o_t = \sigma(W_{ox}x_{t} + W_{or}r_{t-1} + W_{oc}c_t + b_o) \\
- c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c_t}
+ c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c_t} \\
- h_t = o_t \odot act_h(c_t)
+ h_t = o_t \odot act_h(c_t) \\
r_t = \overline{act_h}(W_{rh}h_t)
$$
......@@ -259,9 +260,8 @@ input and previous hidden state.
The $\odot$ is the element-wise product of the vectors. $act_g$ and $act_h$
are the cell input and cell output activation functions and `tanh` is usually
- used for them. $\overline{act_h}$ is the activation function for the projection
- layer. When `share_cell_act` set to `False`, $\overline{act_h}$ is an
- identity activation, otherwise it will be same as $act_h$.
+ used for them. $\overline{act_h}$ is the activation function for the
+ projection output, usually `identity` or the same as $act_h$.
Note that these $W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}$
operations on the input $x_{t}$ are NOT included in this operator.
......@@ -277,22 +277,22 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTMP should not be null.");
"Input(Input) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Projection"),
"Input(Projection) of LSTMP should not be null.");
"Input(Projection) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Cell"),
"Input(Cell) of LSTMP should not be null.");
"Input(Cell) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(Weight) of LSTMP should not be null.");
"Input(Weight) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ProjWeight"),
"Input(ProjWeight) of LSTMP should not be null.");
"Input(ProjWeight) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of LSTMP should not be null.");
"Input(Bias) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
"Input(BatchGate) of LSTMP should not be null.");
"Input(BatchGate) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
"Input(BatchGate) of LSTMP should not be null.");
"Input(BatchGate) of LSTMP operator should not be null.");
auto SetOutGradDim = [&ctx](const std::string& name) {
auto g_name = framework::GradVarName(name);
......
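For reference, the equations in the operator's DOC block above reduce to the single-step NumPy sketch below. Peephole terms are omitted and the gate ordering, function name, and variable names are illustrative rather than taken from the operator's implementation; the sketch only shows where the now-configurable projection activation, written $\overline{act_h}$ above, enters.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstmp_step(x_gates, r_prev, c_prev, w_r, w_rh,
               act=np.tanh, proj_act=lambda v: v):
    # x_gates: input projection W_{*x} x_t for the four gates, shape (1, 4D)
    # r_prev:  previous projection output r_{t-1}, shape (1, P)
    # c_prev:  previous cell state c_{t-1}, shape (1, D)
    # w_r:     recurrent weights mapping r_{t-1} to the gates, shape (P, 4D)
    # w_rh:    projection weights W_{rh}, shape (D, P)
    g = x_gates + np.dot(r_prev, w_r)
    g_i, g_f, g_c, g_o = np.split(g, 4, axis=1)
    i = sigmoid(g_i)                  # input gate (peepholes omitted)
    f = sigmoid(g_f)                  # forget gate
    o = sigmoid(g_o)                  # output gate
    c = f * c_prev + i * act(g_c)     # cell state c_t
    h = o * act(c)                    # hidden state h_t
    r = proj_act(np.dot(h, w_rh))     # r_t = act_proj(W_{rh} h_t)
    return r, c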
......@@ -136,7 +136,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
ctx.Attr<std::string>("cell_activation"));
auto cand_act = math::detail::GetActivationType(
ctx.Attr<std::string>("candidate_activation"));
- auto share_cell_act = ctx.Attr<bool>("share_cell_act");
+ auto proj_act = math::detail::GetActivationType(
+     ctx.Attr<std::string>("proj_activation"));
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
for (size_t n = 0; n < num_batch; n++) {
......@@ -174,7 +175,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
math::matmul<DeviceContext, T>(device_ctx, ordered_h0, false,
*proj_weight, false, static_cast<T>(1.0),
ordered_proj0, static_cast<T>(0.0));
- if (share_cell_act) {
+ if (proj_act != math::detail::ActivationType::kIdentity) {
auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
ActCompute(cell_act, place, proj0_dev, proj0_dev);
}
......@@ -194,7 +195,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
math::matmul<DeviceContext, T>(device_ctx, hidden_t, false, *proj_weight,
false, static_cast<T>(1.0), &proj_t,
static_cast<T>(0.0));
- if (share_cell_act) {
+ if (proj_act != math::detail::ActivationType::kIdentity) {
auto proj_t_dev = EigenMatrix<T>::From(proj_t);
ActCompute(cell_act, place, proj_t_dev, proj_t_dev);
}
......@@ -348,7 +349,8 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
ctx.Attr<std::string>("cell_activation"));
auto cand_act = math::detail::GetActivationType(
ctx.Attr<std::string>("candidate_activation"));
- auto share_cell_act = ctx.Attr<bool>("share_cell_act");
+ auto proj_act = math::detail::GetActivationType(
+     ctx.Attr<std::string>("proj_activation"));
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto batch_starts = batch_gate->lod()[0];
......@@ -359,7 +361,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
Tensor cur_proj = batch_proj.Slice(bstart, bend);
Tensor proj_g = batch_proj_g.Slice(bstart, bend);
- if (share_cell_act) {
+ if (proj_act != math::detail::ActivationType::kIdentity) {
auto cur_proj_dev = EigenMatrix<T>::From(cur_proj);
auto proj_g_dev = EigenMatrix<T>::From(proj_g);
ActGradCompute(cell_act, place, cur_proj_dev, cur_proj_dev, proj_g_dev,
......@@ -439,7 +441,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight,
true, static_cast<T>(1.0), &proj0_g,
static_cast<T>(0.0));
- if (share_cell_act) {
+ if (proj_act != math::detail::ActivationType::kIdentity) {
auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
auto proj0_g_dev = EigenMatrix<T>::From(proj0_g);
ActGradCompute(cell_act, place, proj0_dev, proj0_dev, proj0_g_dev,
......
......@@ -41,7 +41,7 @@ def relu(x):
return np.maximum(x, 0)
- ACTVATION = {
+ ACTIVATION = {
'identity': identity,
'sigmoid': sigmoid,
'tanh': tanh,
......@@ -63,8 +63,9 @@ def lstmp(
act_gate=None,
act_cell=None,
act_cand=None,
- share_cell_act=True):
- def _step(x, w_r, w_rh, w_c, r_pre, c_pre, act_gate, act_cell, act_cand):
+ act_proj=None):
+ def _step(x, w_r, w_rh, w_c, r_pre, c_pre, act_gate, act_cell, act_cand,
+           act_proj):
g = np.dot(r_pre, w_r) # 1 x 4D
g = g + x
g = np.reshape(g, (1, g.size))
......@@ -86,8 +87,7 @@ def lstmp(
h = g_o * act_cell(c)
# projection
r = np.dot(h, w_rh)
- if share_cell_act:
-     r = act_cell(r)
+ r = act_proj(r)
return r, c
def _reverse(x, lod):
......@@ -110,13 +110,12 @@ def lstmp(
seq_len = offset[i + 1] - offset[i]
x = input[offset[i]:offset[i + 1], :]
r_pre = np.dot(h0[i], w_rh) # 1 x P
- if share_cell_act:
-     r_pre = act_cell(r_pre)
+ r_pre = act_proj(r_pre)
c_pre = c0[i] # 1 x D
for j in range(seq_len):
# compute one step
r_pre, c_pre = _step(x[j], w_r, w_rh, w_c, r_pre, c_pre, act_gate,
- act_cell, act_cand)
+ act_cell, act_cand, act_proj)
projection.append(r_pre.flatten())
cell.append(c_pre.flatten())
......@@ -131,7 +130,7 @@ def lstmp(
return projection, cell
- class TestLstmOp(OpTest):
+ class TestLstmpOp(OpTest):
def set_argument(self):
self.lod = [[0, 2, 5, 7]]
# hidden size
......@@ -142,8 +141,8 @@ class TestLstmOp(OpTest):
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
+ self.act_proj = self.act_cell
- self.share_cell_act = True
self.has_initial_state = False
self.is_reverse = False
self.use_peepholes = True
......@@ -172,8 +171,8 @@ class TestLstmOp(OpTest):
w_c = b[:, 4 * self.D:] if self.use_peepholes else None
w_rh = np.random.normal(size=(self.D, self.P)).astype('float64')
r, c = lstmp(x, self.lod, h0, c0, w, w_rh, w_b, w_c, self.is_reverse,
- ACTVATION[self.act_gate], ACTVATION[self.act_cell],
- ACTVATION[self.act_cand], self.share_cell_act)
+ ACTIVATION[self.act_gate], ACTIVATION[self.act_cell],
+ ACTIVATION[self.act_cand], ACTIVATION[self.act_proj])
self.inputs = {'Input': (x, self.lod), 'Weight': w, 'ProjWeight': w_rh}
......@@ -193,7 +192,7 @@ class TestLstmOp(OpTest):
'gate_activation': self.act_gate,
'cell_activation': self.act_cell,
'candidate_activation': self.act_cand,
- 'share_cell_act': self.share_cell_act
+ 'proj_activation': self.act_proj
}
def test_check_output(self):
......@@ -212,7 +211,7 @@ class TestLstmOp(OpTest):
max_relative_error=1e-2)
- class TestLstmOpHasInitial(TestLstmOp):
+ class TestLstmpOpHasInitial(TestLstmpOp):
def set_argument(self):
self.lod = [[0, 2, 5, 7]]
self.D = 16
......@@ -221,8 +220,8 @@ class TestLstmOpHasInitial(TestLstmOp):
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
+ self.act_proj = self.act_cell
- self.share_cell_act = True
self.has_initial_state = True
self.is_reverse = True
self.use_peepholes = True
......@@ -313,7 +312,7 @@ class TestLstmOpHasInitial(TestLstmOp):
no_grad_set=set('C0'))
- class TestLstmOpRerverse(TestLstmOp):
+ class TestLstmpOpRerverse(TestLstmpOp):
def set_argument(self):
self.lod = [[0, 2, 5, 7]]
self.D = 16
......@@ -322,14 +321,14 @@ class TestLstmOpRerverse(TestLstmOp):
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
+ self.act_proj = self.act_cell
- self.share_cell_act = True
self.has_initial_state = False
self.is_reverse = True
self.use_peepholes = True
- class TestLstmOpNotUsePeepholes(TestLstmOp):
+ class TestLstmpOpNotUsePeepholes(TestLstmpOp):
def set_argument(self):
self.lod = [[0, 2, 5, 7]]
self.D = 16
......@@ -338,14 +337,14 @@ class TestLstmOpNotUsePeepholes(TestLstmOp):
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
+ self.act_proj = self.act_cell
- self.share_cell_act = True
self.has_initial_state = False
self.is_reverse = False
self.use_peepholes = False
- class TestLstmOpNotShareCellAct(TestLstmOp):
+ class TestLstmpOpLinearProjection(TestLstmpOp):
def set_argument(self):
self.lod = [[0, 2, 5, 7]]
self.D = 16
......@@ -354,8 +353,8 @@ class TestLstmOpNotShareCellAct(TestLstmOp):
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
+ self.act_proj = 'identity'
- self.share_cell_act = False
self.has_initial_state = False
self.is_reverse = False
self.use_peepholes = True
......