From e83f902b98cee9e0d95c74f388783434b39887ba Mon Sep 17 00:00:00 2001 From: guru4elephant <35550832+guru4elephant@users.noreply.github.com> Date: Fri, 28 Jun 2019 20:59:17 +0800 Subject: [PATCH] =?UTF-8?q?add=20MultiSlotStringDataGenerator=20for=20spee?= =?UTF-8?q?dup=20of=20string=20based=20user=20inp=E2=80=A6=20(#18390)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add MultiSlotStringDataGenerator for speedup of string based user input data --- .../fluid/incubate/data_generator/__init__.py | 47 ++++++++++++++++++- .../data_generator/test_data_generator.py | 10 ++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/incubate/data_generator/__init__.py b/python/paddle/fluid/incubate/data_generator/__init__.py index 0d376ee973f..c5d298f951d 100644 --- a/python/paddle/fluid/incubate/data_generator/__init__.py +++ b/python/paddle/fluid/incubate/data_generator/__init__.py @@ -235,6 +235,50 @@ class DataGenerator(object): return local_iter +# TODO: guru4elephant +# add more generalized DataGenerator that can adapt user-defined slot +# for example, [(name, float_list), (name, str_list), (name, int_list)] +class MultiSlotStringDataGenerator(DataGenerator): + def _gen_str(self, line): + ''' + Further processing the output of the process() function rewritten by + user, outputting data that can be directly read by the MultiSlotDataFeed, + and updating proto_info infomation. + + The input line will be in this format: + >>> [(name, [str(feasign), ...]), ...] + >>> or ((name, [str(feasign), ...]), ...) + The output will be in this format: + >>> [ids_num id1 id2 ...] ... + + For example, if the input is like this: + >>> [("words", ["1926", "08", "17"]), ("label", ["1"])] + >>> or (("words", ["1926", "08", "17"]), ("label", ["1"])) + the output will be: + >>> 3 1234 2345 3456 1 1 + + Args: + line(str): the output of the process() function rewritten by user. + + Returns: + Return a string data that can be read directly by the MultiSlotDataFeed. + ''' + if not isinstance(line, list) and not isinstance(line, tuple): + raise ValueError( + "the output of process() must be in list or tuple type" + "Examples: [('words', ['1926', '08', '17']), ('label', ['1'])]") + output = "" + for index, item in enumerate(line): + name, elements = item + if output: + output += " " + out_str = [] + out_str.append(str(len(elements))) + out_str.extend(elements) + output += " ".join(out_str) + return output + "\n" + + class MultiSlotDataGenerator(DataGenerator): def _gen_str(self, line): ''' @@ -266,7 +310,8 @@ class MultiSlotDataGenerator(DataGenerator): ''' if not isinstance(line, list) and not isinstance(line, tuple): raise ValueError( - "the output of process() must be in list or tuple type") + "the output of process() must be in list or tuple type" + "Example: [('words', [1926, 08, 17]), ('label', [1])]") output = "" if self._proto_info is None: diff --git a/python/paddle/fluid/incubate/data_generator/test_data_generator.py b/python/paddle/fluid/incubate/data_generator/test_data_generator.py index ea42551efb6..dcacd67e92a 100644 --- a/python/paddle/fluid/incubate/data_generator/test_data_generator.py +++ b/python/paddle/fluid/incubate/data_generator/test_data_generator.py @@ -22,5 +22,15 @@ class SyntheticData(MultiSlotDataGenerator): return data_iter +class SyntheticStringData(MultiSlotStringDataGenerator): + def generate_sample(self, line): + def data_iter(): + for i in range(10000): + yield ("words", ["1", "2", "3", "4"], ("label", ["0"])) + + sd = SyntheticData() sd.run_from_memory() + +sd2 = SyntheticStringData() +sd.run_from_memory() -- GitLab