From e981c67acd5d29ffbcf906b36ac5133861b8de35 Mon Sep 17 00:00:00 2001
From: hesham
Date: Fri, 19 Jun 2020 17:40:02 -0400
Subject: [PATCH] Python Tokenizer

!38 Synchronize with latest Ascend software suite 17 Jun 2020

Merge pull request !38 from yanghaoran/master
---
 mindspore/dataset/text/__init__.py           |  5 +-
 mindspore/dataset/text/transforms.py         | 27 +++++++++-
 mindspore/dataset/text/validators.py         | 22 ++++++++
 .../python/dataset/test_python_tokenizer.py  | 52 +++++++++++++++++++
 4 files changed, 102 insertions(+), 4 deletions(-)
 create mode 100644 tests/ut/python/dataset/test_python_tokenizer.py

diff --git a/mindspore/dataset/text/__init__.py b/mindspore/dataset/text/__init__.py
index 4f2ecdc54..7856b55b7 100644
--- a/mindspore/dataset/text/__init__.py
+++ b/mindspore/dataset/text/__init__.py
@@ -22,12 +22,13 @@ from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm
 
 __all__ = [
     "Lookup", "JiebaTokenizer", "UnicodeCharTokenizer", "Ngram",
-    "to_str", "to_bytes", "JiebaMode", "Vocab", "WordpieceTokenizer", "TruncateSequencePair", "ToNumber"
+    "to_str", "to_bytes", "JiebaMode", "Vocab", "WordpieceTokenizer", "TruncateSequencePair", "ToNumber",
+    "PythonTokenizer"
 ]
 
 if platform.system().lower() != 'windows':
     from .transforms import UnicodeScriptTokenizer, WhitespaceTokenizer, CaseFold, NormalizeUTF8, \
-        RegexReplace, RegexTokenizer, BasicTokenizer, BertTokenizer
+        RegexReplace, RegexTokenizer, BasicTokenizer, BertTokenizer, PythonTokenizer
     __all__.append(["UnicodeScriptTokenizer", "WhitespaceTokenizer", "CaseFold", "NormalizeUTF8",
                     "RegexReplace", "RegexTokenizer", "BasicTokenizer", "BertTokenizer", "NormalizeForm"])
 
diff --git a/mindspore/dataset/text/transforms.py b/mindspore/dataset/text/transforms.py
index ad4c12ad9..8e03a436b 100644
--- a/mindspore/dataset/text/transforms.py
+++ b/mindspore/dataset/text/transforms.py
@@ -18,13 +18,14 @@ c transforms for all text related operators
 import os
 import re
 import platform
+import numpy as np
 
 import mindspore._c_dataengine as cde
 
-from .utils import JiebaMode, NormalizeForm
+from .utils import JiebaMode, NormalizeForm, to_str
 from .validators import check_lookup, check_jieba_add_dict, \
     check_jieba_add_word, check_jieba_init, check_ngram, check_pair_truncate, \
-    check_to_number
+    check_to_number, check_python_tokenizer
 
 from ..core.datatypes import mstype_to_detype
 
@@ -406,3 +407,25 @@ class ToNumber(cde.ToNumberOp):
         data_type = mstype_to_detype(data_type)
         self.data_type = str(data_type)
         super().__init__(data_type)
+
+
+class PythonTokenizer:
+    """
+    Callable class to be used for user-defined string tokenizer.
+    Args:
+        tokenizer (Callable): Python function that takes a `str` and returns a list of `str` as tokens.
+
+    Examples:
+        >>> def my_tokenizer(line):
+        >>>     return line.split()
+        >>> data = data.map(operations=PythonTokenizer(my_tokenizer))
+    """
+
+    @check_python_tokenizer
+    def __init__(self, tokenizer):
+        self.tokenizer = np.vectorize(lambda x: np.array(tokenizer(x), dtype='U'), signature='()->(n)')
+
+    def __call__(self, in_array):
+        in_array = to_str(in_array)
+        tokens = self.tokenizer(in_array)
+        return tokens
diff --git a/mindspore/dataset/text/validators.py b/mindspore/dataset/text/validators.py
index 74ff31dd7..72802a840 100644
--- a/mindspore/dataset/text/validators.py
+++ b/mindspore/dataset/text/validators.py
@@ -411,3 +411,25 @@ def check_to_number(method):
         return method(self, **kwargs)
 
     return new_method
+
+
+def check_python_tokenizer(method):
+    """A wrapper that wraps a parameter check to the original function (PythonTokenizer)."""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        tokenizer = (list(args) + [None])[0]
+        if "tokenizer" in kwargs:
+            tokenizer = kwargs.get("tokenizer")
+
+        if tokenizer is None:
+            raise ValueError("tokenizer is a mandatory parameter.")
+
+        if not callable(tokenizer):
+            raise TypeError("tokenizer is not a callable Python function.")
+
+        kwargs["tokenizer"] = tokenizer
+
+        return method(self, **kwargs)
+
+    return new_method
diff --git a/tests/ut/python/dataset/test_python_tokenizer.py b/tests/ut/python/dataset/test_python_tokenizer.py
new file mode 100644
index 000000000..78db55321
--- /dev/null
+++ b/tests/ut/python/dataset/test_python_tokenizer.py
@@ -0,0 +1,52 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Testing PythonTokenizer op in DE
+"""
+import mindspore.dataset as ds
+import mindspore.dataset.text as text
+from mindspore import log as logger
+
+DATA_FILE = "../data/dataset/testTokenizerData/1.txt"
+
+
+def test_python_tokenizer():
+    """
+    Test PythonTokenizer
+    """
+    whitespace_strs = [["Welcome", "to", "Beijing!"],
+                       ["北京欢迎您!"],
+                       ["我喜欢English!"],
+                       [""]]
+
+    def my_tokenizer(line):
+        words = line.split()
+        if not words:
+            return [""]
+        return words
+
+    dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
+    tokenizer = text.PythonTokenizer(my_tokenizer)
+    dataset = dataset.map(operations=tokenizer, num_parallel_workers=1)
+    tokens = []
+    for i in dataset.create_dict_iterator():
+        s = text.to_str(i['text']).tolist()
+        tokens.append(s)
+    logger.info("The output tokens are: {}".format(tokens))
+    assert whitespace_strs == tokens
+
+
+if __name__ == '__main__':
+    test_python_tokenizer()
-- 
GitLab
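
For reference, below is a minimal usage sketch of the PythonTokenizer introduced by this patch. It is an illustration under stated assumptions, not part of the change itself: it presumes a MindSpore build with this patch applied (on a non-Windows platform, per the conditional import above) and reuses the docstring example and the unit test, including the test's DATA_FILE path as a stand-in for any one-column text file.

# Minimal usage sketch, assuming this patch is applied to MindSpore.
import mindspore.dataset as ds
import mindspore.dataset.text as text

# Same one-column text file the unit test reads; any plain-text file would do.
DATA_FILE = "../data/dataset/testTokenizerData/1.txt"


def my_tokenizer(line):
    # Whitespace split; return [""] so an empty line still yields one token.
    words = line.split()
    return words if words else [""]


# Each row's 'text' column is handed to the tokenizer as a Python str
# (PythonTokenizer.__call__ converts the raw bytes with to_str first).
dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
dataset = dataset.map(operations=text.PythonTokenizer(my_tokenizer), num_parallel_workers=1)

for row in dataset.create_dict_iterator():
    # Tokens come back as a NumPy string array; to_str yields Python strings.
    print(text.to_str(row['text']).tolist())

Returning [""] for empty lines, as the unit test does, ensures every row still produces at least one token for the downstream pipeline.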