diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 3fcdc57906c214bdc8179c55b576e2e9e8d80973..69a38618cde7f100849a30e1cb1a2f4738e55e0d 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -123,7 +123,7 @@ class TestDistRunnerBase(object): pass_builder = build_stra._finalize_strategy_and_create_passes() mypass = pass_builder.insert_pass( len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass") - mypass.set_int("num_repeats", args.batch_merge_repeat) + mypass.set("num_repeats", args.batch_merge_repeat) if args.update_method == "nccl2": build_stra.num_trainers = len(args.endpoints.split(",")) diff --git a/python/paddle/fluid/tests/unittests/test_pass_builder.py b/python/paddle/fluid/tests/unittests/test_pass_builder.py index 8c9e489e02839e25cfabe14c16bfd91a908bd734..7e1c2572f08598b8b600517e4a82b48ca71cc20d 100644 --- a/python/paddle/fluid/tests/unittests/test_pass_builder.py +++ b/python/paddle/fluid/tests/unittests/test_pass_builder.py @@ -111,7 +111,7 @@ class TestPassBuilder(unittest.TestCase): pass_builder.remove_pass(len(pass_builder.all_passes()) - 1) self.assertEqual(origin_len + 1, len(pass_builder.all_passes())) - viz_pass.set_str("graph_viz_path", "/tmp/test_viz_pass") + viz_pass.set("graph_viz_path", "/tmp/test_viz_pass") self.check_network_convergence( use_cuda=core.is_compiled_with_cuda(),