提交 5cf00928 编写于 作者: Q Qiao Longfei

add more log and fix test_dist_base in multi_batch_merge_pass

上级 5c36eb8b
...@@ -177,11 +177,13 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -177,11 +177,13 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
#else #else
const bool use_cuda) const { const bool use_cuda) const {
#endif #endif
VLOG(3) << "apply all passes";
// Create a default one if not finalized by user. // Create a default one if not finalized by user.
CreatePassesFromStrategy(false); CreatePassesFromStrategy(false);
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program)); std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) { for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
VLOG(3) << "apply " << pass->Type();
if (IsMultiDevPass(pass->Type())) { if (IsMultiDevPass(pass->Type())) {
pass->Erase(kPlaces); pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places); pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
......
...@@ -19,6 +19,7 @@ namespace paddle { ...@@ -19,6 +19,7 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const { std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
VLOG(3) << "apply pass -> " << Type();
PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty."); PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty.");
for (const std::string& attr : required_pass_attrs_) { for (const std::string& attr : required_pass_attrs_) {
PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(), PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(),
......
...@@ -128,8 +128,7 @@ class TestDistRunnerBase(object): ...@@ -128,8 +128,7 @@ class TestDistRunnerBase(object):
if args.batch_merge_repeat > 1: if args.batch_merge_repeat > 1:
pass_builder = build_stra._finalize_strategy_and_create_passes() pass_builder = build_stra._finalize_strategy_and_create_passes()
mypass = pass_builder.insert_pass( mypass = pass_builder.insert_pass(0, "multi_batch_merge_pass")
len(pass_builder.all_passes()) - 3, "multi_batch_merge_pass")
mypass.set("num_repeats", args.batch_merge_repeat) mypass.set("num_repeats", args.batch_merge_repeat)
if args.update_method == "nccl2": if args.update_method == "nccl2":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册