diff --git a/python/paddle/distributed/fleet/data_generator/data_generator.py b/python/paddle/distributed/fleet/data_generator/data_generator.py index a10c4b8fe5ce08d6dc2d5ccf096547b9a5232da3..3f0a2ea35f335aa7f9667619c958cba5621d67b1 100644 --- a/python/paddle/distributed/fleet/data_generator/data_generator.py +++ b/python/paddle/distributed/fleet/data_generator/data_generator.py @@ -27,14 +27,6 @@ class DataGenerator(object): self._proto_info = None self.batch_size_ = 32 - def _set_line_limit(self, line_limit): - if not isinstance(line_limit, int): - raise ValueError("line_limit%s must be in int type" % - type(line_limit)) - if line_limit < 1: - raise ValueError("line_limit can not less than 1") - self._line_limit = line_limit - def set_batch(self, batch_size): ''' Set batch size of current DataGenerator diff --git a/python/paddle/fluid/tests/unittests/test_data_generator.py b/python/paddle/fluid/tests/unittests/test_data_generator.py index 08974c2df80fbd72aab48fd86b9040764a58079d..8f2961426a80bd96634a849638b127b4130d255a 100644 --- a/python/paddle/fluid/tests/unittests/test_data_generator.py +++ b/python/paddle/fluid/tests/unittests/test_data_generator.py @@ -13,12 +13,15 @@ import paddle import unittest import paddle.distributed.fleet as fleet +import os class MyMultiSlotDataGenerator(fleet.MultiSlotDataGenerator): def generate_sample(self, line): def data_iter(): - for i in range(100): + for i in range(40): + if i == 1: + yield None yield ("words", [1, 2, 3, 4]), ("label", [0]) return data_iter @@ -27,22 +30,140 @@ class MyMultiSlotDataGenerator(fleet.MultiSlotDataGenerator): class MyMultiSlotStringDataGenerator(fleet.MultiSlotStringDataGenerator): def generate_sample(self, line): def data_iter(): - for i in range(100): + for i in range(40): + if i == 1: + yield None yield ("words", ["1", "2", "3", "4"]), ("label", ["0"]) return data_iter +class MyMultiSlotDataGenerator_error(fleet.MultiSlotDataGenerator): + def generate_sample(self, line): + def data_iter(): + for i in range(40): + if i == 1: + yield None + yield "words" + + return data_iter + + +class MyMultiSlotDataGenerator_error_2(fleet.MultiSlotStringDataGenerator): + def generate_sample(self, line): + def data_iter(): + for i in range(40): + if i == 1: + yield None + yield "words" + + return data_iter + + +class MyMultiSlotDataGenerator_error_3(fleet.MultiSlotDataGenerator): + def generate_sample(self, line): + def data_iter(): + for i in range(40): + if i == 1: + yield None + yield (1, ["1", "2", "3", "4"]), (2, ["0"]) + + return data_iter + + +class MyMultiSlotDataGenerator_error_4(fleet.MultiSlotDataGenerator): + def generate_sample(self, line): + def data_iter(): + for i in range(40): + if i == 1: + yield None + yield ("words", "1"), ("label", "0") + + return data_iter + + +class MyMultiSlotDataGenerator_error_5(fleet.MultiSlotDataGenerator): + def generate_sample(self, line): + def data_iter(): + for i in range(40): + if i == 1: + yield None + yield ("words", []), ("label", []) + + return data_iter + + class TestMultiSlotDataGenerator(unittest.TestCase): def test_MultiSlotDataGenerator_basic(self): my_ms_dg = MyMultiSlotDataGenerator() + my_ms_dg.set_batch(1) my_ms_dg.run_from_memory() class TestMultiSlotStringDataGenerator(unittest.TestCase): def test_MyMultiSlotStringDataGenerator_basic(self): - my_mss_dg = MyMultiSlotStringDataGenerator() - my_mss_dg.run_from_memory() + my_ms_dg = MyMultiSlotStringDataGenerator() + my_ms_dg.set_batch(1) + my_ms_dg.run_from_memory() + + +class TestMultiSlotStringDataGenerator_2(unittest.TestCase): + def test_MyMultiSlotStringDataGenerator_stdin(self): + with open("test_queue_dataset_run_a.txt", "w") as f: + data = "2 1 2\n" + data += "2 6 2\n" + data += "2 5 2\n" + data += "2 7 2\n" + f.write(data) + + tmp = os.popen( + "cat test_queue_dataset_run_a.txt | python my_data_generator.py" + ).readlines() + expected_res = [ + '1 2 1 1 1 2\n', '1 2 1 6 1 2\n', '1 2 1 5 1 2\n', '1 2 1 7 1 2\n' + ] + self.assertEqual(tmp, expected_res) + os.remove("./test_queue_dataset_run_a.txt") + + +class TestMultiSlotDataGenerator_error(unittest.TestCase): + def test_MultiSlotDataGenerator_error(self): + with self.assertRaises(ValueError): + my_ms_dg = MyMultiSlotDataGenerator_error() + my_ms_dg.set_batch(1) + my_ms_dg.run_from_memory() + + +class TestMultiSlotDataGenerator_error_2(unittest.TestCase): + def test_MultiSlotDataGenerator_error(self): + with self.assertRaises(ValueError): + my_ms_dg = MyMultiSlotDataGenerator_error_2() + my_ms_dg.set_batch(1) + my_ms_dg.run_from_memory() + + +class TestMultiSlotDataGenerator_error_3(unittest.TestCase): + def test_MultiSlotDataGenerator_error(self): + with self.assertRaises(ValueError): + my_ms_dg = MyMultiSlotDataGenerator_error_3() + my_ms_dg.set_batch(1) + my_ms_dg.run_from_memory() + + +class TestMultiSlotDataGenerator_error_4(unittest.TestCase): + def test_MultiSlotDataGenerator_error(self): + with self.assertRaises(ValueError): + my_ms_dg = MyMultiSlotDataGenerator_error_4() + my_ms_dg.set_batch(1) + my_ms_dg.run_from_memory() + + +class TestMultiSlotDataGenerator_error_5(unittest.TestCase): + def test_MultiSlotDataGenerator_error(self): + with self.assertRaises(ValueError): + my_ms_dg = MyMultiSlotDataGenerator_error_5() + my_ms_dg.set_batch(1) + my_ms_dg.run_from_memory() if __name__ == '__main__':