Unverified commit dd8ee695, authored by Wu Yi and committed by GitHub

Merge pull request #11756 from typhoonzero/cherry_pick_bcast_fix

Merge pull request #11728 from typhoonzero/fix_paraexe_bcast
@@ -483,6 +483,9 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
     }
   } else if (op.Type() == "concat") {
     op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
+    for (auto &varname : op.OutputArgumentNames()) {
+      var_name_on_devices_.emplace(varname, op_dev_id);
+    }
   } else {
     PADDLE_ENFORCE(
         "the distribute training related op should be in [split_byref, "
......
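The hunk above makes the concat op's outputs inherit its device. A minimal stand-alone sketch of that placement rule, using simplified names (PlaceConcatOutputs, var_dev) that are not part of the Paddle codebase:

    #include <string>
    #include <unordered_map>
    #include <vector>

    // Hypothetical helper mirroring the change above: the concat op runs on
    // the device of its first input, and each of its outputs is recorded on
    // that same device so later device lookups do not fall back to -1.
    int PlaceConcatOutputs(const std::vector<std::string> &inputs,
                           const std::vector<std::string> &outputs,
                           std::unordered_map<std::string, int> *var_dev) {
      auto it = var_dev->find(inputs.front());
      int op_dev_id = (it == var_dev->end()) ? -1 : it->second;
      for (const auto &name : outputs) {
        var_dev->emplace(name, op_dev_id);  // pin each output to the concat's device
      }
      return op_dev_id;
    }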
@@ -30,7 +30,7 @@ class SSAGraphBuilder {
   SSAGraphBuilder() {}
   virtual ~SSAGraphBuilder() {}
   virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
-  virtual int GetVarDeviceID(const std::string &var_name) const { return -1; }
+  virtual int GetVarDeviceID(const std::string &var_name) const = 0;
   DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
......
@@ -16,6 +16,8 @@
 #include "paddle/fluid/framework/details/ssa_graph_builder.h"
+#include <string>
 namespace paddle {
 namespace framework {
 namespace details {
@@ -33,6 +35,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
     return graph;
   }
+  int GetVarDeviceID(const std::string& var_name) const override {
+    return builder_->GetVarDeviceID(var_name);
+  }
   bool IsValidGraph(const SSAGraph* graph) const;
  private:
......
@@ -15,6 +15,7 @@
 #pragma once
 #include <iosfwd>
+#include <string>
 #include "paddle/fluid/framework/details/ssa_graph_builder.h"
 namespace paddle {
@@ -55,6 +56,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
     return graph;
   }
+  int GetVarDeviceID(const std::string& var_name) const override {
+    return builder_->GetVarDeviceID(var_name);
+  }
  private:
  std::unique_ptr<SSAGraphPrinter> printer_;
  std::unique_ptr<SSAGraphBuilder> builder_;
......
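The three header changes above turn GetVarDeviceID into a pure virtual method on the base builder and make the checker/printer wrappers forward the call to the builder they decorate. A minimal sketch of that delegation pattern, with illustrative class names that are not the actual Paddle types:

    #include <memory>
    #include <string>
    #include <utility>

    struct GraphBuilder {  // stands in for SSAGraphBuilder
      virtual ~GraphBuilder() = default;
      virtual int GetVarDeviceID(const std::string &var_name) const = 0;  // now pure virtual
    };

    struct WrappingBuilder : GraphBuilder {  // stands in for the checker/printer wrappers
      explicit WrappingBuilder(std::unique_ptr<GraphBuilder> inner)
          : builder_(std::move(inner)) {}
      int GetVarDeviceID(const std::string &var_name) const override {
        // Delegate instead of returning a default -1, so device placement
        // decided by the inner builder stays visible through the wrapper.
        return builder_->GetVarDeviceID(var_name);
      }

     private:
      std::unique_ptr<GraphBuilder> builder_;
    };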
@@ -133,17 +133,18 @@ ParallelExecutor::ParallelExecutor(
 void ParallelExecutor::BCastParamsToGPUs(
     const std::unordered_set<std::string> &vars) const {
-  // the the initialize bcast, all vars would be bcast from device(0), otherwise
+  // the the initializing bcast, all vars would be bcast from device(0),
+  // otherwise
   // bcast from the specified device.
-  bool initialize = builder_.get() == nullptr ? true : false;
+  bool initializing = builder_.get() == nullptr ? true : false;
   for (auto &var : vars) {
     int var_dev_id =
         builder_.get() == nullptr ? -1 : builder_->GetVarDeviceID(var);
-    if (!initialize && var_dev_id == -1) continue;
+    if (!initializing && var_dev_id == -1) continue;
     framework::Variable *main_var = nullptr;
-    if (initialize) {
+    if (initializing) {
       main_var = member_->local_scopes_[0]->FindVar(var);
     } else {
       main_var = member_->local_scopes_[var_dev_id]->FindVar(var);
@@ -164,7 +165,8 @@ void ParallelExecutor::BCastParamsToGPUs(
       auto place = member_->places_[i];
       void *buffer;
-      if ((initialize && i == 0) || (!initialize && i == var_dev_id)) {
+      if ((initializing && i == 0) ||
+          (!initializing && static_cast<int>(i) == var_dev_id)) {
        buffer = const_cast<void *>(main_tensor.data<void>());
       } else {
        auto local_scope = member_->local_scopes_[i];
@@ -181,8 +183,16 @@ void ParallelExecutor::BCastParamsToGPUs(
     platform::NCCLGroupGuard guard;
     for (size_t i = 0; i < member_->places_.size(); ++i) {
       auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]);
-      platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
-                                   nccl_ctx.comm_, nccl_ctx.stream());
+      if (initializing) {
+        platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
+                                     nccl_ctx.comm_, nccl_ctx.stream());
+      } else {
+        if (var_dev_id >= 0) {
+          platform::dynload::ncclBcast(buffers[i], numel, data_type,
+                                       var_dev_id, nccl_ctx.comm_,
+                                       nccl_ctx.stream());
+        }
+      }
     }
     member_->nccl_ctxs_->WaitAll();
   }
......
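The ParallelExecutor change above broadcasts every variable from device 0 only during the initial parameter sync; afterwards a variable is rebroadcast only when the builder has assigned it a device, and that device becomes the NCCL root. A small sketch of the root-selection rule, under the assumption (mine, not stated in the code) that -1 means "skip this variable":

    // Returns the NCCL root device for a variable, or -1 if it should not be
    // broadcast at all. Mirrors the branch added in BCastParamsToGPUs.
    int ChooseBcastRoot(bool initializing, int var_dev_id) {
      if (initializing) {
        return 0;  // initial sync: every variable comes from device 0
      }
      return var_dev_id >= 0 ? var_dev_id  // later: from the device that owns it
                             : -1;         // unknown owner: skip the broadcast
    }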
@@ -302,7 +302,6 @@ class DistributeTranspiler(object):
         """
         # remove optimize ops and add a send op to main_program
         delete_ops(self.origin_program.global_block(), self.optimize_ops)
-        # FIXME(typhoonzero): serialize once will fix error occurs when clone.
         self.origin_program.__str__()
         return self.origin_program
@@ -383,11 +382,12 @@ class DistributeTranspiler(object):
             if self._is_adam_connected_op(op):
                 global_ops.append(op)
-        def __append_optimize_op__(op, block, grad_to_block_id, merged_var):
+        def __append_optimize_op__(op, block, grad_to_block_id, merged_var,
+                                   lr_ops):
             if self._is_optimizer_op(op):
                 self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
                                          self.origin_program, merged_var)
-            else:
+            elif op not in lr_ops:
                 self._append_pserver_non_opt_ops(block, op)
         def __op_have_grad_input__(op):
@@ -447,7 +447,7 @@ class DistributeTranspiler(object):
                 # optimizer is connected to itself
                 if ufind.is_connected(op, opt_op) and op not in global_ops:
                     __append_optimize_op__(op, per_opt_block, grad_to_block_id,
-                                           merged_var)
+                                           merged_var, lr_ops)
         # append global ops
         if global_ops:
@@ -455,7 +455,7 @@ class DistributeTranspiler(object):
                 pserver_program.num_blocks - 1)
             for glb_op in global_ops:
                 __append_optimize_op__(glb_op, opt_state_block,
-                                       grad_to_block_id, None)
+                                       grad_to_block_id, None, lr_ops)
         # process distributed lookup_table
         prefetch_var_name_to_block_id = []
......