未验证 提交 e506c99c 编写于 作者: C chengduo 提交者: GitHub

Enable the fuse broadcast option by default (#18833)

* Fix the VLOG level (pass-related logs move from VLOG(3) to VLOG(1)) and the fuse option type
test=develop
上级 47f670d5
...@@ -126,6 +126,9 @@ void BroadcastOpHandle::BroadcastOneVar( ...@@ -126,6 +126,9 @@ void BroadcastOpHandle::BroadcastOneVar(
&VariableVisitor::GetMutableTensor(out_var)); &VariableVisitor::GetMutableTensor(out_var));
} }
}); });
for (auto &p : places_) {
nccl_ctxs_->DevCtx(p)->Wait();
}
#else #else
PADDLE_THROW("CUDA is not enabled."); PADDLE_THROW("CUDA is not enabled.");
#endif #endif
......
...@@ -278,12 +278,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, ...@@ -278,12 +278,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
#else #else
const bool use_cuda) const { const bool use_cuda) const {
#endif #endif
VLOG(3) << "apply all passes"; VLOG(1) << "apply all passes";
// Create a default one if not finalized by user. // Create a default one if not finalized by user.
CreatePassesFromStrategy(false); CreatePassesFromStrategy(false);
for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) { for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
VLOG(3) << "BuildStrategy::Apply pass:" << pass->Type(); VLOG(1) << "BuildStrategy::Apply pass:" << pass->Type();
if (IsMultiDevPass(pass->Type())) { if (IsMultiDevPass(pass->Type())) {
pass->Erase(kPlaces); pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places); pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
...@@ -349,11 +349,11 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, ...@@ -349,11 +349,11 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
continue; continue;
} }
} }
VLOG(3) << "Start Apply Pass " << pass->Type(); VLOG(1) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(graph); graph = pass->Apply(graph);
VLOG(3) << "Finish Apply Pass " << pass->Type(); VLOG(1) << "Finish Apply Pass " << pass->Type();
} }
VLOG(3) << "All Passes Applied"; VLOG(1) << "All Passes Applied";
return graph; return graph;
} }
......
...@@ -98,7 +98,7 @@ struct BuildStrategy { ...@@ -98,7 +98,7 @@ struct BuildStrategy {
// faster. Because fusing broadcast OP equals delaying the execution of all // faster. Because fusing broadcast OP equals delaying the execution of all
// broadcast Ops, in this case, all nccl streams are used only for reduce // broadcast Ops, in this case, all nccl streams are used only for reduce
// operations for a period of time. // operations for a period of time.
bool fuse_broadcast_ops_{false}; bool fuse_broadcast_ops_{true};
// replace batch_norm with sync_batch_norm. // replace batch_norm with sync_batch_norm.
bool sync_batch_norm_{false}; bool sync_batch_norm_{false};
......
...@@ -21,7 +21,7 @@ namespace framework { ...@@ -21,7 +21,7 @@ namespace framework {
namespace ir { namespace ir {
std::shared_ptr<Pass> PassBuilder::AppendPass(const std::string& pass_type) { std::shared_ptr<Pass> PassBuilder::AppendPass(const std::string& pass_type) {
VLOG(3) << "Append " << pass_type; VLOG(1) << "Append " << pass_type;
auto pass = ir::PassRegistry::Instance().Get(pass_type); auto pass = ir::PassRegistry::Instance().Get(pass_type);
passes_.emplace_back(pass.release()); passes_.emplace_back(pass.release());
return passes_.back(); return passes_.back();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册