From bac08c4a263eeda61cc2f1bcf20d005f51f542ef Mon Sep 17 00:00:00 2001 From: WangZhen Date: Tue, 22 Jan 2019 18:26:00 +0800 Subject: [PATCH] Fix some bugs caused by set functions of the Pass class. test=develop --- python/paddle/fluid/tests/unittests/test_dist_base.py | 2 +- python/paddle/fluid/tests/unittests/test_pass_builder.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 3fcdc5790..69a38618c 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -123,7 +123,7 @@ class TestDistRunnerBase(object): pass_builder = build_stra._finalize_strategy_and_create_passes() mypass = pass_builder.insert_pass( len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass") - mypass.set_int("num_repeats", args.batch_merge_repeat) + mypass.set("num_repeats", args.batch_merge_repeat) if args.update_method == "nccl2": build_stra.num_trainers = len(args.endpoints.split(",")) diff --git a/python/paddle/fluid/tests/unittests/test_pass_builder.py b/python/paddle/fluid/tests/unittests/test_pass_builder.py index 8c9e489e0..7e1c2572f 100644 --- a/python/paddle/fluid/tests/unittests/test_pass_builder.py +++ b/python/paddle/fluid/tests/unittests/test_pass_builder.py @@ -111,7 +111,7 @@ class TestPassBuilder(unittest.TestCase): pass_builder.remove_pass(len(pass_builder.all_passes()) - 1) self.assertEqual(origin_len + 1, len(pass_builder.all_passes())) - viz_pass.set_str("graph_viz_path", "/tmp/test_viz_pass") + viz_pass.set("graph_viz_path", "/tmp/test_viz_pass") self.check_network_convergence( use_cuda=core.is_compiled_with_cuda(), -- GitLab