提交 b2f94aa8 编写于 作者: X xyzhou-puck

update hapi, hapi.text and hapi.text.bert

上级 601db3ab
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT fine-tuning in Paddle Dygraph Mode."""
import os
import io
import time
import argparse
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
from hapi.text.text import PrePostProcessLayer
from hapi.text.bert.bert import BertConfig
from cls import ClsModelLayer
from hapi.text.bert.optimization import Optimizer
from hapi.text.bert.utils.args import ArgumentGroup, print_arguments, check_cuda
from hapi.model import set_device, Model, SoftmaxWithCrossEntropy, Input
from hapi.metrics import Accuracy
from hapi.text.bert.dataloader import SingleSentenceDataLoader, BertInputExample
import hapi.text.tokenizer.tokenization as tokenization
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
model_g.add_arg("bert_config_path", str, "./config/bert_config.json", "Path to the json file for bert model config.")
model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.")
model_g.add_arg("init_pretraining_params", str, None,
"Init pre-training params which preforms fine-tuning from. If the "
"arg 'init_checkpoint' has been set, this argument wouldn't be valid.")
model_g.add_arg("checkpoints", str, "checkpoints", "Path to save checkpoints.")
train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 100, "Number of epoches for training.")
train_g.add_arg("learning_rate", float, 0.0001, "Learning rate used to train with warmup.")
train_g.add_arg("lr_scheduler", str, "linear_warmup_decay",
"scheduler of learning rate.", choices=['linear_warmup_decay', 'noam_decay'])
train_g.add_arg("weight_decay", float, 0.01, "Weight decay rate for L2 regularizer.")
train_g.add_arg("warmup_proportion", float, 0.1, "Proportion of training steps to perform linear learning rate warmup for.")
train_g.add_arg("save_steps", int, 10000, "The steps interval to save checkpoints.")
train_g.add_arg("validation_steps", int, 1000, "The steps interval to evaluate model performance.")
train_g.add_arg("loss_scaling", float, 1.0,
"Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")
log_g = ArgumentGroup(parser, "logging", "logging related.")
log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.")
data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
data_g.add_arg("data_dir", str, None, "Path to training data.")
data_g.add_arg("vocab_path", str, None, "Vocabulary path.")
data_g.add_arg("max_seq_len", int, 512, "Tokens' number of the longest seqence allowed.")
data_g.add_arg("batch_size", int, 32,
"The total number of examples in one batch for training, see also --in_tokens.")
data_g.add_arg("in_tokens", bool, False,
"If set, the batch size will be the maximum number of tokens in one batch. "
"Otherwise, it will be the maximum number of examples in one batch.")
data_g.add_arg("do_lower_case", bool, True,
"Whether to lower case the input text. Should be True for uncased models and False for cased models.")
data_g.add_arg("random_seed", int, 5512, "Random seed.")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
run_type_g.add_arg("shuffle", bool, True, "")
run_type_g.add_arg("task_name", str, None,
"The name of task to perform fine-tuning, should be in {'xnli', 'mnli', 'cola', 'mrpc'}.")
run_type_g.add_arg("do_train", bool, True, "Whether to perform training.")
run_type_g.add_arg("do_test", bool, False, "Whether to perform evaluation on test data set.")
run_type_g.add_arg("use_data_parallel", bool, False, "The flag indicating whether to shuffle instances in each pass.")
run_type_g.add_arg("enable_ce", bool, False, help="The flag indicating whether to run the task for continuous evaluation.")
args = parser.parse_args()
def create_data(batch):
"""
convert data to variable
"""
src_ids = to_variable(batch[0], "src_ids")
position_ids = to_variable(batch[1], "position_ids")
sentence_ids = to_variable(batch[2], "sentence_ids")
input_mask = to_variable(batch[3], "input_mask")
labels = to_variable(batch[4], "labels")
labels.stop_gradient = True
return src_ids, position_ids, sentence_ids, input_mask, labels
def train(args):
device = set_device("gpu" if args.use_cuda else "cpu")
fluid.enable_dygraph(device)
bert_config = BertConfig(args.bert_config_path)
bert_config.print_config()
if not (args.do_train or args.do_test):
raise ValueError("For args `do_train`, `do_test`, at "
"least one of them must be True.")
trainer_count = fluid.dygraph.parallel.Env().nranks
tokenizer = tokenization.FullTokenizer(
vocab_file=args.vocab_path, do_lower_case=args.do_lower_case)
def mnli_line_processor(line_id, line):
if line_id == "0":
return None
uid = tokenization.convert_to_unicode(line[0])
text_a = tokenization.convert_to_unicode(line[8])
text_b = tokenization.convert_to_unicode(line[9])
label = tokenization.convert_to_unicode(line[-1])
if label not in ["contradiction", "entailment", "neutral"]:
label = "contradiction"
return BertInputExample(uid=uid, text_a=text_a, text_b=text_b, label=label)
bert_dataloader = SingleSentenceDataLoader("./data/glue_data/MNLI/train.tsv", tokenizer, ["contradiction", "entailment", "neutral"],
max_seq_length=64, batch_size=32, line_processor=mnli_line_processor)
num_train_examples = len(bert_dataloader.dataset)
max_train_steps = args.epoch * num_train_examples // args.batch_size // trainer_count
warmup_steps = int(max_train_steps * args.warmup_proportion)
print("Trainer count: %d" % trainer_count)
print("Num train examples: %d" % num_train_examples)
print("Max train steps: %d" % max_train_steps)
print("Num warmup steps: %d" % warmup_steps)
if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
inputs = [Input([None, None], 'int64', name='src_ids'),
Input([None, None], 'int64', name='pos_ids'),
Input([None, None], 'int64', name='sent_ids'),
Input([None, None], 'float32', name='input_mask')]
labels = [Input([None, 1], 'int64', name='label')]
cls_model = ClsModelLayer(
args,
bert_config,
3,
is_training=True,
return_pooled_out=True)
optimizer = Optimizer(
warmup_steps=warmup_steps,
num_train_steps=max_train_steps,
learning_rate=args.learning_rate,
model_cls=cls_model,
weight_decay=args.weight_decay,
scheduler=args.lr_scheduler,
loss_scaling=args.loss_scaling,
parameter_list=cls_model.parameters())
cls_model.prepare(
optimizer,
SoftmaxWithCrossEntropy(),
Accuracy(topk=(1, 2)),
inputs,
labels,
device=device)
cls_model.bert_layer.init_parameters(args.init_pretraining_params, verbose=True)
cls_model.fit(train_data=bert_dataloader.dataloader, epochs=args.epoch)
return cls_model
if __name__ == '__main__':
print_arguments(args)
check_cuda(args.use_cuda)
if args.do_train:
cls_model = train(args)
bert_config_path: "./config/bert_config.json"
init_checkpoint: None
init_pretraining_params: None
checkpoints: "./saved_model"
epoch: 3
learning_rate: 0.0001
lr_scheduler: "linear_warmup_decay"
weight_decay: 0.01
warmup_proportion: 0.1
save_steps: 100000
validation_steps: 100000
loss_scaling: 1.0
skip_steps: 100
data_dir: None
vocab_path: None
max_seq_len: 512
batch_size: 32
in_tokens: False
do_lower_case: True
random_seed: 5512
use_cuda: True
shuffle: True
do_train: True
do_test: True
use_data_parallel: False
verbose: False
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT fine-tuning in Paddle Dygraph Mode."""
import paddle.fluid as fluid
from hapi.metrics import Accuracy
from hapi.configure import Config
from hapi.model import set_device, Model, SoftmaxWithCrossEntropy, Input
from cls import ClsModelLayer
import hapi.text.tokenizer.tokenization as tokenization
from hapi.text.bert import Optimizer, BertConfig, BertDataLoader, BertInputExample
def train():
config = Config(yaml_file="./bert.yaml")
config.build()
config.Print()
device = set_device("gpu" if config.use_cuda else "cpu")
fluid.enable_dygraph(device)
bert_config = BertConfig(config.bert_config_path)
bert_config.print_config()
trainer_count = fluid.dygraph.parallel.Env().nranks
tokenizer = tokenization.FullTokenizer(
vocab_file=config.vocab_path, do_lower_case=config.do_lower_case)
def mnli_line_processor(line_id, line):
if line_id == "0":
return None
uid = tokenization.convert_to_unicode(line[0])
text_a = tokenization.convert_to_unicode(line[8])
text_b = tokenization.convert_to_unicode(line[9])
label = tokenization.convert_to_unicode(line[-1])
if label not in ["contradiction", "entailment", "neutral"]:
label = "contradiction"
return BertInputExample(
uid=uid, text_a=text_a, text_b=text_b, label=label)
bert_dataloader = BertDataLoader(
"./data/glue_data/MNLI/train.tsv",
tokenizer, ["contradiction", "entailment", "neutral"],
max_seq_length=64,
batch_size=32,
line_processor=mnli_line_processor)
num_train_examples = len(bert_dataloader.dataset)
max_train_steps = config.epoch * num_train_examples // config.batch_size // trainer_count
warmup_steps = int(max_train_steps * config.warmup_proportion)
print("Trainer count: %d" % trainer_count)
print("Num train examples: %d" % num_train_examples)
print("Max train steps: %d" % max_train_steps)
print("Num warmup steps: %d" % warmup_steps)
inputs = [
Input(
[None, None], 'int64', name='src_ids'), Input(
[None, None], 'int64', name='pos_ids'), Input(
[None, None], 'int64', name='sent_ids'), Input(
[None, None], 'float32', name='input_mask')
]
labels = [Input([None, 1], 'int64', name='label')]
cls_model = ClsModelLayer(
config,
bert_config,
len(["contradiction", "entailment", "neutral"]),
is_training=True,
return_pooled_out=True)
optimizer = Optimizer(
warmup_steps=warmup_steps,
num_train_steps=max_train_steps,
learning_rate=config.learning_rate,
model_cls=cls_model,
weight_decay=config.weight_decay,
scheduler=config.lr_scheduler,
loss_scaling=config.loss_scaling,
parameter_list=cls_model.parameters())
cls_model.prepare(
optimizer,
SoftmaxWithCrossEntropy(),
Accuracy(topk=(1, 2)),
inputs,
labels,
device=device)
cls_model.bert_layer.init_parameters(
config.init_pretraining_params, verbose=config.verbose)
cls_model.fit(train_data=bert_dataloader.dataloader, epochs=config.epoch)
return cls_model
if __name__ == '__main__':
cls_model = train()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,10 +13,6 @@
# limitations under the License.
"dygraph transformer layers"
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import json
import numpy as np
......@@ -25,7 +21,7 @@ import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear, Layer
from hapi.text.bert.bert import BertEncoder
from hapi.text.bert import BertEncoder
from hapi.model import Model
......@@ -63,11 +59,6 @@ class ClsModelLayer(Model):
"""
forward
"""
#src_ids = data_ids[0]
#position_ids = data_ids[1]
#sentence_ids = data_ids[2]
#input_mask = data_ids[3]
#labels = data_ids[4]
enc_output, next_sent_feat = self.bert_layer(src_ids, position_ids,
sentence_ids, input_mask)
......@@ -80,19 +71,3 @@ class ClsModelLayer(Model):
logits = self.cls_fc(cls_feats)
return logits
"""
logits = self.cls_fc(cls_feats)
ce_loss, probs = fluid.layers.softmax_with_cross_entropy(
logits=logits, label=labels, return_softmax=True)
loss = fluid.layers.mean(x=ce_loss)
if self.use_fp16 and self.loss_scaling > 1.0:
loss *= self.loss_scaling
num_seqs = fluid.layers.create_tensor(dtype='int64')
accuracy = fluid.layers.accuracy(
input=probs, label=labels, total=num_seqs)
"""
return loss, accuracy
......@@ -8,7 +8,6 @@ export CUDA_VISIBLE_DEVICES=0
# start fine-tuning
python3.7 bert_classifier.py\
--task_name ${TASK_NAME} \
--use_cuda true \
--do_train true \
--do_test true \
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hapi.configure import Config as Config
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import argparse
import json
import yaml
import six
import logging
logging_only_message = "%(message)s"
logging_details = "%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s"
class JsonConfig(object):
"""
A high-level api for handling json configure file.
"""
def __init__(self, config_path):
self._config_dict = self._parse(config_path)
def _parse(self, config_path):
try:
with open(config_path) as json_file:
config_dict = json.load(json_file)
except:
raise IOError("Error in parsing bert model config file '%s'" %
config_path)
else:
return config_dict
def __getitem__(self, key):
return self._config_dict[key]
def print_config(self):
for arg, value in sorted(six.iteritems(self._config_dict)):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
class ArgumentGroup(object):
def __init__(self, parser, title, des):
self._group = parser.add_argument_group(title=title, description=des)
def add_arg(self, name, type, default, help, **kwargs):
type = str2bool if type == bool else type
self._group.add_argument(
"--" + name,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
class ArgConfig(object):
"""
A high-level api for handling argument configs.
"""
def __init__(self):
parser = argparse.ArgumentParser()
custom_g = ArgumentGroup(parser, "customize", "customized options.")
self.custom_g = custom_g
self.parser = parser
def add_arg(self, name, dtype, default, descrip):
self.custom_g.add_arg(name, dtype, default, descrip)
def build_conf(self):
return self.parser.parse_args()
def str2bool(v):
# because argparse does not support to parse "true, False" as python
# boolean directly
return v.lower() in ("true", "t", "1")
def print_arguments(args, log=None):
if not log:
print('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(vars(args))):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
else:
log.info('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(vars(args))):
log.info('%s: %s' % (arg, value))
log.info('------------------------------------------------')
class Config(object):
"""
A high-level API for managing configuration files in PaddlePaddle.
Can jointly work with command-line-arugment, json files and yaml files.
"""
def __init__(self, json_file="", yaml_file="", fuse_args=True):
"""
Init funciton for PDConfig.
json_file: the path to the json configure file.
yaml_file: the path to the yaml configure file.
fuse_args: if fuse the json/yaml configs with argparse.
"""
assert isinstance(json_file, str)
assert isinstance(yaml_file, str)
if json_file != "" and yaml_file != "":
raise Warning(
"json_file and yaml_file can not co-exist for now. please only use one configure file type."
)
return
self.args = None
self.arg_config = {}
self.json_config = {}
self.yaml_config = {}
parser = argparse.ArgumentParser()
self.default_g = ArgumentGroup(parser, "default", "default options.")
self.yaml_g = ArgumentGroup(parser, "yaml", "options from yaml.")
self.json_g = ArgumentGroup(parser, "json", "options from json.")
self.com_g = ArgumentGroup(parser, "custom", "customized options.")
self.parser = parser
if json_file != "":
self.load_json(json_file, fuse_args=fuse_args)
if yaml_file:
self.load_yaml(yaml_file, fuse_args=fuse_args)
def load_json(self, file_path, fuse_args=True):
if not os.path.exists(file_path):
raise Warning("the json file %s does not exist." % file_path)
return
with open(file_path, "r") as fin:
self.json_config = json.loads(fin.read())
fin.close()
if fuse_args:
for name in self.json_config:
if isinstance(self.json_config[name], list):
self.json_g.add_arg(
name,
type(self.json_config[name][0]),
self.json_config[name],
"This is from %s" % file_path,
nargs=len(self.json_config[name]))
continue
if not isinstance(self.json_config[name], int) \
and not isinstance(self.json_config[name], float) \
and not isinstance(self.json_config[name], str) \
and not isinstance(self.json_config[name], bool):
continue
self.json_g.add_arg(name,
type(self.json_config[name]),
self.json_config[name],
"This is from %s" % file_path)
def load_yaml(self, file_path, fuse_args=True):
if not os.path.exists(file_path):
raise Warning("the yaml file %s does not exist." % file_path)
return
with open(file_path, "r") as fin:
self.yaml_config = yaml.load(fin, Loader=yaml.SafeLoader)
fin.close()
if fuse_args:
for name in self.yaml_config:
if isinstance(self.yaml_config[name], list):
self.yaml_g.add_arg(
name,
type(self.yaml_config[name][0]),
self.yaml_config[name],
"This is from %s" % file_path,
nargs=len(self.yaml_config[name]))
continue
if not isinstance(self.yaml_config[name], int) \
and not isinstance(self.yaml_config[name], float) \
and not isinstance(self.yaml_config[name], str) \
and not isinstance(self.yaml_config[name], bool):
continue
self.yaml_g.add_arg(name,
type(self.yaml_config[name]),
self.yaml_config[name],
"This is from %s" % file_path)
def build(self):
self.args = self.parser.parse_args()
self.arg_config = vars(self.args)
def __add__(self, new_arg):
assert isinstance(new_arg, list) or isinstance(new_arg, tuple)
assert len(new_arg) >= 3
assert self.args is None
name = new_arg[0]
dtype = new_arg[1]
dvalue = new_arg[2]
desc = new_arg[3] if len(
new_arg) == 4 else "Description is not provided."
self.com_g.add_arg(name, dtype, dvalue, desc)
return self
def __getattr__(self, name):
if name in self.arg_config:
return self.arg_config[name]
if name in self.json_config:
return self.json_config[name]
if name in self.yaml_config:
return self.yaml_config[name]
raise Warning("The argument %s is not defined." % name)
def Print(self):
print("-" * 70)
for name in self.arg_config:
print("%s:\t\t\t\t%s" % (str(name), str(self.arg_config[name])))
for name in self.json_config:
if name not in self.arg_config:
print("%s:\t\t\t\t%s" %
(str(name), str(self.json_config[name])))
for name in self.yaml_config:
if name not in self.arg_config:
print("%s:\t\t\t\t%s" %
(str(name), str(self.yaml_config[name])))
print("-" * 70)
if __name__ == "__main__":
"""
pd_config = PDConfig(json_file = "./test/bert_config.json")
pd_config.build()
print(pd_config.do_train)
print(pd_config.hidden_size)
pd_config = PDConfig(yaml_file = "./test/bert_config.yaml")
pd_config.build()
print(pd_config.do_train)
print(pd_config.hidden_size)
"""
config = Config(yaml_file="./bert.yaml")
config += ("my_age", int, 18, "I am forever 18.")
config.build()
print(config.data_dir)
print(config.my_age)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hapi.text.text import RNNCell as RNNCell
from hapi.text.text import BasicLSTMCell as BasicLSTMCell
from hapi.text.text import BasicGRUCell as BasicGRUCell
from hapi.text.text import RNN as RNN
from hapi.text.text import DynamicDecode as DynamicDecode
from hapi.text.text import BeamSearchDecoder as BeamSearchDecoder
from hapi.text.text import MultiHeadAttention as MultiHeadAttention
from hapi.text.text import FFN as FFN
from hapi.text.text import TransformerEncoderLayer as TransformerEncoderLayer
from hapi.text.text import TransformerDecoderLayer as TransformerDecoderLayer
from hapi.text.text import TransformerEncoder as TransformerEncoder
from hapi.text.text import TransformerDecoder as TransformerDecoder
from hapi.text.text import TransformerBeamSearchDecoder as TransformerBeamSearchDecoder
from hapi.text.text import DynamicGRU as DynamicGRU
from hapi.text.text import BiGRU as BiGRU
from hapi.text.text import Linear_chain_crf as Linear_chain_crf
from hapi.text.text import Crf_decoding as Crf_decoding
from hapi.text.text import SequenceTagging as SequenceTagging
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hapi.text.bert.bert import BertConfig as BertConfig
from hapi.text.bert.optimization import Optimizer as Optimizer
from hapi.text.bert.dataloader import BertDataLoader as BertDataLoader
from hapi.text.bert.dataloader import BertInputExample as BertInputExample
from hapi.text.tokenizer import tokenization as tokenization
from hapi.text.bert.bert import BertEncoder as BertEncoder
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -32,7 +32,7 @@ import hapi.text.tokenizer.tokenization as tokenization
__all__ = [
'BertInputExample', 'BertInputFeatures', 'SingleSentenceDataset',
'SentencePairDataset'
'SentencePairDataset', 'BertDataLoader'
]
......@@ -289,7 +289,7 @@ def _prepare_train_batch(insts,
return_num_token=return_num_token)
class SingleSentenceDataLoader(object):
class BertDataLoader(object):
def __init__(self,
input_file,
tokenizer,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# -*- coding: UTF-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved.
# Copyright (c) 2020 Baidu.com, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
......@@ -29,10 +29,7 @@ setuptools.setup(
author="PaddlePaddle",
author_email="zhouxiangyang@baidu.com",
description="A Paddle High-level API that supports both static and dynamic execution modes (still under development)",
# long_description=long_description,
# long_description_content_type="text/markdown",
url="https://github.com/PaddlePaddle/hapi",
# packages=setuptools.find_packages(),
packages=[
'hapi', 'hapi.text', 'hapi.text.tokenizer', 'hapi.text.bert',
'hapi.text.bert.utils'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册