From 274477005e929f82f6d4942605ba823a197e575a Mon Sep 17 00:00:00 2001 From: xjqbest <173596896@qq.com> Date: Thu, 4 Apr 2019 10:07:46 +0800 Subject: [PATCH] fix dataset testcase error test=develop --- .../fluid/tests/unittests/test_dataset.py | 24 +++++++++---------- python/paddle/fluid/trainer_desc.py | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_dataset.py b/python/paddle/fluid/tests/unittests/test_dataset.py index 9c557097a..4cfd99150 100644 --- a/python/paddle/fluid/tests/unittests/test_dataset.py +++ b/python/paddle/fluid/tests/unittests/test_dataset.py @@ -107,12 +107,12 @@ class TestDataset(unittest.TestCase): exe = fluid.Executor(fluid.CPUPlace()) exe.run(fluid.default_startup_program()) for i in range(2): - #try: - exe.train_from_dataset(fluid.default_main_program(), dataset) - #except ImportError as e: - # pass - #except Exception as e: - # self.assertTrue(False) + try: + exe.train_from_dataset(fluid.default_main_program(), dataset) + except ImportError as e: + pass + except Exception as e: + self.assertTrue(False) os.remove("./test_in_memory_dataset_run_a.txt") os.remove("./test_in_memory_dataset_run_b.txt") @@ -151,12 +151,12 @@ class TestDataset(unittest.TestCase): exe = fluid.Executor(fluid.CPUPlace()) exe.run(fluid.default_startup_program()) for i in range(2): - #try: - exe.train_from_dataset(fluid.default_main_program(), dataset) - #except ImportError as e: - # pass - #except Exception as e: - # self.assertTrue(False) + try: + exe.train_from_dataset(fluid.default_main_program(), dataset) + except ImportError as e: + pass + except Exception as e: + self.assertTrue(False) os.remove("./test_queue_dataset_run_a.txt") os.remove("./test_queue_dataset_run_b.txt") diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py index b91f1d1f3..380c404fb 100644 --- a/python/paddle/fluid/trainer_desc.py +++ b/python/paddle/fluid/trainer_desc.py @@ -23,7 +23,7 @@ class TrainerDesc(object): with open(proto_file, 'r') as f: text_format.Parse(f.read(), self.proto_desc) ''' - from .proto import trainer_desc_pb2 + from proto import trainer_desc_pb2 self.proto_desc = trainer_desc_pb2.TrainerDesc() import multiprocessing as mp # set default thread num == cpu count -- GitLab