From a3b27e323759b83d428de7c647baf5ee9822948e Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 6 Nov 2018 13:54:41 +0800 Subject: [PATCH] fix test=develop --- paddle/fluid/framework/details/ssa_graph_executor.cc | 3 +++ python/paddle/fluid/tests/unittests/test_reader_reset.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/paddle/fluid/framework/details/ssa_graph_executor.cc b/paddle/fluid/framework/details/ssa_graph_executor.cc index d283a34ba9..af2cbd5c87 100644 --- a/paddle/fluid/framework/details/ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/ssa_graph_executor.cc @@ -26,6 +26,9 @@ void ClearFetchOp(ir::Graph* graph, std::vector* fetch_ops) { for (auto& out_var : op->Node()->outputs) { graph->RemoveNode(out_var); } + for (auto& in_var : op->Inputs()) { + in_var->RemoveOutput(op, op->Node()); + } graph->RemoveNode(op->Node()); } fetch_ops->clear(); diff --git a/python/paddle/fluid/tests/unittests/test_reader_reset.py b/python/paddle/fluid/tests/unittests/test_reader_reset.py index e97a05b6f9..fbf6e12b00 100644 --- a/python/paddle/fluid/tests/unittests/test_reader_reset.py +++ b/python/paddle/fluid/tests/unittests/test_reader_reset.py @@ -14,6 +14,7 @@ from __future__ import print_function import os +import sys import paddle.fluid as fluid import paddle import numpy as np @@ -90,11 +91,13 @@ class TestReaderReset(unittest.TestCase): try: data_val, label_val = parallel_exe.run(fetch_list, return_numpy=True) + sys.stderr.write('fetched %s\n' % label_val) ins_num = data_val.shape[0] broadcasted_label = np.ones((ins_num, ) + tuple( self.ins_shape)) * label_val.reshape((ins_num, 1)) self.assertEqual(data_val.all(), broadcasted_label.all()) for l in label_val: + sys.stderr.write('label_val: %s\n' % l[0]) self.assertFalse(data_appeared[l[0]]) data_appeared[l[0]] = True @@ -104,6 +107,7 @@ class TestReaderReset(unittest.TestCase): data_appeared = data_appeared[:-parallel_exe.device_count * self.batch_size] for i in data_appeared: + sys.stderr.write('appeared %s\n' % i) self.assertTrue(i) if pass_count < self.test_pass_num: data_appeared = [False] * self.total_ins_num -- GitLab