From 3d69a82b837e9dbfb907c36cd4a029e10d682db8 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 26 Jun 2018 16:23:31 +0800 Subject: [PATCH] fix dist train broadcasting bug --- .../framework/details/multi_devices_graph_builder.cc | 3 +++ paddle/fluid/framework/details/ssa_graph_builder.h | 2 +- paddle/fluid/framework/details/ssa_graph_checker.h | 6 ++++++ paddle/fluid/framework/details/ssa_graph_printer.h | 5 +++++ paddle/fluid/framework/parallel_executor.cc | 12 +++++++----- 5 files changed, 22 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 74f3687c0d6..8765af52b09 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -483,6 +483,9 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, } } else if (op.Type() == "concat") { op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); + for (auto &varname : op.OutputArgumentNames()) { + var_name_on_devices_.emplace(varname, op_dev_id); + } } else { PADDLE_ENFORCE( "the distribute training related op should be in [split_byref, " diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 9eb23c46264..18612c3c1b6 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -30,7 +30,7 @@ class SSAGraphBuilder { SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {} virtual std::unique_ptr Build(const ProgramDesc &program) const = 0; - virtual int GetVarDeviceID(const std::string &var_name) const { return -1; } + virtual int GetVarDeviceID(const std::string &var_name) const = 0; DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h index 304b221e7e4..331aa9d2b58 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.h +++ b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -16,6 +16,8 @@ #include "paddle/fluid/framework/details/ssa_graph_builder.h" +#include + namespace paddle { namespace framework { namespace details { @@ -33,6 +35,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { return graph; } + int GetVarDeviceID(const std::string& var_name) const override { + return builder_->GetVarDeviceID(var_name); + } + bool IsValidGraph(const SSAGraph* graph) const; private: diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h index b4c90013789..09b0333ef2c 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.h +++ b/paddle/fluid/framework/details/ssa_graph_printer.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include "paddle/fluid/framework/details/ssa_graph_builder.h" namespace paddle { @@ -55,6 +56,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { return graph; } + int GetVarDeviceID(const std::string& var_name) const override { + return builder_->GetVarDeviceID(var_name); + } + private: std::unique_ptr printer_; std::unique_ptr builder_; diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index a6788cb6d5d..d06e4c89ea1 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -133,17 +133,18 @@ ParallelExecutor::ParallelExecutor( void ParallelExecutor::BCastParamsToGPUs( const std::unordered_set &vars) const { - // the the initialize bcast, all vars would be bcast from device(0), otherwise + // the the initializing bcast, all vars would be bcast from device(0), + // otherwise // bcast from the specified device. - bool initialize = builder_.get() == nullptr ? true : false; + bool initializing = builder_.get() == nullptr ? false : true; for (auto &var : vars) { int var_dev_id = builder_.get() == nullptr ? -1 : builder_->GetVarDeviceID(var); - if (!initialize && var_dev_id == -1) continue; + if (!initializing && var_dev_id == -1) continue; framework::Variable *main_var = nullptr; - if (initialize) { + if (initializing) { main_var = member_->local_scopes_[0]->FindVar(var); } else { main_var = member_->local_scopes_[var_dev_id]->FindVar(var); @@ -164,7 +165,8 @@ void ParallelExecutor::BCastParamsToGPUs( auto place = member_->places_[i]; void *buffer; - if ((initialize && i == 0) || (!initialize && i == var_dev_id)) { + if ((initializing && i == 0) || + (!initializing && i == static_cast(var_dev_id))) { buffer = const_cast(main_tensor.data()); } else { auto local_scope = member_->local_scopes_[i]; -- GitLab