未验证 提交 acef55df 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix isolated var fetch bug, test=develop (#24070)

上级 3ca700a9
...@@ -210,6 +210,32 @@ std::vector<std::unique_ptr<ir::Graph>> TrySeparateToMultipleSingleDeviceGraphs( ...@@ -210,6 +210,32 @@ std::vector<std::unique_ptr<ir::Graph>> TrySeparateToMultipleSingleDeviceGraphs(
g->Set(kGraphDepVars, new GraphDepVars()); g->Set(kGraphDepVars, new GraphDepVars());
} }
std::vector<VarHandle *> isolated_var_handles;
for (auto *node : graph->Nodes()) {
if (!node->IsWrappedBy<VarHandleBase>()) {
continue;
}
auto &var_handle_base = node->Wrapper<VarHandleBase>();
auto *var_handle = dynamic_cast<VarHandle *>(&var_handle_base);
if (var_handle && var_handle->PendingOps().empty() &&
var_handle->GeneratedOp() == nullptr) {
isolated_var_handles.emplace_back(var_handle);
}
}
for (auto *var_handle : isolated_var_handles) {
auto dev_idx = var_handle->scope_idx();
auto &src_vars = graph->Get<GraphVars>(kGraphVars)[dev_idx];
auto *dst_graph = graphs[dev_idx].get();
auto &dst_vars = dst_graph->Get<GraphVars>(kGraphVars)[0];
VLOG(10) << "Move isolated var " << var_handle->Name() << " at device "
<< dev_idx;
dst_graph->AddNode(graph->RemoveNode(var_handle->Node()).release());
dst_vars[var_handle->Name()].emplace_back(var_handle);
src_vars.erase(var_handle->Name());
}
for (auto &pair : op_to_dev_idx) { for (auto &pair : op_to_dev_idx) {
auto *op = pair.first; auto *op = pair.first;
auto dev_idx = pair.second; auto dev_idx = pair.second;
......
...@@ -44,11 +44,14 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram( ...@@ -44,11 +44,14 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
all_vars.emplace(var->Name(), var); all_vars.emplace(var->Name(), var);
} }
auto not_visited_vars = all_vars;
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
ir::Node *node = CreateOpNode(op); ir::Node *node = CreateOpNode(op);
// For input args, reuse the same var name if it was created before. // For input args, reuse the same var name if it was created before.
// Otherwise, create a new one. // Otherwise, create a new one.
for (auto &each_var_name : op->InputArgumentNames()) { for (auto &each_var_name : op->InputArgumentNames()) {
not_visited_vars.erase(each_var_name);
ir::Node *var = nullptr; ir::Node *var = nullptr;
if (var_nodes.find(each_var_name) != var_nodes.end()) { if (var_nodes.find(each_var_name) != var_nodes.end()) {
var = var_nodes.at(each_var_name).back(); var = var_nodes.at(each_var_name).back();
...@@ -68,6 +71,7 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram( ...@@ -68,6 +71,7 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
// For output args, always create a new var. // For output args, always create a new var.
std::unordered_set<std::string> out_arg_set; std::unordered_set<std::string> out_arg_set;
for (auto &each_var_name : op->OutputArgumentNames()) { for (auto &each_var_name : op->OutputArgumentNames()) {
not_visited_vars.erase(each_var_name);
if (each_var_name != kEmptyVarName) { if (each_var_name != kEmptyVarName) {
PADDLE_ENFORCE_EQ(out_arg_set.count(each_var_name), 0, PADDLE_ENFORCE_EQ(out_arg_set.count(each_var_name), 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -91,6 +95,16 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram( ...@@ -91,6 +95,16 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
var->inputs.push_back(node); var->inputs.push_back(node);
} }
} }
for (auto &pair : not_visited_vars) {
const auto &var_name = pair.first;
auto *var_desc = pair.second;
if (var_name != kEmptyVarName) {
VLOG(10) << "Create isolated var node " << var_name;
var_nodes[var_name].push_back(CreateVarNode(var_desc));
}
}
Set<const std::vector<OpDesc *>>( Set<const std::vector<OpDesc *>>(
details::kStaleProgramOpDescs, details::kStaleProgramOpDescs,
new std::vector<OpDesc *>(program.Block(0).AllOps())); new std::vector<OpDesc *>(program.Block(0).AllOps()));
......
...@@ -174,9 +174,15 @@ void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const { ...@@ -174,9 +174,15 @@ void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
auto nodes = graph->ReleaseNodes(); auto nodes = graph->ReleaseNodes();
ir::Graph &result = *graph; ir::Graph &result = *graph;
std::vector<ir::Node *> isolated_vars;
for (auto &node : nodes) { for (auto &node : nodes) {
if (node->IsVar() && node->Var()) { if (node->IsVar() && node->Var()) {
all_vars_.emplace(node->Name(), node->Var()); all_vars_.emplace(node->Name(), node->Var());
if (node->inputs.empty() && node->outputs.empty()) {
isolated_vars.emplace_back(node.get());
}
} }
} }
...@@ -185,6 +191,10 @@ void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const { ...@@ -185,6 +191,10 @@ void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
result.Set(details::kGraphDepVars, new details::GraphDepVars); result.Set(details::kGraphDepVars, new details::GraphDepVars);
result.Set(kGraphOps, new GraphOps); result.Set(kGraphOps, new GraphOps);
for (auto *var_node : isolated_vars) {
CreateIsolatedVarNode(&result, var_node);
}
bool is_forwarding = true; bool is_forwarding = true;
for (ir::Node *node : sorted_ops) { for (ir::Node *node : sorted_ops) {
...@@ -582,6 +592,15 @@ bool MultiDevSSAGraphBuilderBase::IsSparseGradient( ...@@ -582,6 +592,15 @@ bool MultiDevSSAGraphBuilderBase::IsSparseGradient(
return all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS; return all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS;
} }
void MultiDevSSAGraphBuilderBase::CreateIsolatedVarNode(
ir::Graph *graph, ir::Node *var_node) const {
for (size_t i = 0; i < places_.size(); ++i) {
VLOG(10) << "Create isolated var node " << var_node->Name() << " at device "
<< i;
CreateOrGetLatestVarHandle(graph, var_node, places_[i], i);
}
}
void AllReduceSSAGraphBuilder::InsertCollectiveOp( void AllReduceSSAGraphBuilder::InsertCollectiveOp(
ir::Graph *result, const std::string &p_name, ir::Graph *result, const std::string &p_name,
const std::string &g_name) const { const std::string &g_name) const {
......
...@@ -94,6 +94,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { ...@@ -94,6 +94,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
void CreateOpHandleIOs(ir::Graph *result, ir::Node *node, void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
size_t device_id) const; size_t device_id) const;
void CreateIsolatedVarNode(ir::Graph *result, ir::Node *var_node) const;
#if defined(PADDLE_WITH_NCCL) #if defined(PADDLE_WITH_NCCL)
mutable platform::NCCLContextMap *nccl_ctxs_{nullptr}; mutable platform::NCCLContextMap *nccl_ctxs_{nullptr};
mutable platform::NCCLCommunicator *multi_nccl_ctxs_{nullptr}; mutable platform::NCCLCommunicator *multi_nccl_ctxs_{nullptr};
......
...@@ -369,6 +369,7 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu ...@@ -369,6 +369,7 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu
test_data_norm_op test_imperative_using_non_zero_gpu test_fuse_bn_act_pass test_data_norm_op test_imperative_using_non_zero_gpu test_fuse_bn_act_pass
test_optimizer_in_control_flow test_dataloader_keep_order test_optimizer_in_control_flow test_dataloader_keep_order
test_dataloader_unkeep_order test_dataloader_unkeep_order
test_parallel_executor_fetch_isolated_var
test_parallel_executor_inference_feed_partial_data test_parallel_executor_inference_feed_partial_data
test_parallel_ssa_graph_inference_feed_partial_data test_parallel_ssa_graph_inference_feed_partial_data
test_fetch_unmerged test_fetch_unmerged
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import six
import paddle.fluid as fluid
def enable_parallel_ssa_executor(enabled=True):
if fluid.is_compiled_with_cuda():
fluid.core.globals()['FLAGS_enable_parallel_graph'] = enabled
class TestParallelExecutorFetchIsolatedVarBase(unittest.TestCase):
def build_network(self, is_training):
x = fluid.data(name='x', shape=[-1, 10], dtype='float32')
y = fluid.data(name='y', shape=[-1, 10], dtype='float32')
fc = fluid.layers.fc(x, size=30)
loss = fluid.layers.reduce_mean(fc)
if is_training:
adam = fluid.optimizer.Adam(learning_rate=1e-3)
adam.minimize(loss)
return loss, y
def exec_strategy(self, use_experimental_executor):
strategy = fluid.ExecutionStrategy()
strategy.use_experimental_executor = use_experimental_executor
return strategy
def places(self, use_gpu, dev_cnt):
if use_gpu:
return fluid.cuda_places(list(range(dev_cnt)))
else:
return fluid.cpu_places(dev_cnt)
def test_main(self):
for use_gpu in [False, True]:
for dev_cnt in [1, 2]:
for is_training in [False, True]:
for use_experimental_executor in [False, True]:
for use_parallel_ssa_executor in [False, True]:
func = lambda: self.run_impl(use_gpu, dev_cnt, is_training, use_experimental_executor, use_parallel_ssa_executor)
self.run_func_with_guard(func)
def run_impl(self, use_gpu, dev_cnt, is_training, use_experimental_executor,
use_parallel_ssa_executor):
enable_parallel_ssa_executor(use_parallel_ssa_executor)
if fluid.is_compiled_with_cuda():
if fluid.core.globals()[
'FLAGS_enable_parallel_graph'] and not use_gpu:
return
else:
if use_gpu:
return
loss, isolated_var = self.build_network(is_training)
loss_name = loss.name if is_training else None
places = self.places(use_gpu, dev_cnt)
exe = fluid.Executor(places[0])
exe.run(fluid.default_startup_program())
prog = fluid.CompiledProgram(fluid.default_main_program(
)).with_data_parallel(
loss_name=loss_name,
exec_strategy=self.exec_strategy(use_experimental_executor),
places=places)
BATCH_SIZE = 8 * dev_cnt
for _ in six.moves.range(10):
x_np = np.random.random(size=[BATCH_SIZE, 10]).astype('float32')
y_np = np.random.random(size=[BATCH_SIZE, 10]).astype('float32')
_, y_np_fetch = exe.run(prog,
feed={'x': x_np,
'y': y_np},
fetch_list=[loss, isolated_var])
self.assertTrue(np.array_equal(y_np, y_np_fetch))
enable_parallel_ssa_executor(False)
def run_func_with_guard(self, func):
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.unique_name.guard():
with fluid.scope_guard(fluid.Scope()):
func()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册