diff --git a/model_zoo/Transformer/create_data.py b/model_zoo/Transformer/create_data.py
index af941623cbcf4a6e245c926374af684ef8b28dfe..fc81d3c23213d7d11f344704543a71e4d569a706 100644
--- a/model_zoo/Transformer/create_data.py
+++ b/model_zoo/Transformer/create_data.py
@@ -37,13 +37,13 @@ class SampleInstance():
     def __str__(self):
         s = ""
         s += "source sos tokens: %s\n" % (" ".join(
-            [tokenization.printable_text(x) for x in self.source_sos_tokens]))
+            [tokenization.convert_to_printable(x) for x in self.source_sos_tokens]))
         s += "source eos tokens: %s\n" % (" ".join(
-            [tokenization.printable_text(x) for x in self.source_eos_tokens]))
+            [tokenization.convert_to_printable(x) for x in self.source_eos_tokens]))
         s += "target sos tokens: %s\n" % (" ".join(
-            [tokenization.printable_text(x) for x in self.target_sos_tokens]))
+            [tokenization.convert_to_printable(x) for x in self.target_sos_tokens]))
         s += "target eos tokens: %s\n" % (" ".join(
-            [tokenization.printable_text(x) for x in self.target_eos_tokens]))
+            [tokenization.convert_to_printable(x) for x in self.target_eos_tokens]))
         s += "\n"
         return s
 
@@ -185,9 +185,9 @@ def main():
         if total_written <= 20:
             logging.info("*** Example ***")
             logging.info("source tokens: %s", " ".join(
-                [tokenization.printable_text(x) for x in instance.source_eos_tokens]))
+                [tokenization.convert_to_printable(x) for x in instance.source_eos_tokens]))
             logging.info("target tokens: %s", " ".join(
-                [tokenization.printable_text(x) for x in instance.target_sos_tokens]))
+                [tokenization.convert_to_printable(x) for x in instance.target_sos_tokens]))
 
         for feature_name in features.keys():
             feature = features[feature_name]
diff --git a/model_zoo/Transformer/scripts/replace-quote.perl b/model_zoo/Transformer/scripts/replace-quote.perl
index 95f9abcc9127ffd0fd661f0c2109244f596d5eb6..5e6e715dafa6449f9fc1e2213f0556976f385dc6 100644
--- a/model_zoo/Transformer/scripts/replace-quote.perl
+++ b/model_zoo/Transformer/scripts/replace-quote.perl
@@ -1,4 +1,18 @@
 #!/usr/bin/env perl
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
 
 use warnings;
 use strict;
