未验证 提交 ca99427a 编写于 作者: Z Zeng Jinle 提交者: GitHub

[Cherry-pick Release/2.0] Correct reader device index (#23818)

* correct reader device index, test=develop

* fix async executor scope var initialization, test=release/2.0
上级 1b5122ba
...@@ -73,7 +73,7 @@ set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto ...@@ -73,7 +73,7 @@ set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto
eager_deletion_pass eager_deletion_pass
buffer_shared_inplace_op_pass buffer_shared_inplace_op_pass
buffer_shared_cross_op_memory_reuse_pass buffer_shared_cross_op_memory_reuse_pass
set_reader_device_info_pass set_reader_device_info_utils
add_reader_dependency_pass) add_reader_dependency_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS}) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
......
...@@ -126,8 +126,9 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( ...@@ -126,8 +126,9 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
} }
} }
for (size_t i = 0; i < local_scopes_.size(); ++i) { for (size_t i = local_scopes_.size(); i >= 1; --i) {
InitVarsInScope(var_infos_, local_scopes_[i], local_exec_scopes_[i]); InitVarsInScope(var_infos_, local_scopes_[i - 1],
local_exec_scopes_[i - 1]);
} }
ProcessGraph(graphs_, local_scopes_[0]); ProcessGraph(graphs_, local_scopes_[0]);
} }
......
...@@ -64,7 +64,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -64,7 +64,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendAddReaderDependencyPass(); AppendAddReaderDependencyPass();
AppendMultiDevPass(); AppendMultiDevPass();
AppendSetReaderDeviceIndexPass();
AppendMultiGraphOptPasses(); AppendMultiGraphOptPasses();
AppendPassToSetMkldnnAttr("mkldnn_placement_pass"); AppendPassToSetMkldnnAttr("mkldnn_placement_pass");
...@@ -243,10 +242,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -243,10 +242,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
&strategy_); &strategy_);
} }
void AppendSetReaderDeviceIndexPass() {
AppendPass("set_reader_device_index_pass");
}
void AppendPrintGraphPass(const std::string &pass_name, void AppendPrintGraphPass(const std::string &pass_name,
const std::string &debug_file_suffix) { const std::string &debug_file_suffix) {
if (!strategy_.debug_graphviz_path_.empty()) { if (!strategy_.debug_graphviz_path_.empty()) {
...@@ -403,9 +398,6 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, ...@@ -403,9 +398,6 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
"GPU, skipped."; "GPU, skipped.";
continue; continue;
} }
} else if (pass->Type() == "set_reader_device_index_pass") {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
} }
VLOG(1) << "Start Apply Pass " << pass->Type(); VLOG(1) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(graph); graph = pass->Apply(graph);
...@@ -442,7 +434,6 @@ USE_PASS(fuse_sgd_op_pass); ...@@ -442,7 +434,6 @@ USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_momentum_op_pass); USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass); USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass); USE_PASS(runtime_context_cache_pass);
USE_PASS(set_reader_device_index_pass);
USE_PASS(add_reader_dependency_pass); USE_PASS(add_reader_dependency_pass);
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass); USE_PASS(mkldnn_placement_pass);
......
...@@ -11,7 +11,7 @@ endif() ...@@ -11,7 +11,7 @@ endif()
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${ALL_REDUCE_OP_HANDLES} reduce_op_handle broadcast_op_handle fused_broadcast_op_handle) scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${ALL_REDUCE_OP_HANDLES} reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass) cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(set_reader_device_info_pass SRCS set_reader_device_info_pass.cc DEPS graph graph_helper pass multi_devices_graph_pass) cc_library(set_reader_device_info_utils SRCS set_reader_device_info_utils.cc DEPS graph graph_helper pass multi_devices_graph_pass)
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle) cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass) cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass)
......
...@@ -12,6 +12,10 @@ ...@@ -12,6 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_info_utils.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
...@@ -22,80 +26,62 @@ namespace paddle { ...@@ -22,80 +26,62 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
static int GetDeviceCountFromPassAttr(const Pass &pass) {
return static_cast<int>(
pass.Get<const std::vector<platform::Place>>(details::kPlaces).size());
}
static std::unordered_set<std::string> ReaderOpSet() { static std::unordered_set<std::string> ReaderOpSet() {
return {"create_py_reader"}; return {"create_py_reader"};
} }
class InitReaderDeviceCountPass : public Pass { void InitReaderQueueDeviceCount(Graph *graph, const Scope &scope,
protected: size_t dev_cnt) {
void ApplyImpl(Graph *graph) const override { using QueueHolder =
using QueueHolder = operators::reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder;
operators::reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder;
auto reader_ops = ReaderOpSet(); auto reader_ops = ReaderOpSet();
auto dev_cnt = GetDeviceCountFromPassAttr(*this); for (auto &node : graph->Nodes()) {
const auto &scope = Get<const Scope>(details::kGlobalScope); if (node->IsOp() && node->Op() &&
for (auto &node : graph->Nodes()) { reader_ops.count(node->Op()->Type()) != 0) {
if (node->IsOp() && node->Op() && auto queue_name = node->Op()->Input("blocking_queue")[0];
reader_ops.count(node->Op()->Type()) != 0) { auto var = scope.FindVar(queue_name);
auto queue_name = node->Op()->Input("blocking_queue")[0]; if (var && var->IsType<QueueHolder>()) {
auto var = scope.FindVar(queue_name); VLOG(10) << "Set device count of " << queue_name << " to be "
if (var && var->IsType<QueueHolder>()) { << dev_cnt;
VLOG(10) << "Set device count of " << queue_name << " to be " var->GetMutable<QueueHolder>()->GetQueue()->SetDeviceCount(dev_cnt);
<< dev_cnt;
var->GetMutable<QueueHolder>()->GetQueue()->SetDeviceCount(dev_cnt);
}
} }
} }
} }
}; }
class SetReaderDeviceIndexPass : public Pass { void SetReaderOpDeviceInfo(Graph *graph, size_t dev_cnt, size_t dev_idx) {
protected: auto reader_ops = ReaderOpSet();
void ApplyImpl(Graph *graph) const override { size_t found_op_num = 0;
auto dev_cnt = GetDeviceCountFromPassAttr(*this);
auto reader_ops = ReaderOpSet();
size_t found_op_num = 0;
for (auto &node : graph->Nodes()) { for (auto &node : graph->Nodes()) {
if (node->IsOp() && node->Op() && if (node->IsOp() && node->Op() &&
reader_ops.count(node->Op()->Type()) != 0) { reader_ops.count(node->Op()->Type()) != 0) {
auto &op_handle = dynamic_cast<details::ComputationOpHandle &>( auto &op_handle = dynamic_cast<details::ComputationOpHandle &>(
node->Wrapper<details::OpHandleBase>()); node->Wrapper<details::OpHandleBase>());
auto *op_desc = node->Op(); auto *op_desc = node->Op();
auto &op_base_attrs = auto &op_base_attrs =
const_cast<framework::AttributeMap &>(op_handle.GetOp()->Attrs()); const_cast<framework::AttributeMap &>(op_handle.GetOp()->Attrs());
int dev_idx = static_cast<int>(op_handle.GetScopeIdx()); int actual_dev_idx = static_cast<int>(op_handle.GetScopeIdx());
if (dev_idx != -1UL) {
actual_dev_idx = static_cast<int>(dev_idx);
}
op_desc->SetAttr("device_index", dev_idx); op_desc->SetAttr("device_index", actual_dev_idx);
op_desc->SetAttr("device_count", dev_cnt); op_desc->SetAttr("device_count", static_cast<int>(dev_cnt));
op_base_attrs["device_index"] = dev_idx; op_base_attrs["device_index"] = actual_dev_idx;
op_base_attrs["device_count"] = dev_cnt; op_base_attrs["device_count"] = static_cast<int>(dev_cnt);
++found_op_num; ++found_op_num;
VLOG(10) << "Found op " << op_desc->Type() << " on device " << dev_idx; VLOG(10) << "Found op " << op_desc->Type() << " on device "
} << actual_dev_idx;
} }
VLOG(10) << "Found op number " << found_op_num;
} }
};
VLOG(10) << "Found op number " << found_op_num;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(init_reader_device_count_pass,
paddle::framework::ir::InitReaderDeviceCountPass)
.RequirePassAttr(paddle::framework::details::kGlobalScope)
.RequirePassAttr(paddle::framework::details::kPlaces);
REGISTER_PASS(set_reader_device_index_pass,
paddle::framework::ir::SetReaderDeviceIndexPass)
.RequirePassAttr(paddle::framework::details::kPlaces);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
namespace ir {
// Propagates the device count to the blocking queues feeding reader ops.
//
// Walks every node of `graph`; for each op node whose type is a reader op
// (the definition currently recognises only "create_py_reader"), the first
// name of the op's "blocking_queue" input is looked up in `scope`. If that
// variable exists and holds an
// operators::reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder, the
// device count of the held queue is set to `dev_cnt`. Variables that are
// missing or hold another type are skipped silently.
//
// graph:   graph whose nodes are scanned for reader ops.
// scope:   scope in which the queue variables are searched.
// dev_cnt: device count to store on each matching queue.
void InitReaderQueueDeviceCount(Graph *graph, const Scope &scope,
size_t dev_cnt);
// Stamps device placement attributes onto every reader op in `graph`.
//
// For each op node whose type is a reader op (the definition currently
// recognises only "create_py_reader"), the attributes "device_index" and
// "device_count" are written both to the node's OpDesc and to the attribute
// map of the operator owned by the node's ComputationOpHandle wrapper.
//
// graph:   graph whose reader op nodes are updated.
// dev_cnt: value stored (as int) in "device_count".
// dev_idx: value stored (as int) in "device_index". The default -1UL is a
//          sentinel meaning "not specified": in that case the op handle's
//          scope index is used instead.
// NOTE(review): the definition tests the sentinel with `dev_idx != -1UL`, so
// the default here and that comparison must keep using the same literal.
void SetReaderOpDeviceInfo(Graph *graph, size_t dev_cnt, size_t dev_idx = -1UL);
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -30,6 +30,7 @@ limitations under the License. */ ...@@ -30,6 +30,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_info_utils.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
DECLARE_double(eager_delete_tensor_gb); DECLARE_double(eager_delete_tensor_gb);
...@@ -81,15 +82,6 @@ class ParallelExecutorPrivate { ...@@ -81,15 +82,6 @@ class ParallelExecutorPrivate {
} }
} }
void InitReaderDeviceCount(ir::Graph *graph) const {
auto pass =
ir::PassRegistry::Instance().Get("init_reader_device_count_pass");
pass->SetNotOwned<const Scope>(details::kGlobalScope, global_scope_);
pass->SetNotOwned<const std::vector<platform::Place>>(details::kPlaces,
&places_);
pass->Apply(graph);
}
void SetHasFeed(size_t dev_idx, bool has_feed = true); void SetHasFeed(size_t dev_idx, bool has_feed = true);
bool AllowPartialFeed() const; bool AllowPartialFeed() const;
...@@ -456,7 +448,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places, ...@@ -456,7 +448,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
const BuildStrategy &build_strategy, const BuildStrategy &build_strategy,
ir::Graph *graph) ir::Graph *graph)
: member_(new ParallelExecutorPrivate(places, scope)) { : member_(new ParallelExecutorPrivate(places, scope)) {
member_->InitReaderDeviceCount(graph); ir::InitReaderQueueDeviceCount(graph, *(member_->global_scope_),
member_->places_.size());
member_->use_cuda_ = exec_strategy.use_cuda_; member_->use_cuda_ = exec_strategy.use_cuda_;
member_->build_strategy_ = build_strategy; member_->build_strategy_ = build_strategy;
member_->use_all_reduce_ = member_->build_strategy_.reduce_ == member_->use_all_reduce_ = member_->build_strategy_.reduce_ ==
...@@ -733,6 +726,14 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places, ...@@ -733,6 +726,14 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
op->SetLocalExecScopes(scope_map); op->SetLocalExecScopes(scope_map);
} }
} }
if (final_graphs.size() == 1) {
ir::SetReaderOpDeviceInfo(final_graphs[0], member_->places_.size());
} else {
for (size_t i = 0; i < final_graphs.size(); ++i) {
ir::SetReaderOpDeviceInfo(final_graphs[i], member_->places_.size(), i);
}
}
} }
void ParallelExecutor::BCastParamsToDevices( void ParallelExecutor::BCastParamsToDevices(
...@@ -1061,4 +1062,3 @@ USE_PASS(reference_count_pass); ...@@ -1061,4 +1062,3 @@ USE_PASS(reference_count_pass);
USE_PASS(eager_deletion_pass); USE_PASS(eager_deletion_pass);
USE_PASS(buffer_shared_inplace_pass); USE_PASS(buffer_shared_inplace_pass);
USE_PASS(buffer_shared_cross_op_memory_reuse_pass); USE_PASS(buffer_shared_cross_op_memory_reuse_pass);
USE_PASS(init_reader_device_count_pass);
...@@ -32,12 +32,11 @@ def convolutional_neural_network(use_py_reader): ...@@ -32,12 +32,11 @@ def convolutional_neural_network(use_py_reader):
py_reader = None py_reader = None
if use_py_reader: if use_py_reader:
py_reader = fluid.layers.create_py_reader_by_data( py_reader = fluid.io.DataLoader.from_generator(
capacity=64, capacity=64,
feed_list=[img, label], feed_list=[img, label],
name='py_reader', iterable=False,
use_double_buffer=False) use_double_buffer=False)
img, label = fluid.layers.read_file(py_reader)
conv_pool_1 = fluid.nets.simple_img_conv_pool( conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img, input=img,
...@@ -144,7 +143,7 @@ def train(use_cuda, thread_num, cpu_num): ...@@ -144,7 +143,7 @@ def train(use_cuda, thread_num, cpu_num):
exec_strategy=exec_strategy) exec_strategy=exec_strategy)
print("declare parallel executor done.") print("declare parallel executor done.")
py_reader.decorate_paddle_reader(train_reader) py_reader.set_sample_list_generator(train_reader)
for pass_id in range(2): for pass_id in range(2):
step = 0 step = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册