提交 d993a4f5 编写于 作者: Y yangyaming

Change default value for bias_attr.

上级 69072ef1
...@@ -866,7 +866,7 @@ def lstm_unit(x_t, ...@@ -866,7 +866,7 @@ def lstm_unit(x_t,
cell_t_prev, cell_t_prev,
forget_bias=0.0, forget_bias=0.0,
param_attr=None, param_attr=None,
bias_attr=ParamAttr(), bias_attr=None,
main_program=None, main_program=None,
startup_program=None): startup_program=None):
"""Lstm unit layer. The equation of a lstm step is: """Lstm unit layer. The equation of a lstm step is:
...@@ -909,8 +909,8 @@ def lstm_unit(x_t, ...@@ -909,8 +909,8 @@ def lstm_unit(x_t,
forget_bias (float): The forget bias of lstm unit. forget_bias (float): The forget bias of lstm unit.
param_attr (ParamAttr): The attributes of parameter weights, used to set param_attr (ParamAttr): The attributes of parameter weights, used to set
initializer, name etc. initializer, name etc.
bias_attr (ParamAttr): The attributes of bias weights, used to set bias_attr (ParamAttr): The attributes of bias weights, if not False,
initializer, name etc. bias weights will be created and be set to default value.
main_program (Program): The main program. main_program (Program): The main program.
startup_program (Program): the startup program. startup_program (Program): the startup program.
...@@ -949,6 +949,9 @@ def lstm_unit(x_t, ...@@ -949,6 +949,9 @@ def lstm_unit(x_t,
raise ValueError("The 1s dimension of x_t, hidden_t_prev and " raise ValueError("The 1s dimension of x_t, hidden_t_prev and "
"cell_t_prev must be the same.") "cell_t_prev must be the same.")
if bias_attr is None:
bias_attr = ParamAttr()
size = cell_t_prev.shape[1] size = cell_t_prev.shape[1]
concat_out = concat( concat_out = concat(
input=[x_t, hidden_t_prev], input=[x_t, hidden_t_prev],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册