未验证 提交 59bc4c46 编写于 作者: Q QI JUN 提交者: GitHub

fix dynamic rnn bug in GPU (#7480)

上级 5ad1aef0
...@@ -1220,7 +1220,8 @@ class DynamicRNN(object): ...@@ -1220,7 +1220,8 @@ class DynamicRNN(object):
self.lod_rank_table = None self.lod_rank_table = None
self.max_seq_len = None self.max_seq_len = None
self.step_idx = None self.step_idx = None
self.zero_idx = fill_constant(shape=[1], value=0, dtype='int64') self.zero_idx = fill_constant(
shape=[1], value=0, dtype='int64', force_cpu=True)
self.mem_dict = dict() self.mem_dict = dict()
self.output_array = [] self.output_array = []
self.outputs = [] self.outputs = []
...@@ -1275,7 +1276,8 @@ class DynamicRNN(object): ...@@ -1275,7 +1276,8 @@ class DynamicRNN(object):
def block(self): def block(self):
if self.status != DynamicRNN.BEFORE_RNN: if self.status != DynamicRNN.BEFORE_RNN:
raise ValueError("rnn.block() can only be invoke once") raise ValueError("rnn.block() can only be invoke once")
self.step_idx = fill_constant(shape=[1], dtype='int64', value=0) self.step_idx = fill_constant(
shape=[1], dtype='int64', value=0, force_cpu=True)
self.step_idx.stop_gradient = False self.step_idx.stop_gradient = False
self.status = DynamicRNN.IN_RNN self.status = DynamicRNN.IN_RNN
with self.while_op.block(): with self.while_op.block():
......
...@@ -180,7 +180,7 @@ def assign(input, output): ...@@ -180,7 +180,7 @@ def assign(input, output):
return output return output
def fill_constant(shape, dtype, value, out=None): def fill_constant(shape, dtype, value, force_cpu=False, out=None):
""" """
**fill_constant** **fill_constant**
...@@ -211,9 +211,12 @@ def fill_constant(shape, dtype, value, out=None): ...@@ -211,9 +211,12 @@ def fill_constant(shape, dtype, value, out=None):
type='fill_constant', type='fill_constant',
inputs={}, inputs={},
outputs={'Out': [out]}, outputs={'Out': [out]},
attrs={'shape': shape, attrs={
'dtype': out.dtype, 'shape': shape,
'value': float(value)}) 'dtype': out.dtype,
'value': float(value),
'force_cpu': force_cpu
})
out.stop_gradient = True out.stop_gradient = True
return out return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册