diff --git a/paddle/fluid/framework/details/multi_devices_helper.cc b/paddle/fluid/framework/details/multi_devices_helper.cc index 79279f1b1435bd5e89ecf7af68aab25eb8ab5baf..805fd9a969d7a63244e66d54fd79b51633489f24 100644 --- a/paddle/fluid/framework/details/multi_devices_helper.cc +++ b/paddle/fluid/framework/details/multi_devices_helper.cc @@ -105,38 +105,36 @@ static size_t GetUniqueDeviceIdOfOp(const details::OpHandleBase &op) { return dev_idx; } -/** - * This function tries to separate the original graph into multiple graphs, in - * which each graph would only run on single device. This is usually used to - * separate a data-parallel inference graph to multiple graphs on each device. - * - * The graph can be separated into multiple single device graphs if and only if: - * - * - the graph does not contain any ops related to multi-devices communication, - * such as allreduce, send, recv, sync_batch_norm, etc. - * - * - ops on different devices do not depend on each other. That is to say, the - * graph has several disconnected sub-graphs. - */ -std::vector> TrySeparateToMultipleSingleDeviceGraphs( - ir::Graph *graph) { +static bool IsDataParallelInferenceGraphImpl( + const ir::Graph &graph, + std::unordered_map *p_op_to_dev_idx, + size_t *p_place_num) { + auto &place_num = *p_place_num; + auto &op_to_dev_idx = *p_op_to_dev_idx; + + auto clear_result = [&] { + place_num = 0; + op_to_dev_idx.clear(); + return false; + }; + + clear_result(); + // If sub-block contains multi-devices ops, we cannot separate - if (ContainMultiDeviceOp(graph->OriginProgram(), 1)) { - return {}; + if (ContainMultiDeviceOp(graph.OriginProgram(), 1)) { + return clear_result(); } - size_t place_num = 0; - auto op_handles = ir::FilterByNodeWrapper(*graph); + auto op_handles = ir::FilterByNodeWrapper(graph); if (op_handles.empty()) { - return {}; + return clear_result(); } - std::unordered_map op_to_dev_idx; for (auto &op : op_handles) { auto dev_idx = GetUniqueDeviceIdOfOp(*op); if (dev_idx == kUndefinedDevIdx) { VLOG(10) << "Op " << op->Name() << " is not determined"; - return {}; + return clear_result(); } place_num = std::max(place_num, dev_idx + 1); op_to_dev_idx[op] = dev_idx; @@ -148,7 +146,7 @@ std::vector> TrySeparateToMultipleSingleDeviceGraphs( if (in_var->GeneratedOp()) { auto iter = op_to_dev_idx.find(in_var->GeneratedOp()); if (iter == op_to_dev_idx.end() || iter->second != dev_idx) { - return {}; + return clear_result(); } } } @@ -157,7 +155,7 @@ std::vector> TrySeparateToMultipleSingleDeviceGraphs( for (auto &pending_op : out_var->PendingOps()) { auto iter = op_to_dev_idx.find(pending_op); if (iter == op_to_dev_idx.end() || iter->second != dev_idx) { - return {}; + return clear_result(); } } } @@ -171,6 +169,36 @@ std::vector> TrySeparateToMultipleSingleDeviceGraphs( "issue at https://github.com/PaddlePaddle/Paddle/issues/new. And " "we will resolve it with high priority.")); + return true; +} + +bool IsDataParallelInferenceGraph(const ir::Graph &graph) { + size_t place_num; + std::unordered_map op_to_dev_idx; + return IsDataParallelInferenceGraphImpl(graph, &op_to_dev_idx, &place_num); +} + +/** + * This function tries to separate the original graph into multiple graphs, in + * which each graph would only run on single device. This is usually used to + * separate a data-parallel inference graph to multiple graphs on each device. + * + * The graph can be separated into multiple single device graphs if and only if: + * + * - the graph does not contain any ops related to multi-devices communication, + * such as allreduce, send, recv, sync_batch_norm, etc. + * + * - ops on different devices do not depend on each other. That is to say, the + * graph has several disconnected sub-graphs. + */ +std::vector> TrySeparateToMultipleSingleDeviceGraphs( + ir::Graph *graph) { + size_t place_num; + std::unordered_map op_to_dev_idx; + if (!IsDataParallelInferenceGraphImpl(*graph, &op_to_dev_idx, &place_num)) { + return {}; + } + if (place_num == 1) { return {}; } @@ -182,8 +210,10 @@ std::vector> TrySeparateToMultipleSingleDeviceGraphs( g->Set(kGraphDepVars, new GraphDepVars()); } - for (auto &op : op_handles) { - auto dev_idx = op_to_dev_idx.at(op); + for (auto &pair : op_to_dev_idx) { + auto *op = pair.first; + auto dev_idx = pair.second; + auto *ret_graph = graphs[dev_idx].get(); auto &ret_vars = ret_graph->Get(kGraphVars)[0]; auto &ret_dummy_vars = ret_graph->Get(kGraphDepVars); diff --git a/paddle/fluid/framework/details/multi_devices_helper.h b/paddle/fluid/framework/details/multi_devices_helper.h index ab68cf53280c2dd7b7996b9c0839b3bc809860cc..4c344af09fb4f2ab0fcff1e7c6072e74dfc6d7b5 100644 --- a/paddle/fluid/framework/details/multi_devices_helper.h +++ b/paddle/fluid/framework/details/multi_devices_helper.h @@ -101,6 +101,8 @@ inline std::vector GetOpRoleVarsOrEmpty(const OpDesc &op) { return boost::get>(iter->second); } +bool IsDataParallelInferenceGraph(const ir::Graph &graph); + std::vector> TrySeparateToMultipleSingleDeviceGraphs( ir::Graph *graph); diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 4e5e4c138cb21fd3d4dd97eb9f0ef7d6fb669206..f2cc9d12ee313004cdea25538fef19bb171aa208 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -647,11 +647,22 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, #ifdef PADDLE_WITH_CUDA // TODO(Yancey1989): Remove passing in the main_program when // allreduce_seq_pass doesn't need it as the attr. + bool is_inference = details::IsDataParallelInferenceGraph(*graph); + bool has_drop_last_read_op = details::HasDropLastReadOp(*graph); + auto *pg_exe = new details::ParallelSSAGraphExecutor( exec_strategy, member_->local_scopes_, member_->local_exec_scopes_, member_->places_, graph); final_graphs = pg_exe->Graphs(); member_->executor_.reset(pg_exe); + + if (is_inference && member_->places_.size() > 1) { + member_->inference_executor_ = pg_exe; + if (!has_drop_last_read_op) { + VLOG(5) << "Enable partial feed support in inference phase"; + pg_exe->EnablePartialFeedSupport(); + } + } #else PADDLE_THROW( "Paddle should be compiled with CUDA for ParallelGraph Execution."); diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 763b04d795e372b680b9bf9c9c1c86054e3a1060..704a9b4abb880cd011c88e3266bd9adfd1592014 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -357,6 +357,7 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu test_optimizer_in_control_flow test_dataloader_keep_order test_dataloader_unkeep_order test_parallel_executor_inference_feed_partial_data + test_parallel_ssa_graph_inference_feed_partial_data test_fetch_unmerged test_buffer_shared_memory_reuse_pass PROPERTIES LABELS "RUN_TYPE=DIST") diff --git a/python/paddle/fluid/tests/unittests/test_parallel_ssa_graph_inference_feed_partial_data.py b/python/paddle/fluid/tests/unittests/test_parallel_ssa_graph_inference_feed_partial_data.py new file mode 100644 index 0000000000000000000000000000000000000000..8110c0a03c5990e81426c4d9b72f0436b80628f1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_ssa_graph_inference_feed_partial_data.py @@ -0,0 +1,23 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import unittest + +fluid.core.globals()['FLAGS_enable_parallel_graph'] = 1 + +from test_parallel_executor_inference_feed_partial_data import * + +if __name__ == '__main__': + unittest.main()