Commit c4ded17e authored by Qiao Longfei

async mode support dist train

Parent 84367cf8
@@ -133,10 +133,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
   void AppendMultiDevPass(const BuildStrategy &strategy) {
     ir::Pass *multi_devices_pass;
-    if (strategy_.async_mode_) {
-      multi_devices_pass = AppendPass("async_multi_devices_pass").get();
-    } else if (strategy_.is_distribution_) {
+    if (strategy_.is_distribution_) {
       multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
+    } else if (strategy_.async_mode_) {
+      multi_devices_pass = AppendPass("async_multi_devices_pass").get();
     } else {
       if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
         multi_devices_pass =
......
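Note: this hunk reverses the priority of the two flags, so a strategy with both is_distribution_ and async_mode_ set now resolves to the distributed pass instead of the async pass. A minimal standalone sketch of the resulting selection order, assuming an illustrative StrategyFlags struct and SelectMultiDevPass helper (neither exists in the codebase; the reduce/all-reduce branch stays elided as in the hunk):

#include <string>

// Simplified stand-in for the relevant BuildStrategy flags (illustrative).
struct StrategyFlags {
  bool is_distribution_ = false;
  bool async_mode_ = false;
};

// Mirrors the post-change priority: distribution > async mode > local multi-device.
std::string SelectMultiDevPass(const StrategyFlags &s) {
  if (s.is_distribution_) return "dist_multi_devices_pass";
  if (s.async_mode_) return "async_multi_devices_pass";
  return "<reduce/all-reduce multi-devices pass, elided above>";
}

For example, SelectMultiDevPass with both flags set to true now yields "dist_multi_devices_pass".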
@@ -756,6 +756,11 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
     insert_op = true;
     need_broadcast_var_ = true;
   } else if (OpHaveRole(*node, OpRole::kDist)) {
+    // in async_mode, each graph will send its own gradient, do not need to
+    // merge gradient.
+    if (strategy_.async_mode_ && node->Op()->Type() != "concat") {
+      return false;
+    }
     int op_dev_id = CreateDistTrainOp(result, node);
     if (node->Op()->Type() == "concat") {
       // the input(block of parameter) of concat is on different device,
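The early return added above means that in async mode this branch only keeps handling the "concat" dist op (which reassembles parameter blocks after recv); other OpRole::kDist ops, such as gradient-merge helpers, are skipped because each replica's graph sends its own gradients. A minimal sketch of that guard, using a hypothetical ShouldHandleDistOp helper purely for illustration:

#include <string>

// Hypothetical helper mirroring the guard above: in async mode only the
// "concat" dist op still needs handling here; other dist ops are skipped
// because gradients are not aggregated across replicas.
bool ShouldHandleDistOp(bool async_mode, const std::string &op_type) {
  if (async_mode && op_type != "concat") {
    return false;  // skip: each graph pushes its own gradient
  }
  return true;  // fall through to CreateDistTrainOp, as in the hunk
}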
@@ -827,7 +832,7 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
     }
     auto recv_param_grad = boost::get<std::vector<std::string>>(
         node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
-    if (recv_param_grad.size() == 2U) {
+    if (recv_param_grad.size() == 2U && !strategy_.async_mode_) {
       op_dev_id = GetVarDeviceID(recv_param_grad[1]);
       VLOG(10) << "recv param " << recv_param_grad[0]
                << " get grad place: " << recv_param_grad[1]
......
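With the added !strategy_.async_mode_ condition, the recv op is co-located with its gradient's device only in synchronous mode; in async mode gradients are not merged onto one device, so placement falls through to the default logic outside this hunk. A small sketch of the changed rule, assuming a hypothetical PickRecvDevice helper and a -1 "not yet placed" sentinel:

// Illustrative only: device to pin the recv op to, or -1 (assumed sentinel)
// to defer to the builder's default placement.
int PickRecvDevice(bool async_mode, bool has_param_grad_pair, int grad_dev_id) {
  if (has_param_grad_pair && !async_mode) {
    return grad_dev_id;  // sync mode: co-locate recv with the gradient
  }
  return -1;  // async mode (or no param/grad pair): decided elsewhere
}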
@@ -283,7 +283,7 @@ ParallelExecutor::ParallelExecutor(
       graphs.push_back(std::move(graph));
     }
 #else
-  if (build_strategy.async_mode_) {
+  if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
     for (size_t i = 0; i < member_->places_.size(); ++i) {
       std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
           main_program, {member_->places_[i]}, loss_var_name,
......
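The condition change above narrows per-place graph construction to local async training: when async mode is combined with distributed training, this branch no longer builds one graph per place. A one-line sketch of the decision, with UsePerPlaceGraphs as an illustrative name:

// Illustrative predicate for the branch above: build one graph per place
// only for async training that is not distributed.
bool UsePerPlaceGraphs(bool async_mode, bool is_distribution) {
  return async_mode && !is_distribution;
}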