diff --git a/fluid/neural_machine_translation/transformer/util.py b/fluid/neural_machine_translation/transformer/util.py
index 77dc3868d894d485b59b67d5d75f5a08a53dbb3c..190abf92f4f48bfc943bd99bf61a222cc6c9d2f0 100644
--- a/fluid/neural_machine_translation/transformer/util.py
+++ b/fluid/neural_machine_translation/transformer/util.py
@@ -17,35 +17,6 @@ _ALPHANUMERIC_CHAR_SET = set(
         unicodedata.category(six.unichr(i)).startswith("N")))
 
 
-def tokens_to_ustr(tokens):
-    """
-    Convert a list of tokens to a unicode string.
-    """
-    token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens]
-    ret = []
-    for i, token in enumerate(tokens):
-        if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]:
-            ret.append(u" ")
-        ret.append(token)
-    return "".join(ret)
-
-
-def subtoken_ids_to_tokens(subtoken_ids, vocabs):
-    """
-    Convert a list of subtoken(wordpiece) ids to a list of tokens.
-    """
-    concatenated = "".join(
-        [vocabs.get(subtoken_id, u"") for subtoken_id in subtoken_ids])
-    split = concatenated.split("_")
-    ret = []
-    for t in split:
-        if t:
-            unescaped = unescape_token(t + "_")
-            if unescaped:
-                ret.append(unescaped)
-    return ret
-
-
 def unescape_token(escaped_token):
     """
     Inverse of encoding escaping.
@@ -65,9 +36,33 @@ def unescape_token(escaped_token):
     return _UNESCAPE_REGEX.sub(match, trimmed)
 
 
-def subword_ids_to_str(ids, vocabs):
+def subtoken_ids_to_str(subtoken_ids, vocabs):
     """
     Convert a list of subtoken(word piece) ids to a native string.
     Refer to SubwordTextEncoder in Tensor2Tensor.
     """
-    return tokens_to_ustr(subtoken_ids_to_tokens(ids, vocabs)).decode("utf-8")
+    subtokens = [vocabs.get(subtoken_id, u"") for subtoken_id in subtoken_ids]
+
+    # Convert a list of subtokens to a list of tokens.
+    concatenated = "".join([
+        t if isinstance(t, unicode) else t.decode("utf-8") for t in subtokens
+    ])
+    split = concatenated.split("_")
+    tokens = []
+    for t in split:
+        if t:
+            unescaped = unescape_token(t + "_")
+            if unescaped:
+                tokens.append(unescaped)
+
+    # Convert a list of tokens to a unicode string (by inserting spaces between
+    # word tokens).
+    token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens]
+    ret = []
+    for i, token in enumerate(tokens):
+        if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]:
+            ret.append(u" ")
+        ret.append(token)
+    seq = "".join(ret)
+
+    return seq.encode("utf-8")
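
For reference, the merged helper decodes wordpiece ids by concatenating the vocab entries, splitting on the trailing "_" that marks word boundaries, unescaping each piece, and re-inserting spaces between adjacent alphanumeric tokens. Below is a minimal usage sketch (Python 2, since the code relies on the unicode type); the toy vocab and the "from util import ..." path are illustrative assumptions, not part of this change:

    # -*- coding: utf-8 -*-
    # Minimal sketch; assumes it is run from
    # fluid/neural_machine_translation/transformer/ so util.py is importable.
    from util import subtoken_ids_to_str

    # Hypothetical id-to-wordpiece mapping; a real vocab comes from the
    # model's wordpiece vocabulary file. A trailing "_" marks a word boundary.
    vocabs = {0: u"Hello_", 1: u"wor", 2: u"ld_"}

    # Ids 1 and 2 concatenate into one token ("world"); a space is then
    # inserted between the two alphanumeric tokens.
    print(subtoken_ids_to_str([0, 1, 2], vocabs))  # Hello world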