未验证 提交 14e83376 编写于 作者: Y Yancey 提交者: GitHub

expose h0 in dynamic_lstm (#11391)

* expose h0 in dynamic_lstm

* update by comment

* update by comment

* h0 to H0
上级 8453740b
...@@ -261,9 +261,10 @@ def embedding(input, ...@@ -261,9 +261,10 @@ def embedding(input,
return tmp return tmp
# TODO(qijun): expose H0 and C0
def dynamic_lstm(input, def dynamic_lstm(input,
size, size,
h_0=None,
c_0=None,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
use_peepholes=True, use_peepholes=True,
...@@ -324,6 +325,13 @@ def dynamic_lstm(input, ...@@ -324,6 +325,13 @@ def dynamic_lstm(input,
(T X 4D), where T is the total time steps in this (T X 4D), where T is the total time steps in this
mini-batch, D is the hidden size. mini-batch, D is the hidden size.
size(int): 4 * hidden size. size(int): 4 * hidden size.
h_0(Variable): The initial hidden state, an optional input that defaults to zero.
This is a tensor with shape (N x D), where N is the
batch size and D is the hidden size.
c_0(Variable): The initial cell state, an optional input that defaults to zero.
                      This is a tensor with shape (N x D), where N is the
                      batch size and D is the hidden size. `h_0` and `c_0`
                      can be omitted, but only both at the same time.
param_attr(ParamAttr|None): The parameter attribute for the learnable param_attr(ParamAttr|None): The parameter attribute for the learnable
hidden-hidden weights. hidden-hidden weights.
...@@ -387,12 +395,20 @@ def dynamic_lstm(input, ...@@ -387,12 +395,20 @@ def dynamic_lstm(input,
cell = helper.create_tmp_variable(dtype) cell = helper.create_tmp_variable(dtype)
batch_gate = helper.create_tmp_variable(dtype) batch_gate = helper.create_tmp_variable(dtype)
batch_cell_pre_act = helper.create_tmp_variable(dtype) batch_cell_pre_act = helper.create_tmp_variable(dtype)
inputs = {'Input': input, 'Weight': weight, 'Bias': bias}
batch_size = input.shape[0]
if h_0:
assert h_0.shape == (batch_size, size), \
'The shape of h0 should be (batch_size, %d)' % size
inputs['H0'] = h_0
if c_0:
assert c_0.shape == (batch_size, size), \
'The shape of c0 should be (batch_size, %d)' % size
inputs['C0'] = c_0
helper.append_op( helper.append_op(
type='lstm', type='lstm',
inputs={'Input': input, inputs=inputs,
'Weight': weight,
'Bias': bias},
outputs={ outputs={
'Hidden': hidden, 'Hidden': hidden,
'Cell': cell, 'Cell': cell,
...@@ -677,11 +693,13 @@ def dynamic_gru(input, ...@@ -677,11 +693,13 @@ def dynamic_gru(input,
attr=helper.param_attr, shape=[size, 3 * size], dtype=dtype) attr=helper.param_attr, shape=[size, 3 * size], dtype=dtype)
bias = helper.create_parameter( bias = helper.create_parameter(
attr=helper.bias_attr, shape=[1, 3 * size], dtype=dtype, is_bias=True) attr=helper.bias_attr, shape=[1, 3 * size], dtype=dtype, is_bias=True)
batch_size = input.shape[0]
inputs = {'Input': input, 'Weight': weight, 'Bias': bias} inputs = {'Input': input, 'Weight': weight, 'Bias': bias}
if h_0 != None: if h_0 != None:
assert h_0.shape == ( assert h_0.shape == (
size, size), 'The shape of h0 should be(%d, %d)' % (size, size) batch_size, size
inputs['h0'] = h_0 ), 'The shape of h0 should be(batch_size, %d)' % size
inputs['H0'] = h_0
hidden = helper.create_tmp_variable(dtype) hidden = helper.create_tmp_variable(dtype)
batch_gate = helper.create_tmp_variable(dtype) batch_gate = helper.create_tmp_variable(dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册