提交 a3b27e32 编写于 作者: X Xin Pan

fix

test=develop
上级 f25eb9a7
...@@ -26,6 +26,9 @@ void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops) { ...@@ -26,6 +26,9 @@ void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops) {
for (auto& out_var : op->Node()->outputs) { for (auto& out_var : op->Node()->outputs) {
graph->RemoveNode(out_var); graph->RemoveNode(out_var);
} }
for (auto& in_var : op->Inputs()) {
in_var->RemoveOutput(op, op->Node());
}
graph->RemoveNode(op->Node()); graph->RemoveNode(op->Node());
} }
fetch_ops->clear(); fetch_ops->clear();
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from __future__ import print_function from __future__ import print_function
import os import os
import sys
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle import paddle
import numpy as np import numpy as np
...@@ -90,11 +91,13 @@ class TestReaderReset(unittest.TestCase): ...@@ -90,11 +91,13 @@ class TestReaderReset(unittest.TestCase):
try: try:
data_val, label_val = parallel_exe.run(fetch_list, data_val, label_val = parallel_exe.run(fetch_list,
return_numpy=True) return_numpy=True)
sys.stderr.write('fetched %s\n' % label_val)
ins_num = data_val.shape[0] ins_num = data_val.shape[0]
broadcasted_label = np.ones((ins_num, ) + tuple( broadcasted_label = np.ones((ins_num, ) + tuple(
self.ins_shape)) * label_val.reshape((ins_num, 1)) self.ins_shape)) * label_val.reshape((ins_num, 1))
self.assertEqual(data_val.all(), broadcasted_label.all()) self.assertEqual(data_val.all(), broadcasted_label.all())
for l in label_val: for l in label_val:
sys.stderr.write('label_val: %s\n' % l[0])
self.assertFalse(data_appeared[l[0]]) self.assertFalse(data_appeared[l[0]])
data_appeared[l[0]] = True data_appeared[l[0]] = True
...@@ -104,6 +107,7 @@ class TestReaderReset(unittest.TestCase): ...@@ -104,6 +107,7 @@ class TestReaderReset(unittest.TestCase):
data_appeared = data_appeared[:-parallel_exe.device_count * data_appeared = data_appeared[:-parallel_exe.device_count *
self.batch_size] self.batch_size]
for i in data_appeared: for i in data_appeared:
sys.stderr.write('appeared %s\n' % i)
self.assertTrue(i) self.assertTrue(i)
if pass_count < self.test_pass_num: if pass_count < self.test_pass_num:
data_appeared = [False] * self.total_ins_num data_appeared = [False] * self.total_ins_num
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册