未验证 提交 9cc1eb43 编写于 作者: Y Yancey 提交者: GitHub

Merge pull request #11221 from Yancey1989/overlap_memcpy_with_dist

overlap rpc op memcpy in distributed training
...@@ -57,6 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( ...@@ -57,6 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
for (auto &p : params) { for (auto &p : params) {
grad_names_.insert(GradVarName(p)); grad_names_.insert(GradVarName(p));
} }
balance_vars_.resize(places_.size(), 0);
} }
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
...@@ -140,11 +141,30 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( ...@@ -140,11 +141,30 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
checker(op.InputArgumentNames(), recv_vars); checker(op.InputArgumentNames(), recv_vars);
} }
size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const {
int64_t numel_sum = 0;
for (auto var_name : var_names) {
auto var_desc = all_vars_.at(var_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GT(numel, 0);
numel_sum += numel;
}
auto smallest =
std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
balance_vars_[dev_id] += numel_sum;
return dev_id;
}
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const { const ProgramDesc &program) const {
std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) { for (auto *var : program.Block(0).AllVars()) {
all_vars[var->Name()] = var; all_vars_.emplace(var->Name(), var);
} }
auto graph = new SSAGraph(); auto graph = new SSAGraph();
...@@ -161,35 +181,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -161,35 +181,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto send_vars = FindDistTrainSendVars(program); auto send_vars = FindDistTrainSendVars(program);
auto recv_vars = FindDistTrainRecvVars(program); auto recv_vars = FindDistTrainRecvVars(program);
std::vector<std::unordered_set<std::string>> var_name_on_devices;
std::vector<std::unordered_set<std::string>> bcast_var_name_set; std::vector<std::unordered_set<std::string>> bcast_var_name_set;
var_name_on_devices.resize(places_.size());
bcast_var_name_set.resize(places_.size()); bcast_var_name_set.resize(places_.size());
size_t cur_device_id = 0; size_t cur_device_id = 0;
std::vector<int64_t> balance_grads(places_.size(), 0);
auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
auto var_desc = all_vars.at(g_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GE(numel, 0);
auto smallest =
std::min_element(std::begin(balance_grads), std::end(balance_grads));
size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
balance_grads[dev_id] += numel;
return dev_id;
};
bool is_forwarding = true; bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
if (boost::get<int>( if (boost::get<int>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) { static_cast<int>(OpRole::kRPC)) {
// append rpc op if program is distributed trainer main program.
// always use the first device
CreateRPCOp(&result, *op); CreateRPCOp(&result, *op);
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) { } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
CreateDistTrainOp(&result, *op); CreateDistTrainOp(&result, *op);
...@@ -201,13 +202,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -201,13 +202,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} }
is_forwarding = false; is_forwarding = false;
} else { } else {
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op); int op_dev_id = GetOpDeviceID(*op);
if (op_dev_id == -1) { // var on all device if (op_dev_id == -1) { // var on all device
CreateComputationalOps(&result, *op, places_.size()); CreateComputationalOps(&result, *op, places_.size());
} else { } else {
CreateComputationalOp(&result, *op, op_dev_id); CreateComputationalOp(&result, *op, op_dev_id);
for (auto &var_name : op->OutputArgumentNames()) { for (auto &var_name : op->OutputArgumentNames()) {
var_name_on_devices[op_dev_id].emplace(var_name); var_name_on_devices_.emplace(var_name, op_dev_id);
} }
} }
if (!is_forwarding && places_.size() > 1) { if (!is_forwarding && places_.size() > 1) {
...@@ -230,13 +231,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -230,13 +231,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
switch (strategy_.reduce_) { switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce: case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = get_appropriate_dev(g_name); cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(&result, g_name, cur_device_id); CreateReduceOp(&result, g_name, cur_device_id);
var_name_on_devices[cur_device_id].emplace(g_name); var_name_on_devices_.emplace(g_name, cur_device_id);
bcast_var_name_set[cur_device_id].emplace(p_name); bcast_var_name_set[cur_device_id].emplace(p_name);
break; break;
case BuildStrategy::ReduceStrategy::kAllReduce: case BuildStrategy::ReduceStrategy::kAllReduce:
if (IsSparseGradient(all_vars, g_name)) { if (IsSparseGradient(g_name)) {
CreateReduceOp(&result, g_name, 0); CreateReduceOp(&result, g_name, 0);
CreateBroadcastOp(&result, g_name, 0); CreateBroadcastOp(&result, g_name, 0);
} else { } else {
...@@ -273,11 +274,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -273,11 +274,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
return std::unique_ptr<SSAGraph>(graph); return std::unique_ptr<SSAGraph>(graph);
} }
bool MultiDevSSAGraphBuilder::IsSparseGradient( bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
const std::unordered_map<std::string, VarDesc *> &all_vars, PADDLE_ENFORCE(all_vars_.count(og) != 0);
const std::string &og) const { if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
PADDLE_ENFORCE(all_vars.count(og) != 0);
if (all_vars.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
return true; return true;
} }
return false; return false;
...@@ -363,24 +362,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( ...@@ -363,24 +362,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
return is_pg_once; return is_pg_once;
} }
int MultiDevSSAGraphBuilder::GetOpDeviceID( int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const {
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1; return -1;
} }
int var_dev_id = -1; for (auto &varname : op.InputArgumentNames()) {
for (auto &var_name : op.InputArgumentNames()) { int dev_id = GetVarDeviceID(varname);
if (var_dev_id != -1) break; if (dev_id != -1) {
for (size_t i = 0; i < var_name_on_devices.size(); ++i) { return dev_id;
if (var_name_on_devices[i].count(var_name)) {
var_dev_id = static_cast<int>(i);
break;
}
} }
} }
return var_dev_id; return -1;
}
int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
auto got = var_name_on_devices_.find(varname);
return got == var_name_on_devices_.end() ? -1 : got->second;
} }
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
...@@ -463,7 +461,30 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, ...@@ -463,7 +461,30 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
const OpDesc &op) const { const OpDesc &op) const {
CreateComputationalOp(result, op, 0); int op_dev_id = -1;
if (op.Type() == "split_byref") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
for (auto &varname : op.InputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
}
for (auto &varname : op.OutputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
} else if (op.Type() == "concat") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
} else {
PADDLE_ENFORCE(
"the distribute training related op should be in [split_byref, "
"concat].");
}
PADDLE_ENFORCE(op_dev_id != -1,
"can not find right place for distributed op: %s", op.Type());
CreateComputationalOp(result, op, op_dev_id);
if (op.Type() == "concat") { if (op.Type() == "concat") {
ConnectOp(result, result->ops_.back().get(), "fetch_barrier"); ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
} }
...@@ -471,8 +492,34 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, ...@@ -471,8 +492,34 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
const OpDesc &op) const { const OpDesc &op) const {
result->ops_.emplace_back( int op_dev_id = -1;
new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0])); if (op.Type() == "send") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
// the variable name which contains .block means it was splited by
// split_byref op
// so that we can balance the variable blocks to all the pserver instances.
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
op.InputArgumentNames()[0].find(".block") == std::string::npos) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
for (auto &varname : op.InputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
}
} else if (op.Type() == "recv") {
op_dev_id = GetAppropriateDeviceID(op.OutputArgumentNames());
for (auto &varname : op.OutputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
} else {
// send_barrier and fetch_barrier op can be scheduled on device 0
op_dev_id = 0;
}
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
op.Type());
result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id],
op.Type(), places_[op_dev_id]));
if (op.Type() == "send_barrier") { if (op.Type() == "send_barrier") {
ConnectOp(result, result->ops_.back().get(), "send"); ConnectOp(result, result->ops_.back().get(), "send");
...@@ -488,9 +535,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, ...@@ -488,9 +535,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
"send, send_barrier. recv, fetch_barrier]"); "send, send_barrier. recv, fetch_barrier]");
} }
// TODO(Yancey1989): schedule rpc op on different place may CreateOpHandleIOs(result, op, op_dev_id);
// increate throughput
CreateOpHandleIOs(result, op, 0);
} }
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
......
...@@ -47,10 +47,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -47,10 +47,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
#endif #endif
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override; std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
int GetVarDeviceID(const std::string &varname) const;
private: private:
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op, void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
size_t place_id) const; size_t device_id) const;
private: private:
std::string loss_var_name_; std::string loss_var_name_;
...@@ -96,21 +97,23 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -96,21 +97,23 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::string &og, const std::string &og,
std::unordered_set<std::string> *og_has_been_broadcast) const; std::unordered_set<std::string> *og_has_been_broadcast) const;
int GetOpDeviceID( int GetOpDeviceID(const OpDesc &op) const;
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const;
void InsertAllReduceOp(SSAGraph *result, const std::string &og) const; void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
size_t src_dev_id) const; size_t src_dev_id) const;
bool IsSparseGradient( bool IsSparseGradient(const std::string &og) const;
const std::unordered_map<std::string, VarDesc *> &all_vars,
const std::string &og) const; size_t GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const;
private: private:
BuildStrategy strategy_; BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_;
mutable std::unordered_map<std::string, int> var_name_on_devices_;
mutable std::vector<int64_t> balance_vars_;
void SetCommunicationContext(OpHandleBase *op_handle, void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const; const platform::Place &p) const;
......
...@@ -30,6 +30,7 @@ class SSAGraphBuilder { ...@@ -30,6 +30,7 @@ class SSAGraphBuilder {
SSAGraphBuilder() {} SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0; virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
virtual int GetVarDeviceID(const std::string &var_name) const { return -1; }
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
......
...@@ -110,7 +110,6 @@ ParallelExecutor::ParallelExecutor( ...@@ -110,7 +110,6 @@ ParallelExecutor::ParallelExecutor(
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert // Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp // ncclOp
details::SSAGraphBuilderFactory builder_factory( details::SSAGraphBuilderFactory builder_factory(
member_->places_, loss_var_name, params, member_->local_scopes_, member_->places_, loss_var_name, params, member_->local_scopes_,
build_strategy); build_strategy);
...@@ -122,9 +121,10 @@ ParallelExecutor::ParallelExecutor( ...@@ -122,9 +121,10 @@ ParallelExecutor::ParallelExecutor(
#endif #endif
} }
builder_ = std::move(builder_factory.Create());
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, exec_strategy, member_->local_scopes_, places,
builder_factory.Create()->Build(main_program))); builder_->Build(main_program)));
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos), exec_strategy, member_->local_scopes_, std::move(var_infos),
...@@ -133,10 +133,22 @@ ParallelExecutor::ParallelExecutor( ...@@ -133,10 +133,22 @@ ParallelExecutor::ParallelExecutor(
void ParallelExecutor::BCastParamsToGPUs( void ParallelExecutor::BCastParamsToGPUs(
const std::unordered_set<std::string> &vars) const { const std::unordered_set<std::string> &vars) const {
auto *main_scope = member_->local_scopes_[0]; // the the initialize bcast, all vars would be bcast from device(0), otherwise
// bcast from the specified device.
bool initialize = builder_.get() == nullptr ? true : false;
for (auto &var : vars) { for (auto &var : vars) {
auto *main_var = main_scope->FindVar(var); int var_dev_id =
builder_.get() == nullptr ? -1 : builder_->GetVarDeviceID(var);
if (!initialize && var_dev_id == -1) continue;
framework::Variable *main_var = nullptr;
if (initialize) {
main_var = member_->local_scopes_[0]->FindVar(var);
} else {
main_var = member_->local_scopes_[var_dev_id]->FindVar(var);
}
if (main_var == nullptr || !main_var->IsType<LoDTensor>()) { if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
continue; continue;
} }
...@@ -151,7 +163,8 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -151,7 +163,8 @@ void ParallelExecutor::BCastParamsToGPUs(
for (size_t i = 0; i < member_->places_.size(); ++i) { for (size_t i = 0; i < member_->places_.size(); ++i) {
auto place = member_->places_[i]; auto place = member_->places_[i];
void *buffer; void *buffer;
if (i == 0) {
if ((initialize && i == 0) || (!initialize && i == var_dev_id)) {
buffer = const_cast<void *>(main_tensor.data<void>()); buffer = const_cast<void *>(main_tensor.data<void>());
} else { } else {
auto local_scope = member_->local_scopes_[i]; auto local_scope = member_->local_scopes_[i];
......
...@@ -19,12 +19,14 @@ limitations under the License. */ ...@@ -19,12 +19,14 @@ limitations under the License. */
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -68,6 +70,7 @@ class ParallelExecutor { ...@@ -68,6 +70,7 @@ class ParallelExecutor {
private: private:
ParallelExecutorPrivate *member_; ParallelExecutorPrivate *member_;
std::unique_ptr<details::SSAGraphBuilder> builder_;
}; };
} // namespace framework } // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册