Commit 84367cf8 authored by Qiao Longfei

support async mode in dist mode parallel executor

Parent e72637dd
......@@ -167,6 +167,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
bool is_forwarding = true;
bool insert_collection_ops = NeedCollectiveOps();
if (strategy_.async_mode_) {
// async mode does not need to merge gradients
insert_collection_ops = false;
}
for (ir::Node *node : sorted_ops) {
if (DealWithSpecialOp(&result, node)) {
......@@ -192,8 +196,22 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
// optimize ops are already processed in DealWithSpecialOp,
// here we only consider backward ops
if (!is_bk_op) continue;
/*
 * An op that generates the gradient of a parameter has one attr
 * op_role_var that records the parameter and its gradient, e.g.:
 *   attrs {
 *     name: "op_role_var"
 *     type: STRINGS
 *     strings: "fc_1.b_0"
 *     strings: "fc_1.b_0@GRAD"
 *   }
 */
// Currently, we assume that once a gradient is generated, it can be
// broadcast, and each gradient is broadcast only once.
auto backward_vars =
......@@ -204,7 +222,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
for (size_t i = 0; i < backward_vars.size(); i += 2) {
auto &p_name = backward_vars[i];
auto &g_name = backward_vars[i + 1];
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
VLOG(3) << "Bcast " << g_name << " for parameter " << p_name;
InsertCollectiveOp(&result, p_name, g_name);
}
......@@ -385,7 +403,7 @@ void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
ir::Node *node,
int dev_id) const {
size_t dev_id) const {
result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()),
local_scopes_[dev_id], places_[dev_id], dev_id));
......@@ -454,9 +472,8 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOps(
}
}
VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(ir::Graph *result,
const std::string &og,
int dst_dev_id) const {
VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(
ir::Graph *result, const std::string &og, size_t dst_dev_id) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
......@@ -720,6 +737,10 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
ir::Node *node) const {
bool insert_op = false;
if (OpHaveRole(*node, OpRole::kRPC)) {
// in async_mode, each graph will send its own gradients.
if (strategy_.async_mode_ && node->Op()->Type() == "send") {
return false;
}
int op_dev_id = CreateRPCOp(result, node);
PADDLE_ENFORCE(op_dev_id != -1,
"Can not schedule the RPC operator to the right place.");
......@@ -737,6 +758,8 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
} else if (OpHaveRole(*node, OpRole::kDist)) {
int op_dev_id = CreateDistTrainOp(result, node);
if (node->Op()->Type() == "concat") {
// the inputs (blocks of a parameter) of concat are on different devices;
// the output (the whole parameter) will be on one device.
auto origin_param_name = node->Op()->OutputArgumentNames()[0];
bcast_var_name_set_[op_dev_id].emplace(origin_param_name);
}
......@@ -744,6 +767,7 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
} else {
int op_dev_id = GetOpDeviceID(node);
if (op_dev_id != -1) { // This op only runs on one specific device.
// optimize op will be processed here.
CreateComputationalOp(result, node, op_dev_id);
for (ir::Node *n : node->outputs) {
sharded_var_device_.emplace(n->Name(), op_dev_id);
......@@ -905,6 +929,7 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
const std::string &p_name,
const std::string &g_name) const {
// insert collective ops for the gradient according to the reduce strategy
size_t cur_device_id = 0;
switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce:
......
......@@ -68,10 +68,10 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
proto::VarType::Type dtype) const;
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
int dst_dev_id) const;
size_t dst_dev_id) const;
void CreateComputationalOp(ir::Graph *result, ir::Node *node,
int dev_id) const;
size_t dev_id) const;
bool IsSparseGradient(const std::string &og) const;
......@@ -118,16 +118,16 @@ class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
protected:
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const {}
void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const override {}
bool NeedCollectiveOps() const override { return false; }
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const {
bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const override {
return false;
}
virtual void InsertPostprocessOps(ir::Graph *result) const {}
void InsertPostprocessOps(ir::Graph *result) const override {}
};
class BalanceVarSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
......
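For context (not part of the commit): below is a minimal, self-contained C++ sketch of the override pattern this change relies on. The names used here (GraphBuilderBase, AsyncGraphBuilder, Build) are hypothetical stand-ins, not the real Paddle classes; the point is only that the async builder turns collective-op insertion into a no-op, so each replica's graph keeps and sends its own gradients instead of merging them.

// Illustrative sketch only -- simplified stand-ins, not the real Paddle types.
#include <iostream>
#include <string>
#include <vector>

struct Graph {
  std::vector<std::string> ops;  // names of inserted ops, for demonstration
};

class GraphBuilderBase {
 public:
  virtual ~GraphBuilderBase() = default;

  // Walk the (parameter, gradient) pairs and insert collective ops,
  // unless the concrete builder opts out.
  void Build(Graph *g, const std::vector<std::string> &backward_vars) const {
    bool insert_collective_ops = NeedCollectiveOps();
    for (size_t i = 0; i + 1 < backward_vars.size(); i += 2) {
      if (insert_collective_ops) {
        InsertCollectiveOp(g, backward_vars[i], backward_vars[i + 1]);
      }
    }
  }

 protected:
  virtual bool NeedCollectiveOps() const { return true; }
  virtual void InsertCollectiveOp(Graph *g, const std::string &p_name,
                                  const std::string &g_name) const {
    g->ops.push_back("allreduce(" + g_name + " for " + p_name + ")");
  }
};

// Async mode: no gradient merging; each replica sends its own gradients,
// so collective-op insertion becomes a no-op.
class AsyncGraphBuilder : public GraphBuilderBase {
 protected:
  bool NeedCollectiveOps() const override { return false; }
  void InsertCollectiveOp(Graph *, const std::string &,
                          const std::string &) const override {}
};

int main() {
  std::vector<std::string> backward_vars = {"fc_1.b_0", "fc_1.b_0@GRAD"};

  Graph sync_graph, async_graph;
  GraphBuilderBase sync_builder;
  AsyncGraphBuilder async_builder;

  sync_builder.Build(&sync_graph, backward_vars);
  async_builder.Build(&async_graph, backward_vars);

  std::cout << "sync collective ops:  " << sync_graph.ops.size() << "\n";   // 1
  std::cout << "async collective ops: " << async_graph.ops.size() << "\n";  // 0
  return 0;
}

In the actual commit, AsyncSSAGraphBuilder plays the role of the derived class: NeedCollectiveOps() returns false, and InsertCollectiveOp, DealWithSpecialOp, and InsertPostprocessOps are overridden as no-ops.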