Commit c4ded17e authored by Qiao Longfei

async mode support dist train

Parent 84367cf8
@@ -133,10 +133,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
   void AppendMultiDevPass(const BuildStrategy &strategy) {
     ir::Pass *multi_devices_pass;
-    if (strategy_.async_mode_) {
-      multi_devices_pass = AppendPass("async_multi_devices_pass").get();
-    } else if (strategy_.is_distribution_) {
+    if (strategy_.is_distribution_) {
       multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
+    } else if (strategy_.async_mode_) {
+      multi_devices_pass = AppendPass("async_multi_devices_pass").get();
     } else {
       if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
         multi_devices_pass =
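For clarity, the following is a minimal sketch (not part of the commit; the free-function form and names are simplified) of the pass-selection precedence after this reordering: distributed training is checked before async mode, and only non-distributed builds fall back to the reduce/all_reduce passes.

#include <string>

// Simplified illustration only; the real logic lives in
// ParallelExecutorPassBuilder::AppendMultiDevPass shown above.
std::string ChooseMultiDevPass(bool is_distribution, bool async_mode) {
  if (is_distribution) return "dist_multi_devices_pass";
  if (async_mode) return "async_multi_devices_pass";
  // Placeholder: the non-distributed reduce/all_reduce pass names are not
  // shown in this hunk.
  return "reduce_or_all_reduce_pass";
}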
@@ -756,6 +756,11 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
     insert_op = true;
     need_broadcast_var_ = true;
   } else if (OpHaveRole(*node, OpRole::kDist)) {
+    // In async_mode, each graph sends its own gradients, so there is no need
+    // to merge gradients here.
+    if (strategy_.async_mode_ && node->Op()->Type() != "concat") {
+      return false;
+    }
     int op_dev_id = CreateDistTrainOp(result, node);
     if (node->Op()->Type() == "concat") {
       // the input(block of parameter) of concat is on different device,
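A minimal sketch (hypothetical helper, not in the commit) of the skip rule added above: in async mode each per-device graph sends its own gradients, so kDist ops other than "concat" get no special gradient-merging treatment here.

#include <string>

// Returns true when DealWithSpecialOp should keep handling the kDist op;
// false means "fall through", as the early return above does in async mode.
bool HandleDistOpSpecially(bool async_mode, const std::string &op_type) {
  return !(async_mode && op_type != "concat");
}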
@@ -827,7 +832,7 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
     }
     auto recv_param_grad = boost::get<std::vector<std::string>>(
         node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
-    if (recv_param_grad.size() == 2U) {
+    if (recv_param_grad.size() == 2U && !strategy_.async_mode_) {
       op_dev_id = GetVarDeviceID(recv_param_grad[1]);
       VLOG(10) << "recv param " << recv_param_grad[0]
                << " get grad place: " << recv_param_grad[1]
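A small sketch (assumed free-function form, not in the commit; the device lookup is passed in because GetVarDeviceID is a member of the real builder) of the placement rule after this change: only synchronous distributed training pins the recv op to the device holding the matching gradient, while async mode is left to the fallback placement that follows in the real code.

#include <functional>
#include <string>
#include <vector>

// -1 stands in for "no pinned device" in this sketch.
int RecvOpDeviceID(const std::vector<std::string> &recv_param_grad,
                   bool async_mode,
                   const std::function<int(const std::string &)> &var_device_id) {
  if (recv_param_grad.size() == 2U && !async_mode) {
    return var_device_id(recv_param_grad[1]);  // pin to the gradient's device
  }
  return -1;
}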
@@ -283,7 +283,7 @@ ParallelExecutor::ParallelExecutor(
     graphs.push_back(std::move(graph));
   }
 #else
-  if (build_strategy.async_mode_) {
+  if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
     for (size_t i = 0; i < member_->places_.size(); ++i) {
       std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
           main_program, {member_->places_[i]}, loss_var_name,
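As a minimal sketch (not part of the commit), the net decision this condition encodes is: build one graph per place only for local async training; distributed async training keeps the single-graph path, like synchronous mode.

// One graph per device only when async mode is used without a distributed setup.
bool BuildOneGraphPerPlace(bool async_mode, bool is_distribution) {
  return async_mode && !is_distribution;
}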