未验证 提交 86d7cc97 编写于 作者: W whs 提交者: GitHub

Add bias for gru_unit_op and fix activation function (#10087)

上级 0d94ab13
......@@ -699,8 +699,8 @@ def dynamic_gru(input,
def gru_unit(input,
hidden,
size,
weight=None,
bias=None,
param_attr=None,
bias_attr=None,
activation='tanh',
gate_activation='sigmoid'):
"""
......@@ -731,8 +731,8 @@ def gru_unit(input,
input (Variable): The fc transformed input value of current step.
hidden (Variable): The hidden value of lstm unit from previous step.
size (integer): The input dimension value.
weight (ParamAttr): The weight parameters for gru unit. Default: None
bias (ParamAttr): The bias parameters for gru unit. Default: None
param_attr (ParamAttr): The weight parameters for gru unit. Default: None
bias_attr (ParamAttr): The bias parameters for gru unit. Default: None
activation (string): The activation type for cell (actNode).
Default: 'tanh'
gate_activation (string): The activation type for gates (actGate).
......@@ -764,34 +764,31 @@ def gru_unit(input,
size = size / 3
# create weight
if weight is None:
weight = helper.create_parameter(
attr=helper.param_attr, shape=[size, 3 * size], dtype=dtype)
weight = helper.create_parameter(
attr=helper.param_attr, shape=[size, 3 * size], dtype=dtype)
gate = helper.create_tmp_variable(dtype)
reset_hidden_pre = helper.create_tmp_variable(dtype)
updated_hidden = helper.create_tmp_variable(dtype)
inputs = {'Input': input, 'HiddenPrev': hidden, 'Weight': weight}
# create bias
if bias is None:
if helper.bias_attr:
bias_size = [1, 3 * size]
bias = helper.create_parameter(
attr=helper.bias_attr, shape=bias_size, dtype=dtype, is_bias=True)
gate = helper.create_tmp_variable(dtype)
reset_hidden_pre = helper.create_tmp_variable(dtype)
updated_hidden = helper.create_tmp_variable(dtype)
inputs['Bias'] = bias
helper.append_op(
type='gru_unit',
inputs={'Input': input,
'HiddenPrev': hidden,
'Weight': weight},
inputs=inputs,
outputs={
'Gate': gate,
'ResetHiddenPrev': reset_hidden_pre,
'Hidden': updated_hidden,
},
attrs={
'activation': 0,
'gate_activation': 1,
'activation': 2, # tanh
'gate_activation': 1, # sigmoid
})
return updated_hidden, reset_hidden_pre, gate
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册