diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index a6fe64fa80d6bf036893d49de56d7274d49a3b30..e680655f2ab3b695da86d2b4e799a03661a25b94 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -483,6 +483,9 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, } } else if (op.Type() == "concat") { op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); + for (auto &varname : op.OutputArgumentNames()) { + var_name_on_devices_.emplace(varname, op_dev_id); + } } else { PADDLE_ENFORCE( "the distribute training related op should be in [split_byref, " diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 9eb23c46264f9036f009b0ae9aeeb34ec70c0e53..18612c3c1b62cf4c2ebdc221c301c59ec81c2da7 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -30,7 +30,7 @@ class SSAGraphBuilder { SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {} virtual std::unique_ptr Build(const ProgramDesc &program) const = 0; - virtual int GetVarDeviceID(const std::string &var_name) const { return -1; } + virtual int GetVarDeviceID(const std::string &var_name) const = 0; DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h index 304b221e7e4c414a0ab562a1b99836d3b7c02efb..331aa9d2b5864c470dbd5e29ef6faccffdcf781c 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.h +++ b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -16,6 +16,8 @@ #include "paddle/fluid/framework/details/ssa_graph_builder.h" +#include + namespace paddle { namespace framework { namespace details { @@ -33,6 +35,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { return graph; } + int GetVarDeviceID(const std::string& var_name) const override { + return builder_->GetVarDeviceID(var_name); + } + bool IsValidGraph(const SSAGraph* graph) const; private: diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h index b4c90013789759d17646d95efdc81fc6a0a4f3e7..09b0333ef2cb43a306133aa5af98d37c11454d4d 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.h +++ b/paddle/fluid/framework/details/ssa_graph_printer.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include "paddle/fluid/framework/details/ssa_graph_builder.h" namespace paddle { @@ -55,6 +56,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { return graph; } + int GetVarDeviceID(const std::string& var_name) const override { + return builder_->GetVarDeviceID(var_name); + } + private: std::unique_ptr printer_; std::unique_ptr builder_; diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index d478865fa8f24c653a4185cabd05747a5410ceaa..1bf9b87aea8c7c13bfda57ada8f7ec674d7d65a2 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -133,17 +133,18 @@ ParallelExecutor::ParallelExecutor( void ParallelExecutor::BCastParamsToGPUs( const std::unordered_set &vars) const { - // the the initialize bcast, all vars would be bcast from device(0), otherwise + // the the initializing bcast, all vars would be bcast from device(0), + // otherwise // bcast from the specified device. - bool initialize = builder_.get() == nullptr ? true : false; + bool initializing = builder_.get() == nullptr ? true : false; for (auto &var : vars) { int var_dev_id = builder_.get() == nullptr ? -1 : builder_->GetVarDeviceID(var); - if (!initialize && var_dev_id == -1) continue; + if (!initializing && var_dev_id == -1) continue; framework::Variable *main_var = nullptr; - if (initialize) { + if (initializing) { main_var = member_->local_scopes_[0]->FindVar(var); } else { main_var = member_->local_scopes_[var_dev_id]->FindVar(var); @@ -164,7 +165,8 @@ void ParallelExecutor::BCastParamsToGPUs( auto place = member_->places_[i]; void *buffer; - if ((initialize && i == 0) || (!initialize && i == var_dev_id)) { + if ((initializing && i == 0) || + (!initializing && static_cast(i) == var_dev_id)) { buffer = const_cast(main_tensor.data()); } else { auto local_scope = member_->local_scopes_[i]; @@ -181,8 +183,16 @@ void ParallelExecutor::BCastParamsToGPUs( platform::NCCLGroupGuard guard; for (size_t i = 0; i < member_->places_.size(); ++i) { auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]); - platform::dynload::ncclBcast(buffers[i], numel, data_type, 0, - nccl_ctx.comm_, nccl_ctx.stream()); + if (initializing) { + platform::dynload::ncclBcast(buffers[i], numel, data_type, 0, + nccl_ctx.comm_, nccl_ctx.stream()); + } else { + if (var_dev_id >= 0) { + platform::dynload::ncclBcast(buffers[i], numel, data_type, + var_dev_id, nccl_ctx.comm_, + nccl_ctx.stream()); + } + } } member_->nccl_ctxs_->WaitAll(); } diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index d8d6a7e9418e1c2a9f82d58b5c9650d58604d46e..10bbec0979b77a7aa581cc01c6903e1648866a86 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -302,7 +302,6 @@ class DistributeTranspiler(object): """ # remove optimize ops and add a send op to main_program delete_ops(self.origin_program.global_block(), self.optimize_ops) - # FIXME(typhoonzero): serialize once will fix error occurs when clone. self.origin_program.__str__() return self.origin_program @@ -383,11 +382,12 @@ class DistributeTranspiler(object): if self._is_adam_connected_op(op): global_ops.append(op) - def __append_optimize_op__(op, block, grad_to_block_id, merged_var): + def __append_optimize_op__(op, block, grad_to_block_id, merged_var, + lr_ops): if self._is_optimizer_op(op): self._append_pserver_ops(block, op, endpoint, grad_to_block_id, self.origin_program, merged_var) - else: + elif op not in lr_ops: self._append_pserver_non_opt_ops(block, op) def __op_have_grad_input__(op): @@ -447,7 +447,7 @@ class DistributeTranspiler(object): # optimizer is connected to itself if ufind.is_connected(op, opt_op) and op not in global_ops: __append_optimize_op__(op, per_opt_block, grad_to_block_id, - merged_var) + merged_var, lr_ops) # append global ops if global_ops: @@ -455,7 +455,7 @@ class DistributeTranspiler(object): pserver_program.num_blocks - 1) for glb_op in global_ops: __append_optimize_op__(glb_op, opt_state_block, - grad_to_block_id, None) + grad_to_block_id, None, lr_ops) # process distributed lookup_table prefetch_var_name_to_block_id = []