From 59bc4c4600067381fd82c866cde53e86b29d6309 Mon Sep 17 00:00:00 2001 From: QI JUN Date: Mon, 15 Jan 2018 10:31:08 +0800 Subject: [PATCH] fix dynamic rnn bug in GPU (#7480) --- python/paddle/v2/fluid/layers/control_flow.py | 6 ++++-- python/paddle/v2/fluid/layers/tensor.py | 11 +++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/python/paddle/v2/fluid/layers/control_flow.py b/python/paddle/v2/fluid/layers/control_flow.py index 4b363ecbe78..bef9602bb7c 100644 --- a/python/paddle/v2/fluid/layers/control_flow.py +++ b/python/paddle/v2/fluid/layers/control_flow.py @@ -1220,7 +1220,8 @@ class DynamicRNN(object): self.lod_rank_table = None self.max_seq_len = None self.step_idx = None - self.zero_idx = fill_constant(shape=[1], value=0, dtype='int64') + self.zero_idx = fill_constant( + shape=[1], value=0, dtype='int64', force_cpu=True) self.mem_dict = dict() self.output_array = [] self.outputs = [] @@ -1275,7 +1276,8 @@ class DynamicRNN(object): def block(self): if self.status != DynamicRNN.BEFORE_RNN: raise ValueError("rnn.block() can only be invoke once") - self.step_idx = fill_constant(shape=[1], dtype='int64', value=0) + self.step_idx = fill_constant( + shape=[1], dtype='int64', value=0, force_cpu=True) self.step_idx.stop_gradient = False self.status = DynamicRNN.IN_RNN with self.while_op.block(): diff --git a/python/paddle/v2/fluid/layers/tensor.py b/python/paddle/v2/fluid/layers/tensor.py index 2608a8d1151..2217c56b62a 100644 --- a/python/paddle/v2/fluid/layers/tensor.py +++ b/python/paddle/v2/fluid/layers/tensor.py @@ -180,7 +180,7 @@ def assign(input, output): return output -def fill_constant(shape, dtype, value, out=None): +def fill_constant(shape, dtype, value, force_cpu=False, out=None): """ **fill_constant** @@ -211,9 +211,12 @@ def fill_constant(shape, dtype, value, out=None): type='fill_constant', inputs={}, outputs={'Out': [out]}, - attrs={'shape': shape, - 'dtype': out.dtype, - 'value': float(value)}) + attrs={ + 'shape': shape, + 'dtype': out.dtype, + 'value': float(value), + 'force_cpu': force_cpu + }) out.stop_gradient = True return out -- GitLab