Commit 0bdbbd73 authored by: W wuzewu

Fix the compatibility problem of text generation module

Parent 733b6c21
......@@ -24,6 +24,7 @@ from paddlehub.utils import log, parser, utils
from paddlehub.utils.paddlex import download, ResourceNotFoundError
from paddlehub.server.server_source import ServerConnectionError
from paddlehub.module import Module
from paddlehub.text.bert_tokenizer import BertTokenizer
# In order to maintain the compatibility of the old version, we put the relevant
# compatible code in the paddlehub.compat package, and mapped some modules referenced
......@@ -33,6 +34,9 @@ from paddlehub.compat.module.processor import BaseProcessor
from paddlehub.compat.module.nlp_module import NLPPredictionModule, TransformerModule
from paddlehub.compat.type import DataType
from paddlehub.compat import task
from paddlehub.compat.datasets import couplet
from paddlehub.compat.task.config import RunConfig
from paddlehub.compat.task.text_generation_task import TextGenerationTask
sys.modules['paddlehub.io.parser'] = parser
sys.modules['paddlehub.common.logger'] = log
......@@ -41,3 +45,5 @@ sys.modules['paddlehub.common.utils'] = utils
sys.modules['paddlehub.reader'] = task
common = EasyDict(paddle_helper=paddle_utils)
dataset = EasyDict(Couplet=couplet.Couplet)
# Compatibility stub for the legacy 1.x API; the returned strategy object has no effect.
AdamWeightDecayStrategy = lambda: 0
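# Illustrative sketch, not part of this commit: with the mappings above, code written
# against the PaddleHub 1.x API keeps resolving through this compatibility layer, e.g.
# (assuming the couplet data has been downloaded):
#
#   import paddlehub as hub
#   dataset = hub.dataset.Couplet()            # served by the EasyDict mapping above
#   strategy = hub.AdamWeightDecayStrategy()   # stub object; it no longer has any effect
#   config = hub.RunConfig(num_epoch=1, batch_size=32, strategy=strategy)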
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from paddlehub.utils.log import logger
class InputExample(object):
'''
Input data structure for BERT/ERNIE. It covers single-sequence tasks such as
text classification and sequence labeling, as well as sequence-pair tasks such
as dialogue.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
'''
def __init__(self, guid, text_a, text_b=None, label=None):
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
def __str__(self):
if self.text_b is None:
return 'text={}\tlabel={}'.format(self.text_a, self.label)
else:
return 'text_a={}\ttext_b={}\tlabel={}'.format(self.text_a, self.text_b, self.label)
class BaseDataset(object):
def __init__(self,
base_path,
train_file=None,
dev_file=None,
test_file=None,
predict_file=None,
label_file=None,
label_list=None,
train_file_with_header=False,
dev_file_with_header=False,
test_file_with_header=False,
predict_file_with_header=False):
if not (train_file or dev_file or test_file):
raise ValueError('At least one file should be assigned')
self.base_path = base_path
self.train_file = train_file
self.dev_file = dev_file
self.test_file = test_file
self.predict_file = predict_file
self.label_file = label_file
self.label_list = label_list
self.train_examples = []
self.dev_examples = []
self.test_examples = []
self.predict_examples = []
self.if_file_with_header = {
'train': train_file_with_header,
'dev': dev_file_with_header,
'test': test_file_with_header,
'predict': predict_file_with_header
}
if train_file:
self._load_train_examples()
if dev_file:
self._load_dev_examples()
if test_file:
self._load_test_examples()
if predict_file:
self._load_predict_examples()
if self.label_file:
if not self.label_list:
self.label_list = self._load_label_data()
else:
logger.warning('As label_list has been assigned, label_file will be ignored')
if self.label_list:
self.label_index = dict(zip(self.label_list, range(len(self.label_list))))
def get_train_examples(self):
return self.train_examples
def get_dev_examples(self):
return self.dev_examples
def get_test_examples(self):
return self.test_examples
def get_val_examples(self):
return self.get_dev_examples()
def get_predict_examples(self):
return self.predict_examples
def get_examples(self, phase):
if phase == 'train':
return self.get_train_examples()
elif phase == 'dev':
return self.get_dev_examples()
elif phase == 'test':
return self.get_test_examples()
elif phase == 'val':
return self.get_val_examples()
elif phase == 'predict':
return self.get_predict_examples()
else:
raise ValueError('Invalid phase: %s' % phase)
def get_labels(self):
return self.label_list
@property
def num_labels(self):
return len(self.label_list)
# To be compatible with ImageClassificationDataset
def label_dict(self):
return {index: key for index, key in enumerate(self.label_list)}
def _load_train_examples(self):
self.train_path = os.path.join(self.base_path, self.train_file)
self.train_examples = self._read_file(self.train_path, phase='train')
def _load_dev_examples(self):
self.dev_path = os.path.join(self.base_path, self.dev_file)
self.dev_examples = self._read_file(self.dev_path, phase='dev')
def _load_test_examples(self):
self.test_path = os.path.join(self.base_path, self.test_file)
self.test_examples = self._read_file(self.test_path, phase='test')
def _load_predict_examples(self):
self.predict_path = os.path.join(self.base_path, self.predict_file)
self.predict_examples = self._read_file(self.predict_path, phase='predict')
def _read_file(self, path, phase=None):
raise NotImplementedError
def _load_label_data(self):
with open(os.path.join(self.base_path, self.label_file), 'r', encoding='utf8') as file:
return file.read().strip().split('\n')
def __str__(self):
return 'Dataset: %s with %i train examples, %i dev examples and %i test examples' % (
self.__class__.__name__, len(self.train_examples), len(self.dev_examples), len(self.test_examples))
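# Illustrative sketch, not part of this commit: a minimal BaseDataset subclass. The file
# names and the 'text<TAB>label' layout are assumptions for the example only.
import csv
import io

from paddlehub.compat.datasets.base_dataset import BaseDataset, InputExample


class MyTextClassificationDataset(BaseDataset):
    def __init__(self, base_path):
        super(MyTextClassificationDataset, self).__init__(
            base_path=base_path,
            train_file='train.tsv',
            dev_file='dev.tsv',
            test_file='test.tsv',
            label_list=['negative', 'positive'])

    def _read_file(self, path, phase=None):
        # Each line is expected to be 'text<TAB>label'.
        examples = []
        with io.open(path, 'r', encoding='utf8') as f:
            for i, line in enumerate(csv.reader(f, delimiter='\t', quotechar=None)):
                examples.append(InputExample(guid=i, text_a=line[0], label=line[1]))
        return examples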
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs
import csv
import os
import paddlehub.env as hubenv
from paddlehub.compat.datasets.base_dataset import InputExample
from paddlehub.compat.datasets.nlp_dataset import GenerationDataset
from paddlehub.utils.download import download_data
@download_data('https://bj.bcebos.com/paddlehub-dataset/couplet.tar.gz')
class Couplet(GenerationDataset):
'''An open source couplet dataset, see https://github.com/v-zich/couplet-clean-dataset for details.'''
def __init__(self, tokenizer=None, max_seq_len=None):
dataset_dir = os.path.join(hubenv.DATA_HOME, 'couplet')
with open(os.path.join(dataset_dir, 'vocab.txt'), encoding='utf8') as vocab_file:
label_list = [line.strip() for line in vocab_file.readlines()]
super(Couplet, self).__init__(
base_path=dataset_dir,
train_file='train.tsv',
dev_file='dev.tsv',
test_file='test.tsv',
label_list=label_list,
tokenizer=tokenizer,
max_seq_len=max_seq_len)
def _read_file(self, input_file, phase=None):
'''Reads a tab separated value file.'''
with codecs.open(input_file, 'r', encoding='UTF-8') as f:
reader = csv.reader(f, delimiter='\t', quotechar=None)
examples = []
seq_id = 0
for line in reader:
example = InputExample(guid=seq_id, label=line[1], text_a=line[0])
seq_id += 1
examples.append(example)
return examples
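# Illustrative sketch, not part of this commit: constructing the Couplet dataset together
# with a tokenizer. The vocab path is a placeholder; in the fine-tuning flow it usually
# comes from the pre-trained module (e.g. module.get_vocab_path(), an assumption here).
import paddlehub as hub

tokenizer = hub.BertTokenizer(vocab_file='vocab.txt')
dataset = hub.dataset.Couplet(tokenizer=tokenizer, max_seq_len=128)
print(dataset.num_labels)              # size of the target vocabulary (label_list)
records = dataset.get_train_records()  # list[dict] built by _convert_examples_to_records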
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import csv
import collections
import numpy as np
from tqdm import tqdm
from paddlehub.compat.datasets.base_dataset import InputExample, BaseDataset
from paddlehub.utils.log import logger
from paddlehub.text.tokenizer import CustomTokenizer
from paddlehub.text.bert_tokenizer import BertTokenizer
class BaseNLPDataset(BaseDataset):
def __init__(self,
base_path,
train_file=None,
dev_file=None,
test_file=None,
predict_file=None,
label_file=None,
label_list=None,
train_file_with_header=False,
dev_file_with_header=False,
test_file_with_header=False,
predict_file_with_header=False,
tokenizer=None,
max_seq_len=128):
super(BaseNLPDataset, self).__init__(
base_path=base_path,
train_file=train_file,
dev_file=dev_file,
test_file=test_file,
predict_file=predict_file,
label_file=label_file,
label_list=label_list,
train_file_with_header=train_file_with_header,
dev_file_with_header=dev_file_with_header,
test_file_with_header=test_file_with_header,
predict_file_with_header=predict_file_with_header)
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
self._train_records = None
self._dev_records = None
self._test_records = None
self._predict_records = None
@property
def train_records(self):
if not self._train_records:
examples = self.train_examples
if not self.tokenizer or not examples:
return []
logger.info('Processing the train set...')
self._train_records = self._convert_examples_to_records(examples, phase='train')
return self._train_records
@property
def dev_records(self):
if not self._dev_records:
examples = self.dev_examples
if not self.tokenizer or not examples:
return []
logger.info('Processing the dev set...')
self._dev_records = self._convert_examples_to_records(examples, phase='dev')
return self._dev_records
@property
def test_records(self):
if not self._test_records:
examples = self.test_examples
if not self.tokenizer or not examples:
return []
logger.info('Processing the test set...')
self._test_records = self._convert_examples_to_records(examples, phase='test')
return self._test_records
@property
def predict_records(self):
if not self._predict_records:
examples = self.predict_examples
if not self.tokenizer or not examples:
return []
logger.info('Processing the predict set...')
self._predict_records = self._convert_examples_to_records(examples, phase='predict')
return self._predict_records
def _read_file(self, input_file, phase=None):
'''Reads a tab separated value file.'''
has_warned = False
with io.open(input_file, 'r', encoding='UTF-8') as file:
reader = csv.reader(file, delimiter='\t', quotechar=None)
examples = []
for (i, line) in enumerate(reader):
if i == 0:
ncol = len(line)
if self.if_file_with_header[phase]:
continue
if phase != 'predict':
if ncol == 1:
raise Exception(
'the %s file: %s only has one column but it is not a predict file' % (phase, input_file))
elif ncol == 2:
example = InputExample(guid=i, text_a=line[0], label=line[1])
elif ncol == 3:
example = InputExample(guid=i, text_a=line[0], text_b=line[1], label=line[2])
else:
raise Exception('the %s file: %s has too many columns (should <=3)' % (phase, input_file))
else:
if ncol == 1:
example = InputExample(guid=i, text_a=line[0])
elif ncol == 2:
if not has_warned:
logger.warning(
'the predict file: %s has 2 columns, as it is a predict file, the second one will be regarded as text_b'
% (input_file))
has_warned = True
example = InputExample(guid=i, text_a=line[0], text_b=line[1])
else:
raise Exception('the predict file: %s has too many columns (should <=2)' % (input_file))
examples.append(example)
return examples
def _convert_examples_to_records(self, examples, phase):
'''
Returns a list[dict] containing all the input information that the model needs.
Args:
examples (list): the data examples returned by _read_file.
phase (str): the processing phase, one of 'train', 'dev', 'test' or 'predict'.
Returns:
a list of records, one per example.
'''
records = []
with tqdm(total=len(examples)) as process_bar:
for example in examples:
record = self.tokenizer.encode(
text=example.text_a, text_pair=example.text_b, max_seq_len=self.max_seq_len)
# CustomTokenizer will tokenize the text firstly and then lookup words in the vocab
# When all words are not found in the vocab, the text will be dropped.
if not record:
logger.info('The text %s has been dropped as it has no words in the vocab after tokenization.' %
example.text_a)
continue
if example.label:
record['label'] = self.label_list.index(example.label) if self.label_list else float(example.label)
records.append(record)
process_bar.update(1)
return records
def get_train_records(self, shuffle=False):
return self.get_records('train', shuffle=shuffle)
def get_dev_records(self, shuffle=False):
return self.get_records('dev', shuffle=shuffle)
def get_test_records(self, shuffle=False):
return self.get_records('test', shuffle=shuffle)
def get_val_records(self, shuffle=False):
return self.get_records('val', shuffle=shuffle)
def get_predict_records(self, shuffle=False):
return self.get_records('predict', shuffle=shuffle)
def get_records(self, phase, shuffle=False):
if phase == 'train':
records = self.train_records
elif phase == 'dev':
records = self.dev_records
elif phase == 'test':
records = self.test_records
elif phase == 'val':
records = self.dev_records
elif phase == 'predict':
records = self.predict_records
else:
raise ValueError('Invalid phase: %s' % phase)
if shuffle:
np.random.shuffle(records)
return records
def get_feed_list(self, phase):
records = self.get_records(phase)
if records:
feed_list = list(records[0].keys())
else:
feed_list = []
return feed_list
def batch_records_generator(self, phase, batch_size, shuffle=True, pad_to_batch_max_seq_len=False):
'''Generate batches of records; usually used in dynamic graph mode.
Args:
phase (str): the dataset phase, one of 'train', 'dev', 'val', 'test' or 'predict'.
batch_size (int): the data batch size.
shuffle (bool): if set to True, will shuffle the dataset.
pad_to_batch_max_seq_len (bool): if set to True, will dynamically pad to the max sequence length of the batch data.
Only recommended when the model uses an RNN.
'''
records = self.get_records(phase, shuffle=shuffle)
batch_records = []
batch_lens = []
for record in records:
batch_records.append(record)
if pad_to_batch_max_seq_len:
# This may reduce the processing speed
tokens_wo_pad = [
token for token in self.tokenizer.decode(record, only_convert_to_tokens=True)
if token != self.tokenizer.pad_token
]
batch_lens.append(len(tokens_wo_pad))
if len(batch_records) == batch_size:
if pad_to_batch_max_seq_len:
# This may reduce the processing speed.
batch_max_seq_len = max(batch_lens)
for record in batch_records:
for key, value in record.items():
if isinstance(value, list):
# This may not be universal
record[key] = value[:batch_max_seq_len]
rev_batch_records = {key: [record[key] for record in batch_records] for key in batch_records[0]}
yield rev_batch_records
batch_records = []
batch_lens = []
if batch_records:
if pad_to_batch_max_seq_len:
# This may reduce the processing speed.
batch_max_seq_len = max(batch_lens)
for record in batch_records:
for key in record.keys():
if isinstance(record[key], list):
record[key] = record[key][:batch_max_seq_len]
rev_batch_records = {key: [record[key] for record in batch_records] for key in batch_records[0]}
yield rev_batch_records
class GenerationDataset(BaseNLPDataset):
def __init__(self,
base_path,
train_file=None,
dev_file=None,
test_file=None,
predict_file=None,
label_file=None,
label_list=None,
train_file_with_header=False,
dev_file_with_header=False,
test_file_with_header=False,
predict_file_with_header=False,
tokenizer=None,
max_seq_len=128,
split_char='\002',
start_token='<s>',
end_token='</s>',
unk_token='<unk>'):
self.split_char = split_char
self.start_token = start_token
self.end_token = end_token
self.unk_token = unk_token
super(GenerationDataset, self).__init__(
base_path=base_path,
train_file=train_file,
dev_file=dev_file,
test_file=test_file,
predict_file=predict_file,
label_file=label_file,
label_list=label_list,
train_file_with_header=train_file_with_header,
dev_file_with_header=dev_file_with_header,
test_file_with_header=test_file_with_header,
predict_file_with_header=predict_file_with_header,
tokenizer=tokenizer,
max_seq_len=max_seq_len)
def _convert_examples_to_records(self, examples, phase):
'''
Returns a list[dict] containing all the input information that the model needs.
Args:
examples (list): the data examples returned by _read_file.
phase (str): the processing phase, one of 'train', 'dev', 'test' or 'predict'.
Returns:
a list of records, one per example.
'''
records = []
with tqdm(total=len(examples)) as process_bar:
for example in examples:
record = self.tokenizer.encode(
text=example.text_a.split(self.split_char),
text_pair=example.text_b.split(self.split_char) if example.text_b else None,
max_seq_len=self.max_seq_len)
if example.label:
expand_label = [self.start_token] + example.label.split(
self.split_char)[:self.max_seq_len - 2] + [self.end_token]
expand_label_id = [
self.label_index.get(label, self.label_index[self.unk_token]) for label in expand_label
]
record['label'] = expand_label_id[1:] + [self.label_index[self.end_token]
] * (self.max_seq_len - len(expand_label) + 1)
record['dec_input'] = expand_label_id[:-1] + [self.label_index[self.end_token]
] * (self.max_seq_len - len(expand_label) + 1)
records.append(record)
process_bar.update(1)
return records
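# Illustrative sketch, not part of this commit: how GenerationDataset builds 'dec_input'
# (decoder input, starts with <s>) and 'label' (decoder target, shifted by one and padded
# with </s> up to max_seq_len). The toy vocabulary indices below are assumptions.
label_index = {'<s>': 0, '</s>': 1, '<unk>': 2, '风': 3, '调': 4, '雨': 5, '顺': 6}
max_seq_len = 8
label_tokens = ['风', '调', '雨', '顺']

expand_label = ['<s>'] + label_tokens[:max_seq_len - 2] + ['</s>']
expand_label_id = [label_index.get(t, label_index['<unk>']) for t in expand_label]
pad = [label_index['</s>']] * (max_seq_len - len(expand_label) + 1)

dec_input = expand_label_id[:-1] + pad  # [0, 3, 4, 5, 6, 1, 1, 1] -> record['dec_input']
target = expand_label_id[1:] + pad      # [3, 4, 5, 6, 1, 1, 1, 1] -> record['label']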
......@@ -200,7 +200,7 @@ def set_op_attr(program: paddle.static.Program, is_test: bool = False):
@contextlib.contextmanager
def static_mode_guard():
''''''
'''Enter static graph mode within a `with` statement.'''
premode = 'static' if not paddle.in_dynamic_mode() else 'dynamic'
if premode == 'dynamic':
......@@ -213,7 +213,7 @@ def static_mode_guard():
def run_in_static_mode(func):
''''''
'''Decorate a function to run in static graph mode.'''
def runner(*args, **kwargs):
with static_mode_guard():
......
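# Illustrative sketch, not part of this commit: both helpers above temporarily switch the
# process into static graph mode and restore the previous mode on exit. The import path
# (paddlehub.compat.paddle_utils) is assumed from the surrounding context.
import paddle

from paddlehub.compat.paddle_utils import run_in_static_mode, static_mode_guard

with static_mode_guard():
    program = paddle.static.Program()  # built while static graph mode is active

@run_in_static_mode
def build_main_program():
    return paddle.static.default_main_program()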
......@@ -27,6 +27,7 @@ from paddlehub.compat import paddle_utils
from paddlehub.compat.task.config import RunConfig
from paddlehub.compat.task.hook import TaskHooks
from paddlehub.compat.task.task_utils import RunEnv, RunState
from paddlehub.compat.task.checkpoint import load_checkpoint
from paddlehub.utils.log import logger
from paddlehub.utils.utils import generate_tempdir
......@@ -134,6 +135,26 @@ class BaseTask(object):
def exit_phase(self):
self._phases = self._phases[:-1]
def init_if_necessary(self):
if not self.is_checkpoint_loaded:
if not self.load_checkpoint():
self.exe.run(self._base_startup_program)
self.is_checkpoint_loaded = True
self.is_best_model_loaded = False
def init_if_load_best_model(self):
if not self.is_best_model_loaded:
best_model_path = os.path.join(self.config.checkpoint_dir, "best_model")
logger.info("Load the best model from %s" % best_model_path)
if os.path.exists(best_model_path):
self.load_parameters(best_model_path)
self.is_checkpoint_loaded = False
self.is_best_model_loaded = True
else:
self.init_if_necessary()
else:
logger.info("The best model has been loaded")
def _build_env(self):
'''Building the program and strategy for specific running phase.'''
if self.env.is_inititalized:
......@@ -277,12 +298,25 @@ class BaseTask(object):
@property
def generator(self) -> Generator:
def data_generator(records):
def wrapper():
for record in records:
values = []
for feed_name in self.feed_list:
values.append(record[feed_name])
yield values
return wrapper
if self.is_predict_phase:
data = self._predict_data
records = self._predict_data
else:
data = None
self.env.generator = self._base_data_reader.data_generator(
batch_size=self.config.batch_size, phase=self.phase, data=data, return_list=True)
if self.is_train_phase:
shuffle = True
else:
shuffle = False
records = self.dataset.get_records(phase=self.phase, shuffle=shuffle)
self.env.generator = data_generator(records)
return self.env.generator
......@@ -325,20 +359,15 @@ class BaseTask(object):
@property
def feed_list(self) -> List[str]:
if self._compatible_mode:
feed_list = [varname for varname in self._base_feed_list]
if self.is_train_phase or self.is_test_phase:
feed_list += [label.name for label in self.labels]
else:
if not self.env.is_inititalized:
self._build_env()
if not self.env.is_inititalized:
self._build_env()
if self._predict_data:
feed_list = list(self._predict_data[0].keys())
else:
feed_list = self.dataset.get_feed_list(self.phase)
if self._predict_data:
feed_list = list(self._predict_data[0].keys())
else:
feed_list = self.dataset.get_feed_list(self.phase)
feed_list = [feed_name for feed_name in feed_list if feed_name in self.main_program.global_block().vars]
feed_list = [feed_name for feed_name in feed_list if feed_name in self.main_program.global_block().vars]
return feed_list
@property
......@@ -544,6 +573,23 @@ class BaseTask(object):
# The first key will be used as main metrics to update the best model
raise NotImplementedError
def load_checkpoint(self):
is_load_successful, self.env.current_epoch, self.env.current_step, self.best_score = load_checkpoint(
self.config.checkpoint_dir, self.exe, main_program=self.main_program)
# Revise max_train_steps when incremental training
if is_load_successful:
self.max_train_steps = self.env.current_step + self.max_train_steps / self.config.num_epoch * (
self.config.num_epoch - self.env.current_epoch + 1)
return is_load_successful
def load_parameters(self, dirname):
def if_exist(var):
path = os.path.join(dirname, var.name)
return os.path.exists(path)
paddle.static.load(executor=self.exe, model_path=dirname, program=self.main_program)
def save_inference_model(self, dirname: str, model_filename: str = None, params_filename: str = None):
with self.phase_guard('predict'):
paddle.static.save_inference_model(
......@@ -688,7 +734,8 @@ class BaseTask(object):
self,
data: List[Any] = None,
label_list: List[Any] = None,
return_result: bool = False,
load_best_model: bool = True,
return_result: bool = True,
accelerate_mode: bool = True,
) -> List[RunState]:
'''
......@@ -710,6 +757,9 @@ class BaseTask(object):
self._label_list = label_list
self._predict_start_event()
if load_best_model:
self.init_if_load_best_model()
if not self.accelerate_mode:
run_states = self._run()
else:
......@@ -719,7 +769,7 @@ class BaseTask(object):
self._predict_end_event(run_states)
self._predict_data = None
if return_result or not self._compatible_mode:
if return_result:
return self._postprocessing(run_states)
return run_states
......@@ -746,18 +796,14 @@ class BaseTask(object):
RunState: the running result of specific phase
'''
with paddle.static.program_guard(self.main_program, self.startup_program):
if self.config.use_pyreader:
data_loader = paddle.io.DataLoader.from_generator(
feed_list=self.feed_var_list, capacity=64, use_double_buffer=True, iterable=True)
if self._compatible_mode:
data_reader = data_loader.set_batch_generator(self.generator, places=self.places)
else:
if self.is_predict_phase:
data_reader = data_loader.set_sample_generator(
self.generator, places=self.places, batch_size=self.config.batch_size, drop_last=False)
else:
data_reader = data_loader.set_sample_generator(
self.generator, places=self.places, batch_size=self.config.batch_size, drop_last=True)
data_loader = paddle.io.DataLoader.from_generator(
feed_list=self.feed_var_list, capacity=64, use_double_buffer=True, iterable=True)
if self.is_predict_phase:
data_reader = data_loader.set_sample_generator(
self.generator, places=self.places, batch_size=self.config.batch_size, drop_last=False)
else:
data_reader = data_loader.set_sample_generator(
self.generator, places=self.places, batch_size=self.config.batch_size, drop_last=True)
global_run_states = []
period_run_states = []
......
// Copyright 2019 The Paddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
syntax = "proto3";
option optimize_for = LITE_RUNTIME;
package paddlehub.task.checkpoint;
message CheckPoint {
int64 current_epoch = 1;
int64 global_step = 2;
string latest_model_dir = 3;
double best_score = 4;
}
\ No newline at end of file
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Tuple
import paddle
from paddlehub.compat.task import checkpoint_pb2
from paddlehub.utils.log import logger
CKPT_FILE_NAME = 'ckpt.meta'
def load_checkpoint(checkpoint_dir: str, exe: paddle.static.Executor,
main_program: paddle.static.Program) -> Tuple[bool, int, int, float]:
ckpt_meta_path = os.path.join(checkpoint_dir, CKPT_FILE_NAME)
ckpt = checkpoint_pb2.CheckPoint()
logger.info('Try loading checkpoint from {}'.format(ckpt_meta_path))
if os.path.exists(ckpt_meta_path):
with open(ckpt_meta_path, 'rb') as f:
ckpt.ParseFromString(f.read())
current_epoch = 1
global_step = 0
best_score = -999
if ckpt.latest_model_dir:
paddle.static.load(executor=exe, model_path=ckpt.latest_model_dir, program=main_program)
# Compatible with older versions without best_score in checkpoint_pb2
try:
best_score = ckpt.best_score
except:
best_score = -999
logger.info('PaddleHub model checkpoint loaded. current_epoch={}, '
'global_step={}, best_score={:.5f}'.format(ckpt.current_epoch, ckpt.global_step, best_score))
return True, ckpt.current_epoch, ckpt.global_step, best_score
logger.info('PaddleHub model checkpoint not found, start from scratch...')
return False, current_epoch, global_step, best_score
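# Illustrative sketch, not part of this commit: the ckpt.meta file read by load_checkpoint()
# is a serialized CheckPoint protobuf. A matching writer could look like this (the directory
# and values are placeholders).
import os

from paddlehub.compat.task import checkpoint_pb2

ckpt = checkpoint_pb2.CheckPoint()
ckpt.current_epoch = 3
ckpt.global_step = 1500
ckpt.latest_model_dir = 'ckpt_dir/step_1500'
ckpt.best_score = 0.721

with open(os.path.join('ckpt_dir', 'ckpt.meta'), 'wb') as f:
    f.write(ckpt.SerializeToString())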
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: checkpoint.proto
import sys
_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='checkpoint.proto',
package='paddlehub.task.checkpoint',
syntax='proto3',
serialized_pb=_b(
'\n\x10\x63heckpoint.proto\x12\x19paddlehub.task.checkpoint\"f\n\nCheckPoint\x12\x15\n\rcurrent_epoch\x18\x01 \x01(\x03\x12\x13\n\x0bglobal_step\x18\x02 \x01(\x03\x12\x18\n\x10latest_model_dir\x18\x03 \x01(\t\x12\x12\n\nbest_score\x18\x04 \x01(\x01\x42\x02H\x03\x62\x06proto3'
))
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_CHECKPOINT = _descriptor.Descriptor(
name='CheckPoint',
full_name='paddlehub.task.checkpoint.CheckPoint',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='current_epoch',
full_name='paddlehub.task.checkpoint.CheckPoint.current_epoch',
index=0,
number=1,
type=3,
cpp_type=2,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='global_step',
full_name='paddlehub.task.checkpoint.CheckPoint.global_step',
index=1,
number=2,
type=3,
cpp_type=2,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='latest_model_dir',
full_name='paddlehub.task.checkpoint.CheckPoint.latest_model_dir',
index=2,
number=3,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='best_score',
full_name='paddlehub.task.checkpoint.CheckPoint.best_score',
index=3,
number=4,
type=1,
cpp_type=5,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=47,
serialized_end=149,
)
DESCRIPTOR.message_types_by_name['CheckPoint'] = _CHECKPOINT
CheckPoint = _reflection.GeneratedProtocolMessageType(
'CheckPoint',
(_message.Message, ),
dict(
DESCRIPTOR=_CHECKPOINT,
__module__='checkpoint_pb2'
# @@protoc_insertion_point(class_scope:paddlehub.task.checkpoint.CheckPoint)
))
_sym_db.RegisterMessage(CheckPoint)
DESCRIPTOR.has_options = True
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('H\003'))
# @@protoc_insertion_point(module_scope)
......@@ -14,6 +14,7 @@
# limitations under the License.
import time
from typing import Callable
class RunConfig(object):
......@@ -27,7 +28,8 @@ class RunConfig(object):
use_cuda: bool = True,
checkpoint_dir: str = None,
num_epoch: int = 1,
batch_size: int = 32):
batch_size: int = 32,
strategy: Callable = None):
''' Construct finetune Config '''
self.log_interval = log_interval
self.eval_interval = eval_interval
......
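# Illustrative sketch, not part of this commit: the new 'strategy' argument lets old 1.x
# finetune scripts keep passing a strategy object. With the stub in paddlehub/__init__.py
# it can simply be hub.AdamWeightDecayStrategy(), which no longer has any effect.
import paddlehub as hub

config = hub.RunConfig(
    use_cuda=False,
    checkpoint_dir='ckpt_text_generation',  # placeholder directory
    num_epoch=1,
    batch_size=32,
    strategy=hub.AdamWeightDecayStrategy())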
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import math
from typing import List
def _get_ngrams(segment: str, max_order: int):
"""
Extracts all n-grams up to a given maximum order from an input segment.
Args:
segment: text segment from which n-grams will be extracted.
max_order: maximum length in tokens of the n-grams returned by this
method.
Returns:
The Counter containing all n-grams up to max_order in segment
with a count of how many times each n-gram occurred.
"""
ngram_counts = collections.Counter()
for order in range(1, max_order + 1):
for i in range(0, len(segment) - order + 1):
ngram = tuple(segment[i:i + order])
ngram_counts[ngram] += 1
return ngram_counts
def compute_bleu(reference_corpus: List, translation_corpus: List, max_order: int = 4, smooth: bool = False):
'''
Computes the BLEU score of translated segments against their references.
Args:
reference_corpus: list of references, one per translation. Each reference
should be tokenized into a list of tokens.
translation_corpus: list of translations to score. Each translation
should be tokenized into a list of tokens.
max_order: Maximum n-gram order to use when computing BLEU score.
smooth: Whether or not to apply Lin et al. 2004 smoothing.
Returns:
A 6-tuple of (BLEU score, n-gram precisions, brevity penalty, length ratio,
translation length, reference length).
'''
matches_by_order = [0] * max_order
possible_matches_by_order = [0] * max_order
reference_length = 0
translation_length = 0
for (reference, translation) in zip(reference_corpus, translation_corpus):
reference_length += len(reference)
translation_length += len(translation)
merged_ref_ngram_counts = collections.Counter()
merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
translation_ngram_counts = _get_ngrams(translation, max_order)
overlap = translation_ngram_counts & merged_ref_ngram_counts
for ngram in overlap:
matches_by_order[len(ngram) - 1] += overlap[ngram]
for order in range(1, max_order + 1):
possible_matches = len(translation) - order + 1
if possible_matches > 0:
possible_matches_by_order[order - 1] += possible_matches
precisions = [0] * max_order
for i in range(0, max_order):
if smooth:
precisions[i] = ((matches_by_order[i] + 1.) / (possible_matches_by_order[i] + 1.))
else:
if possible_matches_by_order[i] > 0:
precisions[i] = (float(matches_by_order[i]) / possible_matches_by_order[i])
else:
precisions[i] = 0.0
if min(precisions) > 0:
p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
geo_mean = math.exp(p_log_sum)
else:
geo_mean = 0
ratio = float(translation_length) / reference_length
if ratio > 1.0:
bp = 1.
elif ratio > 0.0:
bp = math.exp(1 - 1. / ratio)
else:
bp = 0
bleu = geo_mean * bp
return (bleu, precisions, bp, ratio, translation_length, reference_length)
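# Illustrative sketch, not part of this commit: compute_bleu takes tokenized references and
# translations; TextGenerationTask calls it with max_order=1 on the label/result token lists.
from paddlehub.compat.task.metrics import compute_bleu

references = [['the', 'cat', 'sat', 'on', 'the', 'mat']]
translations = [['the', 'cat', 'is', 'on', 'the', 'mat']]
bleu, precisions, bp, ratio, trans_len, ref_len = compute_bleu(references, translations, max_order=1)
print('BLEU-1: {:.4f}'.format(bleu))  # 5 of 6 unigrams match -> about 0.8333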
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from collections import OrderedDict
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
from paddle.fluid.layers import RNNCell, LSTMCell, rnn, BeamSearchDecoder, dynamic_decode
from paddlehub.compat.task.metrics import compute_bleu
from paddlehub.compat.task.base_task import BaseTask
class AttentionDecoderCell(RNNCell):
def __init__(self, num_layers, hidden_size, dropout_prob=0., init_scale=0.1):
self.num_layers = num_layers
self.hidden_size = hidden_size
self.dropout_prob = dropout_prob
self.lstm_cells = []
self.init_scale = init_scale
param_attr = ParamAttr(initializer=fluid.initializer.UniformInitializer(low=-init_scale, high=init_scale))
bias_attr = ParamAttr(initializer=fluid.initializer.Constant(0.0))
for i in range(num_layers):
self.lstm_cells.append(LSTMCell(hidden_size, param_attr, bias_attr))
def attention(self, query, enc_output, mask=None):
query = fluid.layers.unsqueeze(query, [1])
memory = fluid.layers.fc(
enc_output,
self.hidden_size,
num_flatten_dims=2,
param_attr=ParamAttr(
name='dec_memory_w',
initializer=fluid.initializer.UniformInitializer(low=-self.init_scale, high=self.init_scale)))
attn = fluid.layers.matmul(query, memory, transpose_y=True)
if mask:
attn = fluid.layers.transpose(attn, [1, 0, 2])
attn = fluid.layers.elementwise_add(attn, mask * 1000000000, -1)
attn = fluid.layers.transpose(attn, [1, 0, 2])
weight = fluid.layers.softmax(attn)
weight_memory = fluid.layers.matmul(weight, memory)
return weight_memory
def call(self, step_input, states, enc_output, enc_padding_mask=None):
lstm_states, input_feed = states
new_lstm_states = []
step_input = fluid.layers.concat([step_input, input_feed], 1)
for i in range(self.num_layers):
out, new_lstm_state = self.lstm_cells[i](step_input, lstm_states[i])
step_input = fluid.layers.dropout(
out, self.dropout_prob, dropout_implementation='upscale_in_train') if self.dropout_prob > 0 else out
new_lstm_states.append(new_lstm_state)
dec_att = self.attention(step_input, enc_output, enc_padding_mask)
dec_att = fluid.layers.squeeze(dec_att, [1])
concat_att_out = fluid.layers.concat([dec_att, step_input], 1)
out = fluid.layers.fc(
concat_att_out,
self.hidden_size,
param_attr=ParamAttr(
name='dec_out_w',
initializer=fluid.initializer.UniformInitializer(low=-self.init_scale, high=self.init_scale)))
return out, [new_lstm_states, out]
class TextGenerationTask(BaseTask):
'''
TextGenerationTask uses an RNN as the decoder and beam search during prediction.
Args:
feature(Variable): The sentence-level feature, shape as [-1, emb_size].
token_feature(Variable): The token-level feature, shape as [-1, seq_len, emb_size].
max_seq_len(int): the decoder max sequence length.
num_classes(int): total labels of the task.
dataset(GenerationDataset): the dataset containing the training set, development set and so on. Set it if you want to finetune the model;
it can be omitted if you only want to use the model for prediction. Default None
num_layers(int): the decoder rnn layers number. Default 1
hidden_size(int): the decoder rnn hidden size. Default 128
dropout(float): the decoder dropout rate. Default 0.
beam_size(int): the beam search size during predict phase. Default 10.
beam_max_step_num(int): the beam search max step number. Default 30.
start_token(str): the beam search start token. Default '<s>'
end_token(str): the beam search end token. Default '</s>'
startup_program(Program): the customized startup_program, default None
config(RunConfig): the config for the task, default None
metrics_choices(list): metrics used for the task, default ['bleu']
'''
def __init__(
self,
feature,
token_feature,
max_seq_len,
num_classes,
dataset=None,
num_layers=1,
hidden_size=512,
dropout=0.,
beam_size=10,
beam_max_step_num=30,
start_token='<s>',
end_token='</s>',
startup_program=None,
config=None,
metrics_choices='default',
):
if metrics_choices == 'default':
metrics_choices = ['bleu']
main_program = feature.block.program
super(TextGenerationTask, self).__init__(
dataset=dataset,
main_program=main_program,
startup_program=startup_program,
config=config,
metrics_choices=metrics_choices)
self.num_layers = num_layers
self.hidden_size = hidden_size
self.dropout = dropout
self.token_feature = token_feature
self.feature = feature
self.max_seq_len = max_seq_len
self.num_classes = num_classes
self.beam_size = beam_size
self.beam_max_step_num = beam_max_step_num
self.start_token = start_token
self.end_token = end_token
def _add_label(self):
label = fluid.layers.data(name='label', shape=[self.max_seq_len, 1], dtype='int64')
return [label]
def _build_net(self):
self.seq_len = fluid.layers.data(name='seq_len', shape=[1], dtype='int64', lod_level=0)
self.seq_len_used = fluid.layers.squeeze(self.seq_len, axes=[1])
src_mask = fluid.layers.sequence_mask(self.seq_len_used, maxlen=self.max_seq_len, dtype='float32')
enc_padding_mask = (src_mask - 1.0)
# Define decoder and initialize it.
dec_cell = AttentionDecoderCell(self.num_layers, self.hidden_size, self.dropout)
dec_init_hidden = fluid.layers.fc(
input=self.feature,
size=self.hidden_size,
num_flatten_dims=1,
param_attr=fluid.ParamAttr(
name='dec_init_hidden_w', initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
bias_attr=fluid.ParamAttr(name='dec_init_hidden_b', initializer=fluid.initializer.Constant(0.)))
dec_initial_states = [
[[dec_init_hidden,
dec_cell.get_initial_states(batch_ref=self.feature, shape=[self.hidden_size])]] * self.num_layers,
dec_cell.get_initial_states(batch_ref=self.feature, shape=[self.hidden_size])
]
tar_vocab_size = len(self._label_list)
tar_embeder = lambda x: fluid.embedding(
input=x,
size=[tar_vocab_size, self.hidden_size],
dtype='float32',
is_sparse=False,
param_attr=fluid.ParamAttr(
name='target_embedding', initializer=fluid.initializer.UniformInitializer(low=-0.1, high=0.1)))
start_token_id = self._label_list.index(self.start_token)
end_token_id = self._label_list.index(self.end_token)
if not self.is_predict_phase:
self.dec_input = fluid.layers.data(name='dec_input', shape=[self.max_seq_len], dtype='int64')
tar_emb = tar_embeder(self.dec_input)
dec_output, _ = rnn(
cell=dec_cell,
inputs=tar_emb,
initial_states=dec_initial_states,
sequence_length=None,
enc_output=self.token_feature,
enc_padding_mask=enc_padding_mask)
self.logits = fluid.layers.fc(
dec_output,
size=tar_vocab_size,
num_flatten_dims=len(dec_output.shape) - 1,
param_attr=fluid.ParamAttr(
name='output_w', initializer=fluid.initializer.UniformInitializer(low=-0.1, high=0.1)))
self.ret_infers = fluid.layers.reshape(x=fluid.layers.argmax(self.logits, axis=2), shape=[-1, 1])
logits = self.logits
logits = fluid.layers.softmax(logits)
return [logits]
else:
output_layer = lambda x: fluid.layers.fc(
x, size=tar_vocab_size, num_flatten_dims=len(x.shape) - 1, param_attr=fluid.ParamAttr(name='output_w'))
beam_search_decoder = BeamSearchDecoder(
dec_cell,
start_token_id,
end_token_id,
self.beam_size,
embedding_fn=tar_embeder,
output_fn=output_layer)
enc_output = beam_search_decoder.tile_beam_merge_with_batch(self.token_feature, self.beam_size)
enc_padding_mask = beam_search_decoder.tile_beam_merge_with_batch(enc_padding_mask, self.beam_size)
self.ret_infers, _ = dynamic_decode(
beam_search_decoder,
inits=dec_initial_states,
max_step_num=self.beam_max_step_num,
enc_output=enc_output,
enc_padding_mask=enc_padding_mask)
return self.ret_infers
def _postprocessing(self, run_states):
results = []
for batch_states in run_states:
batch_results = batch_states.run_results
batch_infers = batch_results[0].astype(np.int32)
seq_lens = batch_results[1].reshape([-1]).astype(np.int32).tolist()
for i, sample_infers in enumerate(batch_infers):
beam_result = []
for beam_infer in sample_infers.T:
seq_result = [self._label_list[infer] for infer in beam_infer.tolist()[:seq_lens[i] - 2]]
beam_result.append(seq_result)
results.append(beam_result)
return results
def _add_metrics(self):
self.ret_labels = fluid.layers.reshape(x=self.labels[0], shape=[-1, 1])
return [self.ret_labels, self.ret_infers, self.seq_len_used]
def _add_loss(self):
loss = fluid.layers.cross_entropy(input=self.outputs[0], label=self.labels[0], soft_label=False)
loss = fluid.layers.unsqueeze(loss, axes=[2])
max_tar_seq_len = fluid.layers.shape(self.dec_input)[1]
tar_sequence_length = fluid.layers.elementwise_sub(self.seq_len_used, fluid.layers.ones_like(self.seq_len_used))
tar_mask = fluid.layers.sequence_mask(tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32')
loss = loss * tar_mask
loss = fluid.layers.reduce_mean(loss, dim=[0])
loss = fluid.layers.reduce_sum(loss)
return loss
@property
def fetch_list(self):
if self.is_train_phase or self.is_test_phase:
return [metric.name for metric in self.metrics] + [self.loss.name]
elif self.is_predict_phase:
return [self.ret_infers.name] + [self.seq_len_used.name]
return [output.name for output in self.outputs]
def _calculate_metrics(self, run_states):
loss_sum = 0
run_step = run_examples = 0
labels = []
results = []
for run_state in run_states:
loss_sum += np.mean(run_state.run_results[-1])
np_labels = run_state.run_results[0]
np_infers = run_state.run_results[1]
np_lens = run_state.run_results[2]
batch_size = len(np_lens)
max_len = len(np_labels) // batch_size
for i in range(batch_size):
label = [
self.dataset.label_list[int(id)] for id in np_labels[i * max_len:i * max_len + np_lens[i] - 2]
] # -2 for CLS and SEP
result = [
self.dataset.label_list[int(id)] for id in np_infers[i * max_len:i * max_len + np_lens[i] - 2]
]
labels.append(label)
results.append(result)
run_examples += run_state.run_examples
run_step += run_state.run_step
run_time_used = time.time() - run_states[0].run_time_begin
run_speed = run_step / run_time_used
avg_loss = loss_sum / run_examples
# The first key will be used as main metrics to update the best model
scores = OrderedDict()
for metric in self.metrics_choices:
if metric == 'bleu':
scores['bleu'] = compute_bleu(labels, results, max_order=1)[0]
else:
raise ValueError('Not Support Metric: \'%s\'' % metric)
return scores, avg_loss, run_speed
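# Illustrative sketch, not part of this commit: roughly how TextGenerationTask is wired up
# for the couplet demo that this compatibility fix targets. The module name, the
# context()/get_vocab_path() calls, the 'pooled_output'/'sequence_output' keys and
# finetune_and_eval() follow the PaddleHub 1.x ERNIE demo and are assumptions here.
import paddlehub as hub

module = hub.Module(name='ernie_tiny')
inputs, outputs, program = module.context(trainable=True, max_seq_len=128)

tokenizer = hub.BertTokenizer(vocab_file=module.get_vocab_path())
dataset = hub.dataset.Couplet(tokenizer=tokenizer, max_seq_len=128)

config = hub.RunConfig(use_cuda=False, num_epoch=1, batch_size=16, checkpoint_dir='ckpt_couplet')

gen_task = hub.TextGenerationTask(
    dataset=dataset,
    feature=outputs['pooled_output'],
    token_feature=outputs['sequence_output'],
    max_seq_len=128,
    num_classes=dataset.num_labels,
    config=config,
    metrics_choices=['bleu'])

gen_task.finetune_and_eval()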
......@@ -335,18 +335,14 @@ class LocalModuleManager(object):
tempdir = self._get_normalized_name(tempdir)
shutil.copytree(directory, tempdir)
directory = tempdir
hub_module_cls = HubModule.load(directory)
# Uninstall local module
if os.path.exists(self._get_normalized_path(hub_module_cls.name)):
self.uninstall(hub_module_cls.name)
if os.path.exists(self._get_normalized_path(module_info.name)):
self.uninstall(module_info.name)
shutil.copytree(directory, self._get_normalized_path(hub_module_cls.name))
shutil.copytree(directory, self._get_normalized_path(module_info.name))
# Reload the Module object to avoid path errors
hub_module_cls = HubModule.load(self._get_normalized_path(hub_module_cls.name))
self._local_modules[hub_module_cls.name] = hub_module_cls
hub_module_cls = HubModule.load(self._get_normalized_path(module_info.name))
self._local_modules[module_info.name] = hub_module_cls
# Install python package requirements
self._install_module_requirements(hub_module_cls)
......
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class DataFormatError(Exception):
def __init__(self, *args):
self.args = args
This diff is collapsed.
# -*- coding:utf-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import inspect
import os
from typing import Callable, List, Union
import paddlehub as hub
from paddlehub.utils.log import logger
from paddlehub.text.bert_tokenizer import BasicTokenizer
from paddlehub.text.utils import load_vocab, whitespace_tokenize
class CustomTokenizer(object):
'''
CustomTokenizer tokenizes the input text into words or phrases and converts each word (str) to an index (int) using the vocab.
If you would like subword tokens instead, please use `hub.BertTokenizer`.
'''
def __init__(self,
vocab_file: str,
do_lower_case: bool = True,
pad_token: str = '[PAD]',
tokenize_chinese_chars: bool = True,
cut_function: Callable = None):
''' Constructs a CustomTokenizer.
Args:
vocab_file (:obj:`string`): File containing the vocabulary.
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether to lower case the input if the input is in English
pad_token (:obj:`string`, `optional`, defaults to '[PAD]'): The token used for padding, for example when batching sequences of different lengths.
tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether to tokenize Chinese characters.
cut_function(:obj:`function`): It is a function that aims to segment a chinese text and get the word segmentation result (list).
'''
if not os.path.isfile(vocab_file):
raise ValueError('Can\'t find a vocabulary file at path \'{}\'.'.format(vocab_file))
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
self.pad_token = pad_token
self.pad_token_id = self.convert_tokens_to_ids(self.pad_token)
self.tokenize_chinese_chars = tokenize_chinese_chars
self.basic_tokenizer = BasicTokenizer(
do_lower_case=do_lower_case, tokenize_chinese_chars=tokenize_chinese_chars)
self.cut_function = cut_function
if not self.cut_function:
lac = hub.Module(name='lac')
self.cut_function = lac.cut
elif inspect.isfunction(self.cut_function):
self.cut_function = cut_function
else:
raise RuntimeError('The cut_function (%s) is not a function.' % cut_function)
@property
def vocab_size(self):
return len(self.vocab)
def get_vocab(self):
return dict(self.vocab)
def _convert_token_to_id(self, token: str):
''' Converts a token (str) to an id using the vocab. '''
return self.vocab.get(token, None)
def _convert_id_to_token(self, index: int):
'''Converts an index (integer) to a token (str) using the vocab.'''
return self.ids_to_tokens.get(index, None)
def convert_tokens_to_string(self, tokens: List[str]):
''' Converts a sequence of tokens (string) to a single string. '''
if self.tokenize_chinese_chars:
out_string = ''.join(tokens).strip()
else:
out_string = ' '.join(tokens).strip()
return out_string
def convert_ids_to_tokens(self, ids: Union[int, List[int]], skip_pad_token: bool):
''' Converts a single index or a sequence of indices (integers) to a token
(resp. a sequence of tokens, str), using the vocabulary and added tokens.
Args:
ids(:obj:`int` or :obj:`List[int]`): a single id or a list of tokenized input ids.
skip_pad_token: if set to True, pad tokens will be skipped in the output. Default: False
'''
if isinstance(ids, int):
return self._convert_id_to_token(ids)
tokens = []
for index in ids:
index = int(index)
if skip_pad_token and index == self.pad_token_id:
continue
tokens.append(self._convert_id_to_token(index))
return tokens
def convert_tokens_to_ids(self, tokens: List[str]):
''' Converts a token string (or a sequence of tokens) to a single integer id
(or a sequence of ids), using the vocabulary.
'''
if tokens is None:
return None
if isinstance(tokens, str):
return self._convert_token_to_id(tokens)
ids = []
for token in tokens:
wid = self._convert_token_to_id(token)
if wid is not None:
ids.append(wid)
return ids
def tokenize(self, text: str):
'''
Converts a string to a sequence of tokens (string), using the tokenizer.
Chinese text is split into words with the word segmentor (Baidu LAC) by default.
If cut_function is set, it is split into words with cut_function instead.
Args:
text (`string`): The sequence to be encoded.
Returns:
split_tokens (`list`): the list of resulting tokens.
'''
if self.tokenize_chinese_chars:
splitted_tokens = self.cut_function(text=text)
else:
splitted_tokens = self.basic_tokenizer.tokenize(text=text)
return splitted_tokens
def encode(self,
text: str,
text_pair: Union[str, List[str], List[int]] = None,
max_seq_len: int = None,
pad_to_max_seq_len: bool = True,
truncation_strategy: str = 'longest_first',
return_length: bool = True,
return_overflowing_tokens: bool = False):
'''
Returns a dictionary containing the encoded sequence or sequence pair and additional information:
the mask for sequence classification and the overflowing elements if a ``max_seq_len`` is specified.
Args:
text (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`):
The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
method)
text_pair (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`, `optional`, defaults to :obj:`None`):
It is ignored and kept only for API compatibility.
max_seq_len (:obj:`int`, `optional`, defaults to :obj:`None`):
If set to a number, will limit the total sequence returned so that it has a maximum length.
If there are overflowing tokens, those will be added to the returned dictionary
pad_to_max_seq_len (:obj:`bool`, `optional`, defaults to :obj:`True`):
If set to True, the returned sequences will be padded according to the model's padding side and
padding index, up to their max length. If no max length is specified, the padding is done up to the
model's max length.
truncation_strategy (:obj:`str`, `optional`, defaults to `longest_first`):
String selected in the following options:
- 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_seq_len
starting from the longest one at each token (when there is a pair of input sequences)
- 'only_first': Only truncate the first sequence
- 'only_second': Only truncate the second sequence
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_seq_len)
return_length (:obj:`bool`, `optional`, defaults to :obj:`True`):
If set the resulting dictionary will include the length of each encoded inputs
return_overflowing_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
Set to True to return overflowing token information (default False).
Return:
A Dictionary of shape::
{
text: list[int],
seq_len: int if return_length is True (default)
overflowing_tokens: list[int] if a ``max_seq_len`` is specified and return_overflowing_tokens is True
}
With the fields:
- ``text``: list of token ids to be fed to a model
- ``length``: the input_ids length
- ``overflowing_tokens``: list of overflowing tokens if a max length is specified.
'''
def get_input_ids(text: str):
if isinstance(text, str):
tokens = self.tokenize(text)
ids = self.convert_tokens_to_ids(tokens)
return ids
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
return self.convert_tokens_to_ids(text)
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
return text
else:
raise ValueError(
'Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers.')
ids = get_input_ids(text)
len_ids = len(ids)
encoded_inputs = {}
# When all words are not found in the vocab, it will return {}.
if not len_ids:
return encoded_inputs
# Truncation: Handle max sequence length
if max_seq_len and len_ids > max_seq_len:
ids, pair_ids, overflowing_tokens = self.truncate_sequences(
ids, num_tokens_to_remove=len_ids - max_seq_len, truncation_strategy=truncation_strategy)
if return_overflowing_tokens:
encoded_inputs['overflowing_tokens'] = overflowing_tokens
encoded_inputs['num_truncated_tokens'] = len_ids - max_seq_len
# Check length and pad
if pad_to_max_seq_len and len(ids) < max_seq_len:
encoded_inputs['text'] = ids + [self.pad_token_id] * (max_seq_len - len(ids))
else:
encoded_inputs['text'] = ids
if return_length:
encoded_inputs['seq_len'] = len(ids)
return encoded_inputs
def truncate_sequences(self,
ids: List[int],
pair_ids: List[int] = None,
num_tokens_to_remove: int = 0,
truncation_strategy: str = 'longest_first',
stride: int = 0):
''' Truncates a sequence pair in place to the maximum length.
Args:
ids: list of tokenized input ids. Can be obtained from a string by chaining the
`tokenize` and `convert_tokens_to_ids` methods.
pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the
`tokenize` and `convert_tokens_to_ids` methods.
num_tokens_to_remove (:obj:`int`, `optional`, defaults to ``0``):
number of tokens to remove using the truncation strategy
truncation_strategy: string selected in the following options:
- 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_seq_len
starting from the longest one at each token (when there is a pair of input sequences).
Overflowing tokens only contains overflow from the first sequence.
- 'only_first': Only truncate the first sequence; raise an error if the first sequence is shorter than or equal to num_tokens_to_remove.
- 'only_second': Only truncate the second sequence
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_seq_len)
stride (:obj:`int`, `optional`, defaults to ``0``):
If set to a number along with max_seq_len, the overflowing tokens returned will contain some tokens
from the main sequence returned. The value of this argument defines the number of additional tokens.
'''
if num_tokens_to_remove <= 0:
return ids, pair_ids, []
if truncation_strategy == 'longest_first':
overflowing_tokens = []
for _ in range(num_tokens_to_remove):
if pair_ids is None or len(ids) > len(pair_ids):
overflowing_tokens = [ids[-1]] + overflowing_tokens
ids = ids[:-1]
else:
pair_ids = pair_ids[:-1]
window_len = min(len(ids), stride)
if window_len > 0:
overflowing_tokens = ids[-window_len:] + overflowing_tokens
elif truncation_strategy == 'only_first':
assert len(ids) > num_tokens_to_remove
window_len = min(len(ids), stride + num_tokens_to_remove)
overflowing_tokens = ids[-window_len:]
ids = ids[:-num_tokens_to_remove]
elif truncation_strategy == 'only_second':
assert pair_ids is not None and len(pair_ids) > num_tokens_to_remove
window_len = min(len(pair_ids), stride + num_tokens_to_remove)
overflowing_tokens = pair_ids[-window_len:]
pair_ids = pair_ids[:-num_tokens_to_remove]
elif truncation_strategy == 'do_not_truncate':
raise ValueError('Input sequence is too long for max_seq_len. Please select a truncation strategy.')
else:
raise ValueError(
'Truncation_strategy should be selected in [\'longest_first\', \'only_first\', \'only_second\', \'do_not_truncate\']'
)
return (ids, pair_ids, overflowing_tokens)
def decode(self,
token_ids: List[int],
only_convert_to_tokens: bool = True,
skip_pad_token: bool = False,
clean_up_tokenization_spaces: bool = True):
'''
Converts a sequence of ids (integer) to a string if only_convert_to_tokens is False, or to a sequence of tokens (str)
when only_convert_to_tokens is True.
Args:
token_ids: list of tokenized input ids or a dict with a key called 'text'; can be obtained by using the `encode` method.
only_convert_to_tokens: if set to True, will only return a sequence of tokens (str). `paddlehub.dataset.base_nlp_dataset` uses this optional argument.
skip_pad_token: if set to True, pad tokens will be removed from the output.
clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces.
'''
if isinstance(token_ids, dict):
token_ids = token_ids['text']
tokens = self.convert_ids_to_tokens(token_ids, skip_pad_token=skip_pad_token)
if only_convert_to_tokens:
return tokens
if tokens and self.tokenize_chinese_chars:
text = ''.join(self.convert_tokens_to_string(tokens))
else:
text = ' '.join(self.convert_tokens_to_string(tokens))
if not self.tokenize_chinese_chars and clean_up_tokenization_spaces:
clean_text = self.clean_up_tokenization(text)
return clean_text
else:
return text
def clean_up_tokenization(self, out_string: str) -> str:
'''
Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms.
'''
out_string = (out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(
' \' ', '\'').replace(' n\'t', 'n\'t').replace(' \'m', '\'m').replace(' \'s', '\'s').replace(
' \'ve', '\'ve').replace(' \'re', '\'re'))
return out_string
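# Illustrative sketch, not part of this commit: CustomTokenizer segments Chinese text with
# LAC (or the given cut_function), maps words to ids via the vocab, and pads to max_seq_len.
# The vocab path is a placeholder.
from paddlehub.text.tokenizer import CustomTokenizer

tokenizer = CustomTokenizer(vocab_file='vocab.txt')
record = tokenizer.encode(text='风调雨顺', max_seq_len=8)
# record looks like {'text': [...word ids..., pad_id, ...], 'seq_len': <number of real tokens>}
tokens = tokenizer.decode(record, only_convert_to_tokens=True, skip_pad_token=True)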
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unicodedata
from collections import OrderedDict
def load_vocab(vocab_file: str):
'''Loads a vocabulary file into a dictionary.'''
vocab = {}
with open(vocab_file, 'r', encoding='utf-8') as reader:
tokens = reader.readlines()
for index, token in enumerate(tokens):
token = token.rstrip('\n').split('\t')[0]
vocab[token] = index
return vocab
def whitespace_tokenize(text: str):
'''Runs basic whitespace cleaning and splitting on a piece of text.'''
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
def is_whitespace(char: str):
'''Checks whether `chars` is a whitespace character.'''
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == ' ' or char == '\t' or char == '\n' or char == '\r':
return True
cat = unicodedata.category(char)
if cat == 'Zs':
return True
return False
def is_control(char: str):
'''Checks whether `chars` is a control character.'''
# These are technically control characters but we count them as whitespace
# characters.
if char == '\t' or char == '\n' or char == '\r':
return False
cat = unicodedata.category(char)
if cat.startswith('C'):
return True
return False
def is_punctuation(char: str):
'''Checks whether `chars` is a punctuation character.'''
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as '^', '$', and '`' are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
return True
cat = unicodedata.category(char)
if cat.startswith('P'):
return True
return False
def is_chinese_char(char: str):
'''Checks whether CP is the codepoint of a CJK character.'''
# This defines a 'chinese character' as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like all of the other languages.
cp = ord(char)
if ((cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF) #
or (cp >= 0x20000 and cp <= 0x2A6DF) #
or (cp >= 0x2A700 and cp <= 0x2B73F) #
or (cp >= 0x2B740 and cp <= 0x2B81F) #
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
or (cp >= 0xF900 and cp <= 0xFAFF) or (cp >= 0x2F800 and cp <= 0x2FA1F) #
): #
return True
return False
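# Illustrative sketch, not part of this commit: the helpers above are small, pure predicates
# used by the tokenizers.
from paddlehub.text.utils import is_chinese_char, is_punctuation, whitespace_tokenize

print(whitespace_tokenize('  两 个 黄鹂 鸣 翠柳 '))  # ['两', '个', '黄鹂', '鸣', '翠柳']
print(is_chinese_char('汉'), is_chinese_char('A'))    # True False
print(is_punctuation('$'), is_punctuation('1'))       # True False (non-letter ASCII counts, digits do not)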
......@@ -227,6 +227,7 @@ def download_with_progress(url: str, path: str = None) -> Generator[str, int, in
def load_py_module(python_path: str, py_module_name: str) -> types.ModuleType:
'''
Load the specified python module.
Args:
python_path(str) : The directory where the python module is located
py_module_name(str) : Module name to be loaded
......