diff --git a/python/paddle/distributed/fleet/data_generator/data_generator.py b/python/paddle/distributed/fleet/data_generator/data_generator.py index 669d2ea24a0c788bbe1c0cff38a843bd96e29016..9d743fc38bf3982454bd0e27779567bfb9f0717e 100644 --- a/python/paddle/distributed/fleet/data_generator/data_generator.py +++ b/python/paddle/distributed/fleet/data_generator/data_generator.py @@ -32,11 +32,11 @@ class DataGenerator(object): ''' Set batch size of current DataGenerator This is necessary only if a user wants to define generator_batch - + Example: .. code-block:: python - + import paddle.distributed.fleet.data_generator as dg class MyData(dg.DataGenerator): @@ -52,7 +52,7 @@ class DataGenerator(object): yield ("words", s[1].extend([s[1][0]])) mydata = MyData() mydata.set_batch(128) - + ''' self.batch_size_ = batch_size @@ -63,7 +63,7 @@ class DataGenerator(object): Example: .. code-block:: python - + import paddle.distributed.fleet.data_generator as dg class MyData(dg.DataGenerator): @@ -100,9 +100,9 @@ class DataGenerator(object): generated. Example: - + .. code-block:: python - + import paddle.distributed.fleet.data_generator as dg class MyData(dg.DataGenerator): @@ -161,7 +161,7 @@ class DataGenerator(object): The data format is list or tuple: [(name, [feasign, ...]), ...] or ((name, [feasign, ...]), ...) - + For example: [("words", [1926, 08, 17]), ("label", [1])] or (("words", [1926, 08, 17]), ("label", [1])) @@ -174,7 +174,7 @@ class DataGenerator(object): Example: .. code-block:: python - + import paddle.distributed.fleet.data_generator as dg class MyData(dg.DataGenerator): @@ -206,7 +206,7 @@ class DataGenerator(object): Example: .. code-block:: python - + import paddle.distributed.fleet.data_generator as dg class MyData(dg.DataGenerator): @@ -259,6 +259,9 @@ class MultiSlotStringDataGenerator(DataGenerator): Returns: Return a string data that can be read directly by the MultiSlotDataFeed. ''' + if sys.version > '3' and isinstance(line, zip): + line = list(line) + if not isinstance(line, list) and not isinstance(line, tuple): raise ValueError( "the output of process() must be in list or tuple type" @@ -289,7 +292,7 @@ class MultiSlotDataGenerator(DataGenerator): >>> [ids_num id1 id2 ...] ... The proto_info will be in this format: >>> [(name, type), ...] - + For example, if the input is like this: >>> [("words", [1926, 08, 17]), ("label", [1])] >>> or (("words", [1926, 08, 17]), ("label", [1])) @@ -304,6 +307,9 @@ class MultiSlotDataGenerator(DataGenerator): Returns: Return a string data that can be read directly by the MultiSlotDataFeed. ''' + if sys.version > '3' and isinstance(line, zip): + line = list(line) + if not isinstance(line, list) and not isinstance(line, tuple): raise ValueError( "the output of process() must be in list or tuple type" diff --git a/python/paddle/fluid/tests/unittests/test_data_generator.py b/python/paddle/fluid/tests/unittests/test_data_generator.py index 6381cb364026369734f1de3747d50cc1ca17d5ef..69d8e01fd464afc724d286740b6c8f42929dd387 100644 --- a/python/paddle/fluid/tests/unittests/test_data_generator.py +++ b/python/paddle/fluid/tests/unittests/test_data_generator.py @@ -95,6 +95,32 @@ class MyMultiSlotDataGenerator_error_5(fleet.MultiSlotDataGenerator): return data_iter +class MyMultiSlotStringDataGenerator_zip(fleet.MultiSlotStringDataGenerator): + def generate_sample(self, line): + def data_iter(): + for i in range(40): + if i == 1: + yield None + feature_name = ["words", "label"] + data = [["1", "2", "3", "4"], ["0"]] + yield zip(feature_name, data) + + return data_iter + + +class MyMultiSlotDataGenerator_zip(fleet.MultiSlotDataGenerator): + def generate_sample(self, line): + def data_iter(): + for i in range(40): + if i == 1: + yield None + feature_name = ["words", "label"] + data = [[1, 2, 3, 4], [0]] + yield zip(feature_name, data) + + return data_iter + + class TestMultiSlotDataGenerator(unittest.TestCase): def test_MultiSlotDataGenerator_basic(self): my_ms_dg = MyMultiSlotDataGenerator() @@ -149,5 +175,19 @@ class TestMultiSlotDataGenerator_error_5(unittest.TestCase): my_ms_dg.run_from_memory() +class TestMultiSlotStringDataGeneratorZip(unittest.TestCase): + def test_MultiSlotStringDataGenerator_zip(self): + my_ms_dg = MyMultiSlotStringDataGenerator_zip() + my_ms_dg.set_batch(1) + my_ms_dg.run_from_memory() + + +class TestMultiSlotDataGeneratorZip(unittest.TestCase): + def test_MultiSlotDataGenerator_zip(self): + my_ms_dg = MyMultiSlotDataGenerator_zip() + my_ms_dg.set_batch(1) + my_ms_dg.run_from_memory() + + if __name__ == '__main__': unittest.main()