Unverified commit eabb2105, authored by chengduo, committed via GitHub

Refactor MultiDevSSAGraphBuilder (#15090)

* Refactor ParallelExecutor
test=develop

* extract Reduce and AllReduce mode from MultiDevSSAGraphBuilder
test=develop

* Refactor MultiDevSSAGraphBuilder
test=develop

* Remove enable_data_balance
test=develop

* code refine
test=develop

* remove data balance
test=develop

* refine ScaleLossGradOp
test=develop

* remove unnecessary file
test=develop

* code refine
test=develop

* modify function name
test=develop

* follow comments
test=develop

* add is_distribution field
test=develop

* set is_distribution
test=develop

* fix DistSSAGraphBuilder
test=develop
Parent 875a07c3
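The diff below replaces the single monolithic MultiDevSSAGraphBuilder pass with a MultiDevSSAGraphBuilderBase that owns the graph-construction loop, plus three subclasses (AllReduceSSAGraphBuilder, ReduceSSAGraphBuilder, DistSSAGraphBuilder) that override a small set of hooks: DealWithSpecialOp, InsertCollectiveOp, InsertPostprocessOps, and SortOperations. The following is a condensed, self-contained sketch of that template-method structure; the Graph/Node types and the printed actions are simplified placeholders, not the real Paddle classes.

```cpp
// Condensed sketch of the refactored builder hierarchy (placeholder types,
// not the actual paddle::framework classes).
#include <iostream>
#include <string>
#include <vector>

struct Node { std::string name; bool is_grad = false; };
struct Graph { std::vector<Node> ops; };

class MultiDevSSAGraphBuilderBase {
 public:
  virtual ~MultiDevSSAGraphBuilderBase() = default;

  // Mirrors ApplyImpl(): one fixed driver loop, with mode-specific behavior
  // delegated to virtual hooks that each subclass overrides.
  void Apply(Graph *graph) const {
    for (const Node &op : graph->ops) {
      if (DealWithSpecialOp(graph, op)) continue;   // e.g. RPC / dist-train ops
      std::cout << "computational op on all devices: " << op.name << "\n";
      if (op.is_grad) InsertCollectiveOp(graph, op.name);  // reduce or allreduce
    }
    InsertPostprocessOps(graph);                     // e.g. parameter broadcasts
  }

 protected:
  virtual bool DealWithSpecialOp(Graph *, const Node &) const { return false; }
  virtual void InsertCollectiveOp(Graph *, const std::string &g_name) const = 0;
  virtual void InsertPostprocessOps(Graph *) const {}
};

class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
 protected:
  void InsertCollectiveOp(Graph *, const std::string &g_name) const override {
    std::cout << "  allreduce " << g_name << "\n";
  }
};

class ReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
 protected:
  void InsertCollectiveOp(Graph *, const std::string &g_name) const override {
    std::cout << "  reduce " << g_name << " onto one device\n";
  }
  void InsertPostprocessOps(Graph *) const override {
    std::cout << "  broadcast optimized parameters\n";
  }
};

int main() {
  Graph g{{{"fc"}, {"fc_grad", true}}};
  AllReduceSSAGraphBuilder allreduce;
  ReduceSSAGraphBuilder reduce;
  allreduce.Apply(&g);
  reduce.Apply(&g);
}
```

In the real pass, ApplyImpl additionally sets up kGraphDepVars/kGraphOps, inserts the scale-loss-grad op, and sorts the operators first, but the dispatch structure is the same.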
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
#include <memory> #include <memory>
#include "paddle/fluid/framework/details/memory_reuse_types.h" #include "paddle/fluid/framework/details/memory_reuse_types.h"
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/sequential_execution_pass.h" #include "paddle/fluid/framework/details/sequential_execution_pass.h"
...@@ -86,10 +86,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -86,10 +86,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
if (strategy.memory_optimize_) { if (strategy.memory_optimize_) {
auto analysis_var_pass = AppendPass("analysis_var_pass"); auto analysis_var_pass = AppendPass("analysis_var_pass");
} }
// Convert graph to run on multi-devices.
auto multi_devices_pass = AppendPass("multi_devices_pass"); AppendMultiDevPass(strategy);
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
&strategy_);
// Add a graph print pass to record a graph with device info. // Add a graph print pass to record a graph with device info.
if (!strategy_.debug_graphviz_path_.empty()) { if (!strategy_.debug_graphviz_path_.empty()) {
...@@ -115,6 +113,25 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -115,6 +113,25 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
} }
} }
// Convert graph to run on multi-devices.
void AppendMultiDevPass(const BuildStrategy &strategy) {
ir::Pass *multi_devices_pass;
if (strategy_.is_distribution_) {
multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
multi_devices_pass =
AppendPass("allreduce_mode_multi_devices_pass").get();
} else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
} else {
PADDLE_THROW("Unknown reduce strategy.");
}
}
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
&strategy_);
}
private: private:
BuildStrategy strategy_; BuildStrategy strategy_;
}; };
...@@ -131,6 +148,10 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy( ...@@ -131,6 +148,10 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
return pass_builder_; return pass_builder_;
} }
bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
}
std::unique_ptr<ir::Graph> BuildStrategy::Apply( std::unique_ptr<ir::Graph> BuildStrategy::Apply(
const ProgramDesc &main_program, const std::vector<platform::Place> &places, const ProgramDesc &main_program, const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::vector<Scope *> &local_scopes, const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
...@@ -145,22 +166,23 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -145,22 +166,23 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program)); std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) { for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
if (pass->Type() == "multi_devices_pass") { if (IsMultiDevPass(pass->Type())) {
pass->Erase("places"); pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>("places", &places); pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
pass->Erase("loss_var_name"); pass->Erase(kLossVarName);
pass->SetNotOwned<const std::string>("loss_var_name", &loss_var_name); pass->SetNotOwned<const std::string>(kLossVarName, &loss_var_name);
pass->Erase("local_scopes"); pass->Erase(kLocalScopes);
pass->SetNotOwned<const std::vector<Scope *>>("local_scopes", pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
&local_scopes); &local_scopes);
pass->Erase("nranks"); pass->Erase(kNRanks);
pass->Set<size_t>("nranks", new size_t(nranks)); pass->Set<size_t>(kNRanks, new size_t(nranks));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr; platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase("nccl_ctxs"); pass->Erase("nccl_ctxs");
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx); pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
#endif #endif
} else if (pass->Type() == "analysis_var_pass") { } else if (pass->Type() == "analysis_var_pass") {
const std::vector<OpDesc *> *all_op_descs = const std::vector<OpDesc *> *all_op_descs =
new std::vector<OpDesc *>(main_program.Block(0).AllOps()); new std::vector<OpDesc *>(main_program.Block(0).AllOps());
...@@ -201,7 +223,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -201,7 +223,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
USE_PASS(fuse_elewise_add_act_pass); USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass); USE_PASS(graph_viz_pass);
USE_PASS(multi_batch_merge_pass); USE_PASS(multi_batch_merge_pass);
USE_PASS(multi_devices_pass); USE_PASS(reduce_mode_multi_devices_pass);
USE_PASS(allreduce_mode_multi_devices_pass);
USE_PASS(dist_multi_devices_pass);
USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass); USE_PASS(multi_devices_print_pass);
USE_PASS(analysis_var_pass); USE_PASS(analysis_var_pass);
......
...@@ -74,8 +74,6 @@ struct BuildStrategy { ...@@ -74,8 +74,6 @@ struct BuildStrategy {
bool fuse_elewise_add_act_ops_{false}; bool fuse_elewise_add_act_ops_{false};
bool enable_data_balance_{false};
bool memory_optimize_{false}; bool memory_optimize_{false};
bool memory_early_delete_{false}; bool memory_early_delete_{false};
...@@ -84,6 +82,10 @@ struct BuildStrategy { ...@@ -84,6 +82,10 @@ struct BuildStrategy {
bool fuse_broadcast_op_{false}; bool fuse_broadcast_op_{false};
// FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
// num_trainers is 1, so the current fields of build_strategy doesn't tell if
// it's distributed model.
bool is_distribution_{false};
int num_trainers_{1}; int num_trainers_{1};
int trainer_id_{0}; int trainer_id_{0};
std::vector<std::string> trainers_endpoints_; std::vector<std::string> trainers_endpoints_;
...@@ -104,6 +106,8 @@ struct BuildStrategy { ...@@ -104,6 +106,8 @@ struct BuildStrategy {
bool IsFinalized() const { return is_finalized_; } bool IsFinalized() const { return is_finalized_; }
bool IsMultiDevPass(const std::string &pass_name) const;
// Apply the passes built by the pass_builder_. The passes will be // Apply the passes built by the pass_builder_. The passes will be
// applied to the Program and output an ir::Graph. // applied to the Program and output an ir::Graph.
std::unique_ptr<ir::Graph> Apply(const ProgramDesc &main_program, std::unique_ptr<ir::Graph> Apply(const ProgramDesc &main_program,
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include <string> #include <string>
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
...@@ -21,7 +21,15 @@ namespace paddle { ...@@ -21,7 +21,15 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { class SSAGraghBuilderWithChecker : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph;
}
bool IsValidGraph(const ir::Graph *graph) const {
std::unordered_map<OpHandleBase *, size_t> pending_ops; std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars; std::unordered_set<VarHandleBase *> pending_vars;
std::unordered_set<VarHandleBase *> ready_vars; std::unordered_set<VarHandleBase *> ready_vars;
...@@ -82,7 +90,9 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { ...@@ -82,7 +90,9 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
ready_vars.clear(); ready_vars.clear();
} }
return true; return true;
} }
};
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include <string>
namespace paddle {
namespace framework {
namespace details {
class SSAGraghBuilderWithChecker : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph;
}
bool IsValidGraph(const ir::Graph* graph) const;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -134,15 +134,8 @@ void AddOutputToLeafOps(ir::Graph *graph) { ...@@ -134,15 +134,8 @@ void AddOutputToLeafOps(ir::Graph *graph) {
} }
} // namespace } // namespace
static const char kLossVarName[] = "loss_var_name"; void MultiDevSSAGraphBuilderBase::Init() const {
static const char kPlaces[] = "places";
static const char kLocalScopes[] = "local_scopes";
static const char kStrategy[] = "strategy";
static const char kNRanks[] = "nranks";
void MultiDevSSAGraphBuilder::Init() const {
all_vars_.clear(); all_vars_.clear();
balance_vars_.clear();
loss_var_name_ = Get<const std::string>(kLossVarName); loss_var_name_ = Get<const std::string>(kLossVarName);
places_ = Get<const std::vector<platform::Place>>(kPlaces); places_ = Get<const std::vector<platform::Place>>(kPlaces);
...@@ -151,31 +144,16 @@ void MultiDevSSAGraphBuilder::Init() const { ...@@ -151,31 +144,16 @@ void MultiDevSSAGraphBuilder::Init() const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs"); nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs");
#endif #endif
balance_vars_.resize(places_.size(), 0);
if (strategy_.enable_data_balance_ && places_.size() == 1) {
LOG(WARNING) << "It is no need to enable data balance when there is only "
"one place. enable_data_balance is set to False.";
strategy_.enable_data_balance_ = false;
}
} }
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
Init(); Init();
// Give the topology sort order and rebuild the graph structure. std::vector<ir::Node *> sorted_ops = SortOperations(*graph);
std::vector<ir::Node *> sorted_ops = ir::TopologySortOperations(*graph);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
sorted_ops = SortForReduceMode(sorted_ops);
}
auto nodes = graph->ReleaseNodes(); auto nodes = graph->ReleaseNodes();
ir::Graph &result = *graph; ir::Graph &result = *graph;
size_t nranks = Get<size_t>(kNRanks);
for (auto &node : nodes) { for (auto &node : nodes) {
if (node->IsVar() && node->Var()) { if (node->IsVar() && node->Var()) {
all_vars_.emplace(node->Name(), node->Var()); all_vars_.emplace(node->Name(), node->Var());
...@@ -187,142 +165,57 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -187,142 +165,57 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
result.Set(kGraphDepVars, new GraphDepVars); result.Set(kGraphDepVars, new GraphDepVars);
result.Set(kGraphOps, new GraphOps); result.Set(kGraphOps, new GraphOps);
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
bcast_var_name_set.resize(places_.size());
bool is_forwarding = true; bool is_forwarding = true;
bool is_dist_train = false; bool insert_collection_ops = NeedCollectiveOps();
std::unordered_map<std::string, int> sharded_var_device;
for (ir::Node *node : sorted_ops) { for (ir::Node *node : sorted_ops) {
if (OpHaveRole(*node, OpRole::kRPC)) { if (DealWithSpecialOp(&result, node)) {
int op_dev_id = CreateRPCOp(&result, node, &sharded_var_device); continue;
PADDLE_ENFORCE(op_dev_id != -1, } else {
"Can not schedule the RPC operator to the right place."); // This op runs on all devices
if (node->Op()->Type() == "recv") { if (IsScaleLossOp(node)) {
auto recv_vars_attr =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE(recv_vars_attr.size() == 2UL); // [parameter, gradient]
if (recv_vars_attr[0].find(".block") == std::string::npos) {
bcast_var_name_set[op_dev_id].emplace(recv_vars_attr[0]);
}
}
is_dist_train = true;
} else if (OpHaveRole(*node, OpRole::kDist)) {
int op_dev_id = CreateDistTrainOp(&result, node, &sharded_var_device);
if (node->Op()->Type() == "concat") {
auto origin_param_name = node->Op()->OutputArgumentNames()[0];
bcast_var_name_set[op_dev_id].emplace(origin_param_name);
}
} else if (IsScaleLossOp(node)) {
// user can customize loss@grad if not use_default_grad_scale_ // user can customize loss@grad if not use_default_grad_scale_
if (strategy_.gradient_scale_ != InsertScaleLossGradOp(&result, node);
BuildStrategy::GradientScaleStrategy::kCustomized) {
// TODO(paddle-dev): Why is there no input for this op_handle?
auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
auto out_dtype = all_vars_.at(loss_grad_name)->GetDataType();
CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0],
out_dtype);
}
// This assumes the backward generating code will ensure IsScaleLossOp // This assumes the backward generating code will ensure IsScaleLossOp
// is true only for the op that scale the final scalar loss. // is true only for the op that scale the final scalar loss.
// It also assumes backward op will always follow the forward op in // It also assumes backward op will always follow the forward op in
// the block. // the block.
is_forwarding = false; is_forwarding = false;
} else {
int op_dev_id = GetOpDeviceID(node, sharded_var_device);
if (op_dev_id != -1) { // This op only runs on one specific device.
CreateComputationalOp(&result, node, op_dev_id);
for (ir::Node *n : node->outputs) {
sharded_var_device.emplace(n->Name(), op_dev_id);
}
} else {
// This op runs on all devices, and its output may have parameter's
// gradients.
// TODO(paddle-dev): Why is so special about "read" op?
if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) {
node->Op()->SetAttr("throw_eof_exp", false);
CreateComputationalOps(&result, node, places_.size());
const auto &data_var_names = node->Op()->Output("Out");
InsertDataBalanceOp(&result, data_var_names);
} else { } else {
CreateComputationalOps(&result, node, places_.size()); CreateComputationalOps(&result, node, places_.size());
} }
if (!is_forwarding && nranks > 1UL) { // Insert collection ops
if (!is_forwarding && insert_collection_ops) {
try {
bool is_bk_op = bool is_bk_op =
static_cast<bool>(boost::get<int>(node->Op()->GetAttr( static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) & OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward)); static_cast<int>(OpRole::kBackward));
if (!is_bk_op) continue; if (!is_bk_op) continue;
// Currently, we assume that once gradient is generated, it can be // Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once. // broadcast, and each gradient is only broadcast once.
try { auto backward_vars =
auto backward_vars = boost::get<std::vector<std::string>>( boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName())); OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0); PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
for (size_t i = 0; i < backward_vars.size(); i += 2) { for (size_t i = 0; i < backward_vars.size(); i += 2) {
auto &p_name = backward_vars[i]; auto &p_name = backward_vars[i];
auto &g_name = backward_vars[i + 1]; auto &g_name = backward_vars[i + 1];
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name; VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
size_t cur_device_id = -1;
switch (strategy_.reduce_) { InsertCollectiveOp(&result, p_name, g_name);
case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(&result, g_name, cur_device_id);
sharded_var_device.emplace(g_name, cur_device_id);
if (!is_dist_train) {
bcast_var_name_set[cur_device_id].emplace(p_name);
}
break;
case BuildStrategy::ReduceStrategy::kAllReduce:
if (IsSparseGradient(g_name)) {
CreateReduceOp(&result, g_name, 0);
CreateBroadcastOp(&result, g_name, 0);
} else {
InsertAllReduceOp(&result, g_name);
}
break;
default:
LOG(FATAL) << "Unknown reduce strategy ";
break;
}
} }
} catch (boost::bad_get e) { } catch (boost::bad_get e) {
} }
} }
} }
} }
}
bool use_gpu = false;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
use_gpu = nccl_ctxs_ != nullptr;
#endif
// Insert broadcast operators principle: InsertPostprocessOps(&result);
// 1. Broadcast optimized parameters in Reduce strategy;
// 2. No need broadcast optimized parameters in AllReduce strategy because of
// the optimization sub-graph would be run on every GPU;
// 3. Allways broadcast received parameters in Distribute Training.
if ((use_gpu &&
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
is_dist_train) {
if (strategy_.fuse_broadcast_op_) {
CreateFusedBroadcastOp(&result, bcast_var_name_set);
} else {
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
auto &to_bcast_set = bcast_var_name_set[dev_id];
for (auto &bcast_name : to_bcast_set) {
CreateBroadcastOp(&result, bcast_name, dev_id);
}
}
}
}
/* /*
Dependency graph has been constructed. However, there are still data Dependency graph has been constructed. However, there are still data
hazards need to be handled. hazards need to be handled.
...@@ -337,65 +230,52 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -337,65 +230,52 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
return graph; return graph;
} }
std::vector<ir::Node *> MultiDevSSAGraphBuilder::SortForReduceMode( void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp(
const std::vector<ir::Node *> &topo_ops) const { ir::Graph *result, const ir::Node *node) const {
std::unordered_map<std::string, int> sharded_var_device; // user can customize loss@grad if not use_default_grad_scale_
std::vector<ir::Node *> sorted_ops; size_t loss_scale = 0;
std::unordered_map<std::string, std::vector<ir::Node *>> delayed_op; switch (this->strategy_.gradient_scale_) {
sorted_ops.reserve(topo_ops.size()); case BuildStrategy::GradientScaleStrategy::kOne:
loss_scale = 1;
auto insert_delayed_op = [&](const std::string &var_name, int dev_id) { break;
sharded_var_device.emplace(var_name, dev_id); case BuildStrategy::GradientScaleStrategy::kCoeffNumDevice:
if (delayed_op.count(var_name)) { loss_scale = Get<size_t>(kNRanks);
auto &ops = delayed_op.at(var_name); break;
sorted_ops.insert(sorted_ops.end(), ops.begin(), ops.end()); case BuildStrategy::GradientScaleStrategy::kCustomized:
delayed_op.at(var_name).clear(); loss_scale = 0;
break;
default:
LOG(FATAL) << "Unknown gradient scale strategy.";
break;
} }
};
for (ir::Node *node : topo_ops) { if (loss_scale) {
int op_dev_id = GetOpDeviceID(node, sharded_var_device, &delayed_op); // TODO(paddle-dev): Why is there no input for this op_handle?
if (op_dev_id > -1) { auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
// This op only runs on one specific device. auto out_dtype = this->all_vars_.at(loss_grad_name)->GetDataType();
sorted_ops.emplace_back(node); this->CreateScaleLossGradOp(result, loss_grad_name, node->outputs[0],
for (ir::Node *n : node->outputs) { loss_scale, out_dtype);
insert_delayed_op(n->Name(), op_dev_id);
}
} else if (op_dev_id == -1) {
// This op runs on all devices, and its output may have parameter's
// gradients.
sorted_ops.emplace_back(node);
bool is_bk_op =
static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
if (!is_bk_op) continue;
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
std::vector<std::string> backward_vars;
try {
backward_vars =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
} catch (boost::bad_get e) {
} }
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0); }
for (size_t i = 0; i < backward_vars.size(); i += 2) { std::vector<ir::Node *> MultiDevSSAGraphBuilderBase::SortOperations(
auto &g_name = backward_vars[i + 1]; const ir::Graph &graph) const {
size_t cur_device_id = GetAppropriateDeviceID({g_name}); return ir::TopologySortOperations(graph);
insert_delayed_op(g_name, static_cast<int>(cur_device_id)); }
}
} else if (op_dev_id == -2) {
// The Op on which the Op depends has not yet been generated.
}
}
PADDLE_ENFORCE_EQ(sorted_ops.size(), topo_ops.size()); bool MultiDevSSAGraphBuilderBase::UseGPU() const {
return sorted_ops; bool use_gpu = false;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
use_gpu = nccl_ctxs_ != nullptr;
#endif
return use_gpu;
} }
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result, bool MultiDevSSAGraphBuilderBase::NeedCollectiveOps() const {
return Get<size_t>(kNRanks) > 1;
}
void MultiDevSSAGraphBuilderBase::CreateOpHandleIOs(ir::Graph *result,
ir::Node *node, ir::Node *node,
size_t place_id) const { size_t place_id) const {
auto p = places_[place_id]; auto p = places_[place_id];
...@@ -420,28 +300,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result, ...@@ -420,28 +300,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
} }
} }
size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( void MultiDevSSAGraphBuilderBase::SetCommunicationContext(
const std::vector<std::string> &var_names) const {
int64_t numel_sum = 0;
for (auto var_name : var_names) {
if (all_vars_.find(var_name) == all_vars_.end()) continue;
auto var_desc = all_vars_.at(var_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GT(numel, 0);
numel_sum += numel;
}
auto smallest =
std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
balance_vars_[dev_id] += numel_sum;
return dev_id;
}
void MultiDevSSAGraphBuilder::SetCommunicationContext(
OpHandleBase *op_handle, const platform::Place &p) const { OpHandleBase *op_handle, const platform::Place &p) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
if (nccl_ctxs_ == nullptr) { if (nccl_ctxs_ == nullptr) {
...@@ -454,7 +313,7 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext( ...@@ -454,7 +313,7 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
#endif #endif
} }
void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result, void MultiDevSSAGraphBuilderBase::CreateBroadcastOp(ir::Graph *result,
const std::string &p_name, const std::string &p_name,
size_t src_dev_id) const { size_t src_dev_id) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
...@@ -484,7 +343,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result, ...@@ -484,7 +343,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
} }
} }
void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp( void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
ir::Graph *result, ir::Graph *result,
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const { const std::vector<std::unordered_set<std::string>> &bcast_varnames) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
...@@ -522,7 +381,7 @@ void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp( ...@@ -522,7 +381,7 @@ void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
} }
} }
void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
ir::Node *node, ir::Node *node,
int dev_id) const { int dev_id) const {
result->Get<GraphOps>(kGraphOps).emplace_back( result->Get<GraphOps>(kGraphOps).emplace_back(
...@@ -531,8 +390,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, ...@@ -531,8 +390,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
CreateOpHandleIOs(result, node, dev_id); CreateOpHandleIOs(result, node, dev_id);
} }
void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result, void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
const std::string &og) const { ir::Graph *result, const std::string &og) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
...@@ -560,102 +419,15 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result, ...@@ -560,102 +419,15 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
} }
} }
void MultiDevSSAGraphBuilder::InsertDataBalanceOp( void MultiDevSSAGraphBuilderBase::CreateScaleLossGradOp(
ir::Graph *result, const std::vector<std::string> &datas) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_));
#else
result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
local_scopes_, places_));
#endif
auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
for (const std::string &d_name : datas) {
auto &vars = result->Get<GraphVars>(kGraphVars)[i][d_name];
PADDLE_ENFORCE(!vars.empty());
op_handle->AddInput(vars.back());
auto var = new VarHandle(
result->CreateEmptyNode(d_name, ir::Node::Type::kVariable),
vars.size(), i, d_name, p);
vars.emplace_back(var);
op_handle->AddOutput(var);
}
}
}
int MultiDevSSAGraphBuilder::GetOpDeviceID(
ir::Node *node,
const std::unordered_map<std::string, int> &sharded_var_device,
std::unordered_map<std::string, std::vector<ir::Node *>> *delay_ops) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1;
}
if (!OpHaveRole(*node, framework::OpRole::kOptimize)) {
return -1;
}
auto param_grad = boost::get<std::vector<std::string>>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
int dev_id = GetVarDeviceID(param_grad[1], sharded_var_device);
if (dev_id == -1) {
(*delay_ops)[param_grad[1]].push_back(node);
return -2;
}
return dev_id;
}
int MultiDevSSAGraphBuilder::GetOpDeviceID(
ir::Node *node,
const std::unordered_map<std::string, int> &sharded_var_device) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1;
}
if (!OpHaveRole(*node, framework::OpRole::kOptimize)) {
return -1;
}
auto param_grad = boost::get<std::vector<std::string>>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
int dev_id = GetVarDeviceID(param_grad[1], sharded_var_device);
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
node->Op()->Type(), param_grad[0], param_grad[1]);
return dev_id;
}
int MultiDevSSAGraphBuilder::GetVarDeviceID(
const std::string &varname,
const std::unordered_map<std::string, int> &sharded_var_device) const {
auto got = sharded_var_device.find(varname);
if (got == sharded_var_device.end()) {
auto pos = varname.find(framework::kNewGradSuffix);
if (pos != std::string::npos) {
got = sharded_var_device.find(varname.substr(0, pos));
}
}
return got == sharded_var_device.end() ? -1 : got->second;
}
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
ir::Graph *result, const std::string &loss_grad_name, ir::Graph *result, const std::string &loss_grad_name,
ir::Node *out_var_node, proto::VarType::Type dtype) const { ir::Node *out_var_node, size_t loss_scale,
size_t nranks = Get<size_t>("nranks"); proto::VarType::Type dtype) const {
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]); auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
auto *op_handle = new ScaleLossGradOpHandle( auto *op_handle = new ScaleLossGradOpHandle(
result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation), result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
nranks, local_scopes_[i], places_[i], dev_ctx, dtype); loss_scale, local_scopes_[i], places_[i], dev_ctx, dtype);
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle); result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale // FIXME: Currently ScaleLossGradOp only use device_count as scale
...@@ -669,9 +441,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp( ...@@ -669,9 +441,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
} }
} }
void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result, void MultiDevSSAGraphBuilderBase::CreateComputationalOps(
ir::Node *node, ir::Graph *result, ir::Node *node, size_t num_places) const {
size_t num_places) const {
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx]; auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx]; auto s = local_scopes_[scope_idx];
...@@ -681,7 +452,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result, ...@@ -681,7 +452,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
} }
} }
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(ir::Graph *result,
const std::string &og, const std::string &og,
int dst_dev_id) const { int dst_dev_id) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
...@@ -712,51 +483,273 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, ...@@ -712,51 +483,273 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
return var; return var;
} }
int MultiDevSSAGraphBuilder::CreateDistTrainOp( bool MultiDevSSAGraphBuilderBase::IsScaleLossOp(ir::Node *node) const {
ir::Graph *result, ir::Node *node, return boost::get<int>(
std::unordered_map<std::string, int> *sharded_var_device) const { node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
int op_dev_id = -1; (static_cast<int>(OpRole::kBackward) |
std::vector<std::string> input_var_names; static_cast<int>(OpRole::kLoss)) &&
std::vector<std::string> output_var_names; !loss_var_name_.empty(); // If loss_var is empty. This is test mode
for (ir::Node *input : node->inputs) { }
input_var_names.push_back(input->Name());
bool MultiDevSSAGraphBuilderBase::IsSparseGradient(
const std::string &og) const {
PADDLE_ENFORCE(all_vars_.count(og) != 0);
if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
return true;
} }
for (ir::Node *output : node->outputs) { return false;
output_var_names.push_back(output->Name()); }
void AllReduceSSAGraphBuilder::InsertCollectiveOp(
ir::Graph *result, const std::string &p_name,
const std::string &g_name) const {
if (IsSparseGradient(g_name)) {
CreateReduceOp(result, g_name, 0);
CreateBroadcastOp(result, g_name, 0);
} else {
CreateAllReduceOp(result, g_name);
} }
}
if (node->Op()->Type() == "split_byref" || int BalanceVarSSAGraphBuilder::GetVarDeviceID(
node->Op()->Type() == "split_selected_rows" || const std::string &varname) const {
node->Op()->Type() == "split_ids") { auto got = sharded_var_device_.find(varname);
// TODO(paddle-dev): getting the first var is not safe. if (got == sharded_var_device_.end()) {
op_dev_id = GetVarDeviceID(input_var_names[0], *sharded_var_device); auto pos = varname.find(framework::kNewGradSuffix);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (pos != std::string::npos) {
op_dev_id = GetAppropriateDeviceID(input_var_names); got = sharded_var_device_.find(varname.substr(0, pos));
for (auto &varname : input_var_names) {
sharded_var_device->emplace(varname, op_dev_id);
} }
} }
for (auto &varname : output_var_names) { return got == sharded_var_device_.end() ? -1 : got->second;
sharded_var_device->emplace(varname, op_dev_id); }
int BalanceVarSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1;
} }
} else if (node->Op()->Type() == "concat") { if (!OpHaveRole(*node, framework::OpRole::kOptimize)) {
op_dev_id = GetVarDeviceID(input_var_names[0], *sharded_var_device); return -1;
for (auto &varname : output_var_names) { }
sharded_var_device->emplace(varname, op_dev_id); auto param_grad = boost::get<std::vector<std::string>>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
int dev_id = GetVarDeviceID(param_grad[1]);
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
node->Op()->Type(), param_grad[0], param_grad[1]);
return dev_id;
}
size_t BalanceVarSSAGraphBuilder::GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const {
int64_t numel_sum = 0;
for (auto var_name : var_names) {
if (all_vars_.find(var_name) == all_vars_.end()) continue;
auto var_desc = all_vars_.at(var_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GT(numel, 0);
numel_sum += numel;
} }
auto smallest =
std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
balance_vars_[dev_id] += numel_sum;
return dev_id;
}
void BalanceVarSSAGraphBuilder::ResetState() const {
balance_vars_.clear();
sharded_var_device_.clear();
balance_vars_.resize(places_.size(), 0);
}
void ReduceSSAGraphBuilder::Init() const {
MultiDevSSAGraphBuilderBase::Init();
ResetState();
}
void ReduceSSAGraphBuilder::ResetState() const {
BalanceVarSSAGraphBuilder::ResetState();
bcast_var_name_set_.clear();
bcast_var_name_set_.resize(places_.size());
}
void ReduceSSAGraphBuilder::InsertCollectiveOp(
ir::Graph *result, const std::string &p_name,
const std::string &g_name) const {
size_t cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(result, g_name, cur_device_id);
sharded_var_device_.emplace(g_name, cur_device_id);
bcast_var_name_set_[cur_device_id].emplace(p_name);
}
bool ReduceSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
ir::Node *node) const {
int op_dev_id = BalanceVarSSAGraphBuilder::GetOpDeviceID(node);
if (op_dev_id != -1) {
// This op only runs on one specific device.
CreateComputationalOp(result, node, op_dev_id);
for (ir::Node *n : node->outputs) {
sharded_var_device_.emplace(n->Name(), op_dev_id);
}
return true;
}
return false;
}
void ReduceSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
if (UseGPU()) {
if (strategy_.fuse_broadcast_op_) {
CreateFusedBroadcastOp(result, bcast_var_name_set_);
} else { } else {
LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type(); for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) {
PADDLE_THROW( auto &to_bcast_set = bcast_var_name_set_[dev_id];
"the distribute training related op should be in [split_byref, " for (auto &bcast_name : to_bcast_set) {
"concat]."); CreateBroadcastOp(result, bcast_name, dev_id);
}
}
} }
}
}
PADDLE_ENFORCE(op_dev_id != -1, int ReduceSSAGraphBuilder::GetOpDeviceID(
"can not find right place for distributed op: %s", ir::Node *node,
node->Op()->Type()); std::unordered_map<std::string, std::vector<ir::Node *>> *delay_ops) const {
if (!OpHaveRole(*node, framework::OpRole::kOptimize)) {
return -1;
}
auto param_grad = boost::get<std::vector<std::string>>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
int dev_id = GetVarDeviceID(param_grad[1]);
if (dev_id == -1) {
(*delay_ops)[param_grad[1]].push_back(node);
return -2;
}
return dev_id;
}
std::vector<ir::Node *> ReduceSSAGraphBuilder::SortOperations(
const ir::Graph &graph) const {
std::vector<ir::Node *> sorted_ops = ir::TopologySortOperations(graph);
return SortForReduceMode(sorted_ops);
}
std::vector<ir::Node *> ReduceSSAGraphBuilder::SortForReduceMode(
const std::vector<ir::Node *> &topo_ops) const {
std::vector<ir::Node *> sorted_ops;
std::unordered_map<std::string, std::vector<ir::Node *>> delayed_op;
sorted_ops.reserve(topo_ops.size());
ResetState();
auto insert_delayed_op = [&](const std::string &var_name, int dev_id) {
sharded_var_device_.emplace(var_name, dev_id);
if (delayed_op.count(var_name)) {
auto &ops = delayed_op.at(var_name);
sorted_ops.insert(sorted_ops.end(), ops.begin(), ops.end());
delayed_op.at(var_name).clear();
}
};
for (ir::Node *node : topo_ops) {
int op_dev_id = GetOpDeviceID(node, &delayed_op);
if (op_dev_id > -1) {
// This op only runs on one specific device.
sorted_ops.emplace_back(node);
for (ir::Node *n : node->outputs) {
insert_delayed_op(n->Name(), op_dev_id);
}
} else if (op_dev_id == -1) {
// This op runs on all devices, and its output may have parameter's
// gradients.
sorted_ops.emplace_back(node);
bool is_bk_op =
static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
if (!is_bk_op) continue;
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
std::vector<std::string> backward_vars;
try {
backward_vars =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
} catch (boost::bad_get e) {
}
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
for (size_t i = 0; i < backward_vars.size(); i += 2) {
auto &g_name = backward_vars[i + 1];
size_t cur_device_id = GetAppropriateDeviceID({g_name});
insert_delayed_op(g_name, static_cast<int>(cur_device_id));
}
} else if (op_dev_id == -2) {
// The Op on which the Op depends has not yet been generated.
}
}
PADDLE_ENFORCE_EQ(sorted_ops.size(), topo_ops.size());
ResetState();
return sorted_ops;
}
void DistSSAGraphBuilder::Init() const {
MultiDevSSAGraphBuilderBase::Init();
ResetState();
}
void DistSSAGraphBuilder::ResetState() const {
BalanceVarSSAGraphBuilder::ResetState();
bcast_var_name_set_.clear();
bcast_var_name_set_.resize(places_.size());
}
bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
ir::Node *node) const {
bool insert_op = false;
if (OpHaveRole(*node, OpRole::kRPC)) {
int op_dev_id = CreateRPCOp(result, node);
PADDLE_ENFORCE(op_dev_id != -1,
"Can not schedule the RPC operator to the right place.");
if (node->Op()->Type() == "recv") {
auto recv_vars_attr =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE(recv_vars_attr.size() == 2UL); // [parameter, gradient]
if (recv_vars_attr[0].find(".block") == std::string::npos) {
bcast_var_name_set_[op_dev_id].emplace(recv_vars_attr[0]);
}
}
insert_op = true;
need_broadcast_var_ = true;
} else if (OpHaveRole(*node, OpRole::kDist)) {
int op_dev_id = CreateDistTrainOp(result, node);
if (node->Op()->Type() == "concat") {
auto origin_param_name = node->Op()->OutputArgumentNames()[0];
bcast_var_name_set_[op_dev_id].emplace(origin_param_name);
}
insert_op = true;
} else {
int op_dev_id = GetOpDeviceID(node);
if (op_dev_id != -1) { // This op only runs on one specific device.
CreateComputationalOp(result, node, op_dev_id); CreateComputationalOp(result, node, op_dev_id);
return op_dev_id; for (ir::Node *n : node->outputs) {
sharded_var_device_.emplace(n->Name(), op_dev_id);
}
insert_op = true;
}
}
return insert_op;
} }
void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) { void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
...@@ -775,13 +768,11 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) { ...@@ -775,13 +768,11 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
} }
// Create RPC related op handles that connects its in ops and out ops. // Create RPC related op handles that connects its in ops and out ops.
int MultiDevSSAGraphBuilder::CreateRPCOp( int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const {
int op_dev_id = -1; int op_dev_id = -1;
if (node->Op()->Type() == "send") { if (node->Op()->Type() == "send") {
// TODO(paddle-dev): getting the first var is not safe. // TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(node->inputs[0]->Name(), *sharded_var_device); op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]), PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
"This hack no longer holds, please fix."); "This hack no longer holds, please fix.");
// the variable name which contains .block means it was splited by // the variable name which contains .block means it was splited by
...@@ -799,9 +790,9 @@ int MultiDevSSAGraphBuilder::CreateRPCOp( ...@@ -799,9 +790,9 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
VLOG(10) << "send grad " << input_var_names[0] << " origin " VLOG(10) << "send grad " << input_var_names[0] << " origin "
<< send_param_grad[1] << " place: " << op_dev_id; << send_param_grad[1] << " place: " << op_dev_id;
for (auto &varname : input_var_names) { for (auto &varname : input_var_names) {
sharded_var_device->emplace(varname, op_dev_id); sharded_var_device_.emplace(varname, op_dev_id);
} }
sharded_var_device->emplace(send_param_grad[1], op_dev_id); sharded_var_device_.emplace(send_param_grad[1], op_dev_id);
} }
} else if (node->Op()->Type() == "recv") { } else if (node->Op()->Type() == "recv") {
std::vector<std::string> output_var_names; std::vector<std::string> output_var_names;
...@@ -811,7 +802,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp( ...@@ -811,7 +802,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
auto recv_param_grad = boost::get<std::vector<std::string>>( auto recv_param_grad = boost::get<std::vector<std::string>>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
if (recv_param_grad.size() == 2U) { if (recv_param_grad.size() == 2U) {
op_dev_id = GetVarDeviceID(recv_param_grad[1], *sharded_var_device); op_dev_id = GetVarDeviceID(recv_param_grad[1]);
VLOG(10) << "recv param " << recv_param_grad[0] VLOG(10) << "recv param " << recv_param_grad[0]
<< " get grad place: " << recv_param_grad[1] << " get grad place: " << recv_param_grad[1]
<< " place: " << op_dev_id; << " place: " << op_dev_id;
...@@ -819,7 +810,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp( ...@@ -819,7 +810,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
op_dev_id = GetAppropriateDeviceID(output_var_names); op_dev_id = GetAppropriateDeviceID(output_var_names);
} }
for (auto &varname : output_var_names) { for (auto &varname : output_var_names) {
sharded_var_device->emplace(varname, op_dev_id); sharded_var_device_.emplace(varname, op_dev_id);
} }
} else { } else {
// send_barrier, fetch_barrier will run on place 0; // send_barrier, fetch_barrier will run on place 0;
...@@ -846,7 +837,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp( ...@@ -846,7 +837,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
for (ir::Node *output : node->outputs) { for (ir::Node *output : node->outputs) {
int outvar_dev_id = op_dev_id; int outvar_dev_id = op_dev_id;
if (node->Op()->Type() == "fetch_barrier") { if (node->Op()->Type() == "fetch_barrier") {
outvar_dev_id = GetVarDeviceID(output->Name(), *sharded_var_device); outvar_dev_id = GetVarDeviceID(output->Name());
PADDLE_ENFORCE_NE(outvar_dev_id, -1, "output name %s", output->Name()); PADDLE_ENFORCE_NE(outvar_dev_id, -1, "output name %s", output->Name());
} }
p = places_[outvar_dev_id]; p = places_[outvar_dev_id];
...@@ -863,29 +854,124 @@ int MultiDevSSAGraphBuilder::CreateRPCOp( ...@@ -863,29 +854,124 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
return op_dev_id; return op_dev_id;
} }
bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
PADDLE_ENFORCE(all_vars_.count(og) != 0); ir::Node *node) const {
if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) { int op_dev_id = -1;
return true; std::vector<std::string> input_var_names;
std::vector<std::string> output_var_names;
for (ir::Node *input : node->inputs) {
input_var_names.push_back(input->Name());
} }
return false; for (ir::Node *output : node->outputs) {
output_var_names.push_back(output->Name());
}
if (node->Op()->Type() == "split_byref" ||
node->Op()->Type() == "split_selected_rows" ||
node->Op()->Type() == "split_ids") {
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(input_var_names[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : input_var_names) {
sharded_var_device_.emplace(varname, op_dev_id);
}
}
for (auto &varname : output_var_names) {
sharded_var_device_.emplace(varname, op_dev_id);
}
} else if (node->Op()->Type() == "concat") {
op_dev_id = GetVarDeviceID(input_var_names[0]);
for (auto &varname : output_var_names) {
sharded_var_device_.emplace(varname, op_dev_id);
}
} else {
LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
PADDLE_THROW(
"the distribute training related op should be in [split_byref, "
"concat].");
}
PADDLE_ENFORCE(op_dev_id != -1,
"can not find right place for distributed op: %s",
node->Op()->Type());
CreateComputationalOp(result, node, op_dev_id);
return op_dev_id;
} }
bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const { void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
return boost::get<int>( const std::string &p_name,
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == const std::string &g_name) const {
(static_cast<int>(OpRole::kBackward) | size_t cur_device_id = 0;
static_cast<int>(OpRole::kLoss)) && switch (strategy_.reduce_) {
!loss_var_name_.empty(); // If loss_var is empty. This is test mode case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(result, g_name, cur_device_id);
sharded_var_device_.emplace(g_name, cur_device_id);
break;
case BuildStrategy::ReduceStrategy::kAllReduce:
if (IsSparseGradient(g_name)) {
CreateReduceOp(result, g_name, 0);
CreateBroadcastOp(result, g_name, 0);
} else {
CreateAllReduceOp(result, g_name);
}
break;
default:
LOG(FATAL) << "Unknown reduce strategy.";
break;
}
} }
void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
if (need_broadcast_var_ ||
(UseGPU() &&
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce)) {
if (strategy_.fuse_broadcast_op_) {
CreateFusedBroadcastOp(result, bcast_var_name_set_);
} else {
for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) {
auto &to_bcast_set = bcast_var_name_set_[dev_id];
for (auto &bcast_name : to_bcast_set) {
CreateBroadcastOp(result, bcast_name, dev_id);
}
}
}
}
}
std::unordered_set<std::string> &MultiDevSSAGraphBuilder() {
static std::unordered_set<std::string> regs;
return regs;
}
static int MultiDevSSAGraphBuilderRegister(const std::string &builder_mode) {
MultiDevSSAGraphBuilder().insert(builder_mode);
return 0;
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(multi_devices_pass, #define REGISTER_MULTI_DEVICES_PASS(pass_name, pass_class) \
paddle::framework::details::MultiDevSSAGraphBuilder) STATIC_ASSERT_GLOBAL_NAMESPACE( \
.RequirePassAttr(paddle::framework::details::kLossVarName) _reg_ssa_graph_builder_##pass_name, \
.RequirePassAttr(paddle::framework::details::kPlaces) "REGISTER_MULTI_DEVICES_PASS must be called in global namespace."); \
.RequirePassAttr(paddle::framework::details::kLocalScopes) int _reg_ssa_graph_builder_entry_##pass_name = \
.RequirePassAttr(paddle::framework::details::kStrategy) paddle::framework::details::MultiDevSSAGraphBuilderRegister(#pass_name); \
.RequirePassAttr(paddle::framework::details::kNRanks); REGISTER_PASS(pass_name, pass_class) \
.RequirePassAttr(paddle::framework::details::kLossVarName) \
.RequirePassAttr(paddle::framework::details::kPlaces) \
.RequirePassAttr(paddle::framework::details::kLocalScopes) \
.RequirePassAttr(paddle::framework::details::kStrategy) \
.RequirePassAttr(paddle::framework::details::kNRanks)
REGISTER_MULTI_DEVICES_PASS(reduce_mode_multi_devices_pass,
paddle::framework::details::ReduceSSAGraphBuilder);
REGISTER_MULTI_DEVICES_PASS(
allreduce_mode_multi_devices_pass,
paddle::framework::details::AllReduceSSAGraphBuilder);
REGISTER_MULTI_DEVICES_PASS(dist_multi_devices_pass,
paddle::framework::details::DistSSAGraphBuilder);
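The REGISTER_MULTI_DEVICES_PASS macro above does two things: it registers the pass through the usual REGISTER_PASS machinery (requiring the kLossVarName/kPlaces/kLocalScopes/kStrategy/kNRanks attributes), and it records the pass name in the set returned by MultiDevSSAGraphBuilder(), which BuildStrategy::IsMultiDevPass() consults in build_strategy.cc before injecting those attributes. Below is a minimal, self-contained sketch of that static-registration idiom; the names are simplified stand-ins and the REGISTER_PASS half is omitted.

```cpp
// Toy version of the pass-name registry behind IsMultiDevPass (hypothetical
// names; the real macro also forwards to REGISTER_PASS / RequirePassAttr).
#include <iostream>
#include <string>
#include <unordered_set>

// Function-local static: the registry is initialized before first use,
// regardless of the order in which translation units are initialized.
std::unordered_set<std::string> &MultiDevPassRegistry() {
  static std::unordered_set<std::string> regs;
  return regs;
}

int RegisterMultiDevPass(const std::string &name) {
  MultiDevPassRegistry().insert(name);
  return 0;  // the return value only exists to initialize a global int
}

#define REGISTER_MULTI_DEV_PASS(pass_name) \
  int _reg_entry_##pass_name = RegisterMultiDevPass(#pass_name)

// Each registration runs at program start-up as a global initializer.
REGISTER_MULTI_DEV_PASS(reduce_mode_multi_devices_pass);
REGISTER_MULTI_DEV_PASS(allreduce_mode_multi_devices_pass);
REGISTER_MULTI_DEV_PASS(dist_multi_devices_pass);

bool IsMultiDevPass(const std::string &pass_name) {
  return MultiDevPassRegistry().count(pass_name) > 0;
}

int main() {
  std::cout << IsMultiDevPass("dist_multi_devices_pass") << "\n";   // 1
  std::cout << IsMultiDevPass("multi_devices_check_pass") << "\n";  // 0
}
```

BuildStrategy::Apply() then relies on this check to hand kPlaces, kLossVarName, kLocalScopes, and kNRanks to whichever of the three multi-devices passes AppendMultiDevPass appended.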
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -30,78 +31,70 @@ namespace framework { ...@@ -30,78 +31,70 @@ namespace framework {
class Scope; class Scope;
namespace details { namespace details {
class MultiDevSSAGraphBuilder : public ir::Pass { constexpr char kLossVarName[] = "loss_var_name";
constexpr char kPlaces[] = "places";
constexpr char kLocalScopes[] = "local_scopes";
constexpr char kStrategy[] = "strategy";
constexpr char kNRanks[] = "nranks";
class MultiDevSSAGraphBuilderBase : public ir::Pass {
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl( std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override; std::unique_ptr<ir::Graph> graph) const override;
private: virtual void Init() const;
void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
size_t device_id) const;
void Init() const;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;
mutable platform::NCCLContextMap *nccl_ctxs_;
#endif
int GetVarDeviceID( virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &varname, const std::string &g_name) const = 0;
const std::unordered_map<std::string, int> &sharded_var_device) const;
bool IsScaleLossOp(ir::Node *node) const; virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const = 0;
virtual void InsertPostprocessOps(ir::Graph *result) const = 0;
int CreateRPCOp( bool UseGPU() const;
ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const; bool NeedCollectiveOps() const;
int CreateDistTrainOp(
ir::Graph *result, ir::Node *node, bool IsScaleLossOp(ir::Node *node) const;
std::unordered_map<std::string, int> *sharded_var_device) const;
void CreateComputationalOps(ir::Graph *result, ir::Node *node, void CreateComputationalOps(ir::Graph *result, ir::Node *node,
size_t num_places) const; size_t num_places) const;
void CreateScaleLossGradOp(ir::Graph *result, void CreateScaleLossGradOp(ir::Graph *result,
const std::string &loss_grad_name, const std::string &loss_grad_name,
ir::Node *out_var_node, ir::Node *out_var_node, size_t loss_scale,
proto::VarType::Type dtype) const; proto::VarType::Type dtype) const;
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og, VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
int dst_dev_id) const; int dst_dev_id) const;
void CreateComputationalOp(ir::Graph *result, ir::Node *node, void CreateComputationalOp(ir::Graph *result, ir::Node *node,
int dev_id) const; int dev_id) const;
int GetOpDeviceID( bool IsSparseGradient(const std::string &og) const;
ir::Node *node,
const std::unordered_map<std::string, int> &sharded_var_device) const;
void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
void InsertDataBalanceOp(ir::Graph *result, void CreateAllReduceOp(ir::Graph *result, const std::string &og) const;
const std::vector<std::string> &datas) const;
void CreateBroadcastOp(ir::Graph *result, const std::string &p_name, void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
size_t src_dev_id) const; size_t src_dev_id) const;
void InsertScaleLossGradOp(ir::Graph *result, const ir::Node *node) const;
void CreateFusedBroadcastOp( void CreateFusedBroadcastOp(
ir::Graph *result, ir::Graph *result,
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const; const std::vector<std::unordered_set<std::string>> &bcast_varnames) const;
bool IsSparseGradient(const std::string &og) const;
size_t GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const;
void SetCommunicationContext(OpHandleBase *op_handle, void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const; const platform::Place &p) const;
std::vector<ir::Node *> SortForReduceMode( void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
const std::vector<ir::Node *> &) const; size_t device_id) const;
int GetOpDeviceID( #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
ir::Node *node, mutable platform::NCCLContextMap *nccl_ctxs_;
const std::unordered_map<std::string, int> &shared_var_device, #endif
std::unordered_map<std::string, std::vector<ir::Node *>> *delay_ops)
const;
mutable std::string loss_var_name_; mutable std::string loss_var_name_;
mutable std::vector<platform::Place> places_; mutable std::vector<platform::Place> places_;
...@@ -109,8 +102,83 @@ class MultiDevSSAGraphBuilder : public ir::Pass { ...@@ -109,8 +102,83 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
mutable BuildStrategy strategy_; mutable BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_; mutable std::unordered_map<std::string, VarDesc *> all_vars_;
};
class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
protected:
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const {
return false;
}
virtual void InsertPostprocessOps(ir::Graph *result) const {}
};
class BalanceVarSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
protected:
int GetVarDeviceID(const std::string &varname) const;
int GetOpDeviceID(ir::Node *node) const;
size_t GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const;
virtual void ResetState() const;
mutable std::unordered_map<std::string, int> sharded_var_device_;
mutable std::vector<int64_t> balance_vars_; mutable std::vector<int64_t> balance_vars_;
}; };
class ReduceSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
protected:
virtual void Init() const;
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;
virtual void InsertPostprocessOps(ir::Graph *result) const;
virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;
virtual void ResetState() const;
int GetOpDeviceID(ir::Node *node,
std::unordered_map<std::string, std::vector<ir::Node *>>
*delay_ops) const;
std::vector<ir::Node *> SortForReduceMode(
const std::vector<ir::Node *> &topo_ops) const;
mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
};
class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
protected:
virtual void Init() const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;
virtual void InsertPostprocessOps(ir::Graph *result) const;
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const;
virtual void ResetState() const;
int CreateRPCOp(ir::Graph *result, ir::Node *node) const;
int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
mutable bool need_broadcast_var_{false};
};
std::unordered_set<std::string> &MultiDevSSAGraphBuilder();
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -946,13 +946,6 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -946,13 +946,6 @@ All parameter, weight, gradient are variables in Paddle.
R"DOC(The type is STR, debug_graphviz_path indicate the path that R"DOC(The type is STR, debug_graphviz_path indicate the path that
writing the SSA Graph to file in the form of graphviz, you. writing the SSA Graph to file in the form of graphviz, you.
It is useful for debugging. Default "")DOC") It is useful for debugging. Default "")DOC")
.def_property(
"enable_data_balance",
[](const BuildStrategy &self) { return self.enable_data_balance_; },
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
self.enable_data_balance_ = b;
}) // FIXME(chengudo): enable_data_balance seems not important
.def_property( .def_property(
"enable_sequential_execution", "enable_sequential_execution",
[](const BuildStrategy &self) { [](const BuildStrategy &self) {
...@@ -1007,6 +1000,10 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1007,6 +1000,10 @@ All parameter, weight, gradient are variables in Paddle.
"memory_optimize", "memory_optimize",
[](const BuildStrategy &self) { return self.memory_optimize_; }, [](const BuildStrategy &self) { return self.memory_optimize_; },
[](BuildStrategy &self, bool b) { self.memory_optimize_ = b; }) [](BuildStrategy &self, bool b) { self.memory_optimize_ = b; })
.def_property(
"is_distribution",
[](const BuildStrategy &self) { return self.is_distribution_; },
[](BuildStrategy &self, bool b) { self.is_distribution_ = b; })
.def_property( .def_property(
"memory_early_delete", "memory_early_delete",
[](const BuildStrategy &self) { return self.memory_early_delete_; }, [](const BuildStrategy &self) { return self.memory_early_delete_; },
......
...@@ -29,6 +29,15 @@ ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy ...@@ -29,6 +29,15 @@ ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy
BuildStrategy = core.ParallelExecutor.BuildStrategy BuildStrategy = core.ParallelExecutor.BuildStrategy
def _is_pserver_mode(main_program):
main = main_program if main_program \
else framework.default_main_program()
for op in main.global_block().ops:
if op.type in ["send", "recv"]:
return True
return False
class ParallelExecutor(object): class ParallelExecutor(object):
""" """
ParallelExecutor is designed for data parallelism, which focuses on distributing ParallelExecutor is designed for data parallelism, which focuses on distributing
...@@ -128,6 +137,11 @@ class ParallelExecutor(object): ...@@ -128,6 +137,11 @@ class ParallelExecutor(object):
build_strategy = BuildStrategy() build_strategy = BuildStrategy()
build_strategy.num_trainers = num_trainers build_strategy.num_trainers = num_trainers
build_strategy.trainer_id = trainer_id build_strategy.trainer_id = trainer_id
# FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
# num_trainers is 1, so the current fields of build_strategy doesn't tell if
# it's distributed model.
build_strategy.is_distribution = _is_pserver_mode(
main_program) or num_trainers > 1
# step4: get main_program, scope, local_scopes # step4: get main_program, scope, local_scopes
main = main_program if main_program \ main = main_program if main_program \
......
...@@ -75,8 +75,6 @@ class TestReaderReset(unittest.TestCase): ...@@ -75,8 +75,6 @@ class TestReaderReset(unittest.TestCase):
exe.run(startup_prog) exe.run(startup_prog)
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
if with_double_buffer:
build_strategy.enable_data_balance = True
exec_strategy = fluid.ExecutionStrategy() exec_strategy = fluid.ExecutionStrategy()
parallel_exe = fluid.ParallelExecutor( parallel_exe = fluid.ParallelExecutor(
use_cuda=self.use_cuda, use_cuda=self.use_cuda,
......