Unverified commit 3a21980b authored by Zeng Jinle, committed by GitHub

add reader dependency pass, test=develop (#23301)

Parent 69e3f993
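What the pass does (my reading of the new code below, not wording from the commit): add_reader_dependency_pass locates every `read` op and every "startup" op (an op none of whose inputs is produced by another op in the graph), then wires a control-dependency variable from each read op to each startup op that does not already precede it, so the scheduler cannot launch otherwise-independent startup ops before the readers. A minimal sketch of applying such a registered pass by hand, assuming the standard ir::PassRegistry lookup API; the Graph pointer is taken as given:

#include "paddle/fluid/framework/ir/pass.h"

// Sketch only: fetch the pass registered via REGISTER_PASS below and run it.
void RunAddReaderDependencyPass(paddle::framework::ir::Graph *graph) {
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "add_reader_dependency_pass");
  pass->Apply(graph);  // mutates the graph in place, adding control-dep vars
}

In the ParallelExecutor build pipeline this happens automatically: the pass is appended by the new AppendAddReaderDependencyPass() call shown in the diff below.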
@@ -73,7 +73,8 @@ set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto
eager_deletion_pass
buffer_shared_inplace_op_pass
buffer_shared_cross_op_memory_reuse_pass
set_reader_device_info_pass)
set_reader_device_info_pass
add_reader_dependency_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
@@ -65,6 +65,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendOpFusePasses();
AppendPrintGraphPass("graph_viz_pass", "_fused_graph");
AppendAddReaderDependencyPass();
AppendMultiDevPass();
AppendSetReaderDeviceIndexPass();
AppendMultiGraphOptPasses();
@@ -203,6 +204,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
VLOG(1) << "CollectiveContext:" << context->String();
}
void AppendAddReaderDependencyPass() {
AppendPass("add_reader_dependency_pass");
}
// Convert graph to run on multi-devices.
void AppendMultiDevPass() {
ir::Pass *multi_devices_pass = nullptr;
@@ -442,6 +447,7 @@ USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(set_reader_device_index_pass);
USE_PASS(add_reader_dependency_pass);
#ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass);
#endif
......
@@ -16,3 +16,4 @@ cc_library(set_reader_device_info_pass SRCS set_reader_device_info_pass.cc DEPS
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass)
cc_library(backward_optimizer_op_deps_pass SRCS backward_optimizer_op_deps_pass.cc DEPS graph graph_helper pass)
cc_library(add_reader_dependency_pass SRCS add_reader_dependency_pass.cc DEPS graph graph_helper pass)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <queue>
#include <unordered_set>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class AddReaderDependencyPass : public Pass {
protected:
void ApplyImpl(Graph *graph) const override;
};
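// Walks backwards (BFS) from `node` through its input variables and collects
// every op that transitively precedes it; `node` itself is excluded.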
static std::unordered_set<Node *> FindAllPrecedingOpNodes(Node *node) {
std::unordered_set<Node *> result;
std::queue<Node *> q;
q.push(node);
while (!q.empty()) {
auto *cur_node = q.front();
q.pop();
for (auto &in_var : cur_node->inputs) {
for (auto &in_op : in_var->inputs) {
if (result.count(in_op) == 0 && in_op != node) {
result.insert(in_op);
q.push(in_op);
}
}
}
}
return result;
}
void AddReaderDependencyPass::ApplyImpl(Graph *graph) const {
const auto &nodes = graph->Nodes();
std::unordered_set<Node *> ops;
std::unordered_set<Node *> read_ops;
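  // Collect all op nodes, and remember which of them are `read` ops.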
for (auto &n : nodes) {
if (n->IsOp() && n->Op()) {
ops.insert(n);
if (n->Op()->Type() == "read") {
read_ops.insert(n);
}
}
}
VLOG(10) << "Found " << read_ops.size() << " read op(s)";
if (read_ops.empty()) {
return;
}
  // Find all startup ops: every op that consumes another op's output is
  // collected into out_ops and erased from ops, leaving only ops whose
  // inputs are not produced by any op in the graph.
std::unordered_set<Node *> out_ops;
for (auto &op : ops) {
for (auto &out_var : op->outputs) {
for (auto &out_op : out_var->outputs) {
out_ops.insert(out_op);
}
}
}
for (auto &out_op : out_ops) {
ops.erase(out_op);
}
VLOG(10) << "Found " << ops.size() << " startup ops";
for (auto &read_op : read_ops) {
auto preceding_ops = FindAllPrecedingOpNodes(read_op);
for (auto &startup_op : ops) {
if (read_op == startup_op || preceding_ops.count(startup_op) > 0) {
VLOG(10) << "Startup op " << startup_op->Op()->Type() << " is skipped";
continue;
}
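      // Insert a control-dependency variable read_op -> dep_var -> startup_op,
      // forcing the scheduler to run the reader before this startup op.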
auto *dep_var = graph->CreateControlDepVar();
read_op->outputs.push_back(dep_var);
startup_op->inputs.push_back(dep_var);
dep_var->inputs.push_back(read_op);
dep_var->outputs.push_back(startup_op);
VLOG(10) << "Add dependencies between " << read_op->Op()->Type()
<< " and " << startup_op->Op()->Type();
}
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(add_reader_dependency_pass,
paddle::framework::ir::AddReaderDependencyPass);
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
import unittest
import numpy as np
import time
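# Appends an in-place `scale` op (out = 1 * x + bias, scale defaults to 1)
# whose output aliases its input, so each run bumps `x` by `bias` in place.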
def inplace_add(x, bias):
helper = LayerHelper('scale', **locals())
helper.append_op(
type='scale',
inputs={'X': [x]},
outputs={'Out': [x]},
attrs={'bias': bias})
return x
class TestAddReaderDependency(unittest.TestCase):
def setUp(self):
self.batch_num = 3
self.sleep_time = 2
self.use_double_buffer = True
def test_main(self):
self.run_main(fluid.CPUPlace())
if fluid.is_compiled_with_cuda():
self.run_main(fluid.CUDAPlace(0))
def run_main(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.scope_guard(fluid.Scope()):
tmp_in = fluid.data(name='tmp_in', dtype='float32', shape=[1])
loader = fluid.io.DataLoader.from_generator(
feed_list=[tmp_in],
capacity=16,
iterable=False,
use_double_buffer=self.use_double_buffer)
def data_source():
for _ in range(self.batch_num):
                    time.sleep(self.sleep_time)  # simulate a slow data source
yield np.random.uniform(
low=-1, high=1, size=[1]).astype('float32'),
persistable_in = fluid.data(
name='persistable_in', dtype='float32', shape=[1])
persistable_in.persistable = True
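                # Fed as -1 on the first batch and bumped by 1 in place on
                # every run, so the fetched value should equal batch_id.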
persistable_in = inplace_add(persistable_in, bias=1)
prog = fluid.CompiledProgram(fluid.default_main_program())
exe = fluid.Executor(place)
loader.set_batch_generator(data_source)
loader.start()
batch_id = 0
try:
while True:
if batch_id == 0:
feed = {
persistable_in.name:
np.array([-1]).astype('float32')
}
else:
feed = None
ret, = exe.run(prog,
feed=feed,
fetch_list=[persistable_in])
self.assertEqual(ret.shape, (1, ))
self.assertEqual(ret[0], batch_id)
batch_id += 1
except fluid.core.EOFException:
loader.reset()
self.assertEqual(batch_id, self.batch_num)
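                # The in-place updates persist in the scope: starting from -1,
                # after batch_num increments the tensor holds batch_num - 1.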
t = fluid.global_scope().find_var(
persistable_in.name).get_tensor()
t_val = np.array(t)
self.assertEqual(t_val.shape, (1, ))
self.assertEqual(t_val[0] + 1, batch_id)
class TestAddReaderDependencyWithoutDoubleBuffer(TestAddReaderDependency):
def setUp(self):
self.batch_num = 3
self.sleep_time = 2
self.use_double_buffer = False
if __name__ == '__main__':
unittest.main()