未验证 提交 e83f902b 编写于 作者: G guru4elephant 提交者: GitHub

add MultiSlotStringDataGenerator for speedup of string based user inp… (#18390)

* add MultiSlotStringDataGenerator for speedup of string based user input data
上级 681d3553
......@@ -235,6 +235,50 @@ class DataGenerator(object):
return local_iter
# TODO: guru4elephant
# add more generalized DataGenerator that can adapt user-defined slot
# for example, [(name, float_list), (name, str_list), (name, int_list)]
class MultiSlotStringDataGenerator(DataGenerator):
def _gen_str(self, line):
'''
Further processing the output of the process() function rewritten by
user, outputting data that can be directly read by the MultiSlotDataFeed,
and updating proto_info infomation.
The input line will be in this format:
>>> [(name, [str(feasign), ...]), ...]
>>> or ((name, [str(feasign), ...]), ...)
The output will be in this format:
>>> [ids_num id1 id2 ...] ...
For example, if the input is like this:
>>> [("words", ["1926", "08", "17"]), ("label", ["1"])]
>>> or (("words", ["1926", "08", "17"]), ("label", ["1"]))
the output will be:
>>> 3 1234 2345 3456 1 1
Args:
line(str): the output of the process() function rewritten by user.
Returns:
Return a string data that can be read directly by the MultiSlotDataFeed.
'''
if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError(
"the output of process() must be in list or tuple type"
"Examples: [('words', ['1926', '08', '17']), ('label', ['1'])]")
output = ""
for index, item in enumerate(line):
name, elements = item
if output:
output += " "
out_str = []
out_str.append(str(len(elements)))
out_str.extend(elements)
output += " ".join(out_str)
return output + "\n"
class MultiSlotDataGenerator(DataGenerator):
def _gen_str(self, line):
'''
......@@ -266,7 +310,8 @@ class MultiSlotDataGenerator(DataGenerator):
'''
if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError(
"the output of process() must be in list or tuple type")
"the output of process() must be in list or tuple type"
"Example: [('words', [1926, 08, 17]), ('label', [1])]")
output = ""
if self._proto_info is None:
......
......@@ -22,5 +22,15 @@ class SyntheticData(MultiSlotDataGenerator):
return data_iter
class SyntheticStringData(MultiSlotStringDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(10000):
yield ("words", ["1", "2", "3", "4"], ("label", ["0"]))
sd = SyntheticData()
sd.run_from_memory()
sd2 = SyntheticStringData()
sd.run_from_memory()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册