未验证 提交 eabb2105 编写于 作者: C chengduo 提交者: GitHub

Refactor MultiDevSSAGraphBuilder (#15090)

* Refactor ParallelExecutor
test=develop

* extract Reduce and AllReduce mode from MultiDevSSAGraphBuilder
test=develop

* Refactor MultiDevSSAGraphBuilder
test=developt

* Remove enable_data_balance
test=develop

* code refine
test=develop

* remove data balance
test=develop

* refine ScaleLossGradOp
test=develop

* remove uncessary file
test=develop

* code refine
test=develop

* modify  function name
test=develop

* follow comments
test=develop

* add is_distribution field
test=develop

* set is_distribution
test=develop

* fix DistSSAGraphBuilder
test=develop
上级 875a07c3
......@@ -18,7 +18,7 @@ limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/details/memory_reuse_types.h"
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
......@@ -86,10 +86,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
if (strategy.memory_optimize_) {
auto analysis_var_pass = AppendPass("analysis_var_pass");
}
// Convert graph to run on multi-devices.
auto multi_devices_pass = AppendPass("multi_devices_pass");
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
&strategy_);
AppendMultiDevPass(strategy);
// Add a graph print pass to record a graph with device info.
if (!strategy_.debug_graphviz_path_.empty()) {
......@@ -115,6 +113,25 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
}
}
// Convert graph to run on multi-devices.
void AppendMultiDevPass(const BuildStrategy &strategy) {
ir::Pass *multi_devices_pass;
if (strategy_.is_distribution_) {
multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
multi_devices_pass =
AppendPass("allreduce_mode_multi_devices_pass").get();
} else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
} else {
PADDLE_THROW("Unknown reduce strategy.");
}
}
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
&strategy_);
}
private:
BuildStrategy strategy_;
};
......@@ -131,6 +148,10 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
return pass_builder_;
}
bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
}
std::unique_ptr<ir::Graph> BuildStrategy::Apply(
const ProgramDesc &main_program, const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
......@@ -145,22 +166,23 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
if (pass->Type() == "multi_devices_pass") {
pass->Erase("places");
pass->SetNotOwned<const std::vector<platform::Place>>("places", &places);
pass->Erase("loss_var_name");
pass->SetNotOwned<const std::string>("loss_var_name", &loss_var_name);
pass->Erase("local_scopes");
pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
if (IsMultiDevPass(pass->Type())) {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
pass->Erase(kLossVarName);
pass->SetNotOwned<const std::string>(kLossVarName, &loss_var_name);
pass->Erase(kLocalScopes);
pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
&local_scopes);
pass->Erase("nranks");
pass->Set<size_t>("nranks", new size_t(nranks));
pass->Erase(kNRanks);
pass->Set<size_t>(kNRanks, new size_t(nranks));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase("nccl_ctxs");
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
#endif
} else if (pass->Type() == "analysis_var_pass") {
const std::vector<OpDesc *> *all_op_descs =
new std::vector<OpDesc *>(main_program.Block(0).AllOps());
......@@ -201,7 +223,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass);
USE_PASS(multi_batch_merge_pass);
USE_PASS(multi_devices_pass);
USE_PASS(reduce_mode_multi_devices_pass);
USE_PASS(allreduce_mode_multi_devices_pass);
USE_PASS(dist_multi_devices_pass);
USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass);
USE_PASS(analysis_var_pass);
......
......@@ -74,8 +74,6 @@ struct BuildStrategy {
bool fuse_elewise_add_act_ops_{false};
bool enable_data_balance_{false};
bool memory_optimize_{false};
bool memory_early_delete_{false};
......@@ -84,6 +82,10 @@ struct BuildStrategy {
bool fuse_broadcast_op_{false};
// FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
// num_trainers is 1, so the current fields of build_strategy doesn't tell if
// it's distributed model.
bool is_distribution_{false};
int num_trainers_{1};
int trainer_id_{0};
std::vector<std::string> trainers_endpoints_;
......@@ -104,6 +106,8 @@ struct BuildStrategy {
bool IsFinalized() const { return is_finalized_; }
bool IsMultiDevPass(const std::string &pass_name) const;
// Apply the passes built by the pass_builder_. The passes will be
// applied to the Program and output an ir::Graph.
std::unique_ptr<ir::Graph> Apply(const ProgramDesc &main_program,
......
......@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include <string>
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
......@@ -21,7 +21,15 @@ namespace paddle {
namespace framework {
namespace details {
bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
class SSAGraghBuilderWithChecker : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph;
}
bool IsValidGraph(const ir::Graph *graph) const {
std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars;
std::unordered_set<VarHandleBase *> ready_vars;
......@@ -82,7 +90,9 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
ready_vars.clear();
}
return true;
}
}
};
} // namespace details
} // namespace framework
} // namespace paddle
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include <string>
namespace paddle {
namespace framework {
namespace details {
class SSAGraghBuilderWithChecker : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph;
}
bool IsValidGraph(const ir::Graph* graph) const;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <string>
#include <utility>
#include <vector>
......@@ -30,78 +31,70 @@ namespace framework {
class Scope;
namespace details {
class MultiDevSSAGraphBuilder : public ir::Pass {
constexpr char kLossVarName[] = "loss_var_name";
constexpr char kPlaces[] = "places";
constexpr char kLocalScopes[] = "local_scopes";
constexpr char kStrategy[] = "strategy";
constexpr char kNRanks[] = "nranks";
class MultiDevSSAGraphBuilderBase : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
private:
void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
size_t device_id) const;
void Init() const;
virtual void Init() const;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
mutable platform::NCCLContextMap *nccl_ctxs_;
#endif
virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;
int GetVarDeviceID(
const std::string &varname,
const std::unordered_map<std::string, int> &sharded_var_device) const;
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const = 0;
bool IsScaleLossOp(ir::Node *node) const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const = 0;
virtual void InsertPostprocessOps(ir::Graph *result) const = 0;
int CreateRPCOp(
ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const;
int CreateDistTrainOp(
ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const;
bool UseGPU() const;
bool NeedCollectiveOps() const;
bool IsScaleLossOp(ir::Node *node) const;
void CreateComputationalOps(ir::Graph *result, ir::Node *node,
size_t num_places) const;
void CreateScaleLossGradOp(ir::Graph *result,
const std::string &loss_grad_name,
ir::Node *out_var_node,
ir::Node *out_var_node, size_t loss_scale,
proto::VarType::Type dtype) const;
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
int dst_dev_id) const;
void CreateComputationalOp(ir::Graph *result, ir::Node *node,
int dev_id) const;
int GetOpDeviceID(
ir::Node *node,
const std::unordered_map<std::string, int> &sharded_var_device) const;
void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
bool IsSparseGradient(const std::string &og) const;
void InsertDataBalanceOp(ir::Graph *result,
const std::vector<std::string> &datas) const;
void CreateAllReduceOp(ir::Graph *result, const std::string &og) const;
void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
size_t src_dev_id) const;
void InsertScaleLossGradOp(ir::Graph *result, const ir::Node *node) const;
void CreateFusedBroadcastOp(
ir::Graph *result,
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const;
bool IsSparseGradient(const std::string &og) const;
size_t GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const;
void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const;
std::vector<ir::Node *> SortForReduceMode(
const std::vector<ir::Node *> &) const;
void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
size_t device_id) const;
int GetOpDeviceID(
ir::Node *node,
const std::unordered_map<std::string, int> &shared_var_device,
std::unordered_map<std::string, std::vector<ir::Node *>> *delay_ops)
const;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
mutable platform::NCCLContextMap *nccl_ctxs_;
#endif
mutable std::string loss_var_name_;
mutable std::vector<platform::Place> places_;
......@@ -109,8 +102,83 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
mutable BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_;
};
class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
protected:
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const {
return false;
}
virtual void InsertPostprocessOps(ir::Graph *result) const {}
};
class BalanceVarSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
protected:
int GetVarDeviceID(const std::string &varname) const;
int GetOpDeviceID(ir::Node *node) const;
size_t GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const;
virtual void ResetState() const;
mutable std::unordered_map<std::string, int> sharded_var_device_;
mutable std::vector<int64_t> balance_vars_;
};
class ReduceSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
protected:
virtual void Init() const;
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;
virtual void InsertPostprocessOps(ir::Graph *result) const;
virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;
virtual void ResetState() const;
int GetOpDeviceID(ir::Node *node,
std::unordered_map<std::string, std::vector<ir::Node *>>
*delay_ops) const;
std::vector<ir::Node *> SortForReduceMode(
const std::vector<ir::Node *> &topo_ops) const;
mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
};
class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
protected:
virtual void Init() const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;
virtual void InsertPostprocessOps(ir::Graph *result) const;
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const;
virtual void ResetState() const;
int CreateRPCOp(ir::Graph *result, ir::Node *node) const;
int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
mutable bool need_broadcast_var_{false};
};
std::unordered_set<std::string> &MultiDevSSAGraphBuilder();
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -946,13 +946,6 @@ All parameter, weight, gradient are variables in Paddle.
R"DOC(The type is STR, debug_graphviz_path indicate the path that
writing the SSA Graph to file in the form of graphviz, you.
It is useful for debugging. Default "")DOC")
.def_property(
"enable_data_balance",
[](const BuildStrategy &self) { return self.enable_data_balance_; },
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
self.enable_data_balance_ = b;
}) // FIXME(chengudo): enable_data_balance seems not important
.def_property(
"enable_sequential_execution",
[](const BuildStrategy &self) {
......@@ -1007,6 +1000,10 @@ All parameter, weight, gradient are variables in Paddle.
"memory_optimize",
[](const BuildStrategy &self) { return self.memory_optimize_; },
[](BuildStrategy &self, bool b) { self.memory_optimize_ = b; })
.def_property(
"is_distribution",
[](const BuildStrategy &self) { return self.is_distribution_; },
[](BuildStrategy &self, bool b) { self.is_distribution_ = b; })
.def_property(
"memory_early_delete",
[](const BuildStrategy &self) { return self.memory_early_delete_; },
......
......@@ -29,6 +29,15 @@ ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy
BuildStrategy = core.ParallelExecutor.BuildStrategy
def _is_pserver_mode(main_program):
main = main_program if main_program \
else framework.default_main_program()
for op in main.global_block().ops:
if op.type in ["send", "recv"]:
return True
return False
class ParallelExecutor(object):
"""
ParallelExecutor is designed for data parallelism, which focuses on distributing
......@@ -128,6 +137,11 @@ class ParallelExecutor(object):
build_strategy = BuildStrategy()
build_strategy.num_trainers = num_trainers
build_strategy.trainer_id = trainer_id
# FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
# num_trainers is 1, so the current fields of build_strategy doesn't tell if
# it's distributed model.
build_strategy.is_distribution = _is_pserver_mode(
main_program) or num_trainers > 1
# step4: get main_program, scope, local_scopes
main = main_program if main_program \
......
......@@ -75,8 +75,6 @@ class TestReaderReset(unittest.TestCase):
exe.run(startup_prog)
build_strategy = fluid.BuildStrategy()
if with_double_buffer:
build_strategy.enable_data_balance = True
exec_strategy = fluid.ExecutionStrategy()
parallel_exe = fluid.ParallelExecutor(
use_cuda=self.use_cuda,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册