提交 27447700 编写于 作者: X xjqbest

fix dataset testcase error

test=develop
上级 5e513928
...@@ -107,12 +107,12 @@ class TestDataset(unittest.TestCase): ...@@ -107,12 +107,12 @@ class TestDataset(unittest.TestCase):
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
for i in range(2): for i in range(2):
#try: try:
exe.train_from_dataset(fluid.default_main_program(), dataset) exe.train_from_dataset(fluid.default_main_program(), dataset)
#except ImportError as e: except ImportError as e:
# pass pass
#except Exception as e: except Exception as e:
# self.assertTrue(False) self.assertTrue(False)
os.remove("./test_in_memory_dataset_run_a.txt") os.remove("./test_in_memory_dataset_run_a.txt")
os.remove("./test_in_memory_dataset_run_b.txt") os.remove("./test_in_memory_dataset_run_b.txt")
...@@ -151,12 +151,12 @@ class TestDataset(unittest.TestCase): ...@@ -151,12 +151,12 @@ class TestDataset(unittest.TestCase):
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
for i in range(2): for i in range(2):
#try: try:
exe.train_from_dataset(fluid.default_main_program(), dataset) exe.train_from_dataset(fluid.default_main_program(), dataset)
#except ImportError as e: except ImportError as e:
# pass pass
#except Exception as e: except Exception as e:
# self.assertTrue(False) self.assertTrue(False)
os.remove("./test_queue_dataset_run_a.txt") os.remove("./test_queue_dataset_run_a.txt")
os.remove("./test_queue_dataset_run_b.txt") os.remove("./test_queue_dataset_run_b.txt")
......
...@@ -23,7 +23,7 @@ class TrainerDesc(object): ...@@ -23,7 +23,7 @@ class TrainerDesc(object):
with open(proto_file, 'r') as f: with open(proto_file, 'r') as f:
text_format.Parse(f.read(), self.proto_desc) text_format.Parse(f.read(), self.proto_desc)
''' '''
from .proto import trainer_desc_pb2 from proto import trainer_desc_pb2
self.proto_desc = trainer_desc_pb2.TrainerDesc() self.proto_desc = trainer_desc_pb2.TrainerDesc()
import multiprocessing as mp import multiprocessing as mp
# set default thread num == cpu count # set default thread num == cpu count
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册