提交 a398e25d 编写于 作者: Y yangyaming

Expose param_attr and bias_attr.

上级 ed56ed9f
...@@ -51,7 +51,10 @@ class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -51,7 +51,10 @@ class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker {
LstmUnitOpMaker(framework::OpProto* proto, LstmUnitOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "FC input before the non-linear activation."); AddInput("X",
"Lstm unit only applies non-linear activations, please make sure"
"that linear tranformation has already been applied to `X`. "
"Linear tranformation can be applied by adding a `fc` layer");
AddInput( AddInput(
"C_prev", "C_prev",
"The cell state tensor of last time-step in the Lstm Unit operator."); "The cell state tensor of last time-step in the Lstm Unit operator.");
......
...@@ -5,6 +5,7 @@ All layers just related to the neural network. ...@@ -5,6 +5,7 @@ All layers just related to the neural network.
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant from ..initializer import Normal, Constant
from ..framework import Variable from ..framework import Variable
from ..param_attr import ParamAttr
from tensor import concat from tensor import concat
__all__ = [ __all__ = [
...@@ -796,6 +797,8 @@ def lstm_unit(x_t, ...@@ -796,6 +797,8 @@ def lstm_unit(x_t,
hidden_t_prev, hidden_t_prev,
cell_t_prev, cell_t_prev,
forget_bias=0.0, forget_bias=0.0,
param_attr=None,
bias_attr=ParamAttr(),
main_program=None, main_program=None,
startup_program=None): startup_program=None):
"""Lstm unit layer. The equation of a lstm step is: """Lstm unit layer. The equation of a lstm step is:
...@@ -836,6 +839,10 @@ def lstm_unit(x_t, ...@@ -836,6 +839,10 @@ def lstm_unit(x_t,
hidden_t_prev (Variable): The hidden value of lstm unit. hidden_t_prev (Variable): The hidden value of lstm unit.
cell_t_prev (Variable): The cell value of lstm unit. cell_t_prev (Variable): The cell value of lstm unit.
forget_bias (float): The forget bias of lstm unit. forget_bias (float): The forget bias of lstm unit.
param_attr (ParamAttr): The attributes of parameter weights, used to set
initializer, name etc.
bias_attr (ParamAttr): The attributes of bias weights, used to set
initializer, name etc.
main_program (Program): The main program. main_program (Program): The main program.
startup_program (Program): the startup program. startup_program (Program): the startup program.
...@@ -882,6 +889,8 @@ def lstm_unit(x_t, ...@@ -882,6 +889,8 @@ def lstm_unit(x_t,
startup_program=startup_program) startup_program=startup_program)
fc_out = fc(input=concat_out, fc_out = fc(input=concat_out,
size=4 * size, size=4 * size,
param_attr=param_attr,
bias_attr=bias_attr,
main_program=main_program, main_program=main_program,
startup_program=startup_program) startup_program=startup_program)
dtype = x_t.dtype dtype = x_t.dtype
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册