未验证 提交 eabb2105 编写于 作者: C chengduo 提交者: GitHub

Refactor MultiDevSSAGraphBuilder (#15090)

* Refactor ParallelExecutor
test=develop

* extract Reduce and AllReduce mode from MultiDevSSAGraphBuilder
test=develop

* Refactor MultiDevSSAGraphBuilder
test=developt

* Remove enable_data_balance
test=develop

* code refine
test=develop

* remove data balance
test=develop

* refine ScaleLossGradOp
test=develop

* remove uncessary file
test=develop

* code refine
test=develop

* modify  function name
test=develop

* follow comments
test=develop

* add is_distribution field
test=develop

* set is_distribution
test=develop

* fix DistSSAGraphBuilder
test=develop
上级 875a07c3
......@@ -18,7 +18,7 @@ limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/details/memory_reuse_types.h"
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
......@@ -86,10 +86,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
if (strategy.memory_optimize_) {
auto analysis_var_pass = AppendPass("analysis_var_pass");
}
// Convert graph to run on multi-devices.
auto multi_devices_pass = AppendPass("multi_devices_pass");
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
&strategy_);
AppendMultiDevPass(strategy);
// Add a graph print pass to record a graph with device info.
if (!strategy_.debug_graphviz_path_.empty()) {
......@@ -115,6 +113,25 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
}
}
// Convert graph to run on multi-devices.
void AppendMultiDevPass(const BuildStrategy &strategy) {
ir::Pass *multi_devices_pass;
if (strategy_.is_distribution_) {
multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
multi_devices_pass =
AppendPass("allreduce_mode_multi_devices_pass").get();
} else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
} else {
PADDLE_THROW("Unknown reduce strategy.");
}
}
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
&strategy_);
}
private:
BuildStrategy strategy_;
};
......@@ -131,6 +148,10 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
return pass_builder_;
}
bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
}
std::unique_ptr<ir::Graph> BuildStrategy::Apply(
const ProgramDesc &main_program, const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
......@@ -145,22 +166,23 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
if (pass->Type() == "multi_devices_pass") {
pass->Erase("places");
pass->SetNotOwned<const std::vector<platform::Place>>("places", &places);
pass->Erase("loss_var_name");
pass->SetNotOwned<const std::string>("loss_var_name", &loss_var_name);
pass->Erase("local_scopes");
pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
if (IsMultiDevPass(pass->Type())) {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
pass->Erase(kLossVarName);
pass->SetNotOwned<const std::string>(kLossVarName, &loss_var_name);
pass->Erase(kLocalScopes);
pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
&local_scopes);
pass->Erase("nranks");
pass->Set<size_t>("nranks", new size_t(nranks));
pass->Erase(kNRanks);
pass->Set<size_t>(kNRanks, new size_t(nranks));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase("nccl_ctxs");
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
#endif
} else if (pass->Type() == "analysis_var_pass") {
const std::vector<OpDesc *> *all_op_descs =
new std::vector<OpDesc *>(main_program.Block(0).AllOps());
......@@ -201,7 +223,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass);
USE_PASS(multi_batch_merge_pass);
USE_PASS(multi_devices_pass);
USE_PASS(reduce_mode_multi_devices_pass);
USE_PASS(allreduce_mode_multi_devices_pass);
USE_PASS(dist_multi_devices_pass);
USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass);
USE_PASS(analysis_var_pass);
......
......@@ -74,8 +74,6 @@ struct BuildStrategy {
bool fuse_elewise_add_act_ops_{false};
bool enable_data_balance_{false};
bool memory_optimize_{false};
bool memory_early_delete_{false};
......@@ -84,6 +82,10 @@ struct BuildStrategy {
bool fuse_broadcast_op_{false};
// FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
// num_trainers is 1, so the current fields of build_strategy doesn't tell if
// it's distributed model.
bool is_distribution_{false};
int num_trainers_{1};
int trainer_id_{0};
std::vector<std::string> trainers_endpoints_;
......@@ -104,6 +106,8 @@ struct BuildStrategy {
bool IsFinalized() const { return is_finalized_; }
bool IsMultiDevPass(const std::string &pass_name) const;
// Apply the passes built by the pass_builder_. The passes will be
// applied to the Program and output an ir::Graph.
std::unique_ptr<ir::Graph> Apply(const ProgramDesc &main_program,
......
......@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include <string>
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
......@@ -21,68 +21,78 @@ namespace paddle {
namespace framework {
namespace details {
bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars;
std::unordered_set<VarHandleBase *> ready_vars;
std::unordered_set<OpHandleBase *> ready_ops;
class SSAGraghBuilderWithChecker : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph;
}
auto insert_pending_var = [&](VarHandleBase *var) {
pending_vars.insert(var);
if (var->GeneratedOp() == nullptr) {
ready_vars.emplace(var);
}
};
bool IsValidGraph(const ir::Graph *graph) const {
std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars;
std::unordered_set<VarHandleBase *> ready_vars;
std::unordered_set<OpHandleBase *> ready_ops;
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair);
auto insert_pending_var = [&](VarHandleBase *var) {
pending_vars.insert(var);
if (var->GeneratedOp() == nullptr) {
ready_vars.emplace(var);
}
}
}
};
for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
insert_pending_var(var);
}
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair);
}
}
}
for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph)) {
if (op->Inputs().empty()) {
ready_ops.insert(op);
} else {
pending_ops.insert({op, op->NoDupInputSize()});
for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
insert_pending_var(var);
}
}
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) {
for (auto out : op->Outputs()) {
ready_vars.emplace(out);
for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph)) {
if (op->Inputs().empty()) {
ready_ops.insert(op);
} else {
pending_ops.insert({op, op->NoDupInputSize()});
}
}
set.clear();
};
while (!pending_vars.empty()) {
run_all_ops(ready_ops);
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) {
for (auto out : op->Outputs()) {
ready_vars.emplace(out);
}
}
set.clear();
};
if (ready_vars.empty()) {
return false;
}
while (!pending_vars.empty()) {
run_all_ops(ready_ops);
for (auto ready_var : ready_vars) {
pending_vars.erase(ready_var);
for (auto *op : ready_var->PendingOps()) {
auto &deps = --pending_ops[op];
if (deps == 0) {
ready_ops.insert(op);
if (ready_vars.empty()) {
return false;
}
for (auto ready_var : ready_vars) {
pending_vars.erase(ready_var);
for (auto *op : ready_var->PendingOps()) {
auto &deps = --pending_ops[op];
if (deps == 0) {
ready_ops.insert(op);
}
}
}
ready_vars.clear();
}
ready_vars.clear();
return true;
}
return true;
}
};
} // namespace details
} // namespace framework
} // namespace paddle
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include <string>
namespace paddle {
namespace framework {
namespace details {
class SSAGraghBuilderWithChecker : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph;
}
bool IsValidGraph(const ir::Graph* graph) const;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -134,15 +134,8 @@ void AddOutputToLeafOps(ir::Graph *graph) {
}
} // namespace
static const char kLossVarName[] = "loss_var_name";
static const char kPlaces[] = "places";
static const char kLocalScopes[] = "local_scopes";
static const char kStrategy[] = "strategy";
static const char kNRanks[] = "nranks";
void MultiDevSSAGraphBuilder::Init() const {
void MultiDevSSAGraphBuilderBase::Init() const {
all_vars_.clear();
balance_vars_.clear();
loss_var_name_ = Get<const std::string>(kLossVarName);
places_ = Get<const std::vector<platform::Place>>(kPlaces);
......@@ -151,31 +144,16 @@ void MultiDevSSAGraphBuilder::Init() const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs");
#endif
balance_vars_.resize(places_.size(), 0);
if (strategy_.enable_data_balance_ && places_.size() == 1) {
LOG(WARNING) << "It is no need to enable data balance when there is only "
"one place. enable_data_balance is set to False.";
strategy_.enable_data_balance_ = false;
}
}
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
Init();
// Give the topology sort order and rebuild the graph structure.
std::vector<ir::Node *> sorted_ops = ir::TopologySortOperations(*graph);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
sorted_ops = SortForReduceMode(sorted_ops);
}
std::vector<ir::Node *> sorted_ops = SortOperations(*graph);
auto nodes = graph->ReleaseNodes();
ir::Graph &result = *graph;
size_t nranks = Get<size_t>(kNRanks);
for (auto &node : nodes) {
if (node->IsVar() && node->Var()) {
all_vars_.emplace(node->Name(), node->Var());
......@@ -187,146 +165,61 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
result.Set(kGraphDepVars, new GraphDepVars);
result.Set(kGraphOps, new GraphOps);
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
bcast_var_name_set.resize(places_.size());
bool is_forwarding = true;
bool is_dist_train = false;
std::unordered_map<std::string, int> sharded_var_device;
bool insert_collection_ops = NeedCollectiveOps();
for (ir::Node *node : sorted_ops) {
if (OpHaveRole(*node, OpRole::kRPC)) {
int op_dev_id = CreateRPCOp(&result, node, &sharded_var_device);
PADDLE_ENFORCE(op_dev_id != -1,
"Can not schedule the RPC operator to the right place.");
if (node->Op()->Type() == "recv") {
auto recv_vars_attr =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE(recv_vars_attr.size() == 2UL); // [parameter, gradient]
if (recv_vars_attr[0].find(".block") == std::string::npos) {
bcast_var_name_set[op_dev_id].emplace(recv_vars_attr[0]);
}
}
is_dist_train = true;
} else if (OpHaveRole(*node, OpRole::kDist)) {
int op_dev_id = CreateDistTrainOp(&result, node, &sharded_var_device);
if (node->Op()->Type() == "concat") {
auto origin_param_name = node->Op()->OutputArgumentNames()[0];
bcast_var_name_set[op_dev_id].emplace(origin_param_name);
}
} else if (IsScaleLossOp(node)) {
// user can customize loss@grad if not use_default_grad_scale_
if (strategy_.gradient_scale_ !=
BuildStrategy::GradientScaleStrategy::kCustomized) {
// TODO(paddle-dev): Why is there no input for this op_handle?
auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
auto out_dtype = all_vars_.at(loss_grad_name)->GetDataType();
CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0],
out_dtype);
}
// This assumes the backward generating code will ensure IsScaleLossOp
// is true only for the op that scale the final scalar loss.
// It also assumes backward op will always follow the forward op in
// the block.
is_forwarding = false;
if (DealWithSpecialOp(&result, node)) {
continue;
} else {
int op_dev_id = GetOpDeviceID(node, sharded_var_device);
if (op_dev_id != -1) { // This op only runs on one specific device.
CreateComputationalOp(&result, node, op_dev_id);
for (ir::Node *n : node->outputs) {
sharded_var_device.emplace(n->Name(), op_dev_id);
}
// This op runs on all devices
if (IsScaleLossOp(node)) {
// user can customize loss@grad if not use_default_grad_scale_
InsertScaleLossGradOp(&result, node);
// This assumes the backward generating code will ensure IsScaleLossOp
// is true only for the op that scale the final scalar loss.
// It also assumes backward op will always follow the forward op in
// the block.
is_forwarding = false;
} else {
// This op runs on all devices, and its output may have parameter's
// gradients.
// TODO(paddle-dev): Why is so special about "read" op?
if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) {
node->Op()->SetAttr("throw_eof_exp", false);
CreateComputationalOps(&result, node, places_.size());
const auto &data_var_names = node->Op()->Output("Out");
InsertDataBalanceOp(&result, data_var_names);
} else {
CreateComputationalOps(&result, node, places_.size());
}
CreateComputationalOps(&result, node, places_.size());
}
if (!is_forwarding && nranks > 1UL) {
// Insert collection ops
if (!is_forwarding && insert_collection_ops) {
try {
bool is_bk_op =
static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
if (!is_bk_op) continue;
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
try {
auto backward_vars = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
for (size_t i = 0; i < backward_vars.size(); i += 2) {
auto &p_name = backward_vars[i];
auto &g_name = backward_vars[i + 1];
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
size_t cur_device_id = -1;
switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(&result, g_name, cur_device_id);
sharded_var_device.emplace(g_name, cur_device_id);
if (!is_dist_train) {
bcast_var_name_set[cur_device_id].emplace(p_name);
}
break;
case BuildStrategy::ReduceStrategy::kAllReduce:
if (IsSparseGradient(g_name)) {
CreateReduceOp(&result, g_name, 0);
CreateBroadcastOp(&result, g_name, 0);
} else {
InsertAllReduceOp(&result, g_name);
}
break;
default:
LOG(FATAL) << "Unknown reduce strategy ";
break;
}
}
} catch (boost::bad_get e) {
auto backward_vars =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
for (size_t i = 0; i < backward_vars.size(); i += 2) {
auto &p_name = backward_vars[i];
auto &g_name = backward_vars[i + 1];
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
InsertCollectiveOp(&result, p_name, g_name);
}
} catch (boost::bad_get e) {
}
}
}
}
bool use_gpu = false;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
use_gpu = nccl_ctxs_ != nullptr;
#endif
// Insert broadcast operators principle:
// 1. Broadcast optimized parameters in Reduce strategy;
// 2. No need broadcast optimized parameters in AllReduce strategy because of
// the optimization sub-graph would be run on every GPU;
// 3. Allways broadcast received parameters in Distribute Training.
if ((use_gpu &&
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
is_dist_train) {
if (strategy_.fuse_broadcast_op_) {
CreateFusedBroadcastOp(&result, bcast_var_name_set);
} else {
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
auto &to_bcast_set = bcast_var_name_set[dev_id];
for (auto &bcast_name : to_bcast_set) {
CreateBroadcastOp(&result, bcast_name, dev_id);
}
}
}
}
InsertPostprocessOps(&result);
/*
Dependency graph has been constructed. However, there are still data
hazards need to be handled.
*/
*/
PolishGraphToSupportDataHazards(&result);
/*
......@@ -337,67 +230,54 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
return graph;
}
std::vector<ir::Node *> MultiDevSSAGraphBuilder::SortForReduceMode(
const std::vector<ir::Node *> &topo_ops) const {
std::unordered_map<std::string, int> sharded_var_device;
std::vector<ir::Node *> sorted_ops;
std::unordered_map<std::string, std::vector<ir::Node *>> delayed_op;
sorted_ops.reserve(topo_ops.size());
auto insert_delayed_op = [&](const std::string &var_name, int dev_id) {
sharded_var_device.emplace(var_name, dev_id);
if (delayed_op.count(var_name)) {
auto &ops = delayed_op.at(var_name);
sorted_ops.insert(sorted_ops.end(), ops.begin(), ops.end());
delayed_op.at(var_name).clear();
}
};
void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp(
ir::Graph *result, const ir::Node *node) const {
// user can customize loss@grad if not use_default_grad_scale_
size_t loss_scale = 0;
switch (this->strategy_.gradient_scale_) {
case BuildStrategy::GradientScaleStrategy::kOne:
loss_scale = 1;
break;
case BuildStrategy::GradientScaleStrategy::kCoeffNumDevice:
loss_scale = Get<size_t>(kNRanks);
break;
case BuildStrategy::GradientScaleStrategy::kCustomized:
loss_scale = 0;
break;
default:
LOG(FATAL) << "Unknown gradient scale strategy.";
break;
}
if (loss_scale) {
// TODO(paddle-dev): Why is there no input for this op_handle?
auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
auto out_dtype = this->all_vars_.at(loss_grad_name)->GetDataType();
this->CreateScaleLossGradOp(result, loss_grad_name, node->outputs[0],
loss_scale, out_dtype);
}
}
for (ir::Node *node : topo_ops) {
int op_dev_id = GetOpDeviceID(node, sharded_var_device, &delayed_op);
if (op_dev_id > -1) {
// This op only runs on one specific device.
sorted_ops.emplace_back(node);
for (ir::Node *n : node->outputs) {
insert_delayed_op(n->Name(), op_dev_id);
}
} else if (op_dev_id == -1) {
// This op runs on all devices, and its output may have parameter's
// gradients.
sorted_ops.emplace_back(node);
bool is_bk_op =
static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
if (!is_bk_op) continue;
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
std::vector<std::string> backward_vars;
try {
backward_vars =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
} catch (boost::bad_get e) {
}
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
std::vector<ir::Node *> MultiDevSSAGraphBuilderBase::SortOperations(
const ir::Graph &graph) const {
return ir::TopologySortOperations(graph);
}
for (size_t i = 0; i < backward_vars.size(); i += 2) {
auto &g_name = backward_vars[i + 1];
size_t cur_device_id = GetAppropriateDeviceID({g_name});
insert_delayed_op(g_name, static_cast<int>(cur_device_id));
}
} else if (op_dev_id == -2) {
// The Op on which the Op depends has not yet been generated.
}
}
bool MultiDevSSAGraphBuilderBase::UseGPU() const {
bool use_gpu = false;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
use_gpu = nccl_ctxs_ != nullptr;
#endif
return use_gpu;
}
PADDLE_ENFORCE_EQ(sorted_ops.size(), topo_ops.size());
return sorted_ops;
bool MultiDevSSAGraphBuilderBase::NeedCollectiveOps() const {
return Get<size_t>(kNRanks) > 1;
}
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
ir::Node *node,
size_t place_id) const {
void MultiDevSSAGraphBuilderBase::CreateOpHandleIOs(ir::Graph *result,
ir::Node *node,
size_t place_id) const {
auto p = places_[place_id];
auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
op_handle->SetDeviceContext(p,
......@@ -420,28 +300,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
}
}
size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const {
int64_t numel_sum = 0;
for (auto var_name : var_names) {
if (all_vars_.find(var_name) == all_vars_.end()) continue;
auto var_desc = all_vars_.at(var_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GT(numel, 0);
numel_sum += numel;
}
auto smallest =
std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
balance_vars_[dev_id] += numel_sum;
return dev_id;
}
void MultiDevSSAGraphBuilder::SetCommunicationContext(
void MultiDevSSAGraphBuilderBase::SetCommunicationContext(
OpHandleBase *op_handle, const platform::Place &p) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
if (nccl_ctxs_ == nullptr) {
......@@ -454,9 +313,9 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
#endif
}
void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
const std::string &p_name,
size_t src_dev_id) const {
void MultiDevSSAGraphBuilderBase::CreateBroadcastOp(ir::Graph *result,
const std::string &p_name,
size_t src_dev_id) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto *op_handle = new BroadcastOpHandle(
result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
......@@ -484,7 +343,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
}
}
void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
ir::Graph *result,
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
......@@ -522,17 +381,17 @@ void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
}
}
void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
ir::Node *node,
int dev_id) const {
void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
ir::Node *node,
int dev_id) const {
result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()),
local_scopes_[dev_id], places_[dev_id], dev_id));
CreateOpHandleIOs(result, node, dev_id);
}
void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
const std::string &og) const {
void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
ir::Graph *result, const std::string &og) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
......@@ -560,102 +419,15 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
}
}
void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
ir::Graph *result, const std::vector<std::string> &datas) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_));
#else
result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
local_scopes_, places_));
#endif
auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
for (const std::string &d_name : datas) {
auto &vars = result->Get<GraphVars>(kGraphVars)[i][d_name];
PADDLE_ENFORCE(!vars.empty());
op_handle->AddInput(vars.back());
auto var = new VarHandle(
result->CreateEmptyNode(d_name, ir::Node::Type::kVariable),
vars.size(), i, d_name, p);
vars.emplace_back(var);
op_handle->AddOutput(var);
}
}
}
int MultiDevSSAGraphBuilder::GetOpDeviceID(
ir::Node *node,
const std::unordered_map<std::string, int> &sharded_var_device,
std::unordered_map<std::string, std::vector<ir::Node *>> *delay_ops) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1;
}
if (!OpHaveRole(*node, framework::OpRole::kOptimize)) {
return -1;
}
auto param_grad = boost::get<std::vector<std::string>>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
int dev_id = GetVarDeviceID(param_grad[1], sharded_var_device);
if (dev_id == -1) {
(*delay_ops)[param_grad[1]].push_back(node);
return -2;
}
return dev_id;
}
int MultiDevSSAGraphBuilder::GetOpDeviceID(
ir::Node *node,
const std::unordered_map<std::string, int> &sharded_var_device) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1;
}
if (!OpHaveRole(*node, framework::OpRole::kOptimize)) {
return -1;
}
auto param_grad = boost::get<std::vector<std::string>>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
int dev_id = GetVarDeviceID(param_grad[1], sharded_var_device);
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
node->Op()->Type(), param_grad[0], param_grad[1]);
return dev_id;
}
int MultiDevSSAGraphBuilder::GetVarDeviceID(
const std::string &varname,
const std::unordered_map<std::string, int> &sharded_var_device) const {
auto got = sharded_var_device.find(varname);
if (got == sharded_var_device.end()) {
auto pos = varname.find(framework::kNewGradSuffix);
if (pos != std::string::npos) {
got = sharded_var_device.find(varname.substr(0, pos));
}
}
return got == sharded_var_device.end() ? -1 : got->second;
}
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
void MultiDevSSAGraphBuilderBase::CreateScaleLossGradOp(
ir::Graph *result, const std::string &loss_grad_name,
ir::Node *out_var_node, proto::VarType::Type dtype) const {
size_t nranks = Get<size_t>("nranks");
ir::Node *out_var_node, size_t loss_scale,
proto::VarType::Type dtype) const {
for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
auto *op_handle = new ScaleLossGradOpHandle(
result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
nranks, local_scopes_[i], places_[i], dev_ctx, dtype);
loss_scale, local_scopes_[i], places_[i], dev_ctx, dtype);
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
......@@ -669,9 +441,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
}
}
void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
ir::Node *node,
size_t num_places) const {
void MultiDevSSAGraphBuilderBase::CreateComputationalOps(
ir::Graph *result, ir::Node *node, size_t num_places) const {
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx];
......@@ -681,9 +452,9 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
}
}
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
const std::string &og,
int dst_dev_id) const {
VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(ir::Graph *result,
const std::string &og,
int dst_dev_id) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
......@@ -712,51 +483,273 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
return var;
}
int MultiDevSSAGraphBuilder::CreateDistTrainOp(
ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const {
int op_dev_id = -1;
std::vector<std::string> input_var_names;
std::vector<std::string> output_var_names;
for (ir::Node *input : node->inputs) {
input_var_names.push_back(input->Name());
bool MultiDevSSAGraphBuilderBase::IsScaleLossOp(ir::Node *node) const {
return boost::get<int>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
(static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss)) &&
!loss_var_name_.empty(); // If loss_var is empty. This is test mode
}
bool MultiDevSSAGraphBuilderBase::IsSparseGradient(
const std::string &og) const {
PADDLE_ENFORCE(all_vars_.count(og) != 0);
if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
return true;
}
for (ir::Node *output : node->outputs) {
output_var_names.push_back(output->Name());
return false;
}
void AllReduceSSAGraphBuilder::InsertCollectiveOp(
ir::Graph *result, const std::string &p_name,
const std::string &g_name) const {
if (IsSparseGradient(g_name)) {
CreateReduceOp(result, g_name, 0);
CreateBroadcastOp(result, g_name, 0);
} else {
CreateAllReduceOp(result, g_name);
}
}
if (node->Op()->Type() == "split_byref" ||
node->Op()->Type() == "split_selected_rows" ||
node->Op()->Type() == "split_ids") {
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(input_var_names[0], *sharded_var_device);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : input_var_names) {
sharded_var_device->emplace(varname, op_dev_id);
int BalanceVarSSAGraphBuilder::GetVarDeviceID(
const std::string &varname) const {
auto got = sharded_var_device_.find(varname);
if (got == sharded_var_device_.end()) {
auto pos = varname.find(framework::kNewGradSuffix);
if (pos != std::string::npos) {
got = sharded_var_device_.find(varname.substr(0, pos));
}
}
return got == sharded_var_device_.end() ? -1 : got->second;
}
int BalanceVarSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1;
}
if (!OpHaveRole(*node, framework::OpRole::kOptimize)) {
return -1;
}
auto param_grad = boost::get<std::vector<std::string>>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
int dev_id = GetVarDeviceID(param_grad[1]);
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
node->Op()->Type(), param_grad[0], param_grad[1]);
return dev_id;
}
size_t BalanceVarSSAGraphBuilder::GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const {
int64_t numel_sum = 0;
for (auto var_name : var_names) {
if (all_vars_.find(var_name) == all_vars_.end()) continue;
auto var_desc = all_vars_.at(var_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GT(numel, 0);
numel_sum += numel;
}
auto smallest =
std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
balance_vars_[dev_id] += numel_sum;
return dev_id;
}
void BalanceVarSSAGraphBuilder::ResetState() const {
balance_vars_.clear();
sharded_var_device_.clear();
balance_vars_.resize(places_.size(), 0);
}
void ReduceSSAGraphBuilder::Init() const {
MultiDevSSAGraphBuilderBase::Init();
ResetState();
}
void ReduceSSAGraphBuilder::ResetState() const {
BalanceVarSSAGraphBuilder::ResetState();
bcast_var_name_set_.clear();
bcast_var_name_set_.resize(places_.size());
}
void ReduceSSAGraphBuilder::InsertCollectiveOp(
ir::Graph *result, const std::string &p_name,
const std::string &g_name) const {
size_t cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(result, g_name, cur_device_id);
sharded_var_device_.emplace(g_name, cur_device_id);
bcast_var_name_set_[cur_device_id].emplace(p_name);
}
bool ReduceSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
ir::Node *node) const {
int op_dev_id = BalanceVarSSAGraphBuilder::GetOpDeviceID(node);
if (op_dev_id != -1) {
// This op only runs on one specific device.
CreateComputationalOp(result, node, op_dev_id);
for (ir::Node *n : node->outputs) {
sharded_var_device_.emplace(n->Name(), op_dev_id);
}
return true;
}
return false;
}
void ReduceSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
if (UseGPU()) {
if (strategy_.fuse_broadcast_op_) {
CreateFusedBroadcastOp(result, bcast_var_name_set_);
} else {
for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) {
auto &to_bcast_set = bcast_var_name_set_[dev_id];
for (auto &bcast_name : to_bcast_set) {
CreateBroadcastOp(result, bcast_name, dev_id);
}
}
}
for (auto &varname : output_var_names) {
sharded_var_device->emplace(varname, op_dev_id);
}
}
int ReduceSSAGraphBuilder::GetOpDeviceID(
ir::Node *node,
std::unordered_map<std::string, std::vector<ir::Node *>> *delay_ops) const {
if (!OpHaveRole(*node, framework::OpRole::kOptimize)) {
return -1;
}
auto param_grad = boost::get<std::vector<std::string>>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
int dev_id = GetVarDeviceID(param_grad[1]);
if (dev_id == -1) {
(*delay_ops)[param_grad[1]].push_back(node);
return -2;
}
return dev_id;
}
std::vector<ir::Node *> ReduceSSAGraphBuilder::SortOperations(
const ir::Graph &graph) const {
std::vector<ir::Node *> sorted_ops = ir::TopologySortOperations(graph);
return SortForReduceMode(sorted_ops);
}
std::vector<ir::Node *> ReduceSSAGraphBuilder::SortForReduceMode(
const std::vector<ir::Node *> &topo_ops) const {
std::vector<ir::Node *> sorted_ops;
std::unordered_map<std::string, std::vector<ir::Node *>> delayed_op;
sorted_ops.reserve(topo_ops.size());
ResetState();
auto insert_delayed_op = [&](const std::string &var_name, int dev_id) {
sharded_var_device_.emplace(var_name, dev_id);
if (delayed_op.count(var_name)) {
auto &ops = delayed_op.at(var_name);
sorted_ops.insert(sorted_ops.end(), ops.begin(), ops.end());
delayed_op.at(var_name).clear();
}
} else if (node->Op()->Type() == "concat") {
op_dev_id = GetVarDeviceID(input_var_names[0], *sharded_var_device);
for (auto &varname : output_var_names) {
sharded_var_device->emplace(varname, op_dev_id);
};
for (ir::Node *node : topo_ops) {
int op_dev_id = GetOpDeviceID(node, &delayed_op);
if (op_dev_id > -1) {
// This op only runs on one specific device.
sorted_ops.emplace_back(node);
for (ir::Node *n : node->outputs) {
insert_delayed_op(n->Name(), op_dev_id);
}
} else if (op_dev_id == -1) {
// This op runs on all devices, and its output may have parameter's
// gradients.
sorted_ops.emplace_back(node);
bool is_bk_op =
static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
if (!is_bk_op) continue;
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
std::vector<std::string> backward_vars;
try {
backward_vars =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
} catch (boost::bad_get e) {
}
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
for (size_t i = 0; i < backward_vars.size(); i += 2) {
auto &g_name = backward_vars[i + 1];
size_t cur_device_id = GetAppropriateDeviceID({g_name});
insert_delayed_op(g_name, static_cast<int>(cur_device_id));
}
} else if (op_dev_id == -2) {
// The Op on which the Op depends has not yet been generated.
}
} else {
LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
PADDLE_THROW(
"the distribute training related op should be in [split_byref, "
"concat].");
}
PADDLE_ENFORCE(op_dev_id != -1,
"can not find right place for distributed op: %s",
node->Op()->Type());
PADDLE_ENFORCE_EQ(sorted_ops.size(), topo_ops.size());
CreateComputationalOp(result, node, op_dev_id);
return op_dev_id;
ResetState();
return sorted_ops;
}
void DistSSAGraphBuilder::Init() const {
MultiDevSSAGraphBuilderBase::Init();
ResetState();
}
void DistSSAGraphBuilder::ResetState() const {
BalanceVarSSAGraphBuilder::ResetState();
bcast_var_name_set_.clear();
bcast_var_name_set_.resize(places_.size());
}
bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
ir::Node *node) const {
bool insert_op = false;
if (OpHaveRole(*node, OpRole::kRPC)) {
int op_dev_id = CreateRPCOp(result, node);
PADDLE_ENFORCE(op_dev_id != -1,
"Can not schedule the RPC operator to the right place.");
if (node->Op()->Type() == "recv") {
auto recv_vars_attr =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE(recv_vars_attr.size() == 2UL); // [parameter, gradient]
if (recv_vars_attr[0].find(".block") == std::string::npos) {
bcast_var_name_set_[op_dev_id].emplace(recv_vars_attr[0]);
}
}
insert_op = true;
need_broadcast_var_ = true;
} else if (OpHaveRole(*node, OpRole::kDist)) {
int op_dev_id = CreateDistTrainOp(result, node);
if (node->Op()->Type() == "concat") {
auto origin_param_name = node->Op()->OutputArgumentNames()[0];
bcast_var_name_set_[op_dev_id].emplace(origin_param_name);
}
insert_op = true;
} else {
int op_dev_id = GetOpDeviceID(node);
if (op_dev_id != -1) { // This op only runs on one specific device.
CreateComputationalOp(result, node, op_dev_id);
for (ir::Node *n : node->outputs) {
sharded_var_device_.emplace(n->Name(), op_dev_id);
}
insert_op = true;
}
}
return insert_op;
}
void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
......@@ -775,13 +768,11 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
}
// Create RPC related op handles that connects its in ops and out ops.
int MultiDevSSAGraphBuilder::CreateRPCOp(
ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const {
int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
int op_dev_id = -1;
if (node->Op()->Type() == "send") {
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(node->inputs[0]->Name(), *sharded_var_device);
op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
"This hack no longer holds, please fix.");
// the variable name which contains .block means it was splited by
......@@ -799,9 +790,9 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
VLOG(10) << "send grad " << input_var_names[0] << " origin "
<< send_param_grad[1] << " place: " << op_dev_id;
for (auto &varname : input_var_names) {
sharded_var_device->emplace(varname, op_dev_id);
sharded_var_device_.emplace(varname, op_dev_id);
}
sharded_var_device->emplace(send_param_grad[1], op_dev_id);
sharded_var_device_.emplace(send_param_grad[1], op_dev_id);
}
} else if (node->Op()->Type() == "recv") {
std::vector<std::string> output_var_names;
......@@ -811,7 +802,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
auto recv_param_grad = boost::get<std::vector<std::string>>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
if (recv_param_grad.size() == 2U) {
op_dev_id = GetVarDeviceID(recv_param_grad[1], *sharded_var_device);
op_dev_id = GetVarDeviceID(recv_param_grad[1]);
VLOG(10) << "recv param " << recv_param_grad[0]
<< " get grad place: " << recv_param_grad[1]
<< " place: " << op_dev_id;
......@@ -819,7 +810,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
op_dev_id = GetAppropriateDeviceID(output_var_names);
}
for (auto &varname : output_var_names) {
sharded_var_device->emplace(varname, op_dev_id);
sharded_var_device_.emplace(varname, op_dev_id);
}
} else {
// send_barrier, fetch_barrier will run on place 0;
......@@ -846,7 +837,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
for (ir::Node *output : node->outputs) {
int outvar_dev_id = op_dev_id;
if (node->Op()->Type() == "fetch_barrier") {
outvar_dev_id = GetVarDeviceID(output->Name(), *sharded_var_device);
outvar_dev_id = GetVarDeviceID(output->Name());
PADDLE_ENFORCE_NE(outvar_dev_id, -1, "output name %s", output->Name());
}
p = places_[outvar_dev_id];
......@@ -863,29 +854,124 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
return op_dev_id;
}
bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
PADDLE_ENFORCE(all_vars_.count(og) != 0);
if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
return true;
int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
ir::Node *node) const {
int op_dev_id = -1;
std::vector<std::string> input_var_names;
std::vector<std::string> output_var_names;
for (ir::Node *input : node->inputs) {
input_var_names.push_back(input->Name());
}
return false;
for (ir::Node *output : node->outputs) {
output_var_names.push_back(output->Name());
}
if (node->Op()->Type() == "split_byref" ||
node->Op()->Type() == "split_selected_rows" ||
node->Op()->Type() == "split_ids") {
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(input_var_names[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : input_var_names) {
sharded_var_device_.emplace(varname, op_dev_id);
}
}
for (auto &varname : output_var_names) {
sharded_var_device_.emplace(varname, op_dev_id);
}
} else if (node->Op()->Type() == "concat") {
op_dev_id = GetVarDeviceID(input_var_names[0]);
for (auto &varname : output_var_names) {
sharded_var_device_.emplace(varname, op_dev_id);
}
} else {
LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
PADDLE_THROW(
"the distribute training related op should be in [split_byref, "
"concat].");
}
PADDLE_ENFORCE(op_dev_id != -1,
"can not find right place for distributed op: %s",
node->Op()->Type());
CreateComputationalOp(result, node, op_dev_id);
return op_dev_id;
}
bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
return boost::get<int>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
(static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss)) &&
!loss_var_name_.empty(); // If loss_var is empty. This is test mode
void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
const std::string &p_name,
const std::string &g_name) const {
size_t cur_device_id = 0;
switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(result, g_name, cur_device_id);
sharded_var_device_.emplace(g_name, cur_device_id);
break;
case BuildStrategy::ReduceStrategy::kAllReduce:
if (IsSparseGradient(g_name)) {
CreateReduceOp(result, g_name, 0);
CreateBroadcastOp(result, g_name, 0);
} else {
CreateAllReduceOp(result, g_name);
}
break;
default:
LOG(FATAL) << "Unknown reduce strategy.";
break;
}
}
void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
if (need_broadcast_var_ ||
(UseGPU() &&
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce)) {
if (strategy_.fuse_broadcast_op_) {
CreateFusedBroadcastOp(result, bcast_var_name_set_);
} else {
for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) {
auto &to_bcast_set = bcast_var_name_set_[dev_id];
for (auto &bcast_name : to_bcast_set) {
CreateBroadcastOp(result, bcast_name, dev_id);
}
}
}
}
}
std::unordered_set<std::string> &MultiDevSSAGraphBuilder() {
static std::unordered_set<std::string> regs;
return regs;
}
static int MultiDevSSAGraphBuilderRegister(const std::string &builder_mode) {
MultiDevSSAGraphBuilder().insert(builder_mode);
return 0;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(multi_devices_pass,
paddle::framework::details::MultiDevSSAGraphBuilder)
.RequirePassAttr(paddle::framework::details::kLossVarName)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes)
.RequirePassAttr(paddle::framework::details::kStrategy)
.RequirePassAttr(paddle::framework::details::kNRanks);
#define REGISTER_MULTI_DEVICES_PASS(pass_name, pass_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
_reg_ssa_graph_builder_##pass_name, \
"REGISTER_MULTI_DEVICES_PASS must be called in global namespace."); \
int _reg_ssa_graph_builder_entry_##pass_name = \
paddle::framework::details::MultiDevSSAGraphBuilderRegister(#pass_name); \
REGISTER_PASS(pass_name, pass_class) \
.RequirePassAttr(paddle::framework::details::kLossVarName) \
.RequirePassAttr(paddle::framework::details::kPlaces) \
.RequirePassAttr(paddle::framework::details::kLocalScopes) \
.RequirePassAttr(paddle::framework::details::kStrategy) \
.RequirePassAttr(paddle::framework::details::kNRanks)
REGISTER_MULTI_DEVICES_PASS(reduce_mode_multi_devices_pass,
paddle::framework::details::ReduceSSAGraphBuilder);
REGISTER_MULTI_DEVICES_PASS(
allreduce_mode_multi_devices_pass,
paddle::framework::details::AllReduceSSAGraphBuilder);
REGISTER_MULTI_DEVICES_PASS(dist_multi_devices_pass,
paddle::framework::details::DistSSAGraphBuilder);
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <string>
#include <utility>
#include <vector>
......@@ -30,78 +31,70 @@ namespace framework {
class Scope;
namespace details {
class MultiDevSSAGraphBuilder : public ir::Pass {
constexpr char kLossVarName[] = "loss_var_name";
constexpr char kPlaces[] = "places";
constexpr char kLocalScopes[] = "local_scopes";
constexpr char kStrategy[] = "strategy";
constexpr char kNRanks[] = "nranks";
class MultiDevSSAGraphBuilderBase : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
private:
void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
size_t device_id) const;
void Init() const;
virtual void Init() const;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
mutable platform::NCCLContextMap *nccl_ctxs_;
#endif
virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;
int GetVarDeviceID(
const std::string &varname,
const std::unordered_map<std::string, int> &sharded_var_device) const;
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const = 0;
bool IsScaleLossOp(ir::Node *node) const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const = 0;
virtual void InsertPostprocessOps(ir::Graph *result) const = 0;
int CreateRPCOp(
ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const;
int CreateDistTrainOp(
ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const;
bool UseGPU() const;
bool NeedCollectiveOps() const;
bool IsScaleLossOp(ir::Node *node) const;
void CreateComputationalOps(ir::Graph *result, ir::Node *node,
size_t num_places) const;
void CreateScaleLossGradOp(ir::Graph *result,
const std::string &loss_grad_name,
ir::Node *out_var_node,
ir::Node *out_var_node, size_t loss_scale,
proto::VarType::Type dtype) const;
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
int dst_dev_id) const;
void CreateComputationalOp(ir::Graph *result, ir::Node *node,
int dev_id) const;
int GetOpDeviceID(
ir::Node *node,
const std::unordered_map<std::string, int> &sharded_var_device) const;
void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
bool IsSparseGradient(const std::string &og) const;
void InsertDataBalanceOp(ir::Graph *result,
const std::vector<std::string> &datas) const;
void CreateAllReduceOp(ir::Graph *result, const std::string &og) const;
void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
size_t src_dev_id) const;
void InsertScaleLossGradOp(ir::Graph *result, const ir::Node *node) const;
void CreateFusedBroadcastOp(
ir::Graph *result,
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const;
bool IsSparseGradient(const std::string &og) const;
size_t GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const;
void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const;
std::vector<ir::Node *> SortForReduceMode(
const std::vector<ir::Node *> &) const;
void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
size_t device_id) const;
int GetOpDeviceID(
ir::Node *node,
const std::unordered_map<std::string, int> &shared_var_device,
std::unordered_map<std::string, std::vector<ir::Node *>> *delay_ops)
const;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
mutable platform::NCCLContextMap *nccl_ctxs_;
#endif
mutable std::string loss_var_name_;
mutable std::vector<platform::Place> places_;
......@@ -109,8 +102,83 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
mutable BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_;
};
class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
protected:
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const {
return false;
}
virtual void InsertPostprocessOps(ir::Graph *result) const {}
};
class BalanceVarSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
protected:
int GetVarDeviceID(const std::string &varname) const;
int GetOpDeviceID(ir::Node *node) const;
size_t GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const;
virtual void ResetState() const;
mutable std::unordered_map<std::string, int> sharded_var_device_;
mutable std::vector<int64_t> balance_vars_;
};
class ReduceSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
protected:
virtual void Init() const;
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;
virtual void InsertPostprocessOps(ir::Graph *result) const;
virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;
virtual void ResetState() const;
int GetOpDeviceID(ir::Node *node,
std::unordered_map<std::string, std::vector<ir::Node *>>
*delay_ops) const;
std::vector<ir::Node *> SortForReduceMode(
const std::vector<ir::Node *> &topo_ops) const;
mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
};
class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
protected:
virtual void Init() const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;
virtual void InsertPostprocessOps(ir::Graph *result) const;
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const;
virtual void ResetState() const;
int CreateRPCOp(ir::Graph *result, ir::Node *node) const;
int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
mutable bool need_broadcast_var_{false};
};
std::unordered_set<std::string> &MultiDevSSAGraphBuilder();
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -946,13 +946,6 @@ All parameter, weight, gradient are variables in Paddle.
R"DOC(The type is STR, debug_graphviz_path indicate the path that
writing the SSA Graph to file in the form of graphviz, you.
It is useful for debugging. Default "")DOC")
.def_property(
"enable_data_balance",
[](const BuildStrategy &self) { return self.enable_data_balance_; },
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
self.enable_data_balance_ = b;
}) // FIXME(chengudo): enable_data_balance seems not important
.def_property(
"enable_sequential_execution",
[](const BuildStrategy &self) {
......@@ -1007,6 +1000,10 @@ All parameter, weight, gradient are variables in Paddle.
"memory_optimize",
[](const BuildStrategy &self) { return self.memory_optimize_; },
[](BuildStrategy &self, bool b) { self.memory_optimize_ = b; })
.def_property(
"is_distribution",
[](const BuildStrategy &self) { return self.is_distribution_; },
[](BuildStrategy &self, bool b) { self.is_distribution_ = b; })
.def_property(
"memory_early_delete",
[](const BuildStrategy &self) { return self.memory_early_delete_; },
......
......@@ -29,6 +29,15 @@ ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy
BuildStrategy = core.ParallelExecutor.BuildStrategy
def _is_pserver_mode(main_program):
main = main_program if main_program \
else framework.default_main_program()
for op in main.global_block().ops:
if op.type in ["send", "recv"]:
return True
return False
class ParallelExecutor(object):
"""
ParallelExecutor is designed for data parallelism, which focuses on distributing
......@@ -128,6 +137,11 @@ class ParallelExecutor(object):
build_strategy = BuildStrategy()
build_strategy.num_trainers = num_trainers
build_strategy.trainer_id = trainer_id
# FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
# num_trainers is 1, so the current fields of build_strategy doesn't tell if
# it's distributed model.
build_strategy.is_distribution = _is_pserver_mode(
main_program) or num_trainers > 1
# step4: get main_program, scope, local_scopes
main = main_program if main_program \
......
......@@ -75,8 +75,6 @@ class TestReaderReset(unittest.TestCase):
exe.run(startup_prog)
build_strategy = fluid.BuildStrategy()
if with_double_buffer:
build_strategy.enable_data_balance = True
exec_strategy = fluid.ExecutionStrategy()
parallel_exe = fluid.ParallelExecutor(
use_cuda=self.use_cuda,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册