未验证 提交 14e83376 编写于 作者: Y Yancey 提交者: GitHub

expose h0 in dynamic_lstm (#11391)

* expose h0 in dynamic_lstm

* update by comment

* update by comment

* h0 to H0
上级 8453740b
......@@ -261,9 +261,10 @@ def embedding(input,
return tmp
# TODO(qijun): expose H0 and C0
def dynamic_lstm(input,
size,
h_0=None,
c_0=None,
param_attr=None,
bias_attr=None,
use_peepholes=True,
......@@ -324,6 +325,13 @@ def dynamic_lstm(input,
(T X 4D), where T is the total time steps in this
mini-batch, D is the hidden size.
size(int): 4 * hidden size.
h_0(Variable): The initial hidden state is an optional input, default is zero.
This is a tensor with shape (N x D), where N is the
batch size and D is the hidden size.
c_0(Variable): The initial cell state is an optional input, default is zero.
This is a tensor with shape (N x D), where N is the
batch size. `h_0` and `c_0` can be NULL but only at the same time.
param_attr(ParamAttr|None): The parameter attribute for the learnable
hidden-hidden weights.
......@@ -387,12 +395,20 @@ def dynamic_lstm(input,
cell = helper.create_tmp_variable(dtype)
batch_gate = helper.create_tmp_variable(dtype)
batch_cell_pre_act = helper.create_tmp_variable(dtype)
inputs = {'Input': input, 'Weight': weight, 'Bias': bias}
batch_size = input.shape[0]
if h_0:
assert h_0.shape == (batch_size, size), \
'The shape of h0 should be (batch_size, %d)' % size
inputs['H0'] = h_0
if c_0:
assert c_0.shape == (batch_size, size), \
'The shape of c0 should be (batch_size, %d)' % size
inputs['C0'] = c_0
helper.append_op(
type='lstm',
inputs={'Input': input,
'Weight': weight,
'Bias': bias},
inputs=inputs,
outputs={
'Hidden': hidden,
'Cell': cell,
......@@ -677,11 +693,13 @@ def dynamic_gru(input,
attr=helper.param_attr, shape=[size, 3 * size], dtype=dtype)
bias = helper.create_parameter(
attr=helper.bias_attr, shape=[1, 3 * size], dtype=dtype, is_bias=True)
batch_size = input.shape[0]
inputs = {'Input': input, 'Weight': weight, 'Bias': bias}
if h_0 != None:
assert h_0.shape == (
size, size), 'The shape of h0 should be(%d, %d)' % (size, size)
inputs['h0'] = h_0
batch_size, size
), 'The shape of h0 should be(batch_size, %d)' % size
inputs['H0'] = h_0
hidden = helper.create_tmp_variable(dtype)
batch_gate = helper.create_tmp_variable(dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册