Unverified commit 59bc4c46, authored by QI JUN, committed via GitHub

fix dynamic rnn bug in GPU (#7480)

Parent 5ad1aef0
@@ -1220,7 +1220,8 @@ class DynamicRNN(object):
         self.lod_rank_table = None
         self.max_seq_len = None
         self.step_idx = None
-        self.zero_idx = fill_constant(shape=[1], value=0, dtype='int64')
+        self.zero_idx = fill_constant(
+            shape=[1], value=0, dtype='int64', force_cpu=True)
         self.mem_dict = dict()
         self.output_array = []
         self.outputs = []
@@ -1275,7 +1276,8 @@ class DynamicRNN(object):
     def block(self):
         if self.status != DynamicRNN.BEFORE_RNN:
             raise ValueError("rnn.block() can only be invoke once")
-        self.step_idx = fill_constant(shape=[1], dtype='int64', value=0)
+        self.step_idx = fill_constant(
+            shape=[1], dtype='int64', value=0, force_cpu=True)
         self.step_idx.stop_gradient = False
         self.status = DynamicRNN.IN_RNN
         with self.while_op.block():
......
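The change above pins DynamicRNN's internal index tensors (zero_idx and the per-iteration step_idx) to host memory with force_cpu=True, so the surrounding control-flow ops (the While condition and LoDRankTable lookups) can read them even when the program itself runs on a GPU place. Below is a minimal usage sketch of a DynamicRNN that this fix affects; it assumes the fluid.layers API of this release (older packages exposed it as paddle.v2.fluid), and the layer names and sizes are purely illustrative.

```python
import paddle.fluid as fluid

# Illustrative sketch: a word-level dynamic RNN over a LoD (variable-length) input.
sentence = fluid.layers.data(name='word', shape=[1], dtype='int64', lod_level=1)
embedded = fluid.layers.embedding(input=sentence, size=[10000, 32])

drnn = fluid.layers.DynamicRNN()
with drnn.block():
    word = drnn.step_input(embedded)            # one time step of the sequence
    prev = drnn.memory(shape=[32], value=0.0)   # hidden state carried across steps
    hidden = fluid.layers.fc(input=[word, prev], size=32, act='tanh')
    drnn.update_memory(prev, hidden)            # propagate state to the next step
    drnn.output(hidden)

# With this fix, the step counter created inside drnn.block() lives on CPU
# (force_cpu=True), so the loop also works when the executor uses CUDAPlace.
rnn_out = drnn()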
@@ -180,7 +180,7 @@ def assign(input, output):
     return output


-def fill_constant(shape, dtype, value, out=None):
+def fill_constant(shape, dtype, value, force_cpu=False, out=None):
     """
     **fill_constant**
@@ -211,9 +211,12 @@ def fill_constant(shape, dtype, value, out=None):
         type='fill_constant',
         inputs={},
         outputs={'Out': [out]},
-        attrs={'shape': shape,
-               'dtype': out.dtype,
-               'value': float(value)})
+        attrs={
+            'shape': shape,
+            'dtype': out.dtype,
+            'value': float(value),
+            'force_cpu': force_cpu
+        })
     out.stop_gradient = True
     return out
......
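The new force_cpu attribute lets callers keep small scalar tensors such as loop counters and sequence indices in host memory regardless of where the rest of the program executes. A minimal sketch of calling the extended API, assuming the fluid.layers module path (older releases exposed it as paddle.v2.fluid):

```python
import paddle.fluid as fluid

# A step counter pinned to CPU memory so that While / LoDRankTable control-flow
# ops can read it even when the executor runs on a GPU place.
step_idx = fluid.layers.fill_constant(
    shape=[1], dtype='int64', value=0, force_cpu=True)

# Default behaviour is unchanged: without force_cpu the constant is created
# on the executor's place (CPU or GPU).
one = fluid.layers.fill_constant(shape=[1], dtype='float32', value=1.0)
```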