提交 cb602fce 编写于 作者: Y yaoxuefeng6

add data_generator ut

上级 dfbe4488
......@@ -27,14 +27,6 @@ class DataGenerator(object):
self._proto_info = None
self.batch_size_ = 32
def _set_line_limit(self, line_limit):
if not isinstance(line_limit, int):
raise ValueError("line_limit%s must be in int type" %
type(line_limit))
if line_limit < 1:
raise ValueError("line_limit can not less than 1")
self._line_limit = line_limit
def set_batch(self, batch_size):
'''
Set batch size of current DataGenerator
......
......@@ -13,12 +13,15 @@
import paddle
import unittest
import paddle.distributed.fleet as fleet
import os
class MyMultiSlotDataGenerator(fleet.MultiSlotDataGenerator):
    """Generator fixture yielding integer-valued ``words`` and ``label`` slots."""

    def generate_sample(self, line):
        """Return a generator function that produces the fixture samples.

        Args:
            line: Ignored; every call produces the same fixed samples.

        Returns:
            A zero-argument generator function yielding 40 two-slot samples,
            with a single ``None`` emitted just before the second sample.
        """

        def data_iter():
            for i in range(40):
                if i == 1:
                    # Presumably exercises how the consumer treats a None
                    # sample -- confirm against the base class.
                    yield None
                yield ("words", [1, 2, 3, 4]), ("label", [0])

        return data_iter
......@@ -27,22 +30,140 @@ class MyMultiSlotDataGenerator(fleet.MultiSlotDataGenerator):
class MyMultiSlotStringDataGenerator(fleet.MultiSlotStringDataGenerator):
    """Generator fixture yielding string-valued ``words`` and ``label`` slots."""

    def generate_sample(self, line):
        """Return a generator function that produces the fixture samples.

        Args:
            line: Ignored; every call produces the same fixed samples.

        Returns:
            A zero-argument generator function yielding 40 two-slot samples,
            with a single ``None`` emitted just before the second sample.
        """

        def data_iter():
            for i in range(40):
                if i == 1:
                    # Presumably exercises how the consumer treats a None
                    # sample -- confirm against the base class.
                    yield None
                yield ("words", ["1", "2", "3", "4"]), ("label", ["0"])

        return data_iter
class MyMultiSlotDataGenerator_error(fleet.MultiSlotDataGenerator):
    """Malformed fixture: each sample is a bare string, not slot pairs."""

    def generate_sample(self, line):
        """Return a generator function yielding the malformed samples.

        ``line`` is ignored. The stream holds 40 copies of the string
        ``"words"``, with one ``None`` emitted before the second copy.
        """

        def sample_stream():
            for index in range(40):
                if index == 1:
                    yield None
                yield "words"

        return sample_stream
class MyMultiSlotDataGenerator_error_2(fleet.MultiSlotStringDataGenerator):
    """Malformed string-generator fixture: samples are bare strings."""

    def generate_sample(self, line):
        """Return a generator function yielding the malformed samples.

        ``line`` is ignored. The stream holds 40 copies of the string
        ``"words"``, with one ``None`` emitted before the second copy.
        """

        def sample_stream():
            for index in range(40):
                if index == 1:
                    yield None
                yield "words"

        return sample_stream
class MyMultiSlotDataGenerator_error_3(fleet.MultiSlotDataGenerator):
    """Malformed fixture: slot names are the integers 1 and 2, not strings."""

    def generate_sample(self, line):
        """Return a generator function yielding the malformed samples.

        ``line`` is ignored. The stream holds 40 two-slot samples keyed by
        integers, with one ``None`` emitted before the second sample.
        """

        def sample_stream():
            for index in range(40):
                if index == 1:
                    yield None
                # Fresh lists are built for every sample, as in the original.
                yield (1, ["1", "2", "3", "4"]), (2, ["0"])

        return sample_stream
class MyMultiSlotDataGenerator_error_4(fleet.MultiSlotDataGenerator):
    """Malformed fixture: slot values are plain strings instead of lists."""

    def generate_sample(self, line):
        """Return a generator function yielding the malformed samples.

        ``line`` is ignored. The stream holds 40 two-slot samples whose
        values are single strings, with one ``None`` emitted before the
        second sample.
        """

        def sample_stream():
            for index in range(40):
                if index == 1:
                    yield None
                yield ("words", "1"), ("label", "0")

        return sample_stream
class MyMultiSlotDataGenerator_error_5(fleet.MultiSlotDataGenerator):
    """Malformed fixture: both slots carry empty value lists."""

    def generate_sample(self, line):
        """Return a generator function yielding the malformed samples.

        ``line`` is ignored. The stream holds 40 two-slot samples with empty
        lists as values, with one ``None`` emitted before the second sample.
        """

        def sample_stream():
            for index in range(40):
                if index == 1:
                    yield None
                # Fresh lists are built for every sample, as in the original.
                yield ("words", []), ("label", [])

        return sample_stream
class TestMultiSlotDataGenerator(unittest.TestCase):
    """Smoke test for the well-formed integer-slot generator fixture."""

    def test_MultiSlotDataGenerator_basic(self):
        """Run the generator from memory with the batch size set to 1."""
        generator = MyMultiSlotDataGenerator()
        generator.set_batch(1)
        generator.run_from_memory()
class TestMultiSlotStringDataGenerator(unittest.TestCase):
    """Smoke test for the well-formed string-slot generator fixture."""

    def test_MyMultiSlotStringDataGenerator_basic(self):
        """Run the generator twice: without ``set_batch``, then with 1."""
        # First pass: batch size left untouched.
        untouched_batch_gen = MyMultiSlotStringDataGenerator()
        untouched_batch_gen.run_from_memory()

        # Second pass: a fresh instance with the batch size forced to 1.
        single_batch_gen = MyMultiSlotStringDataGenerator()
        single_batch_gen.set_batch(1)
        single_batch_gen.run_from_memory()
class TestMultiSlotStringDataGenerator_2(unittest.TestCase):
    """Feed a small input file to ``my_data_generator.py`` through stdin."""

    def test_MyMultiSlotStringDataGenerator_stdin(self):
        """Compare the script's stdout with the expected slot-encoded lines.

        The input file is written to the current working directory and is
        removed afterwards even when the assertion fails.
        """
        input_path = "test_queue_dataset_run_a.txt"
        data = "2 1 2\n"
        data += "2 6 2\n"
        data += "2 5 2\n"
        data += "2 7 2\n"
        expected_res = [
            '1 2 1 1 1 2\n', '1 2 1 6 1 2\n', '1 2 1 5 1 2\n', '1 2 1 7 1 2\n'
        ]
        try:
            with open(input_path, "w") as f:
                f.write(data)
            # The context manager closes the pipe and reaps the child shell.
            with os.popen(
                    "cat test_queue_dataset_run_a.txt | python my_data_generator.py"
            ) as pipe:
                tmp = pipe.readlines()
            self.assertEqual(tmp, expected_res)
        finally:
            # Clean up on failure too, so a broken run does not leave the
            # input file behind for later tests.
            if os.path.exists(input_path):
                os.remove(input_path)
class TestMultiSlotDataGenerator_error(unittest.TestCase):
    """Expect ``ValueError`` from the bare-string sample fixture."""

    def test_MultiSlotDataGenerator_error(self):
        """Build, configure and run the fixture inside ``assertRaises``."""
        with self.assertRaises(ValueError):
            broken_generator = MyMultiSlotDataGenerator_error()
            broken_generator.set_batch(1)
            broken_generator.run_from_memory()
class TestMultiSlotDataGenerator_error_2(unittest.TestCase):
    """Expect ``ValueError`` from the bare-string string-generator fixture."""

    def test_MultiSlotDataGenerator_error(self):
        """Build, configure and run the fixture inside ``assertRaises``."""
        with self.assertRaises(ValueError):
            broken_generator = MyMultiSlotDataGenerator_error_2()
            broken_generator.set_batch(1)
            broken_generator.run_from_memory()
class TestMultiSlotDataGenerator_error_3(unittest.TestCase):
    """Expect ``ValueError`` from the integer-slot-name fixture."""

    def test_MultiSlotDataGenerator_error(self):
        """Build, configure and run the fixture inside ``assertRaises``."""
        with self.assertRaises(ValueError):
            broken_generator = MyMultiSlotDataGenerator_error_3()
            broken_generator.set_batch(1)
            broken_generator.run_from_memory()
class TestMultiSlotDataGenerator_error_4(unittest.TestCase):
    """Expect ``ValueError`` from the string-valued-slot fixture."""

    def test_MultiSlotDataGenerator_error(self):
        """Build, configure and run the fixture inside ``assertRaises``."""
        with self.assertRaises(ValueError):
            broken_generator = MyMultiSlotDataGenerator_error_4()
            broken_generator.set_batch(1)
            broken_generator.run_from_memory()
class TestMultiSlotDataGenerator_error_5(unittest.TestCase):
    """Expect ``ValueError`` from the empty-slot-list fixture."""

    def test_MultiSlotDataGenerator_error(self):
        """Build, configure and run the fixture inside ``assertRaises``."""
        with self.assertRaises(ValueError):
            broken_generator = MyMultiSlotDataGenerator_error_5()
            broken_generator.set_batch(1)
            broken_generator.run_from_memory()
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册