提交 3d69a82b 编写于 作者: Y yi.wu

fix dist train broadcasting bug

上级 75618e4b
...@@ -483,6 +483,9 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, ...@@ -483,6 +483,9 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
} }
} else if (op.Type() == "concat") { } else if (op.Type() == "concat") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
for (auto &varname : op.OutputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
} else { } else {
PADDLE_ENFORCE( PADDLE_ENFORCE(
"the distribute training related op should be in [split_byref, " "the distribute training related op should be in [split_byref, "
......
...@@ -30,7 +30,7 @@ class SSAGraphBuilder { ...@@ -30,7 +30,7 @@ class SSAGraphBuilder {
SSAGraphBuilder() {} SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0; virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
virtual int GetVarDeviceID(const std::string &var_name) const { return -1; } virtual int GetVarDeviceID(const std::string &var_name) const = 0;
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
#include "paddle/fluid/framework/details/ssa_graph_builder.h" #include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include <string>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
...@@ -33,6 +35,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { ...@@ -33,6 +35,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
return graph; return graph;
} }
int GetVarDeviceID(const std::string& var_name) const override {
return builder_->GetVarDeviceID(var_name);
}
bool IsValidGraph(const SSAGraph* graph) const; bool IsValidGraph(const SSAGraph* graph) const;
private: private:
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <iosfwd> #include <iosfwd>
#include <string>
#include "paddle/fluid/framework/details/ssa_graph_builder.h" #include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace paddle { namespace paddle {
...@@ -55,6 +56,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { ...@@ -55,6 +56,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
return graph; return graph;
} }
int GetVarDeviceID(const std::string& var_name) const override {
return builder_->GetVarDeviceID(var_name);
}
private: private:
std::unique_ptr<SSAGraphPrinter> printer_; std::unique_ptr<SSAGraphPrinter> printer_;
std::unique_ptr<SSAGraphBuilder> builder_; std::unique_ptr<SSAGraphBuilder> builder_;
......
...@@ -133,17 +133,18 @@ ParallelExecutor::ParallelExecutor( ...@@ -133,17 +133,18 @@ ParallelExecutor::ParallelExecutor(
void ParallelExecutor::BCastParamsToGPUs( void ParallelExecutor::BCastParamsToGPUs(
const std::unordered_set<std::string> &vars) const { const std::unordered_set<std::string> &vars) const {
// the the initialize bcast, all vars would be bcast from device(0), otherwise // the the initializing bcast, all vars would be bcast from device(0),
// otherwise
// bcast from the specified device. // bcast from the specified device.
bool initialize = builder_.get() == nullptr ? true : false; bool initializing = builder_.get() == nullptr ? false : true;
for (auto &var : vars) { for (auto &var : vars) {
int var_dev_id = int var_dev_id =
builder_.get() == nullptr ? -1 : builder_->GetVarDeviceID(var); builder_.get() == nullptr ? -1 : builder_->GetVarDeviceID(var);
if (!initialize && var_dev_id == -1) continue; if (!initializing && var_dev_id == -1) continue;
framework::Variable *main_var = nullptr; framework::Variable *main_var = nullptr;
if (initialize) { if (initializing) {
main_var = member_->local_scopes_[0]->FindVar(var); main_var = member_->local_scopes_[0]->FindVar(var);
} else { } else {
main_var = member_->local_scopes_[var_dev_id]->FindVar(var); main_var = member_->local_scopes_[var_dev_id]->FindVar(var);
...@@ -164,7 +165,8 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -164,7 +165,8 @@ void ParallelExecutor::BCastParamsToGPUs(
auto place = member_->places_[i]; auto place = member_->places_[i];
void *buffer; void *buffer;
if ((initialize && i == 0) || (!initialize && i == var_dev_id)) { if ((initializing && i == 0) ||
(!initializing && i == static_cast<size_t>(var_dev_id))) {
buffer = const_cast<void *>(main_tensor.data<void>()); buffer = const_cast<void *>(main_tensor.data<void>());
} else { } else {
auto local_scope = member_->local_scopes_[i]; auto local_scope = member_->local_scopes_[i];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册