提交 bac08c4a 编写于 作者: W WangZhen

Fix some bugs caused by set functions of the Pass class. test=develop

上级 4e91d8d2
...@@ -123,7 +123,7 @@ class TestDistRunnerBase(object): ...@@ -123,7 +123,7 @@ class TestDistRunnerBase(object):
pass_builder = build_stra._finalize_strategy_and_create_passes() pass_builder = build_stra._finalize_strategy_and_create_passes()
mypass = pass_builder.insert_pass( mypass = pass_builder.insert_pass(
len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass") len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
mypass.set_int("num_repeats", args.batch_merge_repeat) mypass.set("num_repeats", args.batch_merge_repeat)
if args.update_method == "nccl2": if args.update_method == "nccl2":
build_stra.num_trainers = len(args.endpoints.split(",")) build_stra.num_trainers = len(args.endpoints.split(","))
......
...@@ -111,7 +111,7 @@ class TestPassBuilder(unittest.TestCase): ...@@ -111,7 +111,7 @@ class TestPassBuilder(unittest.TestCase):
pass_builder.remove_pass(len(pass_builder.all_passes()) - 1) pass_builder.remove_pass(len(pass_builder.all_passes()) - 1)
self.assertEqual(origin_len + 1, len(pass_builder.all_passes())) self.assertEqual(origin_len + 1, len(pass_builder.all_passes()))
viz_pass.set_str("graph_viz_path", "/tmp/test_viz_pass") viz_pass.set("graph_viz_path", "/tmp/test_viz_pass")
self.check_network_convergence( self.check_network_convergence(
use_cuda=core.is_compiled_with_cuda(), use_cuda=core.is_compiled_with_cuda(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册