diff --git a/model_zoo/Transformer/src/tokenization.py b/model_zoo/Transformer/src/tokenization.py
index fd0fc9795566334d9289a861a54b0c0944954bcb..6c4f4fec206af78e56e62263e5da220136cb91c9 100644
--- a/model_zoo/Transformer/src/tokenization.py
+++ b/model_zoo/Transformer/src/tokenization.py
@@ -1,193 +1,158 @@
-# coding=utf-8
-# Copyright 2018 The Google AI Language Team Authors.
+# Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-#     http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-###############################################################################
-# Modified by Huawei Technologies Co., Ltd, May, 2020, with following changes:
-#    - Remove some unused classes and functions
-#    - Modify load_vocab, convert_to_unicode, printable_text function
-#    - Modify BasicTokenizer class
-#    - Add WhiteSpaceTokenizer class
-###############################################################################
-
+# ============================================================================
 """Tokenization utilities."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
+import sys
 import collections
 import unicodedata
-import six
 
 
-def convert_to_unicode(text):
-    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
-    if six.PY3:
+def convert_to_printable(text):
+    """
+    Converts `text` to a printable coding format.
+    """
+    if sys.version_info[0] == 3:
         if isinstance(text, str):
             return text
         if isinstance(text, bytes):
             return text.decode("utf-8", "ignore")
-        raise ValueError("Unsupported string type: %s" % (type(text)))
-    if six.PY2:
+        raise ValueError("Only support type `str` or `bytes`, while text type is `%s`" % (type(text)))
+    if sys.version_info[0] == 2:
         if isinstance(text, str):
-            return text.decode("utf-8", "ignore")
-        if isinstance(text, unicode):
             return text
-        raise ValueError("Unsupported string type: %s" % (type(text)))
-    raise ValueError("Not running on Python2 or Python 3?")
-
+        if isinstance(text, unicode):
+            return text.encode("utf-8")
+        raise ValueError("Only support type `str` or `unicode`, while text type is `%s`" % (type(text)))
+    raise ValueError("Only supported when running on Python2 or Python3.")
 
-def printable_text(text):
-    """Returns text encoded in a way suitable for print or `logging`."""
-    # These functions want `str` for both Python2 and Python3, but in one case
-    # it's a Unicode string and in the other it's a byte string.
-    if six.PY3:
+def convert_to_unicode(text):
+    """
+    Converts `text` to Unicode format.
+    """
+    if sys.version_info[0] == 3:
         if isinstance(text, str):
             return text
         if isinstance(text, bytes):
             return text.decode("utf-8", "ignore")
-        raise ValueError("Unsupported string type: %s" % (type(text)))
-    if six.PY2:
+        raise ValueError("Only support type `str` or `bytes`, while text type is `%s`" % (type(text)))
+    if sys.version_info[0] == 2:
        if isinstance(text, str):
-            return text
+            return text.decode("utf-8", "ignore")
         if isinstance(text, unicode):
-            return text.encode("utf-8")
-        raise ValueError("Unsupported string type: %s" % (type(text)))
-    raise ValueError("Not running on Python2 or Python 3?")
+            return text
+        raise ValueError("Only support type `str` or `unicode`, while text type is `%s`" % (type(text)))
+    raise ValueError("Only supported when running on Python2 or Python3.")
 
 
-def load_vocab(vocab_file):
-    """Loads a vocabulary file into a dictionary."""
-    vocab = collections.OrderedDict()
+def load_vocab_file(vocab_file):
+    """
+    Loads a vocabulary file and turns it into a {token: id} dictionary.
+ """ + vocab_dict = collections.OrderedDict() index = 0 - with open(vocab_file, "r") as reader: + with open(vocab_file, "r") as vocab: while True: - token = convert_to_unicode(reader.readline()) + token = convert_to_unicode(vocab.readline()) if not token: break token = token.strip() - vocab[token] = index + vocab_dict[token] = index index += 1 - return vocab + return vocab_dict -def convert_by_vocab(vocab, items): - """Converts a sequence of [tokens|ids] using the vocab.""" +def convert_by_vocab_dict(vocab_dict, items): + """ + Converts a sequence of [tokens|ids] according to the vocab dict. + """ output = [] for item in items: - if item in vocab: - output.append(vocab[item]) + if item in vocab_dict: + output.append(vocab_dict[item]) else: - output.append(vocab[""]) + output.append(vocab_dict[""]) return output -def convert_tokens_to_ids(vocab, tokens): - return convert_by_vocab(vocab, tokens) - - -def convert_ids_to_tokens(inv_vocab, ids): - return convert_by_vocab(inv_vocab, ids) - - -def whitespace_tokenize(text): - """Runs basic whitespace cleaning and splitting on a piece of text.""" - text = text.strip() - if not text: - return [] - tokens = text.split() - return tokens - - class WhiteSpaceTokenizer(): - """Runs end-to-end tokenziation.""" + """ + Whitespace tokenizer. + """ def __init__(self, vocab_file): - self.vocab = load_vocab(vocab_file) - self.inv_vocab = {v: k for k, v in self.vocab.items()} - self.basic_tokenizer = BasicTokenizer() - - def tokenize(self, text): - return self.basic_tokenizer.tokenize(text) - - def convert_tokens_to_ids(self, tokens): - return convert_by_vocab(self.vocab, tokens) - - def convert_ids_to_tokens(self, ids): - return convert_by_vocab(self.inv_vocab, ids) - - -class BasicTokenizer(): - """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" - - def __init__(self): - """Constructs a BasicTokenizer.""" + self.vocab_dict = load_vocab_file(vocab_file) + self.inv_vocab_dict = {index: token for token, index in self.vocab_dict.items()} + + def _is_whitespace_char(self, char): + """ + Checks if it is a whitespace character(regard "\t", "\n", "\r" as whitespace here). + """ + if char in (" ", "\t", "\n", "\r"): + return True + uni = unicodedata.category(char) + if uni == "Zs": + return True + return False - def tokenize(self, text): - """Tokenizes a piece of text.""" - text = convert_to_unicode(text) - text = self._clean_text(text) - return whitespace_tokenize(text) + def _is_control_char(self, char): + """ + Checks if it is a control character. + """ + if char in ("\t", "\n", "\r"): + return False + uni = unicodedata.category(char) + if uni in ("Cc", "Cf"): + return True + return False def _clean_text(self, text): - """Performs invalid character removal and whitespace cleanup on text.""" + """ + Remove invalid characters and cleanup whitespace. + """ output = [] for char in text: cp = ord(char) - if cp == 0 or cp == 0xfffd or _is_control(char): + if cp == 0 or cp == 0xfffd or self._is_control_char(char): continue - if _is_whitespace(char): + if self._is_whitespace_char(char): output.append(" ") else: output.append(char) return "".join(output) + def _whitespace_tokenize(self, text): + """ + Clean whitespace and split text into tokens. 
+ """ + text = text.strip() + if not text: + tokens = [] + else: + tokens = text.split() + return tokens -def _is_whitespace(char): - """Checks whether `chars` is a whitespace character.""" - # \t, \n, and \r are technically contorl characters but we treat them - # as whitespace since they are generally considered as such. - if char in (" ", "\t", "\n", "\r"): - return True - cat = unicodedata.category(char) - if cat == "Zs": - return True - return False + def tokenize(self, text): + """ + Tokenizes text. + """ + text = convert_to_unicode(text) + text = self._clean_text(text) + tokens = self._whitespace_tokenize(text) + return tokens + def convert_tokens_to_ids(self, tokens): + return convert_by_vocab_dict(self.vocab_dict, tokens) -def _is_control(char): - """Checks whether `chars` is a control character.""" - # These are technically control characters but we count them as whitespace - # characters. - if char in ("\t", "\n", "\r"): - return False - cat = unicodedata.category(char) - if cat in ("Cc", "Cf"): - return True - return False - - -def _is_punctuation(char): - """Checks whether `chars` is a punctuation character.""" - cp = ord(char) - # We treat all non-letter/number ASCII as punctuation. - # Characters such as "^", "$", and "`" are not in the Unicode - # Punctuation class but we treat them as punctuation anyways, for - # consistency. - if ((33 <= cp <= 47) or (58 <= cp <= 64) or (91 <= cp <= 96) or (123 <= cp <= 126)): - return True - cat = unicodedata.category(char) - if cat.startswith("P"): - return True - return False + def convert_ids_to_tokens(self, ids): + return convert_by_vocab_dict(self.inv_vocab_dict, ids)