From 16145775d955347e0ea4fc04cfa02eff10fbc742 Mon Sep 17 00:00:00 2001
From: Zeyu Chen
Date: Fri, 29 Mar 2019 03:48:24 +0800
Subject: [PATCH] finish bert and ernie text classification task

---
 demo/bert-cls/finetune_with_hub.py    | 35 ++++++---------------------
 demo/bert-cls/reader/cls.py           | 14 +++++------
 demo/bert-cls/run_fintune_with_hub.sh |  5 +---
 paddle_hub/finetune/finetune.py       |  4 +++
 4 files changed, 19 insertions(+), 39 deletions(-)

diff --git a/demo/bert-cls/finetune_with_hub.py b/demo/bert-cls/finetune_with_hub.py
index aa9541c2..2b1c8648 100644
--- a/demo/bert-cls/finetune_with_hub.py
+++ b/demo/bert-cls/finetune_with_hub.py
@@ -21,7 +21,6 @@ import os
 import time
 import argparse
 import numpy as np
-import multiprocessing
 
 import paddle
 import paddle.fluid as fluid
@@ -33,24 +32,14 @@ from paddle_hub.finetune.config import FinetuneConfig
 
 # yapf: disable
 parser = argparse.ArgumentParser(__doc__)
-model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
-model_g.add_arg("bert_config_path", str, None, "Path to the json file for bert model config.")
 
 train_g = ArgumentGroup(parser, "training", "training options.")
 train_g.add_arg("epoch", int, 3, "Number of epoches for fine-tuning.")
 train_g.add_arg("learning_rate", float, 5e-5, "Learning rate used to train with warmup.")
-train_g.add_arg("lr_scheduler", str, "linear_warmup_decay",
-                "scheduler of learning rate.", choices=['linear_warmup_decay', 'noam_decay'])
+train_g.add_arg("lr_scheduler", str, "linear_warmup_decay", "scheduler of learning rate.", choices=['linear_warmup_decay', 'noam_decay'])
 train_g.add_arg("weight_decay", float, 0.01, "Weight decay rate for L2 regularizer.")
 train_g.add_arg("warmup_proportion", float, 0.1,
                 "Proportion of training steps to perform linear learning rate warmup for.")
-train_g.add_arg("validation_steps", int, 1000, "The steps interval to evaluate model performance.")
-train_g.add_arg("loss_scaling", float, 1.0,
-                "Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")
-
-log_g = ArgumentGroup(parser, "logging", "logging related.")
-log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.")
-log_g.add_arg("verbose", bool, False, "Whether to output verbose log.")
 
 data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
 data_g.add_arg("data_dir", str, None, "Path to training data.")
@@ -60,12 +49,6 @@ data_g.add_arg("batch_size", int, 32, "Total examples' number in batch fo
 data_g.add_arg("in_tokens", bool, False,
                "If set, the batch size will be the maximum number of tokens in one batch. "
                "Otherwise, it will be the maximum number of examples in one batch.")
-data_g.add_arg("do_lower_case", bool, True,
-               "Whether to lower case the input text. Should be True for uncased models and False for cased models.")
-data_g.add_arg("random_seed", int, 0, "Random seed.")
-
-run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
-run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
 
 args = parser.parse_args()
 # yapf: enable.
@@ -85,26 +68,22 @@ if __name__ == '__main__':
         weight_decay=args.weight_decay,
         finetune_strategy="bert_finetune",
         with_memory_optimization=True,
-        in_tokens=True,
+        in_tokens=False,
         optimizer=None,
         warmup_proportion=args.warmup_proportion)
 
-    module = hub.Module(
-        module_dir="./hub_module/chinese_L-12_H-768_A-12.hub_module")
+    # loading paddlehub BERT
+    # module = hub.Module(
+    #     module_dir="./hub_module/chinese_L-12_H-768_A-12.hub_module")
+    module = hub.Module(module_dir="./hub_module/ernie-stable.hub_module")
 
-    print("vocab_path = {}".format(module.get_vocab_path()))
     processor = reader.ChnsenticorpProcessor(
         data_dir=args.data_dir,
         vocab_path=module.get_vocab_path(),
-        max_seq_len=args.max_seq_len,
-        do_lower_case=args.do_lower_case,
-        in_tokens=args.in_tokens,
-        random_seed=args.random_seed)
+        max_seq_len=args.max_seq_len)
 
     num_labels = len(processor.get_labels())
 
-    # loading paddlehub BERT
-
     # bert's input tensor, output tensor and forward graph
     # If you want to fine-tune the pretrain model parameter, please set
     # trainable to True
diff --git a/demo/bert-cls/reader/cls.py b/demo/bert-cls/reader/cls.py
index f004405c..d09d4ba9 100644
--- a/demo/bert-cls/reader/cls.py
+++ b/demo/bert-cls/reader/cls.py
@@ -27,8 +27,8 @@ class DataProcessor(object):
                  data_dir,
                  vocab_path,
                  max_seq_len,
-                 do_lower_case,
-                 in_tokens,
+                 do_lower_case=True,
+                 in_tokens=False,
                  random_seed=None):
         self.data_dir = data_dir
         self.max_seq_len = max_seq_len
@@ -83,7 +83,7 @@ class DataProcessor(object):
             voc_size=-1,
             mask_id=-1,
             return_input_mask=True,
-            return_max_len=False,
+            return_max_len=True,
             return_num_token=False):
         return prepare_batch_data(
             batch_data,
@@ -93,9 +93,9 @@ class DataProcessor(object):
             cls_id=self.vocab["[CLS]"],
             sep_id=self.vocab["[SEP]"],
             mask_id=-1,
-            return_input_mask=True,
-            return_max_len=False,
-            return_num_token=False)
+            return_input_mask=return_input_mask,
+            return_max_len=True,
+            return_num_token=return_num_token)
 
     @classmethod
     def _read_tsv(cls, input_file, quotechar=None):
@@ -188,7 +188,7 @@ class DataProcessor(object):
                 voc_size=-1,
                 mask_id=-1,
                 return_input_mask=True,
-                return_max_len=False,
+                return_max_len=True,
                 return_num_token=False)
             yield batch_data
 
diff --git a/demo/bert-cls/run_fintune_with_hub.sh b/demo/bert-cls/run_fintune_with_hub.sh
index c1b82def..d71c6a38 100644
--- a/demo/bert-cls/run_fintune_with_hub.sh
+++ b/demo/bert-cls/run_fintune_with_hub.sh
@@ -1,8 +1,6 @@
 export CUDA_VISIBLE_DEVICES=2
 
-BERT_BASE_PATH="chinese_L-12_H-768_A-12"
-TASK_NAME='chnsenticorp'
-DATA_PATH=chnsenticorp_data
+DATA_PATH=./chnsenticorp_data
 
 rm -rf $CKPT_PATH
 python -u finetune_with_hub.py \
@@ -10,7 +8,6 @@ python -u finetune_with_hub.py \
                    --batch_size 32 \
                    --in_tokens false \
                    --data_dir ${DATA_PATH} \
-                   --vocab_path ${BERT_BASE_PATH}/vocab.txt \
                    --weight_decay 0.01 \
                    --warmup_proportion 0.0 \
                    --validation_steps 50 \
diff --git a/paddle_hub/finetune/finetune.py b/paddle_hub/finetune/finetune.py
index 8d54ebd7..44f68f44 100644
--- a/paddle_hub/finetune/finetune.py
+++ b/paddle_hub/finetune/finetune.py
@@ -78,6 +78,10 @@ def _finetune_model(task,
         logger.info(
             "Memory optimization done! Time elapsed %f sec" % time_used)
 
+    lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
+        program=main_program, batch_size=batch_size)
+    logger.info("Theoretical memory usage in training: %.3f - %.3f %s" %
+                (lower_mem, upper_mem, unit)),
     # initilize all parameters
     exe.run(fluid.default_startup_program())
     step = 0
--
GitLab
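
The four lines this patch adds to paddle_hub/finetune/finetune.py rely on fluid.contrib.memory_usage(), which returns a theoretical lower bound, upper bound, and unit string for a program's training memory at a given batch size. The snippet below is a minimal standalone sketch of that call, assuming a PaddlePaddle 1.x-era fluid install; the toy two-layer program it builds is an illustrative assumption and is not part of this patch or the demo.

# Standalone sketch of the memory-usage estimate added to _finetune_model.
# The toy program below is illustrative only; only fluid.contrib.memory_usage
# itself is taken from the patch.
import paddle.fluid as fluid

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    # A small stand-in for the fine-tuning program built by the demo.
    x = fluid.layers.data(name="x", shape=[128], dtype="float32")
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")
    fc = fluid.layers.fc(input=x, size=2, act="softmax")
    loss = fluid.layers.mean(
        fluid.layers.cross_entropy(input=fc, label=label))

# Same call as in the patch: theoretical lower/upper memory usage for running
# the program at the given batch size.
lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
    program=main_program, batch_size=32)
print("Theoretical memory usage in training: %.3f - %.3f %s" %
      (lower_mem, upper_mem, unit))

In the patched _finetune_model the estimate is logged on the fine-tuning main_program right after the optional memory-optimization pass and just before exe.run(fluid.default_startup_program()), so it reflects the program that is actually about to be trained.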