Commit 16145775 authored by Zeyu Chen

Finish BERT and ERNIE text classification task

Parent b45479ee
......@@ -21,7 +21,6 @@ import os
import time
import argparse
import numpy as np
import multiprocessing
import paddle
import paddle.fluid as fluid
......@@ -33,24 +32,14 @@ from paddle_hub.finetune.config import FinetuneConfig
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
model_g.add_arg("bert_config_path", str, None, "Path to the json file for bert model config.")
train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 3, "Number of epoches for fine-tuning.")
train_g.add_arg("learning_rate", float, 5e-5, "Learning rate used to train with warmup.")
train_g.add_arg("lr_scheduler", str, "linear_warmup_decay",
"scheduler of learning rate.", choices=['linear_warmup_decay', 'noam_decay'])
train_g.add_arg("lr_scheduler", str, "linear_warmup_decay", "scheduler of learning rate.", choices=['linear_warmup_decay', 'noam_decay'])
train_g.add_arg("weight_decay", float, 0.01, "Weight decay rate for L2 regularizer.")
train_g.add_arg("warmup_proportion", float, 0.1,
"Proportion of training steps to perform linear learning rate warmup for.")
train_g.add_arg("validation_steps", int, 1000, "The steps interval to evaluate model performance.")
train_g.add_arg("loss_scaling", float, 1.0,
"Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")
log_g = ArgumentGroup(parser, "logging", "logging related.")
log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.")
log_g.add_arg("verbose", bool, False, "Whether to output verbose log.")
data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
data_g.add_arg("data_dir", str, None, "Path to training data.")
......@@ -60,12 +49,6 @@ data_g.add_arg("batch_size", int, 32, "Total examples' number in batch fo
data_g.add_arg("in_tokens", bool, False,
"If set, the batch size will be the maximum number of tokens in one batch. "
"Otherwise, it will be the maximum number of examples in one batch.")
data_g.add_arg("do_lower_case", bool, True,
"Whether to lower case the input text. Should be True for uncased models and False for cased models.")
data_g.add_arg("random_seed", int, 0, "Random seed.")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
args = parser.parse_args()
# yapf: enable.
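
For reference, the add_arg calls above rely on a small helper that wraps argparse argument groups. The real ArgumentGroup ships with the repo and is not part of this diff; a minimal sketch compatible with the calls shown here (a hypothetical re-implementation, for illustration only) could look like:

```python
import argparse


def str2bool(value):
    # argparse does not parse the string "False" as a boolean False,
    # so map string values explicitly.
    return value.lower() in ("true", "t", "1", "yes", "y")


class ArgumentGroup(object):
    """Groups related command-line arguments under a common title."""

    def __init__(self, parser, title, description):
        self._group = parser.add_argument_group(title=title, description=description)

    def add_arg(self, name, dtype, default, help_text, **kwargs):
        dtype = str2bool if dtype == bool else dtype
        self._group.add_argument(
            "--" + name,
            type=dtype,
            default=default,
            help="%s Default: %s." % (help_text, default),
            **kwargs)


# Usage mirroring the script above:
parser = argparse.ArgumentParser(__doc__)
train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 3, "Number of epochs for fine-tuning.")
```
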
......@@ -85,26 +68,22 @@ if __name__ == '__main__':
weight_decay=args.weight_decay,
finetune_strategy="bert_finetune",
with_memory_optimization=True,
in_tokens=True,
in_tokens=False,
optimizer=None,
warmup_proportion=args.warmup_proportion)
module = hub.Module(
module_dir="./hub_module/chinese_L-12_H-768_A-12.hub_module")
# loading paddlehub BERT
# module = hub.Module(
# module_dir="./hub_module/chinese_L-12_H-768_A-12.hub_module")
module = hub.Module(module_dir="./hub_module/ernie-stable.hub_module")
print("vocab_path = {}".format(module.get_vocab_path()))
processor = reader.ChnsenticorpProcessor(
data_dir=args.data_dir,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len,
do_lower_case=args.do_lower_case,
in_tokens=args.in_tokens,
random_seed=args.random_seed)
max_seq_len=args.max_seq_len)
num_labels = len(processor.get_labels())
# loading paddlehub BERT
# bert's input tensor, output tensor and forward graph
# If you want to fine-tune the pretrain model parameter, please set
# trainable to True
......
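
Put together, the post-change flow in finetune_with_hub.py loads the ERNIE module and takes the vocabulary from the module itself, so the reader only needs the data directory and the maximum sequence length. A condensed sketch of that wiring (the import names, the concrete data path, and the max_seq_len value are assumptions based on the surrounding code and the run script):

```python
import paddle_hub as hub  # import name assumed from the hub.Module usage above
import reader             # the Chnsenticorp reader modified later in this commit

# Load the pre-trained ERNIE module; the BERT module remains available as the
# commented-out alternative in the script.
module = hub.Module(module_dir="./hub_module/ernie-stable.hub_module")
print("vocab_path = {}".format(module.get_vocab_path()))

# The vocabulary now comes from the module, so do_lower_case, in_tokens and
# random_seed fall back to the reader's new defaults.
processor = reader.ChnsenticorpProcessor(
    data_dir="./chnsenticorp_data",      # DATA_PATH from the run script
    vocab_path=module.get_vocab_path(),
    max_seq_len=128)                     # placeholder; the script uses args.max_seq_len
num_labels = len(processor.get_labels())
```
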
......@@ -27,8 +27,8 @@ class DataProcessor(object):
data_dir,
vocab_path,
max_seq_len,
do_lower_case,
in_tokens,
do_lower_case=True,
in_tokens=False,
random_seed=None):
self.data_dir = data_dir
self.max_seq_len = max_seq_len
......@@ -83,7 +83,7 @@ class DataProcessor(object):
voc_size=-1,
mask_id=-1,
return_input_mask=True,
return_max_len=False,
return_max_len=True,
return_num_token=False):
return prepare_batch_data(
batch_data,
......@@ -93,9 +93,9 @@ class DataProcessor(object):
cls_id=self.vocab["[CLS]"],
sep_id=self.vocab["[SEP]"],
mask_id=-1,
return_input_mask=True,
return_max_len=False,
return_num_token=False)
return_input_mask=return_input_mask,
return_max_len=True,
return_num_token=return_num_token)
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
......@@ -188,7 +188,7 @@ class DataProcessor(object):
voc_size=-1,
mask_id=-1,
return_input_mask=True,
return_max_len=False,
return_max_len=True,
return_num_token=False)
yield batch_data
......
export CUDA_VISIBLE_DEVICES=2
BERT_BASE_PATH="chinese_L-12_H-768_A-12"
TASK_NAME='chnsenticorp'
DATA_PATH=chnsenticorp_data
DATA_PATH=./chnsenticorp_data
rm -rf $CKPT_PATH
python -u finetune_with_hub.py \
......@@ -10,7 +8,6 @@ python -u finetune_with_hub.py \
--batch_size 32 \
--in_tokens false \
--data_dir ${DATA_PATH} \
--vocab_path ${BERT_BASE_PATH}/vocab.txt \
--weight_decay 0.01 \
--warmup_proportion 0.0 \
--validation_steps 50 \
......
......@@ -78,6 +78,10 @@ def _finetune_model(task,
logger.info(
"Memory optimization done! Time elapsed %f sec" % time_used)
lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
program=main_program, batch_size=batch_size)
logger.info("Theoretical memory usage in training: %.3f - %.3f %s" %
(lower_mem, upper_mem, unit))
# initialize all parameters
exe.run(fluid.default_startup_program())
step = 0
......
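
The lines added above log a theoretical memory estimate right after memory optimization and before parameter initialization. A standalone sketch of the same fluid.contrib.memory_usage call on a toy program (assuming the PaddlePaddle 1.x fluid API used throughout this repo):

```python
import paddle.fluid as fluid

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    # A toy network, just to have something to measure.
    data = fluid.layers.data(name="x", shape=[128], dtype="float32")
    fc = fluid.layers.fc(input=data, size=64)

# Returns a (lower_bound, upper_bound, unit) estimate for the given batch size.
lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
    program=main_program, batch_size=32)
print("Theoretical memory usage in training: %.3f - %.3f %s"
      % (lower_mem, upper_mem, unit))
```
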