From 8c11d3fed6a9f7afeca0596cc6759e2ff034ccf8 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 6 Nov 2018 15:53:01 +0800 Subject: [PATCH] clean up --- .../framework/details/fast_threaded_ssa_graph_executor.cc | 2 +- .../fluid/framework/details/multi_devices_graph_check_pass.cc | 2 +- paddle/fluid/framework/details/multi_devices_graph_pass.cc | 1 + .../fluid/framework/details/multi_devices_graph_print_pass.cc | 2 +- paddle/fluid/framework/details/reference_count_pass.cc | 2 +- paddle/fluid/framework/details/threaded_ssa_graph_executor.cc | 2 +- paddle/fluid/framework/ir/graph_helper.h | 2 +- python/paddle/fluid/tests/unittests/test_reader_reset.py | 4 ---- 8 files changed, 7 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc index 230ad7ac0bf..9a0e84e3fb2 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -33,7 +33,7 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( pool_(strategy.num_threads_ + 1), // add one more thread for generate op_deps fetch_ctxs_(places) { - for (auto &op : ir::GetFilteredNodes(*graph_)) { + for (auto &op : ir::FilterByNodeWrapper(*graph_)) { int dep = static_cast(op->NotReadyInputSize()); op_deps_.emplace(op, dep); if (dep == 0) { diff --git a/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc index 5b03e9f9604..c8ea1880463 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc @@ -46,7 +46,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { insert_pending_var(var); } - for (OpHandleBase *op : ir::GetFilteredNodes(*graph)) { + for (OpHandleBase *op : ir::FilterByNodeWrapper(*graph)) { if (op->Inputs().empty()) { ready_ops.insert(op); } else { diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 58b7ea0b9e9..67d29a42d75 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -36,6 +36,7 @@ namespace framework { namespace details { namespace { +// TODO(panyx0718): Clean this up as well. // all operators. NOTE that even we use a vector here, the operators is // unordered. typedef std::vector GraphOps; diff --git a/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc index ae50905f761..8f92f0948d7 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc @@ -63,7 +63,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, }); size_t op_id = 0; - for (auto &op : ir::GetFilteredNodes(graph)) { + for (auto &op : ir::FilterByNodeWrapper(graph)) { std::string op_name = "op_" + std::to_string(op_id++); sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" << std::endl; diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc index 42b248650ea..08783fb5f8b 100644 --- a/paddle/fluid/framework/details/reference_count_pass.cc +++ b/paddle/fluid/framework/details/reference_count_pass.cc @@ -157,7 +157,7 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( } }; - auto all_ops = ir::GetFilteredNodes(*graph); + auto all_ops = ir::FilterByNodeWrapper(*graph); for (auto &op : all_ops) { auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs()); auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs()); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 97b6d4a1ac7..39f5eca53c9 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -60,7 +60,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( InsertPendingVar(&pending_vars, ready_vars.get(), var); } - for (auto &op : ir::GetFilteredNodes(*graph_)) { + for (auto &op : ir::FilterByNodeWrapper(*graph_)) { if (op->Inputs().empty()) { // Special case, Op has no input. ready_ops.insert(op); } else { diff --git a/paddle/fluid/framework/ir/graph_helper.h b/paddle/fluid/framework/ir/graph_helper.h index a107aaf7f57..8d92c406689 100644 --- a/paddle/fluid/framework/ir/graph_helper.h +++ b/paddle/fluid/framework/ir/graph_helper.h @@ -38,7 +38,7 @@ std::map> BuildOperationAdjList( const Graph &graph); template -std::vector GetFilteredNodes(const Graph &graph) { +std::vector FilterByNodeWrapper(const Graph &graph) { std::vector ret; for (ir::Node *n : graph.Nodes()) { if (n->IsWrappedBy()) ret.push_back(&n->Wrapper()); diff --git a/python/paddle/fluid/tests/unittests/test_reader_reset.py b/python/paddle/fluid/tests/unittests/test_reader_reset.py index fbf6e12b003..e97a05b6f92 100644 --- a/python/paddle/fluid/tests/unittests/test_reader_reset.py +++ b/python/paddle/fluid/tests/unittests/test_reader_reset.py @@ -14,7 +14,6 @@ from __future__ import print_function import os -import sys import paddle.fluid as fluid import paddle import numpy as np @@ -91,13 +90,11 @@ class TestReaderReset(unittest.TestCase): try: data_val, label_val = parallel_exe.run(fetch_list, return_numpy=True) - sys.stderr.write('fetched %s\n' % label_val) ins_num = data_val.shape[0] broadcasted_label = np.ones((ins_num, ) + tuple( self.ins_shape)) * label_val.reshape((ins_num, 1)) self.assertEqual(data_val.all(), broadcasted_label.all()) for l in label_val: - sys.stderr.write('label_val: %s\n' % l[0]) self.assertFalse(data_appeared[l[0]]) data_appeared[l[0]] = True @@ -107,7 +104,6 @@ class TestReaderReset(unittest.TestCase): data_appeared = data_appeared[:-parallel_exe.device_count * self.batch_size] for i in data_appeared: - sys.stderr.write('appeared %s\n' % i) self.assertTrue(i) if pass_count < self.test_pass_num: data_appeared = [False] * self.total_ins_num -- GitLab