Commit e4d7d7ae authored by Xin Pan

pass refactoring

Parent 142e832d
@@ -244,6 +244,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
   result.Set("vars", new GraphVars(places_.size()));
   result.Set("dep_vars", new GraphDepVars);
   result.Set("ops", new GraphOps);
+  result.Set("sharded_var_device", new ShardedVarDevice);

   // find send/recv vars so that we can place the distributed training
   // realted op in the place 0
@@ -276,11 +277,12 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
       // the block.
       is_forwarding = false;
     } else {
-      int op_dev_id = GetOpDeviceID(node);
+      int op_dev_id = GetOpDeviceID(result, node);
       if (op_dev_id != -1) {  // This op only runs on one specific device.
         CreateComputationalOp(&result, node, op_dev_id);
         for (ir::Node *n : node->outputs) {
-          var_name_on_devices_.emplace(n->Name(), op_dev_id);
+          graph->Get<ShardedVarDevice>("sharded_var_device")
+              .emplace(n->Name(), op_dev_id);
         }
       } else {
         // This op runs on all devices, and its output may have parameter's
@@ -317,7 +319,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
           case BuildStrategy::ReduceStrategy::kReduce:
             cur_device_id = GetAppropriateDeviceID({g_name});
             CreateReduceOp(&result, g_name, cur_device_id);
-            var_name_on_devices_.emplace(g_name, cur_device_id);
+            graph->Get<ShardedVarDevice>("sharded_var_device")
+                .emplace(g_name, cur_device_id);
             bcast_var_name_set[cur_device_id].emplace(p_name);
             break;
           case BuildStrategy::ReduceStrategy::kAllReduce:
@@ -499,7 +502,8 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
   return is_pg_once;
 }

-int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
+int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
+                                           ir::Node *node) const {
   if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
     return -1;
   }
@@ -512,15 +516,17 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
       node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
   PADDLE_ENFORCE_EQ(param_grad.size(), 2U);

-  int dev_id = GetVarDeviceID(param_grad[1]);
+  int dev_id = GetVarDeviceID(graph, param_grad[1]);
   PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
                     node->Op()->Type(), param_grad[0], param_grad[1]);
   return dev_id;
 }

-int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
-  auto got = var_name_on_devices_.find(varname);
-  return got == var_name_on_devices_.end() ? -1 : got->second;
+int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
+                                            const std::string &varname) const {
+  auto &sharded_var_device = graph.Get<ShardedVarDevice>("sharded_var_device");
+  auto got = sharded_var_device.find(varname);
+  return got == sharded_var_device.end() ? -1 : got->second;
 }

 void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
@@ -625,20 +631,23 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
   if (node->Op()->Type() == "split_byref" ||
       node->Op()->Type() == "split_selected_rows") {
     // TODO(paddle-dev): getting the first var is not safe.
-    op_dev_id = GetVarDeviceID(input_var_names[0]);
+    op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
     if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
       op_dev_id = GetAppropriateDeviceID(input_var_names);
       for (auto &varname : input_var_names) {
-        var_name_on_devices_.emplace(varname, op_dev_id);
+        result->Get<ShardedVarDevice>("sharded_var_device")
+            .emplace(varname, op_dev_id);
       }
     }
     for (auto &varname : output_var_names) {
-      var_name_on_devices_.emplace(varname, op_dev_id);
+      result->Get<ShardedVarDevice>("sharded_var_device")
+          .emplace(varname, op_dev_id);
     }
   } else if (node->Op()->Type() == "concat") {
-    op_dev_id = GetVarDeviceID(input_var_names[0]);
+    op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
     for (auto &varname : output_var_names) {
-      var_name_on_devices_.emplace(varname, op_dev_id);
+      result->Get<ShardedVarDevice>("sharded_var_device")
+          .emplace(varname, op_dev_id);
     }
   } else {
     PADDLE_ENFORCE(
@@ -663,7 +672,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
   int op_dev_id = -1;
   if (node->Op()->Type() == "send") {
     // TODO(paddle-dev): getting the first var is not safe.
-    op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
+    op_dev_id = GetVarDeviceID(*result, node->inputs[0]->Name());
     PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
                    "This hack no longer holds, please fix.");
     // the variable name which contains .block means it was splited by
@@ -678,7 +687,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
       }
       op_dev_id = GetAppropriateDeviceID(input_var_names);
       for (auto &varname : input_var_names) {
-        var_name_on_devices_.emplace(varname, op_dev_id);
+        result->Get<ShardedVarDevice>("sharded_var_device")
+            .emplace(varname, op_dev_id);
       }
     }
   } else if (node->Op()->Type() == "recv") {
@@ -688,7 +698,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
     }
     op_dev_id = GetAppropriateDeviceID(output_var_names);
     for (auto &varname : output_var_names) {
-      var_name_on_devices_.emplace(varname, op_dev_id);
+      result->Get<ShardedVarDevice>("sharded_var_device")
+          .emplace(varname, op_dev_id);
     }
   } else {
     // send_barrier and fetch_barrier op can be scheduled on device 0
@@ -730,3 +741,6 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
+
+REGISTER_PASS(multi_device_pass,
+              paddle::framework::details::MultiDevSSAGraphBuilder);
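The core of the change above is that per-variable device placement moves out of builder member state (`var_name_on_devices_`) and into a named, typed attribute stored on the graph itself, so any later pass or executor can read it back. Below is a minimal standalone sketch of that attribute pattern; `Graph`, `Set`/`Get`, and the variable name are simplified stand-ins for illustration, not the commit's actual classes.

```cpp
// Mini-model of the graph-attribute pattern this commit adopts: a pass
// records per-variable device placement in a typed attribute owned by the
// graph instead of in pass member state. Illustrative types only.
#include <any>
#include <cassert>
#include <iostream>
#include <string>
#include <unordered_map>

using ShardedVarDevice = std::unordered_map<std::string, int>;

class Graph {
 public:
  template <typename AttrType>
  void Set(const std::string &name, AttrType *attr) {
    attrs_[name] = attr;  // graph takes ownership (simplified here)
  }
  template <typename AttrType>
  AttrType &Get(const std::string &name) const {
    assert(attrs_.count(name) && "attr not registered for graph");
    return *std::any_cast<AttrType *>(attrs_.at(name));
  }

 private:
  std::unordered_map<std::string, std::any> attrs_;
};

int main() {
  Graph g;
  g.Set("sharded_var_device", new ShardedVarDevice);  // builder pass does this
  g.Get<ShardedVarDevice>("sharded_var_device").emplace("w@GRAD", 1);

  // A later consumer (think BCastParamsToDevices) looks placement up again:
  auto &sharded = g.Get<ShardedVarDevice>("sharded_var_device");
  auto it = sharded.find("w@GRAD");
  std::cout << (it == sharded.end() ? -1 : it->second) << "\n";  // prints 1

  delete &g.Get<ShardedVarDevice>("sharded_var_device");  // sketch-only cleanup
}
```

The payoff shows up in `BCastParamsToDevices` further down: it recovers placement from the executor's graph instead of holding a long-lived pointer to the builder.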
@@ -34,7 +34,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
  public:
   std::unique_ptr<ir::Graph> Apply(
       std::unique_ptr<ir::Graph> graph) const override;
-  int GetVarDeviceID(const std::string &varname) const override;

  private:
   void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
@@ -51,6 +50,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   mutable platform::NCCLContextMap *nccl_ctxs_;
 #endif

+  int GetVarDeviceID(const ir::Graph &graph, const std::string &varname) const;
+
   bool IsScaleLossOp(ir::Node *node) const;

   void CreateRPCOp(ir::Graph *result, ir::Node *node) const;
@@ -84,7 +85,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
       const std::string &og,
       std::unordered_set<std::string> *og_has_been_broadcast) const;

-  int GetOpDeviceID(ir::Node *node) const;
+  int GetOpDeviceID(const ir::Graph &graph, ir::Node *node) const;

   void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
@@ -102,7 +103,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
  private:
   mutable BuildStrategy strategy_;
   mutable std::unordered_map<std::string, VarDesc *> all_vars_;
-  mutable std::unordered_map<std::string, int> var_name_on_devices_;
   mutable std::vector<int64_t> balance_vars_;

   void SetCommunicationContext(OpHandleBase *op_handle,
...
@@ -40,6 +40,9 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
       ExecutionStrategy strategy, std::vector<Scope*> local_scopes,
       std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
       std::unique_ptr<SSAGraphExecutor>&& underlying_executor);

+  const ir::Graph& Graph() const { return underlying_executor_->Graph(); }
+
   FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;

  private:
...
@@ -47,13 +47,13 @@ typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
 // unordered.
 typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;

+typedef std::unordered_map<std::string, int> ShardedVarDevice;
+
 class SSAGraphBuilder : public ir::Pass {
  public:
   SSAGraphBuilder() {}
   virtual ~SSAGraphBuilder() {}

-  virtual int GetVarDeviceID(const std::string &var_name) const = 0;
-
   DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);

  protected:
...
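With the pure-virtual `GetVarDeviceID` removed, `SSAGraphBuilder` narrows to the plain `ir::Pass` contract: a graph goes in, a transformed graph comes out, and any side outputs travel as graph attributes. A minimal self-contained model of that contract (the types and the `device_count` field are illustrative, not Paddle's headers):

```cpp
// Mini-model of the pass contract after this commit: graph in, graph out.
#include <cassert>
#include <memory>

struct Graph {
  int device_count = 0;  // stand-in for real graph state
};

class Pass {
 public:
  virtual ~Pass() = default;
  virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> g) const = 0;
};

class BuilderPass : public Pass {
 public:
  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> g) const override {
    g->device_count = 4;  // pretend we built per-device op handles
    return g;
  }
};

int main() {
  std::unique_ptr<Pass> pass(new BuilderPass);
  auto g = pass->Apply(std::unique_ptr<Graph>(new Graph));
  assert(g->device_count == 4);
}
```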
@@ -21,8 +21,8 @@
 namespace paddle {
 namespace framework {
 namespace details {
-std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
-  std::unique_ptr<SSAGraphBuilder> res(new MultiDevSSAGraphBuilder);
+std::unique_ptr<ir::Pass> ParallelExecutorPassManager::Create() {
+  std::unique_ptr<ir::Pass> res(new MultiDevSSAGraphBuilder);
   res->SetNotOwned<std::vector<platform::Place>>("places", &places_);
   res->SetNotOwned<std::string>("loss_var_name", &loss_var_name_);
   res->SetNotOwned<std::unordered_set<std::string>>("params", &param_names_);
@@ -33,18 +33,18 @@ std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
 #endif

   if (!strategy_.debug_graphviz_path_.empty()) {
-    SSAGraphBuilder *previous_pass = res.release();
+    ir::Pass *previous_pass = res.release();
     res.reset(new SSAGraghBuilderWithPrinter);
-    res->Set<SSAGraphBuilder>("previous_pass", previous_pass);
+    res->Set<ir::Pass>("previous_pass", previous_pass);
     res->SetNotOwned<std::string>("debug_graphviz_path",
                                   &strategy_.debug_graphviz_path_);
     res->Set<GraphvizSSAGraphPrinter>("graph_printer",
                                       new GraphvizSSAGraphPrinter);
   }

-  SSAGraphBuilder *previous_pass = res.release();
+  ir::Pass *previous_pass = res.release();
   res.reset(new SSAGraghBuilderWithChecker);
-  res->Set<SSAGraphBuilder>("previous_pass", previous_pass);
+  res->Set<ir::Pass>("previous_pass", previous_pass);

   return res;
 }
...
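The factory keeps its decorator arrangement, but the links are now typed as `ir::Pass`: the printer and checker wrappers own their `previous_pass` and run it before their own work. A minimal standalone model of that chaining follows; the names, the member (rather than attribute) storage of the inner pass, and the validity check are simplifications for illustration, not the commit's code.

```cpp
// Mini-model of the "previous_pass" chaining used by the checker/printer
// wrapper passes.
#include <cassert>
#include <memory>
#include <utility>

struct Graph { int op_count = 0; };

class Pass {
 public:
  virtual ~Pass() = default;
  virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> g) const = 0;
};

// Stand-in for MultiDevSSAGraphBuilder: does the real transformation.
class BuildPass : public Pass {
 public:
  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> g) const override {
    g->op_count = 42;
    return g;
  }
};

// Stand-in for SSAGraghBuilderWithChecker: delegates first, then validates.
class CheckPass : public Pass {
 public:
  explicit CheckPass(std::unique_ptr<Pass> prev) : previous_(std::move(prev)) {}
  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> g) const override {
    auto out = previous_->Apply(std::move(g));  // run the wrapped pass
    assert(out->op_count > 0);                  // IsValidGraph stand-in
    return out;
  }

 private:
  std::unique_ptr<Pass> previous_;
};

int main() {
  std::unique_ptr<Pass> chain(
      new CheckPass(std::unique_ptr<Pass>(new BuildPass)));
  auto g = chain->Apply(std::unique_ptr<Graph>(new Graph));
  return g->op_count == 42 ? 0 : 1;
}
```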
@@ -29,13 +29,13 @@ namespace framework {
 class Scope;
 namespace details {

-class SSAGraphBuilderFactory {
+class ParallelExecutorPassManager {
  public:
-  SSAGraphBuilderFactory(const std::vector<platform::Place>& places,
-                         const std::string& loss_var_name,
-                         const std::unordered_set<std::string>& param_names,
-                         const std::vector<Scope*>& local_scopes,
-                         const BuildStrategy& strategy)
+  ParallelExecutorPassManager(
+      const std::vector<platform::Place>& places,
+      const std::string& loss_var_name,
+      const std::unordered_set<std::string>& param_names,
+      const std::vector<Scope*>& local_scopes, const BuildStrategy& strategy)
       : places_(places),
         loss_var_name_(loss_var_name),
         param_names_(param_names),
@@ -52,7 +52,7 @@ class SSAGraphBuilderFactory {
   }
 #endif

-  std::unique_ptr<SSAGraphBuilder> Create();
+  std::unique_ptr<ir::Pass> Create();

  private:
   std::vector<platform::Place> places_;
...
@@ -85,3 +85,6 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
+
+REGISTER_PASS(multi_device_check_pass,
+              paddle::framework::details::SSAGraghBuilderWithChecker);
@@ -26,16 +26,11 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
  public:
   std::unique_ptr<ir::Graph> Apply(
       std::unique_ptr<ir::Graph> graph) const override {
-    auto new_graph =
-        Get<SSAGraphBuilder>("previous_pass").Apply(std::move(graph));
+    auto new_graph = Get<ir::Pass>("previous_pass").Apply(std::move(graph));
     PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
     return new_graph;
   }

-  int GetVarDeviceID(const std::string& var_name) const override {
-    return Get<SSAGraphBuilder>("previous_pass").GetVarDeviceID(var_name);
-  }
-
   bool IsValidGraph(const ir::Graph* graph) const;
 };
...
@@ -32,7 +32,9 @@ class SSAGraphExecutor {
   virtual ~SSAGraphExecutor();

-  virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0;
+  virtual const ir::Graph& Graph() const = 0;
+
+  virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
 };
 }  // namespace details
 }  // namespace framework
...
@@ -81,3 +81,6 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
+
+REGISTER_PASS(multi_device_print_pass,
+              paddle::framework::details::SSAGraghBuilderWithPrinter);
@@ -39,8 +39,7 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
  public:
   std::unique_ptr<ir::Graph> Apply(
       std::unique_ptr<ir::Graph> graph) const override {
-    auto new_graph =
-        Get<SSAGraphBuilder>("previous_pass").Apply(std::move(graph));
+    auto new_graph = Get<ir::Pass>("previous_pass").Apply(std::move(graph));
     std::unique_ptr<std::ostream> fout(
         new std::ofstream(Get<std::string>("debug_graphviz_path")));
@@ -48,10 +47,6 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
     Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*new_graph, *fout);
     return new_graph;
   }
-
-  int GetVarDeviceID(const std::string& var_name) const override {
-    return Get<SSAGraphBuilder>("previous_pass").GetVarDeviceID(var_name);
-  }
 };

 }  // namespace details
...
@@ -42,6 +42,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
                            const std::vector<platform::Place> &places,
                            std::unique_ptr<ir::Graph> &&graph);

+  const ir::Graph &Graph() const { return *graph_; }
   // Run a SSAGraph by a thread pool
   // Use topological sort algorithm
   FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
...
@@ -42,6 +42,8 @@ class Graph {
   template <typename AttrType>
   AttrType &Get(const std::string &attr_name) const {
+    PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(),
+                   "%s attr not registered for graph.", attr_name);
     return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
   }
...
@@ -44,6 +44,8 @@ class Pass {
   template <typename AttrType>
   AttrType &Get(const std::string &attr_name) const {
+    PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(),
+                   "%s attr not registered for pass.", attr_name);
     return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
   }
...
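The two guards above matter because now that passes communicate through named attributes, a missing registration would otherwise surface as an opaque `std::out_of_range` from `unordered_map::at` (or a `boost::bad_any_cast`). The check names the offending attribute instead. A plain-C++ stand-in for the guarded lookup (no Paddle macros; `PADDLE_ENFORCE` is modeled with an exception):

```cpp
// Standalone illustration: an explicit existence check turns a cryptic
// lookup failure into an error that names the missing attribute.
#include <any>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

std::unordered_map<std::string, std::any> attrs;

template <typename T>
T &Get(const std::string &name) {
  if (attrs.find(name) == attrs.end()) {
    throw std::runtime_error(name + " attr not registered for graph.");
  }
  return *std::any_cast<T *>(attrs.at(name));
}

int main() {
  try {
    Get<int>("sharded_var_device");  // never registered
  } catch (const std::exception &e) {
    std::cout << e.what() << "\n";
    // prints: sharded_var_device attr not registered for graph.
  }
}
```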
@@ -33,6 +33,48 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

+std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
+    const ProgramDesc &main_program, const std::vector<platform::Place> &places,
+    const std::string &loss_var_name,
+    const std::unordered_set<std::string> &param_names,
+    const std::vector<Scope *> &local_scopes, const bool use_cuda,
+#ifdef PADDLE_WITH_CUDA
+    const BuildStrategy &strategy, platform::NCCLContextMap *nccl_ctxs) {
+#else
+    const BuildStrategy &strategy) {
+#endif
+  details::ParallelExecutorPassManager builder_factory(
+      places, loss_var_name, param_names, local_scopes, strategy);
+  if (use_cuda) {
+#ifdef PADDLE_WITH_CUDA
+    builder_factory.SetNCCLContextMap(nccl_ctxs);
+#else
+    PADDLE_THROW("Not compiled with CUDA.");
+#endif
+  }
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
+  if (!strategy.debug_graphviz_path_.empty()) {
+    auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
+    const std::string graph_path = string::Sprintf(
+        "%s%s", strategy.debug_graphviz_path_.c_str(), "_original_graph");
+    viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
+    graph = viz_pass->Apply(std::move(graph));
+  }
+
+  auto builder = builder_factory.Create();
+  graph = builder->Apply(std::move(graph));
+
+  if (!strategy.debug_graphviz_path_.empty()) {
+    auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
+    const std::string graph_path = string::Sprintf(
+        "%s%s", strategy.debug_graphviz_path_.c_str(), "_before_exec");
+    viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
+    graph = viz_pass->Apply(std::move(graph));
+  }
+  return graph;
+}
+
 class ParallelExecutorPrivate {
  public:
   explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
@@ -120,38 +162,18 @@ ParallelExecutor::ParallelExecutor(
     var_infos.back().persistable_ = var->Persistable();
   }

   // Step 3. Convert main_program to SSA form and dependency graph. Also, insert
   // ncclOp
-  details::SSAGraphBuilderFactory builder_factory(
-      member_->places_, loss_var_name, params, member_->local_scopes_,
-      build_strategy);
-  if (member_->use_cuda_) {
 #ifdef PADDLE_WITH_CUDA
-    builder_factory.SetNCCLContextMap(member_->nccl_ctxs_.get());
+  std::unique_ptr<ir::Graph> graph = ApplyParallelExecutorPass(
+      main_program, member_->places_, loss_var_name, params,
+      member_->local_scopes_, member_->use_cuda_, build_strategy,
+      member_->nccl_ctxs_.get());
 #else
-    PADDLE_THROW("Not compiled with CUDA.");
+  std::unique_ptr<ir::Graph> graph = ApplyParallelExecutorPass(
+      main_program, member_->places_, loss_var_name, params,
+      member_->local_scopes_, member_->use_cuda_, build_strategy);
 #endif
-  }
-  std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
-  if (!build_strategy.debug_graphviz_path_.empty()) {
-    auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
-    const std::string graph_path = string::Sprintf(
-        "%s%s", build_strategy.debug_graphviz_path_.c_str(), "_original_graph");
-    viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
-    graph = viz_pass->Apply(std::move(graph));
-  }
-  builder_ = builder_factory.Create();
-  graph = builder_->Apply(std::move(graph));
-  if (!build_strategy.debug_graphviz_path_.empty()) {
-    auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
-    const std::string graph_path = string::Sprintf(
-        "%s%s", build_strategy.debug_graphviz_path_.c_str(), "_before_exec");
-    viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
-    graph = viz_pass->Apply(std::move(graph));
-  }

   member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
       exec_strategy, member_->local_scopes_, places, std::move(graph)));
@@ -165,11 +187,18 @@ void ParallelExecutor::BCastParamsToDevices(
   // the initializing bcast, all vars would be bcast from device(0),
   // otherwise
   // bcast from the specified device.
-  bool initializing = builder_.get() == nullptr ? true : false;
+  bool initializing = member_->executor_ ? false : true;

   for (auto &var : vars) {
-    int var_dev_id =
-        builder_.get() == nullptr ? -1 : builder_->GetVarDeviceID(var);
+    int var_dev_id = -1;
+    if (member_->executor_) {
+      auto &sharded_var_device =
+          member_->executor_->Graph().Get<details::ShardedVarDevice>(
+              "sharded_var_device");
+      if (sharded_var_device.find(var) != sharded_var_device.end()) {
+        var_dev_id = sharded_var_device.at(var);
+      }
+    }
+
     if (!initializing && var_dev_id == -1) continue;

     framework::Variable *main_var = nullptr;
@@ -307,3 +336,6 @@ ParallelExecutor::~ParallelExecutor() {
 }  // namespace paddle

 USE_PASS(graph_viz_pass);
+USE_PASS(multi_device_pass);
+USE_PASS(multi_device_check_pass);
+USE_PASS(multi_device_print_pass);
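The `REGISTER_PASS` lines added in the `.cc` files and the `USE_PASS` lines here are two halves of one mechanism: registration installs a factory into a global registry at static-initialization time, and the use-declaration in the consumer forces the registering translation unit to be linked in. A single-file, simplified model of that pairing (the macro and registry below are stand-ins, not Paddle's actual definitions):

```cpp
// Mini-model of the REGISTER_PASS / USE_PASS idiom: a global factory
// registry populated by static initializers.
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Pass {
  virtual ~Pass() = default;
  virtual const char *Name() const = 0;
};

class PassRegistry {
 public:
  static PassRegistry &Instance() {
    static PassRegistry r;
    return r;
  }
  void Register(const std::string &type,
                std::function<std::unique_ptr<Pass>()> factory) {
    factories_[type] = std::move(factory);
  }
  std::unique_ptr<Pass> Get(const std::string &type) {
    return factories_.at(type)();  // creates a fresh pass instance
  }

 private:
  std::unordered_map<std::string, std::function<std::unique_ptr<Pass>()>>
      factories_;
};

// Runs at static-init time; in real code this lives in the pass's .cc file
// and USE_PASS in the consumer keeps the linker from dropping it.
#define REGISTER_PASS_SKETCH(type, Class)                      \
  static int _reg_##type = [] {                                \
    PassRegistry::Instance().Register(                         \
        #type, [] { return std::unique_ptr<Pass>(new Class); }); \
    return 0;                                                  \
  }()

struct MultiDevicePass : Pass {
  const char *Name() const override { return "multi_device_pass"; }
};
REGISTER_PASS_SKETCH(multi_device_pass, MultiDevicePass);

int main() {
  auto p = PassRegistry::Instance().Get("multi_device_pass");
  std::cout << p->Name() << "\n";  // prints: multi_device_pass
}
```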
@@ -70,7 +70,6 @@ class ParallelExecutor {
  private:
   ParallelExecutorPrivate *member_;
-  std::unique_ptr<details::SSAGraphBuilder> builder_;
 };

 }  // namespace framework
...
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under
-   the Apache License, Version 2.0 (the "License"); you may not use this file
-   except in compliance with the License.
-   You may obtain a copy of the License at
-   http://www.apache.org/licenses/LICENSE-2.0
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.

 syntax = "proto3";
 package sendrecv;

 option cc_generic_services = false;

 service SendRecvService {
   // For parameter server round-robin like hashing, do not split tensors.
   // Send and recv only one tensor
   // TODO(typhoonzero): add streaming API
   rpc SendVariable(VariableMessage) returns (VoidMessage) {}
   // Argument VariableMessage for GetVariable should only contain varname.
   rpc GetVariable(VariableMessage) returns (VariableMessage) {}
   // pre-fetch variable by given variable name and Ids
   rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}

   rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
 }

 // VariableMessage is serialized paddle variable message.
 // It can be:
 // LoDTensor
 // SelectedRows
 enum VarType {
   LOD_TENSOR = 0;
   SELECTED_ROWS = 1;
   NCCL_ID = 2;
 }

 // NOTICE(gongwb):don't modify this proto if you are not
 // not familar with how we serialize in sendrecvop_utils.h
 // and deserilize it in variable_response.h.
 message VariableMessage {
   enum Type {
     // Pod Types
     BOOL = 0;
     INT16 = 1;
     INT32 = 2;
     INT64 = 3;
     FP16 = 4;
     FP32 = 5;
     FP64 = 6;
   }

   message LodData { repeated int64 lod_data = 1; }
   string varname = 1;
   // TODO(Yancey1989): reference framework::proto::VarDesc::VarType
   VarType type = 2;
   // bool persistable is not needed for sending.
   // tensor info:
   Type data_type = 3;
   repeated int64 dims = 4;

   // lod details:
   int64 lod_level = 5;
   repeated LodData lod = 6;
   // selected_rows height, aka. original dim0
   int64 slr_height = 7;
   // tensor data
   bytes serialized = 8;
   // selected_rows data
   bytes rows = 9;
   // Look up table block execution output variable name.
   string out_varname = 10;
   // If 1, the ps server will start profiling, the ps
   // server stops profiling and generates a profile to /tmp/profile_ps_*
   // when profile switches from 1 to 2.
   int64 profile = 11;
 }

 message VoidMessage {}
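For orientation, here is a hedged sketch of how a `VariableMessage` might be filled and round-tripped from C++, assuming the file has been compiled with `protoc` to `send_recv.pb.h`/`.cc`; the accessors follow the standard proto3 C++ mapping, and nothing in this sketch is part of the commit.

```cpp
// Hypothetical usage of the generated proto3 API for VariableMessage.
#include <iostream>
#include <string>
#include "send_recv.pb.h"  // assumed protoc output for send_recv.proto

int main() {
  sendrecv::VariableMessage msg;
  msg.set_varname("fc_0.w_0");                        // string varname = 1
  msg.set_type(sendrecv::LOD_TENSOR);                 // VarType type = 2
  msg.set_data_type(sendrecv::VariableMessage::FP32); // Type data_type = 3
  msg.add_dims(128);                                  // repeated int64 dims = 4
  msg.add_dims(64);
  msg.set_serialized("raw tensor bytes go here");     // bytes serialized = 8

  std::string wire;
  msg.SerializeToString(&wire);  // what actually travels over the RPC

  sendrecv::VariableMessage decoded;
  decoded.ParseFromString(wire);
  std::cout << decoded.varname() << " dims=" << decoded.dims_size() << "\n";
}
```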