未验证 提交 f58cb018 编写于 作者: C Chengmo 提交者: GitHub

【Paddle.Fleet】fix dataset zip py3 bug (#31441)

* fix zip py3 bug
上级 bf09dcb3
......@@ -32,11 +32,11 @@ class DataGenerator(object):
'''
Set batch size of current DataGenerator
This is necessary only if a user wants to define generator_batch
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
......@@ -52,7 +52,7 @@ class DataGenerator(object):
yield ("words", s[1].extend([s[1][0]]))
mydata = MyData()
mydata.set_batch(128)
'''
self.batch_size_ = batch_size
......@@ -63,7 +63,7 @@ class DataGenerator(object):
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
......@@ -100,9 +100,9 @@ class DataGenerator(object):
generated.
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
......@@ -161,7 +161,7 @@ class DataGenerator(object):
The data format is list or tuple:
[(name, [feasign, ...]), ...]
or ((name, [feasign, ...]), ...)
For example:
[("words", [1926, 08, 17]), ("label", [1])]
or (("words", [1926, 08, 17]), ("label", [1]))
......@@ -174,7 +174,7 @@ class DataGenerator(object):
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
......@@ -206,7 +206,7 @@ class DataGenerator(object):
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
......@@ -259,6 +259,9 @@ class MultiSlotStringDataGenerator(DataGenerator):
Returns:
Return a string data that can be read directly by the MultiSlotDataFeed.
'''
if sys.version > '3' and isinstance(line, zip):
line = list(line)
if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError(
"the output of process() must be in list or tuple type"
......@@ -289,7 +292,7 @@ class MultiSlotDataGenerator(DataGenerator):
>>> [ids_num id1 id2 ...] ...
The proto_info will be in this format:
>>> [(name, type), ...]
For example, if the input is like this:
>>> [("words", [1926, 08, 17]), ("label", [1])]
>>> or (("words", [1926, 08, 17]), ("label", [1]))
......@@ -304,6 +307,9 @@ class MultiSlotDataGenerator(DataGenerator):
Returns:
Return a string data that can be read directly by the MultiSlotDataFeed.
'''
if sys.version > '3' and isinstance(line, zip):
line = list(line)
if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError(
"the output of process() must be in list or tuple type"
......
......@@ -95,6 +95,32 @@ class MyMultiSlotDataGenerator_error_5(fleet.MultiSlotDataGenerator):
return data_iter
class MyMultiSlotStringDataGenerator_zip(fleet.MultiSlotStringDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
feature_name = ["words", "label"]
data = [["1", "2", "3", "4"], ["0"]]
yield zip(feature_name, data)
return data_iter
class MyMultiSlotDataGenerator_zip(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
feature_name = ["words", "label"]
data = [[1, 2, 3, 4], [0]]
yield zip(feature_name, data)
return data_iter
class TestMultiSlotDataGenerator(unittest.TestCase):
def test_MultiSlotDataGenerator_basic(self):
my_ms_dg = MyMultiSlotDataGenerator()
......@@ -149,5 +175,19 @@ class TestMultiSlotDataGenerator_error_5(unittest.TestCase):
my_ms_dg.run_from_memory()
class TestMultiSlotStringDataGeneratorZip(unittest.TestCase):
def test_MultiSlotStringDataGenerator_zip(self):
my_ms_dg = MyMultiSlotStringDataGenerator_zip()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
class TestMultiSlotDataGeneratorZip(unittest.TestCase):
def test_MultiSlotDataGenerator_zip(self):
my_ms_dg = MyMultiSlotDataGenerator_zip()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册