未验证 提交 f58cb018 编写于 作者: C Chengmo 提交者: GitHub

【Paddle.Fleet】fix dataset zip py3 bug (#31441)

* fix zip py3 bug
上级 bf09dcb3
......@@ -259,6 +259,9 @@ class MultiSlotStringDataGenerator(DataGenerator):
Returns:
Return a string data that can be read directly by the MultiSlotDataFeed.
'''
if sys.version > '3' and isinstance(line, zip):
line = list(line)
if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError(
"the output of process() must be in list or tuple type"
......@@ -304,6 +307,9 @@ class MultiSlotDataGenerator(DataGenerator):
Returns:
Return a string data that can be read directly by the MultiSlotDataFeed.
'''
if sys.version > '3' and isinstance(line, zip):
line = list(line)
if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError(
"the output of process() must be in list or tuple type"
......
......@@ -95,6 +95,32 @@ class MyMultiSlotDataGenerator_error_5(fleet.MultiSlotDataGenerator):
return data_iter
class MyMultiSlotStringDataGenerator_zip(fleet.MultiSlotStringDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
feature_name = ["words", "label"]
data = [["1", "2", "3", "4"], ["0"]]
yield zip(feature_name, data)
return data_iter
class MyMultiSlotDataGenerator_zip(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
feature_name = ["words", "label"]
data = [[1, 2, 3, 4], [0]]
yield zip(feature_name, data)
return data_iter
class TestMultiSlotDataGenerator(unittest.TestCase):
def test_MultiSlotDataGenerator_basic(self):
my_ms_dg = MyMultiSlotDataGenerator()
......@@ -149,5 +175,19 @@ class TestMultiSlotDataGenerator_error_5(unittest.TestCase):
my_ms_dg.run_from_memory()
class TestMultiSlotStringDataGeneratorZip(unittest.TestCase):
def test_MultiSlotStringDataGenerator_zip(self):
my_ms_dg = MyMultiSlotStringDataGenerator_zip()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
class TestMultiSlotDataGeneratorZip(unittest.TestCase):
def test_MultiSlotDataGenerator_zip(self):
my_ms_dg = MyMultiSlotDataGenerator_zip()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册