Commit 642fd68c authored by Yancey1989

update by comment test=develop

Parent 7cd6de37
@@ -21,8 +21,6 @@ namespace paddle {
 namespace framework {
 namespace details {
-constexpr char kAllOpDescs[] = "all_op_descs";
-
 // TODO(gongwb): overlap allreduce with backward computation.
 class AllReduceDepsPass : public ir::Pass {
  protected:
@@ -29,8 +29,6 @@ namespace paddle {
 namespace framework {
 namespace details {
-constexpr char kAllOpDescs[] = "all_op_descs";
-
 std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);
 
 // NOTE(dzh): A ordered set for node reuse in memory optimize.
@@ -221,7 +221,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
    * Only variables should be the leaves of graph.
    */
   AddOutputToLeafOps(&result);
-  // result.Erase(kGraphOps);
   return graph;
 }
@@ -19,12 +19,12 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
-    const std::vector<platform::Place> &places,
-    std::unique_ptr<ir::Graph> graph) {
+std::vector<std::unique_ptr<ir::Graph>>
+ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(
+    std::unique_ptr<ir::Graph> &&graph) {
   std::vector<std::unique_ptr<ir::Graph>> graphs;
-  graphs.reserve(places.size());
-  for (size_t i = 0; i < places.size(); ++i) {
+  graphs.reserve(places_.size());
+  for (size_t i = 0; i < places_.size(); ++i) {
     ProgramDesc empty;
     graphs.emplace_back(std::unique_ptr<ir::Graph>(new ir::Graph(empty)));
     auto &g = graphs.back();
@@ -60,7 +60,7 @@ std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
     }
   }
-  for (size_t dev_id = 0; dev_id < places.size(); ++dev_id) {
+  for (size_t dev_id = 0; dev_id < places_.size(); ++dev_id) {
     auto &dev_vars = graphs[dev_id]->Get<GraphVars>(kGraphVars)[0];
     auto &origin_vars = graph->Get<GraphVars>(kGraphVars)[dev_id];
     for (auto &name_pair : origin_vars) {
@@ -80,14 +80,26 @@ std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
 ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
     const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places,
-    std::vector<std::unique_ptr<ir::Graph>> &&graphs)
+    const framework::ProgramDesc &main_prog, std::unique_ptr<ir::Graph> &&graph)
     : strategy_(std::move(strategy)),
       local_scopes_(std::move(local_scopes)),
       pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
       places_(std::move(places)),
-      graphs_(std::move(graphs)) {
+      main_prog_(main_prog),
+      // TODO(Yancey1989): copy graphs is not safely since it deleted the attrs.
+      graphs_(SeparateMultiDevicesGraph(std::move(graph))) {
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
+
+  auto seq_allreduce_pass =
+      ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
+  seq_allreduce_pass->Erase(details::kAllOpDescs);
+  seq_allreduce_pass->Set<const std::vector<OpDesc *>>(
+      details::kAllOpDescs,
+      new std::vector<OpDesc *>(main_prog_.Block(0).AllOps()));
+  for (size_t i = 0; i < graphs_.size(); ++i) {
+    graphs_[i] = seq_allreduce_pass->Apply(std::move(graphs_[i]));
+  }
+
   // set the correct size of thread pool to each device.
   strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
                                ? 1UL
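Note: the new constructor body above follows a common IR-pass pattern: look up a registered pass, refresh its `kAllOpDescs` attribute from the main program's ops, then apply the pass to every per-device graph. Below is a minimal standalone sketch of that lookup-and-apply flow; `Graph`, `Pass`, `PassRegistry`, and `SeqAllReducePass` are simplified stand-ins invented for illustration, not the real PaddlePaddle types.

```cpp
// Minimal sketch of the registry lookup-and-apply pattern used above.
// Graph, Pass, PassRegistry, and SeqAllReducePass are simplified
// stand-ins for illustration, not the real PaddlePaddle types.
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Graph {
  std::string name;
};

class Pass {
 public:
  virtual ~Pass() = default;
  // Drop a previously attached attribute, mirroring Pass::Erase in the diff.
  void Erase(const std::string &key) { attrs_.erase(key); }
  // Attach an attribute the pass will consume, mirroring Pass::Set.
  template <typename T>
  void Set(const std::string &key, T *value) {
    attrs_[key] = std::shared_ptr<void>(value);
  }
  virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) = 0;

 protected:
  std::map<std::string, std::shared_ptr<void>> attrs_;
};

// A toy pass that just tags each graph it rewrites.
class SeqAllReducePass : public Pass {
 public:
  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) override {
    graph->name += " [allreduce deps ordered]";
    return graph;
  }
};

class PassRegistry {
 public:
  static PassRegistry &Instance() {
    static PassRegistry registry;
    return registry;
  }
  void Register(const std::string &name,
                std::function<std::unique_ptr<Pass>()> creator) {
    creators_[name] = std::move(creator);
  }
  std::unique_ptr<Pass> Get(const std::string &name) const {
    return creators_.at(name)();
  }

 private:
  std::map<std::string, std::function<std::unique_ptr<Pass>()>> creators_;
};

int main() {
  PassRegistry::Instance().Register("all_reduce_deps_pass", [] {
    return std::unique_ptr<Pass>(new SeqAllReducePass);
  });

  // One graph per device, as SeparateMultiDevicesGraph would produce.
  std::vector<std::unique_ptr<Graph>> graphs;
  graphs.emplace_back(new Graph{"dev0 graph"});
  graphs.emplace_back(new Graph{"dev1 graph"});

  // Same shape as the constructor body in the diff: fetch the pass,
  // refresh its attribute, then apply it to every per-device graph.
  auto pass = PassRegistry::Instance().Get("all_reduce_deps_pass");
  pass->Erase("all_op_descs");
  pass->Set("all_op_descs", new std::vector<std::string>{"op1", "op2"});
  for (auto &g : graphs) {
    g = pass->Apply(std::move(g));
  }

  for (const auto &g : graphs) {
    std::cout << g->name << "\n";
  }
  return 0;
}
```

Keeping this step inside the ParallelSSAGraphExecutor constructor, as this commit does, means callers such as ParallelExecutor hand over the whole graph plus the main ProgramDesc instead of pre-splitting and pre-processing the graphs themselves.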
@@ -28,16 +28,13 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
-    const std::vector<platform::Place> &places,
-    std::unique_ptr<ir::Graph> graph);
-
 class ParallelSSAGraphExecutor : public SSAGraphExecutor {
  public:
   ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
                            const std::vector<Scope *> &local_scopes,
                            const std::vector<platform::Place> &places,
-                           std::vector<std::unique_ptr<ir::Graph>> &&graphs);
+                           const framework::ProgramDesc &main_prog,
+                           std::unique_ptr<ir::Graph> &&graph);
   ~ParallelSSAGraphExecutor() final = default;
 
   const ir::Graph &Graph() const override { return *graphs_[0]; }
@@ -45,10 +42,14 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
   FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
 
  private:
+  std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
+      std::unique_ptr<ir::Graph> &&graph);
+
   ExecutionStrategy strategy_;
   std::vector<Scope *> local_scopes_;
   std::unique_ptr<::ThreadPool> pool_{nullptr};
   std::vector<platform::Place> places_;
+  framework::ProgramDesc main_prog_;
   std::vector<std::unique_ptr<ir::Graph>> graphs_;
 
   std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
@@ -26,6 +26,11 @@ limitations under the License. */
 
 namespace paddle {
 namespace framework {
+
+namespace details {
+constexpr char kAllOpDescs[] = "all_op_descs";
+}  // namespace details
+
 namespace ir {
 
 /*
@@ -305,21 +305,11 @@ ParallelExecutor::ParallelExecutor(
   if (build_strategy.enable_parallel_graph_) {
 #ifdef PADDLE_WITH_CUDA
-    auto parallel_graph =
-        details::SeparateMultiDevicesGraph(member_->places_, std::move(graph));
-    auto seq_allreduce_pass =
-        ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
-    seq_allreduce_pass->Erase(details::kAllOpDescs);
-    seq_allreduce_pass->Set<const std::vector<OpDesc *>>(
-        details::kAllOpDescs,
-        new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
-    for (size_t i = 0; i < parallel_graph.size(); ++i) {
-      parallel_graph[i] =
-          seq_allreduce_pass->Apply(std::move(parallel_graph[i]));
-    }
+    // TODO(Yancey1989): Remove passing in the main_program when
+    // allreduce_seq_pass doesn't need it as the attr.
     member_->executor_.reset(new details::ParallelSSAGraphExecutor(
-        exec_strategy, member_->local_scopes_, member_->places_,
-        std::move(parallel_graph)));
+        exec_strategy, member_->local_scopes_, member_->places_, main_program,
+        std::move(graph)));
 #else
     PADDLE_THROW(
         "Paddle should be compiled with CUDA for ParallelGraph Execution.");