提交 62f1248f 编写于 作者: Q Qiao Longfei

fix use gpu test=develop

上级 45b19cbc
...@@ -731,6 +731,7 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result, ...@@ -731,6 +731,7 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
} }
} }
insert_op = true; insert_op = true;
need_broadcast_var_ = true;
} else if (OpHaveRole(*node, OpRole::kDist)) { } else if (OpHaveRole(*node, OpRole::kDist)) {
int op_dev_id = CreateDistTrainOp(result, node); int op_dev_id = CreateDistTrainOp(result, node);
if (node->Op()->Type() == "concat") { if (node->Op()->Type() == "concat") {
...@@ -925,7 +926,9 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, ...@@ -925,7 +926,9 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const { void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
// only GPU reduce mode needs to broadcast parameters to each device. // only GPU reduce mode needs to broadcast parameters to each device.
if (UseGPU() && strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) { if (UseGPU()) {
if (need_broadcast_var_ ||
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
if (strategy_.fuse_broadcast_op_) { if (strategy_.fuse_broadcast_op_) {
CreateFusedBroadcastOp(result, bcast_var_name_set_); CreateFusedBroadcastOp(result, bcast_var_name_set_);
} else { } else {
...@@ -937,6 +940,7 @@ void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const { ...@@ -937,6 +940,7 @@ void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
} }
} }
} }
}
} }
std::unordered_set<std::string> &MultiDevSSAGraphBuilder() { std::unordered_set<std::string> &MultiDevSSAGraphBuilder() {
......
...@@ -174,6 +174,7 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder { ...@@ -174,6 +174,7 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const; int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_; mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
mutable bool need_broadcast_var_{false};
}; };
std::unordered_set<std::string> &MultiDevSSAGraphBuilder(); std::unordered_set<std::string> &MultiDevSSAGraphBuilder();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册