提交 8c11d3fe 编写于 作者: X Xin Pan

clean up

上级 0a896505
...@@ -33,7 +33,7 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( ...@@ -33,7 +33,7 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
pool_(strategy.num_threads_ + pool_(strategy.num_threads_ +
1), // add one more thread for generate op_deps 1), // add one more thread for generate op_deps
fetch_ctxs_(places) { fetch_ctxs_(places) {
for (auto &op : ir::GetFilteredNodes<OpHandleBase>(*graph_)) { for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
int dep = static_cast<int>(op->NotReadyInputSize()); int dep = static_cast<int>(op->NotReadyInputSize());
op_deps_.emplace(op, dep); op_deps_.emplace(op, dep);
if (dep == 0) { if (dep == 0) {
......
...@@ -46,7 +46,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { ...@@ -46,7 +46,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
insert_pending_var(var); insert_pending_var(var);
} }
for (OpHandleBase *op : ir::GetFilteredNodes<OpHandleBase>(*graph)) { for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph)) {
if (op->Inputs().empty()) { if (op->Inputs().empty()) {
ready_ops.insert(op); ready_ops.insert(op);
} else { } else {
......
...@@ -36,6 +36,7 @@ namespace framework { ...@@ -36,6 +36,7 @@ namespace framework {
namespace details { namespace details {
namespace { namespace {
// TODO(panyx0718): Clean this up as well.
// all operators. NOTE that even we use a vector here, the operators is // all operators. NOTE that even we use a vector here, the operators is
// unordered. // unordered.
typedef std::vector<OpHandleBase *> GraphOps; typedef std::vector<OpHandleBase *> GraphOps;
......
...@@ -63,7 +63,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, ...@@ -63,7 +63,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
}); });
size_t op_id = 0; size_t op_id = 0;
for (auto &op : ir::GetFilteredNodes<OpHandleBase>(graph)) { for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(graph)) {
std::string op_name = "op_" + std::to_string(op_id++); std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl; << std::endl;
......
...@@ -157,7 +157,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -157,7 +157,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
} }
}; };
auto all_ops = ir::GetFilteredNodes<OpHandleBase>(*graph); auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
for (auto &op : all_ops) { for (auto &op : all_ops) {
auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs()); auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs()); auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs());
......
...@@ -60,7 +60,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -60,7 +60,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
InsertPendingVar(&pending_vars, ready_vars.get(), var); InsertPendingVar(&pending_vars, ready_vars.get(), var);
} }
for (auto &op : ir::GetFilteredNodes<OpHandleBase>(*graph_)) { for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
if (op->Inputs().empty()) { // Special case, Op has no input. if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op); ready_ops.insert(op);
} else { } else {
......
...@@ -38,7 +38,7 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList( ...@@ -38,7 +38,7 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
const Graph &graph); const Graph &graph);
template <typename T> template <typename T>
std::vector<T *> GetFilteredNodes(const Graph &graph) { std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
std::vector<T *> ret; std::vector<T *> ret;
for (ir::Node *n : graph.Nodes()) { for (ir::Node *n : graph.Nodes()) {
if (n->IsWrappedBy<T>()) ret.push_back(&n->Wrapper<T>()); if (n->IsWrappedBy<T>()) ret.push_back(&n->Wrapper<T>());
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
from __future__ import print_function from __future__ import print_function
import os import os
import sys
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle import paddle
import numpy as np import numpy as np
...@@ -91,13 +90,11 @@ class TestReaderReset(unittest.TestCase): ...@@ -91,13 +90,11 @@ class TestReaderReset(unittest.TestCase):
try: try:
data_val, label_val = parallel_exe.run(fetch_list, data_val, label_val = parallel_exe.run(fetch_list,
return_numpy=True) return_numpy=True)
sys.stderr.write('fetched %s\n' % label_val)
ins_num = data_val.shape[0] ins_num = data_val.shape[0]
broadcasted_label = np.ones((ins_num, ) + tuple( broadcasted_label = np.ones((ins_num, ) + tuple(
self.ins_shape)) * label_val.reshape((ins_num, 1)) self.ins_shape)) * label_val.reshape((ins_num, 1))
self.assertEqual(data_val.all(), broadcasted_label.all()) self.assertEqual(data_val.all(), broadcasted_label.all())
for l in label_val: for l in label_val:
sys.stderr.write('label_val: %s\n' % l[0])
self.assertFalse(data_appeared[l[0]]) self.assertFalse(data_appeared[l[0]])
data_appeared[l[0]] = True data_appeared[l[0]] = True
...@@ -107,7 +104,6 @@ class TestReaderReset(unittest.TestCase): ...@@ -107,7 +104,6 @@ class TestReaderReset(unittest.TestCase):
data_appeared = data_appeared[:-parallel_exe.device_count * data_appeared = data_appeared[:-parallel_exe.device_count *
self.batch_size] self.batch_size]
for i in data_appeared: for i in data_appeared:
sys.stderr.write('appeared %s\n' % i)
self.assertTrue(i) self.assertTrue(i)
if pass_count < self.test_pass_num: if pass_count < self.test_pass_num:
data_appeared = [False] * self.total_ins_num data_appeared = [False] * self.total_ins_num
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册