From e4d7d7ae8fd7198447df21188d3fd85868c8bafa Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Thu, 26 Jul 2018 11:06:12 +0800 Subject: [PATCH] pass refactoring --- .../details/multi_devices_graph_builder.cc | 46 ++++++--- .../details/multi_devices_graph_builder.h | 6 +- .../scope_buffered_ssa_graph_executor.h | 3 + .../framework/details/ssa_graph_builder.h | 4 +- .../details/ssa_graph_builder_factory.cc | 12 +-- .../details/ssa_graph_builder_factory.h | 14 +-- .../framework/details/ssa_graph_checker.cc | 3 + .../framework/details/ssa_graph_checker.h | 7 +- .../framework/details/ssa_graph_executor.h | 4 +- .../framework/details/ssa_graph_printer.cc | 3 + .../framework/details/ssa_graph_printer.h | 7 +- .../details/threaded_ssa_graph_executor.h | 1 + paddle/fluid/framework/ir/graph.h | 2 + paddle/fluid/framework/ir/pass.h | 2 + paddle/fluid/framework/parallel_executor.cc | 98 ++++++++++++------- paddle/fluid/framework/parallel_executor.h | 1 - .../operators/distributed/send_recv.proto | 97 ++++++++++++++++++ 17 files changed, 229 insertions(+), 81 deletions(-) create mode 100644 paddle/fluid/operators/distributed/send_recv.proto diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 4fad520f405..d211f02689e 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -244,6 +244,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( result.Set("vars", new GraphVars(places_.size())); result.Set("dep_vars", new GraphDepVars); result.Set("ops", new GraphOps); + result.Set("sharded_var_device", new ShardedVarDevice); // find send/recv vars so that we can place the distributed training // realted op in the place 0 @@ -276,11 +277,12 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( // the block. is_forwarding = false; } else { - int op_dev_id = GetOpDeviceID(node); + int op_dev_id = GetOpDeviceID(result, node); if (op_dev_id != -1) { // This op only runs on one specific device. CreateComputationalOp(&result, node, op_dev_id); for (ir::Node *n : node->outputs) { - var_name_on_devices_.emplace(n->Name(), op_dev_id); + graph->Get("sharded_var_device") + .emplace(n->Name(), op_dev_id); } } else { // This op runs on all devices, and its output may have parameter's @@ -317,7 +319,8 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( case BuildStrategy::ReduceStrategy::kReduce: cur_device_id = GetAppropriateDeviceID({g_name}); CreateReduceOp(&result, g_name, cur_device_id); - var_name_on_devices_.emplace(g_name, cur_device_id); + graph->Get("sharded_var_device") + .emplace(g_name, cur_device_id); bcast_var_name_set[cur_device_id].emplace(p_name); break; case BuildStrategy::ReduceStrategy::kAllReduce: @@ -499,7 +502,8 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( return is_pg_once; } -int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const { +int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph, + ir::Node *node) const { if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { return -1; } @@ -512,15 +516,17 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const { node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); PADDLE_ENFORCE_EQ(param_grad.size(), 2U); - int dev_id = GetVarDeviceID(param_grad[1]); + int dev_id = GetVarDeviceID(graph, param_grad[1]); PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]", node->Op()->Type(), param_grad[0], param_grad[1]); return dev_id; } -int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const { - auto got = var_name_on_devices_.find(varname); - return got == var_name_on_devices_.end() ? -1 : got->second; +int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph, + const std::string &varname) const { + auto &sharded_var_device = graph.Get("sharded_var_device"); + auto got = sharded_var_device.find(varname); + return got == sharded_var_device.end() ? -1 : got->second; } void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const { @@ -625,20 +631,23 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, if (node->Op()->Type() == "split_byref" || node->Op()->Type() == "split_selected_rows") { // TODO(paddle-dev): getting the first var is not safe. - op_dev_id = GetVarDeviceID(input_var_names[0]); + op_dev_id = GetVarDeviceID(*result, input_var_names[0]); if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { op_dev_id = GetAppropriateDeviceID(input_var_names); for (auto &varname : input_var_names) { - var_name_on_devices_.emplace(varname, op_dev_id); + result->Get("sharded_var_device") + .emplace(varname, op_dev_id); } } for (auto &varname : output_var_names) { - var_name_on_devices_.emplace(varname, op_dev_id); + result->Get("sharded_var_device") + .emplace(varname, op_dev_id); } } else if (node->Op()->Type() == "concat") { - op_dev_id = GetVarDeviceID(input_var_names[0]); + op_dev_id = GetVarDeviceID(*result, input_var_names[0]); for (auto &varname : output_var_names) { - var_name_on_devices_.emplace(varname, op_dev_id); + result->Get("sharded_var_device") + .emplace(varname, op_dev_id); } } else { PADDLE_ENFORCE( @@ -663,7 +672,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, int op_dev_id = -1; if (node->Op()->Type() == "send") { // TODO(paddle-dev): getting the first var is not safe. - op_dev_id = GetVarDeviceID(node->inputs[0]->Name()); + op_dev_id = GetVarDeviceID(*result, node->inputs[0]->Name()); PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]), "This hack no longer holds, please fix."); // the variable name which contains .block means it was splited by @@ -678,7 +687,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, } op_dev_id = GetAppropriateDeviceID(input_var_names); for (auto &varname : input_var_names) { - var_name_on_devices_.emplace(varname, op_dev_id); + result->Get("sharded_var_device") + .emplace(varname, op_dev_id); } } } else if (node->Op()->Type() == "recv") { @@ -688,7 +698,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, } op_dev_id = GetAppropriateDeviceID(output_var_names); for (auto &varname : output_var_names) { - var_name_on_devices_.emplace(varname, op_dev_id); + result->Get("sharded_var_device") + .emplace(varname, op_dev_id); } } else { // send_barrier and fetch_barrier op can be scheduled on device 0 @@ -730,3 +741,6 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const { } // namespace details } // namespace framework } // namespace paddle + +REGISTER_PASS(multi_device_pass, + paddle::framework::details::MultiDevSSAGraphBuilder); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index c8c1b2a4386..baea091af30 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -34,7 +34,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { public: std::unique_ptr Apply( std::unique_ptr graph) const override; - int GetVarDeviceID(const std::string &varname) const override; private: void CreateOpHandleIOs(ir::Graph *result, ir::Node *node, @@ -51,6 +50,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { mutable platform::NCCLContextMap *nccl_ctxs_; #endif + int GetVarDeviceID(const ir::Graph &graph, const std::string &varname) const; + bool IsScaleLossOp(ir::Node *node) const; void CreateRPCOp(ir::Graph *result, ir::Node *node) const; @@ -84,7 +85,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const std::string &og, std::unordered_set *og_has_been_broadcast) const; - int GetOpDeviceID(ir::Node *node) const; + int GetOpDeviceID(const ir::Graph &graph, ir::Node *node) const; void InsertAllReduceOp(ir::Graph *result, const std::string &og) const; @@ -102,7 +103,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { private: mutable BuildStrategy strategy_; mutable std::unordered_map all_vars_; - mutable std::unordered_map var_name_on_devices_; mutable std::vector balance_vars_; void SetCommunicationContext(OpHandleBase *op_handle, diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h index cbfbcb1c0cd..1b188aec599 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h @@ -40,6 +40,9 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { ExecutionStrategy strategy, std::vector local_scopes, std::vector var_infos, std::vector places, std::unique_ptr&& underlying_executor); + + const ir::Graph& Graph() const { return underlying_executor_->Graph(); } + FeedFetchList Run(const std::vector& fetch_tensors) override; private: diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 2b4f31f2ff3..e0ad0273150 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -47,13 +47,13 @@ typedef std::unordered_set> GraphDepVars; // unordered. typedef std::vector> GraphOps; +typedef std::unordered_map ShardedVarDevice; + class SSAGraphBuilder : public ir::Pass { public: SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {} - virtual int GetVarDeviceID(const std::string &var_name) const = 0; - DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); protected: diff --git a/paddle/fluid/framework/details/ssa_graph_builder_factory.cc b/paddle/fluid/framework/details/ssa_graph_builder_factory.cc index e8d83943acf..2254a3b41ea 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder_factory.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder_factory.cc @@ -21,8 +21,8 @@ namespace paddle { namespace framework { namespace details { -std::unique_ptr SSAGraphBuilderFactory::Create() { - std::unique_ptr res(new MultiDevSSAGraphBuilder); +std::unique_ptr ParallelExecutorPassManager::Create() { + std::unique_ptr res(new MultiDevSSAGraphBuilder); res->SetNotOwned>("places", &places_); res->SetNotOwned("loss_var_name", &loss_var_name_); res->SetNotOwned>("params", ¶m_names_); @@ -33,18 +33,18 @@ std::unique_ptr SSAGraphBuilderFactory::Create() { #endif if (!strategy_.debug_graphviz_path_.empty()) { - SSAGraphBuilder *previous_pass = res.release(); + ir::Pass *previous_pass = res.release(); res.reset(new SSAGraghBuilderWithPrinter); - res->Set("previous_pass", previous_pass); + res->Set("previous_pass", previous_pass); res->SetNotOwned("debug_graphviz_path", &strategy_.debug_graphviz_path_); res->Set("graph_printer", new GraphvizSSAGraphPrinter); } - SSAGraphBuilder *previous_pass = res.release(); + ir::Pass *previous_pass = res.release(); res.reset(new SSAGraghBuilderWithChecker); - res->Set("previous_pass", previous_pass); + res->Set("previous_pass", previous_pass); return res; } diff --git a/paddle/fluid/framework/details/ssa_graph_builder_factory.h b/paddle/fluid/framework/details/ssa_graph_builder_factory.h index 91a119de83e..1bfc3e71e8c 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder_factory.h +++ b/paddle/fluid/framework/details/ssa_graph_builder_factory.h @@ -29,13 +29,13 @@ namespace framework { class Scope; namespace details { -class SSAGraphBuilderFactory { +class ParallelExecutorPassManager { public: - SSAGraphBuilderFactory(const std::vector& places, - const std::string& loss_var_name, - const std::unordered_set& param_names, - const std::vector& local_scopes, - const BuildStrategy& strategy) + ParallelExecutorPassManager( + const std::vector& places, + const std::string& loss_var_name, + const std::unordered_set& param_names, + const std::vector& local_scopes, const BuildStrategy& strategy) : places_(places), loss_var_name_(loss_var_name), param_names_(param_names), @@ -52,7 +52,7 @@ class SSAGraphBuilderFactory { } #endif - std::unique_ptr Create(); + std::unique_ptr Create(); private: std::vector places_; diff --git a/paddle/fluid/framework/details/ssa_graph_checker.cc b/paddle/fluid/framework/details/ssa_graph_checker.cc index 0438b096109..2994329f485 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.cc +++ b/paddle/fluid/framework/details/ssa_graph_checker.cc @@ -85,3 +85,6 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { } // namespace details } // namespace framework } // namespace paddle + +REGISTER_PASS(multi_device_check_pass, + paddle::framework::details::SSAGraghBuilderWithChecker); diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h index ae5ad16b0c9..fb766fb4155 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.h +++ b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -26,16 +26,11 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { public: std::unique_ptr Apply( std::unique_ptr graph) const override { - auto new_graph = - Get("previous_pass").Apply(std::move(graph)); + auto new_graph = Get("previous_pass").Apply(std::move(graph)); PADDLE_ENFORCE(IsValidGraph(new_graph.get())); return new_graph; } - int GetVarDeviceID(const std::string& var_name) const override { - return Get("previous_pass").GetVarDeviceID(var_name); - } - bool IsValidGraph(const ir::Graph* graph) const; }; diff --git a/paddle/fluid/framework/details/ssa_graph_executor.h b/paddle/fluid/framework/details/ssa_graph_executor.h index 8815ec89b23..96fffb7d943 100644 --- a/paddle/fluid/framework/details/ssa_graph_executor.h +++ b/paddle/fluid/framework/details/ssa_graph_executor.h @@ -32,7 +32,9 @@ class SSAGraphExecutor { virtual ~SSAGraphExecutor(); - virtual FeedFetchList Run(const std::vector &fetch_tensors) = 0; + virtual const ir::Graph& Graph() const = 0; + + virtual FeedFetchList Run(const std::vector& fetch_tensors) = 0; }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/ssa_graph_printer.cc b/paddle/fluid/framework/details/ssa_graph_printer.cc index 20aab146440..95d0641d72c 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.cc +++ b/paddle/fluid/framework/details/ssa_graph_printer.cc @@ -81,3 +81,6 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, } // namespace details } // namespace framework } // namespace paddle + +REGISTER_PASS(multi_device_print_pass, + paddle::framework::details::SSAGraghBuilderWithPrinter); diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h index 2a939ef4c9e..b7d20aa9838 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.h +++ b/paddle/fluid/framework/details/ssa_graph_printer.h @@ -39,8 +39,7 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { public: std::unique_ptr Apply( std::unique_ptr graph) const override { - auto new_graph = - Get("previous_pass").Apply(std::move(graph)); + auto new_graph = Get("previous_pass").Apply(std::move(graph)); std::unique_ptr fout( new std::ofstream(Get("debug_graphviz_path"))); @@ -48,10 +47,6 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { Get("graph_printer").Print(*new_graph, *fout); return new_graph; } - - int GetVarDeviceID(const std::string& var_name) const override { - return Get("previous_pass").GetVarDeviceID(var_name); - } }; } // namespace details diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 3d67daa45e2..82d6b5272ab 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -42,6 +42,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { const std::vector &places, std::unique_ptr &&graph); + const ir::Graph &Graph() const { return *graph_; } // Run a SSAGraph by a thread pool // Use topological sort algorithm FeedFetchList Run(const std::vector &fetch_tensors) override; diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 4f59ec82a7d..49f39df4b96 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -42,6 +42,8 @@ class Graph { template AttrType &Get(const std::string &attr_name) const { + PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(), + "%s attr not registered for graph.", attr_name); return *boost::any_cast(attrs_.at(attr_name)); } diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 9466924262d..5ab7f9a1e2d 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -44,6 +44,8 @@ class Pass { template AttrType &Get(const std::string &attr_name) const { + PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(), + "%s attr not registered for pass.", attr_name); return *boost::any_cast(attrs_.at(attr_name)); } diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index ff661d0013c..a23fd2a41a8 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -33,6 +33,48 @@ limitations under the License. */ namespace paddle { namespace framework { +std::unique_ptr ApplyParallelExecutorPass( + const ProgramDesc &main_program, const std::vector &places, + const std::string &loss_var_name, + const std::unordered_set ¶m_names, + const std::vector &local_scopes, const bool use_cuda, +#ifdef PADDLE_WITH_CUDA + const BuildStrategy &strategy, platform::NCCLContextMap *nccl_ctxs) { +#else + const BuildStrategy &strategy) { +#endif + details::ParallelExecutorPassManager builder_factory( + places, loss_var_name, param_names, local_scopes, strategy); + if (use_cuda) { +#ifdef PADDLE_WITH_CUDA + builder_factory.SetNCCLContextMap(nccl_ctxs); +#else + PADDLE_THROW("Not compiled with CUDA."); +#endif + } + + std::unique_ptr graph(new ir::Graph(main_program)); + if (!strategy.debug_graphviz_path_.empty()) { + auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass"); + const std::string graph_path = string::Sprintf( + "%s%s", strategy.debug_graphviz_path_.c_str(), "_original_graph"); + viz_pass->Set("graph_viz_path", new std::string(graph_path)); + graph = viz_pass->Apply(std::move(graph)); + } + + auto builder = builder_factory.Create(); + graph = builder->Apply(std::move(graph)); + + if (!strategy.debug_graphviz_path_.empty()) { + auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass"); + const std::string graph_path = string::Sprintf( + "%s%s", strategy.debug_graphviz_path_.c_str(), "_before_exec"); + viz_pass->Set("graph_viz_path", new std::string(graph_path)); + graph = viz_pass->Apply(std::move(graph)); + } + return graph; +} + class ParallelExecutorPrivate { public: explicit ParallelExecutorPrivate(const std::vector &places) @@ -120,38 +162,18 @@ ParallelExecutor::ParallelExecutor( var_infos.back().persistable_ = var->Persistable(); } - // Step 3. Convert main_program to SSA form and dependency graph. Also, insert - // ncclOp - details::SSAGraphBuilderFactory builder_factory( - member_->places_, loss_var_name, params, member_->local_scopes_, - build_strategy); - if (member_->use_cuda_) { +// Step 3. Convert main_program to SSA form and dependency graph. Also, insert +// ncclOp #ifdef PADDLE_WITH_CUDA - builder_factory.SetNCCLContextMap(member_->nccl_ctxs_.get()); + std::unique_ptr graph = ApplyParallelExecutorPass( + main_program, member_->places_, loss_var_name, params, + member_->local_scopes_, member_->use_cuda_, build_strategy, + member_->nccl_ctxs_.get()); #else - PADDLE_THROW("Not compiled with CUDA."); + std::unique_ptr graph = ApplyParallelExecutorPass( + main_program, member_->places_, loss_var_name, params, + member_->local_scopes_, member_->use_cuda_, build_strategy); #endif - } - - std::unique_ptr graph(new ir::Graph(main_program)); - if (!build_strategy.debug_graphviz_path_.empty()) { - auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass"); - const std::string graph_path = string::Sprintf( - "%s%s", build_strategy.debug_graphviz_path_.c_str(), "_original_graph"); - viz_pass->Set("graph_viz_path", new std::string(graph_path)); - graph = viz_pass->Apply(std::move(graph)); - } - - builder_ = builder_factory.Create(); - graph = builder_->Apply(std::move(graph)); - - if (!build_strategy.debug_graphviz_path_.empty()) { - auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass"); - const std::string graph_path = string::Sprintf( - "%s%s", build_strategy.debug_graphviz_path_.c_str(), "_before_exec"); - viz_pass->Set("graph_viz_path", new std::string(graph_path)); - graph = viz_pass->Apply(std::move(graph)); - } member_->executor_.reset(new details::ThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, places, std::move(graph))); @@ -165,11 +187,18 @@ void ParallelExecutor::BCastParamsToDevices( // the initializing bcast, all vars would be bcast from device(0), // otherwise // bcast from the specified device. - bool initializing = builder_.get() == nullptr ? true : false; - + bool initializing = member_->executor_ ? false : true; for (auto &var : vars) { - int var_dev_id = - builder_.get() == nullptr ? -1 : builder_->GetVarDeviceID(var); + int var_dev_id = -1; + if (member_->executor_) { + auto &sharded_var_device = + member_->executor_->Graph().Get( + "sharded_var_device"); + if (sharded_var_device.find(var) != sharded_var_device.end()) { + var_dev_id = sharded_var_device.at(var); + } + } + if (!initializing && var_dev_id == -1) continue; framework::Variable *main_var = nullptr; @@ -307,3 +336,6 @@ ParallelExecutor::~ParallelExecutor() { } // namespace paddle USE_PASS(graph_viz_pass); +USE_PASS(multi_device_pass); +USE_PASS(multi_device_check_pass); +USE_PASS(multi_device_print_pass); diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index ffb9934a2d7..d624956acde 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -70,7 +70,6 @@ class ParallelExecutor { private: ParallelExecutorPrivate *member_; - std::unique_ptr builder_; }; } // namespace framework diff --git a/paddle/fluid/operators/distributed/send_recv.proto b/paddle/fluid/operators/distributed/send_recv.proto new file mode 100644 index 00000000000..d0595ef1089 --- /dev/null +++ b/paddle/fluid/operators/distributed/send_recv.proto @@ -0,0 +1,97 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under +the Apache License, Version 2.0 (the "License"); you may not use this file +except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +syntax = "proto3"; +package sendrecv; + +option cc_generic_services = false; + +service SendRecvService { + // For parameter server round-robin like hashing, do not split tensors. + // Send and recv only one tensor + // TODO(typhoonzero): add streaming API + rpc SendVariable(VariableMessage) returns (VoidMessage) {} + // Argument VariableMessage for GetVariable should only contain varname. + rpc GetVariable(VariableMessage) returns (VariableMessage) {} + // pre-fetch variable by given variable name and Ids + rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {} + + rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {} +} + +// VariableMessage is serialized paddle variable message. +// It can be: +// LoDTensor +// SelectedRows +enum VarType { + LOD_TENSOR = 0; + SELECTED_ROWS = 1; + NCCL_ID = 2; +} + +// NOTICE(gongwb):don't modify this proto if you are not +// not familar with how we serialize in sendrecvop_utils.h +// and deserilize it in variable_response.h. +message VariableMessage { + enum Type { + // Pod Types + BOOL = 0; + INT16 = 1; + INT32 = 2; + INT64 = 3; + FP16 = 4; + FP32 = 5; + FP64 = 6; + } + + message LodData { repeated int64 lod_data = 1; } + string varname = 1; + // TODO(Yancey1989): reference framework::proto::VarDesc::VarType + VarType type = 2; + // bool persistable is not needed for sending. + // tensor info: + Type data_type = 3; + repeated int64 dims = 4; + + // lod details: + int64 lod_level = 5; + repeated LodData lod = 6; + // selected_rows height, aka. original dim0 + int64 slr_height = 7; + // tensor data + bytes serialized = 8; + // selected_rows data + bytes rows = 9; + // Look up table block execution output variable name. + string out_varname = 10; + // If 1, the ps server will start profiling, the ps + // server stops profiling and generates a profile to /tmp/profile_ps_* + // when profile switches from 1 to 2. + int64 profile = 11; +} + +message VoidMessage {} -- GitLab