提交 a398e25d 编写于 作者: Y yangyaming

Expose param_attr and bias_attr.

上级 ed56ed9f
......@@ -51,7 +51,10 @@ class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker {
LstmUnitOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "FC input before the non-linear activation.");
AddInput("X",
"Lstm unit only applies non-linear activations, please make sure"
"that linear tranformation has already been applied to `X`. "
"Linear tranformation can be applied by adding a `fc` layer");
AddInput(
"C_prev",
"The cell state tensor of last time-step in the Lstm Unit operator.");
......
......@@ -5,6 +5,7 @@ All layers just related to the neural network.
from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant
from ..framework import Variable
from ..param_attr import ParamAttr
from tensor import concat
__all__ = [
......@@ -796,6 +797,8 @@ def lstm_unit(x_t,
hidden_t_prev,
cell_t_prev,
forget_bias=0.0,
param_attr=None,
bias_attr=ParamAttr(),
main_program=None,
startup_program=None):
"""Lstm unit layer. The equation of a lstm step is:
......@@ -836,6 +839,10 @@ def lstm_unit(x_t,
hidden_t_prev (Variable): The hidden value of lstm unit.
cell_t_prev (Variable): The cell value of lstm unit.
forget_bias (float): The forget bias of lstm unit.
param_attr (ParamAttr): The attributes of parameter weights, used to set
initializer, name etc.
bias_attr (ParamAttr): The attributes of bias weights, used to set
initializer, name etc.
main_program (Program): The main program.
startup_program (Program): the startup program.
......@@ -882,6 +889,8 @@ def lstm_unit(x_t,
startup_program=startup_program)
fc_out = fc(input=concat_out,
size=4 * size,
param_attr=param_attr,
bias_attr=bias_attr,
main_program=main_program,
startup_program=startup_program)
dtype = x_t.dtype
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册