未验证 提交 f58cb018 编写于 作者: C Chengmo 提交者: GitHub

【Paddle.Fleet】fix dataset zip py3 bug (#31441)

* fix zip py3 bug
上级 bf09dcb3
...@@ -32,11 +32,11 @@ class DataGenerator(object): ...@@ -32,11 +32,11 @@ class DataGenerator(object):
''' '''
Set batch size of current DataGenerator Set batch size of current DataGenerator
This is necessary only if a user wants to define generator_batch This is necessary only if a user wants to define generator_batch
Example: Example:
.. code-block:: python .. code-block:: python
import paddle.distributed.fleet.data_generator as dg import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator): class MyData(dg.DataGenerator):
...@@ -52,7 +52,7 @@ class DataGenerator(object): ...@@ -52,7 +52,7 @@ class DataGenerator(object):
yield ("words", s[1].extend([s[1][0]])) yield ("words", s[1].extend([s[1][0]]))
mydata = MyData() mydata = MyData()
mydata.set_batch(128) mydata.set_batch(128)
''' '''
self.batch_size_ = batch_size self.batch_size_ = batch_size
...@@ -63,7 +63,7 @@ class DataGenerator(object): ...@@ -63,7 +63,7 @@ class DataGenerator(object):
Example: Example:
.. code-block:: python .. code-block:: python
import paddle.distributed.fleet.data_generator as dg import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator): class MyData(dg.DataGenerator):
...@@ -100,9 +100,9 @@ class DataGenerator(object): ...@@ -100,9 +100,9 @@ class DataGenerator(object):
generated. generated.
Example: Example:
.. code-block:: python .. code-block:: python
import paddle.distributed.fleet.data_generator as dg import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator): class MyData(dg.DataGenerator):
...@@ -161,7 +161,7 @@ class DataGenerator(object): ...@@ -161,7 +161,7 @@ class DataGenerator(object):
The data format is list or tuple: The data format is list or tuple:
[(name, [feasign, ...]), ...] [(name, [feasign, ...]), ...]
or ((name, [feasign, ...]), ...) or ((name, [feasign, ...]), ...)
For example: For example:
[("words", [1926, 08, 17]), ("label", [1])] [("words", [1926, 08, 17]), ("label", [1])]
or (("words", [1926, 08, 17]), ("label", [1])) or (("words", [1926, 08, 17]), ("label", [1]))
...@@ -174,7 +174,7 @@ class DataGenerator(object): ...@@ -174,7 +174,7 @@ class DataGenerator(object):
Example: Example:
.. code-block:: python .. code-block:: python
import paddle.distributed.fleet.data_generator as dg import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator): class MyData(dg.DataGenerator):
...@@ -206,7 +206,7 @@ class DataGenerator(object): ...@@ -206,7 +206,7 @@ class DataGenerator(object):
Example: Example:
.. code-block:: python .. code-block:: python
import paddle.distributed.fleet.data_generator as dg import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator): class MyData(dg.DataGenerator):
...@@ -259,6 +259,9 @@ class MultiSlotStringDataGenerator(DataGenerator): ...@@ -259,6 +259,9 @@ class MultiSlotStringDataGenerator(DataGenerator):
Returns: Returns:
Return a string data that can be read directly by the MultiSlotDataFeed. Return a string data that can be read directly by the MultiSlotDataFeed.
''' '''
if sys.version > '3' and isinstance(line, zip):
line = list(line)
if not isinstance(line, list) and not isinstance(line, tuple): if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError( raise ValueError(
"the output of process() must be in list or tuple type" "the output of process() must be in list or tuple type"
...@@ -289,7 +292,7 @@ class MultiSlotDataGenerator(DataGenerator): ...@@ -289,7 +292,7 @@ class MultiSlotDataGenerator(DataGenerator):
>>> [ids_num id1 id2 ...] ... >>> [ids_num id1 id2 ...] ...
The proto_info will be in this format: The proto_info will be in this format:
>>> [(name, type), ...] >>> [(name, type), ...]
For example, if the input is like this: For example, if the input is like this:
>>> [("words", [1926, 08, 17]), ("label", [1])] >>> [("words", [1926, 08, 17]), ("label", [1])]
>>> or (("words", [1926, 08, 17]), ("label", [1])) >>> or (("words", [1926, 08, 17]), ("label", [1]))
...@@ -304,6 +307,9 @@ class MultiSlotDataGenerator(DataGenerator): ...@@ -304,6 +307,9 @@ class MultiSlotDataGenerator(DataGenerator):
Returns: Returns:
Return a string data that can be read directly by the MultiSlotDataFeed. Return a string data that can be read directly by the MultiSlotDataFeed.
''' '''
if sys.version > '3' and isinstance(line, zip):
line = list(line)
if not isinstance(line, list) and not isinstance(line, tuple): if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError( raise ValueError(
"the output of process() must be in list or tuple type" "the output of process() must be in list or tuple type"
......
...@@ -95,6 +95,32 @@ class MyMultiSlotDataGenerator_error_5(fleet.MultiSlotDataGenerator): ...@@ -95,6 +95,32 @@ class MyMultiSlotDataGenerator_error_5(fleet.MultiSlotDataGenerator):
return data_iter return data_iter
class MyMultiSlotStringDataGenerator_zip(fleet.MultiSlotStringDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
feature_name = ["words", "label"]
data = [["1", "2", "3", "4"], ["0"]]
yield zip(feature_name, data)
return data_iter
class MyMultiSlotDataGenerator_zip(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
feature_name = ["words", "label"]
data = [[1, 2, 3, 4], [0]]
yield zip(feature_name, data)
return data_iter
class TestMultiSlotDataGenerator(unittest.TestCase): class TestMultiSlotDataGenerator(unittest.TestCase):
def test_MultiSlotDataGenerator_basic(self): def test_MultiSlotDataGenerator_basic(self):
my_ms_dg = MyMultiSlotDataGenerator() my_ms_dg = MyMultiSlotDataGenerator()
...@@ -149,5 +175,19 @@ class TestMultiSlotDataGenerator_error_5(unittest.TestCase): ...@@ -149,5 +175,19 @@ class TestMultiSlotDataGenerator_error_5(unittest.TestCase):
my_ms_dg.run_from_memory() my_ms_dg.run_from_memory()
class TestMultiSlotStringDataGeneratorZip(unittest.TestCase):
def test_MultiSlotStringDataGenerator_zip(self):
my_ms_dg = MyMultiSlotStringDataGenerator_zip()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
class TestMultiSlotDataGeneratorZip(unittest.TestCase):
def test_MultiSlotDataGenerator_zip(self):
my_ms_dg = MyMultiSlotDataGenerator_zip()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册