Unverified · Commit ca99427a · Authored by Zeng Jinle · Committed by GitHub

[Cherry-pick Release/2.0] Correct reader device index (#23818)

* correct reader device index, test=develop

* fix async executor scope var initialization, test=release/2.0
Parent 1b5122ba
@@ -73,7 +73,7 @@ set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto
         eager_deletion_pass
         buffer_shared_inplace_op_pass
         buffer_shared_cross_op_memory_reuse_pass
-        set_reader_device_info_pass
+        set_reader_device_info_utils
         add_reader_dependency_pass)
 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
......
@@ -126,8 +126,9 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
     }
   }
-  for (size_t i = 0; i < local_scopes_.size(); ++i) {
-    InitVarsInScope(var_infos_, local_scopes_[i], local_exec_scopes_[i]);
+  for (size_t i = local_scopes_.size(); i >= 1; --i) {
+    InitVarsInScope(var_infos_, local_scopes_[i - 1],
+                    local_exec_scopes_[i - 1]);
   }
   ProcessGraph(graphs_, local_scopes_[0]);
 }
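A note on the new loop bounds, since the pattern is easy to misread: `local_scopes_` is indexed by an unsigned `size_t`, so the familiar `for (i = n - 1; i >= 0; --i)` would never terminate (the condition is always true for unsigned types). The replacement therefore counts from `n` down to `1` and indexes with `i - 1`; the visible behavioral change is that scope 0 is now initialized last, just before `ProcessGraph` consumes it, which the commit message ties to the async executor scope var fix. A minimal standalone sketch of the idiom, with nothing Paddle-specific:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<int> local_scopes = {10, 11, 12, 13};
  // Reverse traversal with an unsigned index: loop from n down to 1 and
  // subtract 1 when indexing, because `i >= 0` is always true for size_t.
  for (size_t i = local_scopes.size(); i >= 1; --i) {
    std::cout << "init scope #" << (i - 1) << " = " << local_scopes[i - 1]
              << "\n";  // visits indices 3, 2, 1, 0
  }
  return 0;
}
```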
......
@@ -64,7 +64,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     AppendAddReaderDependencyPass();
     AppendMultiDevPass();
-    AppendSetReaderDeviceIndexPass();
     AppendMultiGraphOptPasses();
     AppendPassToSetMkldnnAttr("mkldnn_placement_pass");
@@ -243,10 +242,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
                &strategy_);
   }
-  void AppendSetReaderDeviceIndexPass() {
-    AppendPass("set_reader_device_index_pass");
-  }
   void AppendPrintGraphPass(const std::string &pass_name,
                             const std::string &debug_file_suffix) {
     if (!strategy_.debug_graphviz_path_.empty()) {
@@ -403,9 +398,6 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
                    "GPU, skipped.";
         continue;
       }
-    } else if (pass->Type() == "set_reader_device_index_pass") {
-      pass->Erase(kPlaces);
-      pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
     }
     VLOG(1) << "Start Apply Pass " << pass->Type();
     graph = pass->Apply(graph);
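For context on the plumbing removed here: a `Pass` receives inputs such as `kPlaces` through a type-erased attribute map (`Erase`/`SetNotOwned`/`Get`) rather than ordinary arguments, and that indirection is part of what this commit eliminates by switching to plain functions. A rough single-file sketch of the attribute pattern, using hypothetical names and none of Paddle's actual classes:

```cpp
#include <cassert>
#include <iostream>
#include <string>
#include <typeindex>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for a pass's attribute store: inputs are wired in
// by name and void pointer, then read back with a runtime type check.
class AttrHolder {
 public:
  template <typename T>
  void SetNotOwned(const std::string &name, const T *ptr) {
    attrs_.insert_or_assign(
        name, std::pair<std::type_index, const void *>(typeid(T), ptr));
  }
  template <typename T>
  const T &Get(const std::string &name) const {
    const auto &entry = attrs_.at(name);
    assert(entry.first == std::type_index(typeid(T)));  // checked at runtime
    return *static_cast<const T *>(entry.second);
  }

 private:
  std::unordered_map<std::string, std::pair<std::type_index, const void *>>
      attrs_;
};

int main() {
  std::vector<int> places = {0, 1, 2};
  AttrHolder pass;
  pass.SetNotOwned("kPlaces", &places);  // caller wires the input by name
  std::cout << pass.Get<std::vector<int>>("kPlaces").size() << "\n";  // 3
  return 0;
}
```

Passing `places.size()` directly, as the new `SetReaderOpDeviceInfo` signature below does, removes both the name indirection and the runtime type check.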
@@ -442,7 +434,6 @@ USE_PASS(fuse_sgd_op_pass);
 USE_PASS(fuse_momentum_op_pass);
 USE_PASS(fuse_all_reduce_op_pass);
 USE_PASS(runtime_context_cache_pass);
-USE_PASS(set_reader_device_index_pass);
 USE_PASS(add_reader_dependency_pass);
 #ifdef PADDLE_WITH_MKLDNN
 USE_PASS(mkldnn_placement_pass);
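The `USE_PASS` deletion has to travel with the pass removal above. As a sketch of the usual registration pattern (hypothetical code, not Paddle's exact macros), `REGISTER_PASS` typically defines a symbol whose initialization registers the pass, and `USE_PASS` references that symbol so the linker keeps it; a dangling `USE_PASS` for a deleted pass would then fail at link time:

```cpp
#include <iostream>
#include <map>
#include <string>

// Toy registry: the REGISTER_PASS side defines a "touch" function whose
// side effect registers the pass; the USE_PASS side references that symbol
// so the linker keeps (and runs) the registration.
std::map<std::string, bool> &PassRegistry() {
  static std::map<std::string, bool> registry;
  return registry;
}

// What REGISTER_PASS(add_reader_dependency_pass, ...) might boil down to:
int TouchAddReaderDependencyPass() {
  PassRegistry()["add_reader_dependency_pass"] = true;
  return 0;
}

// What USE_PASS(add_reader_dependency_pass) might boil down to:
static int use_add_reader_dependency_pass = TouchAddReaderDependencyPass();

int main() {
  std::cout << PassRegistry().count("add_reader_dependency_pass") << "\n";
  return 0;  // prints 1: the static initializer above did the registration
}
```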
......
@@ -11,7 +11,7 @@ endif()
 cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
         scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${ALL_REDUCE_OP_HANDLES} reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
 cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
-cc_library(set_reader_device_info_pass SRCS set_reader_device_info_pass.cc DEPS graph graph_helper pass multi_devices_graph_pass)
+cc_library(set_reader_device_info_utils SRCS set_reader_device_info_utils.cc DEPS graph graph_helper pass multi_devices_graph_pass)
 cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
 cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass)
......
@@ -12,6 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_info_utils.h"
+#include <string>
+#include <unordered_set>
+#include <vector>
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/ir/graph.h"
@@ -22,80 +26,62 @@ namespace paddle {
 namespace framework {
 namespace ir {

-static int GetDeviceCountFromPassAttr(const Pass &pass) {
-  return static_cast<int>(
-      pass.Get<const std::vector<platform::Place>>(details::kPlaces).size());
-}
-
 static std::unordered_set<std::string> ReaderOpSet() {
   return {"create_py_reader"};
 }

-class InitReaderDeviceCountPass : public Pass {
- protected:
-  void ApplyImpl(Graph *graph) const override {
-    using QueueHolder =
-        operators::reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder;
-    auto reader_ops = ReaderOpSet();
-    auto dev_cnt = GetDeviceCountFromPassAttr(*this);
-    const auto &scope = Get<const Scope>(details::kGlobalScope);
-    for (auto &node : graph->Nodes()) {
-      if (node->IsOp() && node->Op() &&
-          reader_ops.count(node->Op()->Type()) != 0) {
-        auto queue_name = node->Op()->Input("blocking_queue")[0];
-        auto var = scope.FindVar(queue_name);
-        if (var && var->IsType<QueueHolder>()) {
-          VLOG(10) << "Set device count of " << queue_name << " to be "
-                   << dev_cnt;
-          var->GetMutable<QueueHolder>()->GetQueue()->SetDeviceCount(dev_cnt);
-        }
-      }
-    }
-  }
-};
+void InitReaderQueueDeviceCount(Graph *graph, const Scope &scope,
+                                size_t dev_cnt) {
+  using QueueHolder =
+      operators::reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder;
+  auto reader_ops = ReaderOpSet();
+  for (auto &node : graph->Nodes()) {
+    if (node->IsOp() && node->Op() &&
+        reader_ops.count(node->Op()->Type()) != 0) {
+      auto queue_name = node->Op()->Input("blocking_queue")[0];
+      auto var = scope.FindVar(queue_name);
+      if (var && var->IsType<QueueHolder>()) {
+        VLOG(10) << "Set device count of " << queue_name << " to be "
+                 << dev_cnt;
+        var->GetMutable<QueueHolder>()->GetQueue()->SetDeviceCount(dev_cnt);
+      }
+    }
+  }
+}

-class SetReaderDeviceIndexPass : public Pass {
- protected:
-  void ApplyImpl(Graph *graph) const override {
-    auto dev_cnt = GetDeviceCountFromPassAttr(*this);
-    auto reader_ops = ReaderOpSet();
-    size_t found_op_num = 0;
-    for (auto &node : graph->Nodes()) {
-      if (node->IsOp() && node->Op() &&
-          reader_ops.count(node->Op()->Type()) != 0) {
-        auto &op_handle = dynamic_cast<details::ComputationOpHandle &>(
-            node->Wrapper<details::OpHandleBase>());
-        auto *op_desc = node->Op();
-        auto &op_base_attrs =
-            const_cast<framework::AttributeMap &>(op_handle.GetOp()->Attrs());
-        int dev_idx = static_cast<int>(op_handle.GetScopeIdx());
-        op_desc->SetAttr("device_index", dev_idx);
-        op_desc->SetAttr("device_count", dev_cnt);
-        op_base_attrs["device_index"] = dev_idx;
-        op_base_attrs["device_count"] = dev_cnt;
-        ++found_op_num;
-        VLOG(10) << "Found op " << op_desc->Type() << " on device " << dev_idx;
-      }
-    }
-    VLOG(10) << "Found op number " << found_op_num;
-  }
-};
+void SetReaderOpDeviceInfo(Graph *graph, size_t dev_cnt, size_t dev_idx) {
+  auto reader_ops = ReaderOpSet();
+  size_t found_op_num = 0;
+  for (auto &node : graph->Nodes()) {
+    if (node->IsOp() && node->Op() &&
+        reader_ops.count(node->Op()->Type()) != 0) {
+      auto &op_handle = dynamic_cast<details::ComputationOpHandle &>(
+          node->Wrapper<details::OpHandleBase>());
+      auto *op_desc = node->Op();
+      auto &op_base_attrs =
+          const_cast<framework::AttributeMap &>(op_handle.GetOp()->Attrs());
+      int actual_dev_idx = static_cast<int>(op_handle.GetScopeIdx());
+      if (dev_idx != -1UL) {
+        actual_dev_idx = static_cast<int>(dev_idx);
+      }
+      op_desc->SetAttr("device_index", actual_dev_idx);
+      op_desc->SetAttr("device_count", static_cast<int>(dev_cnt));
+      op_base_attrs["device_index"] = actual_dev_idx;
+      op_base_attrs["device_count"] = static_cast<int>(dev_cnt);
+      ++found_op_num;
+      VLOG(10) << "Found op " << op_desc->Type() << " on device "
+               << actual_dev_idx;
+    }
+  }
+  VLOG(10) << "Found op number " << found_op_num;
+}

 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
-
-REGISTER_PASS(init_reader_device_count_pass,
-              paddle::framework::ir::InitReaderDeviceCountPass)
-    .RequirePassAttr(paddle::framework::details::kGlobalScope)
-    .RequirePassAttr(paddle::framework::details::kPlaces);
-REGISTER_PASS(set_reader_device_index_pass,
-              paddle::framework::ir::SetReaderDeviceIndexPass)
-    .RequirePassAttr(paddle::framework::details::kPlaces);
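Why `InitReaderQueueDeviceCount` must run up front: the reader's blocking queue is an `OrderedMultiDeviceLoDTensorBlockingQueueHolder`, and an "ordered multi-device" queue has to know the device count before any data is pushed so it can deal batches to per-device sub-queues deterministically. A toy model of that idea, assuming a simple round-robin split rather than Paddle's actual implementation:

```cpp
#include <cstddef>
#include <iostream>
#include <queue>
#include <vector>

// Toy "ordered multi-device" queue: one input stream is dealt round-robin
// across per-device sub-queues, so device k deterministically receives
// items k, k + dev_cnt, k + 2 * dev_cnt, ...
class ToyOrderedMultiDeviceQueue {
 public:
  void SetDeviceCount(size_t dev_cnt) { queues_.resize(dev_cnt); }
  void Push(int batch) {
    queues_[next_].push(batch);
    next_ = (next_ + 1) % queues_.size();
  }
  std::queue<int> &OnDevice(size_t dev_idx) { return queues_[dev_idx]; }

 private:
  std::vector<std::queue<int>> queues_;
  size_t next_ = 0;
};

int main() {
  ToyOrderedMultiDeviceQueue queue;
  queue.SetDeviceCount(2);  // must happen before any Push, as in the code above
  for (int batch = 0; batch < 4; ++batch) queue.Push(batch);
  std::cout << "device 0 front: " << queue.OnDevice(0).front() << "\n";  // 0
  std::cout << "device 1 front: " << queue.OnDevice(1).front() << "\n";  // 1
  return 0;
}
```

The new header file declaring both utility functions follows.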
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
namespace ir {
void InitReaderQueueDeviceCount(Graph *graph, const Scope &scope,
                                size_t dev_cnt);

void SetReaderOpDeviceInfo(Graph *graph, size_t dev_cnt, size_t dev_idx = -1UL);
} // namespace ir
} // namespace framework
} // namespace paddle
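One detail worth calling out in this header: the default argument `dev_idx = -1UL` relies on unsigned conversion, so the default is the maximum `size_t` value, used as a "no explicit index" sentinel that `SetReaderOpDeviceInfo` tests with `dev_idx != -1UL`. A small self-contained sketch of the same convention, with a hypothetical function name:

```cpp
#include <cstddef>
#include <iostream>

// Hypothetical stand-in for SetReaderOpDeviceInfo's default argument:
// -1UL converts to a maximal unsigned value, a conventional sentinel
// meaning "no device index supplied".
void SetDeviceInfo(size_t dev_cnt, size_t dev_idx = -1UL) {
  if (dev_idx != -1UL) {
    // Per-device graph: pin every reader op to the given index.
    std::cout << "pin readers to device " << dev_idx << " of " << dev_cnt
              << "\n";
  } else {
    // Single fused graph: each op derives its index from its own scope.
    std::cout << "derive index per op, device count " << dev_cnt << "\n";
  }
}

int main() {
  SetDeviceInfo(4);     // sentinel path, as in the single-graph call below
  SetDeviceInfo(4, 2);  // explicit path, as in the per-graph loop below
  return 0;
}
```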
@@ -30,6 +30,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
 #include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
+#include "paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_info_utils.h"
 #include "paddle/fluid/platform/profiler.h"

 DECLARE_double(eager_delete_tensor_gb);
@@ -81,15 +82,6 @@ class ParallelExecutorPrivate {
     }
   }

-  void InitReaderDeviceCount(ir::Graph *graph) const {
-    auto pass =
-        ir::PassRegistry::Instance().Get("init_reader_device_count_pass");
-    pass->SetNotOwned<const Scope>(details::kGlobalScope, global_scope_);
-    pass->SetNotOwned<const std::vector<platform::Place>>(details::kPlaces,
-                                                          &places_);
-    pass->Apply(graph);
-  }
-
   void SetHasFeed(size_t dev_idx, bool has_feed = true);

   bool AllowPartialFeed() const;
@@ -456,7 +448,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
                                    const BuildStrategy &build_strategy,
                                    ir::Graph *graph)
     : member_(new ParallelExecutorPrivate(places, scope)) {
-  member_->InitReaderDeviceCount(graph);
+  ir::InitReaderQueueDeviceCount(graph, *(member_->global_scope_),
+                                 member_->places_.size());
   member_->use_cuda_ = exec_strategy.use_cuda_;
   member_->build_strategy_ = build_strategy;
   member_->use_all_reduce_ = member_->build_strategy_.reduce_ ==
@@ -733,6 +726,14 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
       op->SetLocalExecScopes(scope_map);
     }
   }
+
+  if (final_graphs.size() == 1) {
+    ir::SetReaderOpDeviceInfo(final_graphs[0], member_->places_.size());
+  } else {
+    for (size_t i = 0; i < final_graphs.size(); ++i) {
+      ir::SetReaderOpDeviceInfo(final_graphs[i], member_->places_.size(), i);
+    }
+  }
 }

 void ParallelExecutor::BCastParamsToDevices(
@@ -1061,4 +1062,3 @@ USE_PASS(reference_count_pass);
 USE_PASS(eager_deletion_pass);
 USE_PASS(buffer_shared_inplace_pass);
 USE_PASS(buffer_shared_cross_op_memory_reuse_pass);
-USE_PASS(init_reader_device_count_pass);
@@ -32,12 +32,11 @@ def convolutional_neural_network(use_py_reader):
     py_reader = None
     if use_py_reader:
-        py_reader = fluid.layers.create_py_reader_by_data(
+        py_reader = fluid.io.DataLoader.from_generator(
             capacity=64,
             feed_list=[img, label],
-            name='py_reader',
+            iterable=False,
             use_double_buffer=False)
-        img, label = fluid.layers.read_file(py_reader)

     conv_pool_1 = fluid.nets.simple_img_conv_pool(
         input=img,
@@ -144,7 +143,7 @@ def train(use_cuda, thread_num, cpu_num):
         exec_strategy=exec_strategy)
     print("declare parallel executor done.")

-    py_reader.decorate_paddle_reader(train_reader)
+    py_reader.set_sample_list_generator(train_reader)

     for pass_id in range(2):
         step = 0
......