提交 bac08c4a 编写于 作者: W WangZhen

Fix some bugs caused by set functions of the Pass class. test=develop

上级 4e91d8d2
......@@ -123,7 +123,7 @@ class TestDistRunnerBase(object):
pass_builder = build_stra._finalize_strategy_and_create_passes()
mypass = pass_builder.insert_pass(
len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
mypass.set_int("num_repeats", args.batch_merge_repeat)
mypass.set("num_repeats", args.batch_merge_repeat)
if args.update_method == "nccl2":
build_stra.num_trainers = len(args.endpoints.split(","))
......
......@@ -111,7 +111,7 @@ class TestPassBuilder(unittest.TestCase):
pass_builder.remove_pass(len(pass_builder.all_passes()) - 1)
self.assertEqual(origin_len + 1, len(pass_builder.all_passes()))
viz_pass.set_str("graph_viz_path", "/tmp/test_viz_pass")
viz_pass.set("graph_viz_path", "/tmp/test_viz_pass")
self.check_network_convergence(
use_cuda=core.is_compiled_with_cuda(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册