提交 8c11d3fe 编写于 作者: X Xin Pan

clean up

上级 0a896505
......@@ -33,7 +33,7 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
pool_(strategy.num_threads_ +
1), // add one more thread for generate op_deps
fetch_ctxs_(places) {
for (auto &op : ir::GetFilteredNodes<OpHandleBase>(*graph_)) {
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
int dep = static_cast<int>(op->NotReadyInputSize());
op_deps_.emplace(op, dep);
if (dep == 0) {
......
......@@ -46,7 +46,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
insert_pending_var(var);
}
for (OpHandleBase *op : ir::GetFilteredNodes<OpHandleBase>(*graph)) {
for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph)) {
if (op->Inputs().empty()) {
ready_ops.insert(op);
} else {
......
......@@ -36,6 +36,7 @@ namespace framework {
namespace details {
namespace {
// TODO(panyx0718): Clean this up as well.
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<OpHandleBase *> GraphOps;
......
......@@ -63,7 +63,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
});
size_t op_id = 0;
for (auto &op : ir::GetFilteredNodes<OpHandleBase>(graph)) {
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(graph)) {
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl;
......
......@@ -157,7 +157,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
}
};
auto all_ops = ir::GetFilteredNodes<OpHandleBase>(*graph);
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
for (auto &op : all_ops) {
auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs());
......
......@@ -60,7 +60,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
InsertPendingVar(&pending_vars, ready_vars.get(), var);
}
for (auto &op : ir::GetFilteredNodes<OpHandleBase>(*graph_)) {
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op);
} else {
......
......@@ -38,7 +38,7 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
const Graph &graph);
template <typename T>
std::vector<T *> GetFilteredNodes(const Graph &graph) {
std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
std::vector<T *> ret;
for (ir::Node *n : graph.Nodes()) {
if (n->IsWrappedBy<T>()) ret.push_back(&n->Wrapper<T>());
......
......@@ -14,7 +14,6 @@
from __future__ import print_function
import os
import sys
import paddle.fluid as fluid
import paddle
import numpy as np
......@@ -91,13 +90,11 @@ class TestReaderReset(unittest.TestCase):
try:
data_val, label_val = parallel_exe.run(fetch_list,
return_numpy=True)
sys.stderr.write('fetched %s\n' % label_val)
ins_num = data_val.shape[0]
broadcasted_label = np.ones((ins_num, ) + tuple(
self.ins_shape)) * label_val.reshape((ins_num, 1))
self.assertEqual(data_val.all(), broadcasted_label.all())
for l in label_val:
sys.stderr.write('label_val: %s\n' % l[0])
self.assertFalse(data_appeared[l[0]])
data_appeared[l[0]] = True
......@@ -107,7 +104,6 @@ class TestReaderReset(unittest.TestCase):
data_appeared = data_appeared[:-parallel_exe.device_count *
self.batch_size]
for i in data_appeared:
sys.stderr.write('appeared %s\n' % i)
self.assertTrue(i)
if pass_count < self.test_pass_num:
data_appeared = [False] * self.total_ins_num
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册