Commit 3d69a82b authored by yi.wu

fix dist train broadcasting bug

Parent 75618e4b
@@ -483,6 +483,9 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
     }
   } else if (op.Type() == "concat") {
     op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
+    for (auto &varname : op.OutputArgumentNames()) {
+      var_name_on_devices_.emplace(varname, op_dev_id);
+    }
   } else {
     PADDLE_ENFORCE(
         "the distribute training related op should be in [split_byref, "
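The three added lines record, for every output of a dist-train concat op, the device the op itself was assigned to, so a later GetVarDeviceID lookup can report the owning device instead of falling back to -1. Below is a minimal sketch of the kind of name-to-device bookkeeping this feeds; DevicePlacementMap is a made-up illustration class, not the real MultiDevSSAGraphBuilder, which keeps this state in var_name_on_devices_.

// Simplified sketch (hypothetical class) of the name -> device map that
// the loop above populates and that GetVarDeviceID consults.
#include <string>
#include <unordered_map>

class DevicePlacementMap {
 public:
  // Mirrors var_name_on_devices_.emplace(varname, op_dev_id) in the hunk above.
  void Record(const std::string &var_name, int dev_id) {
    placements_.emplace(var_name, dev_id);
  }

  // -1 means "no fixed placement", matching the convention used by the builders.
  int GetVarDeviceID(const std::string &var_name) const {
    auto it = placements_.find(var_name);
    return it == placements_.end() ? -1 : it->second;
  }

 private:
  std::unordered_map<std::string, int> placements_;
};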
@@ -30,7 +30,7 @@ class SSAGraphBuilder {
   SSAGraphBuilder() {}
   virtual ~SSAGraphBuilder() {}
   virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
-  virtual int GetVarDeviceID(const std::string &var_name) const { return -1; }
+  virtual int GetVarDeviceID(const std::string &var_name) const = 0;
   DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
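Turning GetVarDeviceID from a virtual function with a default "return -1" body into a pure virtual one forces every concrete builder, including the checker and printer wrappers changed below, to supply an implementation. Under the old default, a wrapper that did not override the method silently answered -1 ("no placement") even when the builder it wrapped knew the device, which appears to be exactly the broadcasting bug this commit fixes. A small self-contained sketch of that failure mode follows; OldBuilderBase, RealBuilder, and PrinterWrapper are hypothetical stand-ins, not the real Paddle classes.

// Sketch of the failure mode removed by the pure-virtual change
// (hypothetical simplified classes, not the real Paddle headers).
#include <iostream>
#include <memory>
#include <string>

struct OldBuilderBase {
  virtual ~OldBuilderBase() = default;
  // Old-style interface: a default body that silently answers "no placement".
  virtual int GetVarDeviceID(const std::string &) const { return -1; }
};

struct RealBuilder : OldBuilderBase {
  // The wrapped builder actually knows where the variable lives.
  int GetVarDeviceID(const std::string &) const override { return 1; }
};

struct PrinterWrapper : OldBuilderBase {
  explicit PrinterWrapper(std::unique_ptr<OldBuilderBase> b) : inner_(std::move(b)) {}
  // No GetVarDeviceID override: callers get the base-class -1 default
  // instead of the wrapped builder's answer.
  std::unique_ptr<OldBuilderBase> inner_;
};

int main() {
  PrinterWrapper wrapper(std::make_unique<RealBuilder>());
  std::cout << wrapper.GetVarDeviceID("some_param") << "\n";  // prints -1
  return 0;
}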
@@ -16,6 +16,8 @@
 #include "paddle/fluid/framework/details/ssa_graph_builder.h"
+#include <string>
 namespace paddle {
 namespace framework {
 namespace details {
@@ -33,6 +35,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
     return graph;
   }
+  int GetVarDeviceID(const std::string& var_name) const override {
+    return builder_->GetVarDeviceID(var_name);
+  }
   bool IsValidGraph(const SSAGraph* graph) const;
  private:
@@ -15,6 +15,7 @@
 #pragma once
 #include <iosfwd>
+#include <string>
 #include "paddle/fluid/framework/details/ssa_graph_builder.h"
 namespace paddle {
@@ -55,6 +56,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
     return graph;
   }
+  int GetVarDeviceID(const std::string& var_name) const override {
+    return builder_->GetVarDeviceID(var_name);
+  }
  private:
   std::unique_ptr<SSAGraphPrinter> printer_;
   std::unique_ptr<SSAGraphBuilder> builder_;
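The two header hunks above are the counterpart of that interface change: the checker and printer are decorators around another SSAGraphBuilder (held in builder_), and each now satisfies the pure-virtual contract by delegating the query to the wrapped builder. A matching sketch of the fixed shape, again with hypothetical names rather than the real Paddle headers:

// Fixed counterpart to the sketch above: with a pure-virtual interface the
// wrapper cannot compile without an override, so it forwards explicitly,
// mirroring SSAGraghBuilderWithPrinter::GetVarDeviceID.
#include <memory>
#include <string>

struct BuilderBase {
  virtual ~BuilderBase() = default;
  virtual int GetVarDeviceID(const std::string &var_name) const = 0;  // no silent default
};

struct PrinterWrapper : BuilderBase {
  explicit PrinterWrapper(std::unique_ptr<BuilderBase> b) : inner_(std::move(b)) {}
  // Mandatory override: delegate to the wrapped builder.
  int GetVarDeviceID(const std::string &var_name) const override {
    return inner_->GetVarDeviceID(var_name);
  }
  std::unique_ptr<BuilderBase> inner_;
};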
@@ -133,17 +133,18 @@ ParallelExecutor::ParallelExecutor(
 void ParallelExecutor::BCastParamsToGPUs(
     const std::unordered_set<std::string> &vars) const {
-  // the the initialize bcast, all vars would be bcast from device(0), otherwise
+  // the the initializing bcast, all vars would be bcast from device(0),
+  // otherwise
   // bcast from the specified device.
-  bool initialize = builder_.get() == nullptr ? true : false;
+  bool initializing = builder_.get() == nullptr ? false : true;
   for (auto &var : vars) {
     int var_dev_id =
         builder_.get() == nullptr ? -1 : builder_->GetVarDeviceID(var);
-    if (!initialize && var_dev_id == -1) continue;
+    if (!initializing && var_dev_id == -1) continue;
     framework::Variable *main_var = nullptr;
-    if (initialize) {
+    if (initializing) {
       main_var = member_->local_scopes_[0]->FindVar(var);
     } else {
       main_var = member_->local_scopes_[var_dev_id]->FindVar(var);
@@ -164,7 +165,8 @@ void ParallelExecutor::BCastParamsToGPUs(
       auto place = member_->places_[i];
       void *buffer;
-      if ((initialize && i == 0) || (!initialize && i == var_dev_id)) {
+      if ((initializing && i == 0) ||
+          (!initializing && i == static_cast<size_t>(var_dev_id))) {
         buffer = const_cast<void *>(main_tensor.data<void>());
       } else {
         auto local_scope = member_->local_scopes_[i];
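The executor change ties the pieces together: the initializing broadcast (done before a graph builder exists) still sends every variable from device 0, while later broadcasts send a variable from the device the builder assigned to it and skip variables with no recorded placement. A tiny sketch isolating just that source-device choice; BcastSourceDevice is a hypothetical helper, not a real member of ParallelExecutor, and the real logic lives inline in BCastParamsToGPUs.

// Hypothetical helper isolating the broadcast-root decision made in the loops above.
// Returns the device index whose tensor is the broadcast source, or -1 when the
// variable should not be broadcast at all.
inline int BcastSourceDevice(bool initializing, int var_dev_id) {
  if (initializing) return 0;       // initializing bcast: everything comes from device 0
  if (var_dev_id == -1) return -1;  // later bcasts: skip vars with no fixed placement
  return var_dev_id;                // otherwise bcast from the variable's own device
}

// Usage inside the per-device loop (i is the device index):
//   bool is_source = static_cast<int>(i) == BcastSourceDevice(initializing, var_dev_id);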