diff --git a/paddle/operators/lstm_unit_op.cc b/paddle/operators/lstm_unit_op.cc index 18b9cdf2a39e8226c634194ff2cc56d169979774..b6eb33bafe50548502a0478d37842fd2dfdebda4 100644 --- a/paddle/operators/lstm_unit_op.cc +++ b/paddle/operators/lstm_unit_op.cc @@ -51,7 +51,10 @@ class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker { LstmUnitOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "FC input before the non-linear activation."); + AddInput("X", + "Lstm unit only applies non-linear activations, please make sure " + "that linear transformation has already been applied to `X`. " + "Linear transformation can be applied by adding a `fc` layer"); AddInput( "C_prev", "The cell state tensor of last time-step in the Lstm Unit operator."); diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 84e62d988ce9dbf35c3cfc6e3abec1fb5c191ec3..1c101c62c2dc4c502d0bb5fa3b3835513db5b090 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -5,6 +5,7 @@ All layers just related to the neural network. from ..layer_helper import LayerHelper from ..initializer import Normal, Constant from ..framework import Variable +from ..param_attr import ParamAttr from tensor import concat __all__ = [ @@ -796,6 +797,8 @@ def lstm_unit(x_t, hidden_t_prev, cell_t_prev, forget_bias=0.0, + param_attr=None, + bias_attr=ParamAttr(), main_program=None, startup_program=None): """Lstm unit layer. The equation of a lstm step is: @@ -836,6 +839,10 @@ def lstm_unit(x_t, hidden_t_prev (Variable): The hidden value of lstm unit. cell_t_prev (Variable): The cell value of lstm unit. forget_bias (float): The forget bias of lstm unit. + param_attr (ParamAttr): The attributes of parameter weights, used to set + initializer, name etc. + bias_attr (ParamAttr): The attributes of bias weights, used to set + initializer, name etc. 
main_program (Program): The main program. startup_program (Program): the startup program. @@ -882,6 +889,8 @@ def lstm_unit(x_t, startup_program=startup_program) fc_out = fc(input=concat_out, size=4 * size, + param_attr=param_attr, + bias_attr=bias_attr, main_program=main_program, startup_program=startup_program) dtype = x_t.dtype