Unverified commit 1c4b18c0, authored by Zhong Hui, committed via GitHub

fix package requirement of regex and sentencepiece (#5261)

Parent: 4dba6afa
@@ -208,7 +208,7 @@ class GPT2Dataset(paddle.io.Dataset):
         # -INF mask value as default
         attention_mask = (attention_mask - 1.0) * 1e9
         # Bool mask of attention
-        # attention_mask = attention_mask.astype("float32")
+        attention_mask = attention_mask.astype("float32")
         return [tokens, loss_mask, attention_mask, position_ids, labels]

     def _get_single_sample_from_idx(self, doc_index_f, doc_index_l, offset_f,
...
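The hunk above builds the attention mask additively: allowed positions hold 1.0 and masked positions 0.0, so (attention_mask - 1.0) * 1e9 maps allowed positions to 0 and masked positions to -1e9. Added to the attention logits before softmax, that bias drives masked weights to zero. A minimal NumPy sketch of the same trick (the causal lower-triangular shape is an assumption for illustration, not taken from this diff):

import numpy as np

seq_len = 4
# Lower-triangular causal mask: 1.0 where attention is allowed, 0.0 elsewhere.
attention_mask = np.tril(np.ones((seq_len, seq_len)))
# Allowed -> 0.0, masked -> -1e9: adding this to the pre-softmax logits
# effectively removes masked positions from the softmax.
attention_mask = (attention_mask - 1.0) * 1e9
attention_mask = attention_mask.astype("float32")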
@@ -13,16 +13,17 @@
 # limitations under the License.
 import collections
+import math
 import numpy as np
 import paddle
 import paddle.nn as nn
-import paddle.tensor as tensor
 import paddle.nn.functional as F
-import math
+import paddle.tensor as tensor
 from paddle.fluid import layers
+from paddle.nn.layer.transformer import _convert_param_attr_to_list
 from .. import PretrainedModel, register_base_model
-from paddle.nn.layer.transformer import _convert_param_attr_to_list

 __all__ = [
     'GPT2Model',
...
@@ -13,14 +13,13 @@
 # limitations under the License.
 import os
-import regex as re
-import unicodedata
+from functools import lru_cache
+from collections import namedtuple
 import json
-import sentencepiece
 import jieba
-from functools import lru_cache
-from collections import namedtuple
+from paddle.utils import try_import
 from .. import PretrainedTokenizer
 from ..tokenizer_utils import convert_to_unicode, whitespace_tokenize,\
     _is_whitespace, _is_control, _is_punctuation
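This is the substance of the fix: the hard top-level imports of regex and sentencepiece are replaced with paddle.utils.try_import, so those packages are only required when a tokenizer that uses them is actually instantiated. A minimal sketch of what such a lazy-import helper looks like, assuming the usual importlib approach (the real paddle.utils.try_import may differ in details such as its error message):

import importlib

def try_import(module_name):
    # Import a module by name at call time; fail with an actionable
    # message only when the optional dependency is actually needed.
    try:
        return importlib.import_module(module_name)
    except ImportError:
        raise ImportError(
            "%s is required by this feature but is not installed. "
            "Try: pip install %s" % (module_name, module_name))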
@@ -122,7 +121,8 @@ class GPT2ChineseTokenizer(PretrainedTokenizer):
         self.max_len = max_len if max_len is not None else int(1e12)
         self.encoder = json.load(open(vocab_file))
         self.decoder = {v: k for k, v in self.encoder.items()}
-        self.sp = sentencepiece.SentencePieceProcessor(model_file=model_file)
+        mod = try_import("sentencepiece")
+        self.sp = mod.SentencePieceProcessor(model_file=model_file)
         self.translator = str.maketrans(" \n", "\u2582\u2583")

     def tokenize(self, text):
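After this change, sentencepiece is resolved on demand inside GPT2ChineseTokenizer.__init__, and spaces and newlines are rewritten to the printable placeholders \u2582 and \u2583 before the SentencePiece model sees the text. A hedged usage sketch (the model path is hypothetical; SentencePieceProcessor(model_file=...) and encode(..., out_type=str) follow the documented sentencepiece API):

from paddle.utils import try_import

spm = try_import("sentencepiece")  # raises a clear error if the package is missing
sp = spm.SentencePieceProcessor(model_file="chinese_vocab.model")  # hypothetical path
# Mirror self.translator from the diff: " " -> \u2582, "\n" -> \u2583.
translator = str.maketrans(" \n", "\u2582\u2583")
pieces = sp.encode("你好 世界".translate(translator), out_type=str)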
@@ -220,7 +220,7 @@ class GPT2Tokenizer(PretrainedTokenizer):
         bpe_merges = [tuple(merge.split()) for merge in bpe_data]
         self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
         self.cache = {}
+        re = try_import("regex")
         self.pat = re.compile(
             r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
         )
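The pattern being compiled is GPT-2's standard pre-tokenization regex: it splits text into common English contractions, runs of letters, runs of digits, runs of other symbols, and whitespace. It depends on Unicode property classes such as \p{L} and \p{N}, which the third-party regex package supports and the standard-library re module does not, hence try_import("regex"). For example:

import regex as re

pat = re.compile(
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
print(re.findall(pat, "We've got 2 GPUs!"))
# ['We', "'ve", ' got', ' 2', ' GPUs', '!']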
@@ -295,6 +295,7 @@ class GPT2Tokenizer(PretrainedTokenizer):
     def tokenize(self, text):
         """ Tokenize a string. """
         bpe_tokens = []
+        re = try_import("regex")
         for token in re.findall(self.pat, text):
             token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
             bpe_tokens.extend(
...
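In tokenize(), each pre-token is mapped through self.byte_encoder before BPE so that arbitrary UTF-8 bytes become printable characters the merge table can handle. The table itself is not shown in this diff; the sketch below mirrors the reference GPT-2 bytes-to-unicode construction for context:

from functools import lru_cache

@lru_cache()
def bytes_to_unicode():
    # Printable bytes keep their own code point; the remaining bytes are
    # shifted into the 256+ range so every byte maps to a visible character.
    bs = (list(range(ord("!"), ord("~") + 1)) +
          list(range(ord("\xa1"), ord("\xac") + 1)) +
          list(range(ord("\xae"), ord("\xff") + 1)))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    return dict(zip(bs, (chr(c) for c in cs)))

byte_encoder = bytes_to_unicode()
print(''.join(byte_encoder[b] for b in "héllo".encode('utf-8')))  # 'hÃ©llo'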