Commit edf5630f, authored by zhangxuefei

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleHub into develop

export FLAGS_eager_delete_tensor_gb=0.0
+export CUDA_VISIBLE_DEVICES=0
python -u img_classifier.py $@
export FLAGS_eager_delete_tensor_gb=0.0
+export CUDA_VISIBLE_DEVICES=0
python -u predict.py $@
export FLAGS_eager_delete_tensor_gb=0.0
export CUDA_VISIBLE_DEVICES=0
CKPT_DIR="./ckpt_qa"
# Recommended hyper parameters for different tasks
# ChnSentiCorp: batch_size=24, weight_decay=0.01, num_epoch=3, max_seq_len=128, lr=5e-5
......
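For context, the recommended ChnSentiCorp settings above typically map onto PaddleHub's Finetune API as sketched below. This is a minimal sketch assuming the PaddleHub 1.x `AdamWeightDecayStrategy` signature; it is not part of this commit.

```python
# Hedged sketch (not from this commit): lr and weight_decay feed the
# fine-tuning strategy, while batch_size, num_epoch and max_seq_len are
# passed to RunConfig and to module.context()/the reader elsewhere.
import paddlehub as hub

strategy = hub.AdamWeightDecayStrategy(
    learning_rate=5e-5,     # lr=5e-5 from the comment above
    weight_decay=0.01,      # weight_decay=0.01 from the comment above
    warmup_proportion=0.1)  # assumed warm-up value, not from this commit
```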
@@ -89,9 +89,7 @@ if __name__ == '__main__':
    # Setup running config for PaddleHub Finetune API
    config = hub.RunConfig(
-        log_interval=10,
        eval_interval=300,
-        save_ckpt_interval=10000,
        use_pyreader=args.use_pyreader,
        use_data_parallel=args.use_data_parallel,
        use_cuda=args.use_gpu,
......
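With `log_interval` and `save_ckpt_interval` dropped, the demo now pins only `eval_interval` and leaves the rest to RunConfig defaults. Below is a minimal, self-contained sketch of the resulting call; the concrete values stand in for the demo's argparse flags, and `checkpoint_dir`, `num_epoch`, `batch_size` and the strategy are assumptions.

```python
# Hedged sketch (not from this commit): the RunConfig call as it reads after
# this diff, with the removed intervals left to PaddleHub's defaults.
import paddlehub as hub

use_gpu = True            # stands in for args.use_gpu
use_pyreader = True       # stands in for args.use_pyreader
use_data_parallel = True  # stands in for args.use_data_parallel

config = hub.RunConfig(
    eval_interval=300,                   # kept by this change
    use_pyreader=use_pyreader,
    use_data_parallel=use_data_parallel,
    use_cuda=use_gpu,
    checkpoint_dir="./ckpt_qa",          # directory name taken from the script above
    num_epoch=3,                         # assumed
    batch_size=24,                       # assumed
    strategy=hub.AdamWeightDecayStrategy())
```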
export FLAGS_eager_delete_tensor_gb=0.0
+export CUDA_VISIBLE_DEVICES=0
# Recommended hyper parameters for different tasks
# squad: batch_size=8, weight_decay=0, num_epoch=3, max_seq_len=512, lr=5e-5
......
export FLAGS_eager_delete_tensor_gb=0.0
+export CUDA_VISIBLE_DEVICES=0
CKPT_DIR="./ckpt_cmrc2018"
dataset=cmrc2018
......
export FLAGS_eager_delete_tensor_gb=0.0
-# export CUDA_VISIBLE_DEVICES=0
+export CUDA_VISIBLE_DEVICES=0
# Users can select chnsenticorp, nlpcc_dbqa, lcqmc and so on for different tasks
DATASET="STS-B"
......
@@ -41,7 +41,7 @@ args = parser.parse_args()
if __name__ == '__main__':
    # Loading the PaddleHub ERNIE pretrained model
-    module = hub.Module(name="ernie")
+    module = hub.Module(name="ernie_tiny")
    inputs, outputs, program = module.context(max_seq_len=args.max_seq_len)
    # Sentence labeling dataset reader
@@ -49,7 +49,9 @@ if __name__ == '__main__':
    reader = hub.reader.SequenceLabelReader(
        dataset=dataset,
        vocab_path=module.get_vocab_path(),
-        max_seq_len=args.max_seq_len)
+        max_seq_len=args.max_seq_len,
+        sp_model_path=module.get_spm_path(),
+        word_dict_path=module.get_word_dict_path())
    inv_label_map = {val: key for key, val in reader.label_map.items()}
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
......
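Taken together with the ernie -> ernie_tiny switch above, the reader now also consumes the module's sentencepiece model and word dict. Below is a minimal, self-contained sketch of the assembled predict-side setup; the `MSRA_NER` dataset choice and the fixed `max_seq_len`/`use_gpu` values are assumptions, while the reader arguments come from this hunk.

```python
# Hedged sketch (not from this commit): sequence-label setup after the
# ernie -> ernie_tiny switch; dataset choice and constants are assumptions.
import paddle.fluid as fluid
import paddlehub as hub

max_seq_len = 128   # stands in for args.max_seq_len
use_gpu = False     # stands in for args.use_gpu

# Load the PaddleHub ernie_tiny pretrained module.
module = hub.Module(name="ernie_tiny")
inputs, outputs, program = module.context(max_seq_len=max_seq_len)

# Sequence labeling dataset reader; sp_model_path/word_dict_path are the
# ernie_tiny-specific inputs added by this diff.
dataset = hub.dataset.MSRA_NER()   # assumed dataset for the demo
reader = hub.reader.SequenceLabelReader(
    dataset=dataset,
    vocab_path=module.get_vocab_path(),
    max_seq_len=max_seq_len,
    sp_model_path=module.get_spm_path(),
    word_dict_path=module.get_word_dict_path())

inv_label_map = {val: key for key, val in reader.label_map.items()}
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
```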
export FLAGS_eager_delete_tensor_gb=0.0
+export CUDA_VISIBLE_DEVICES=0
CKPT_DIR="./ckpt_sequence_label"
python -u sequence_label.py \
......
@@ -71,9 +71,6 @@ if __name__ == '__main__':
    # Setup running config for PaddleHub Finetune API
    config = hub.RunConfig(
-        log_interval=10,
-        eval_interval=300,
-        save_ckpt_interval=10000,
        use_data_parallel=args.use_data_parallel,
        use_pyreader=args.use_pyreader,
        use_cuda=args.use_gpu,
......
@@ -45,15 +45,35 @@ if __name__ == '__main__':
    # Download dataset and use ClassifyReader to read dataset
    if args.dataset.lower() == "chnsenticorp":
        dataset = hub.dataset.ChnSentiCorp()
-        module = hub.Module(name="ernie")
+        module = hub.Module(name="ernie_tiny")
+        metrics_choices = ["acc"]
+    elif args.dataset.lower() == "tnews":
+        dataset = hub.dataset.TNews()
+        module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
        metrics_choices = ["acc"]
    elif args.dataset.lower() == "nlpcc_dbqa":
        dataset = hub.dataset.NLPCC_DBQA()
-        module = hub.Module(name="ernie")
+        module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
        metrics_choices = ["acc"]
    elif args.dataset.lower() == "lcqmc":
        dataset = hub.dataset.LCQMC()
-        module = hub.Module(name="ernie")
+        module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
+        metrics_choices = ["acc"]
+    elif args.dataset.lower() == 'inews':
+        dataset = hub.dataset.INews()
+        module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
+        metrics_choices = ["acc"]
+    elif args.dataset.lower() == 'bq':
+        dataset = hub.dataset.BQ()
+        module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
+        metrics_choices = ["acc"]
+    elif args.dataset.lower() == 'thucnews':
+        dataset = hub.dataset.THUCNEWS()
+        module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
+        metrics_choices = ["acc"]
+    elif args.dataset.lower() == 'iflytek':
+        dataset = hub.dataset.IFLYTEK()
+        module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
        metrics_choices = ["acc"]
    elif args.dataset.lower() == "mrpc":
        dataset = hub.dataset.GLUE("MRPC")
@@ -90,7 +110,7 @@ if __name__ == '__main__':
        metrics_choices = ["acc"]
    elif args.dataset.lower().startswith("xnli"):
        dataset = hub.dataset.XNLI(language=args.dataset.lower()[-2:])
-        module = hub.Module(name="bert_multi_cased_L-12_H-768_A-12")
+        module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
        metrics_choices = ["acc"]
    else:
        raise ValueError("%s dataset is not defined" % args.dataset)
......
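Downstream of this selection, the chosen dataset/module pair feeds the ClassifyReader mentioned in the comment above and a classification task. The sketch below shows that wiring under PaddleHub 1.x conventions; the ClassifyReader keyword set, the `TextClassifierTask` call and all concrete values are assumptions insofar as they are not shown in this hunk.

```python
# Hedged sketch (not from this commit): how the selected dataset/module pair
# is typically consumed by the rest of the text-classification demo.
import paddlehub as hub

max_seq_len = 128
dataset = hub.dataset.ChnSentiCorp()
module = hub.Module(name="ernie_tiny")
metrics_choices = ["acc"]

inputs, outputs, program = module.context(
    trainable=True, max_seq_len=max_seq_len)

# The sp_model_path/word_dict_path kwargs mirror the SequenceLabelReader
# change elsewhere in this diff.
reader = hub.reader.ClassifyReader(
    dataset=dataset,
    vocab_path=module.get_vocab_path(),
    max_seq_len=max_seq_len,
    sp_model_path=module.get_spm_path(),
    word_dict_path=module.get_word_dict_path())

pooled_output = outputs["pooled_output"]
feed_list = [
    inputs["input_ids"].name,
    inputs["position_ids"].name,
    inputs["segment_ids"].name,
    inputs["input_mask"].name,
]

# TextClassifierTask (PaddleHub 1.x) accepts the metrics_choices list that
# the elif chain above prepares; the RunConfig values here are placeholders.
cls_task = hub.TextClassifierTask(
    data_reader=reader,
    feature=pooled_output,
    feed_list=feed_list,
    num_classes=dataset.num_labels,
    config=hub.RunConfig(use_cuda=False, num_epoch=1, batch_size=24),
    metrics_choices=metrics_choices)
```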
export FLAGS_eager_delete_tensor_gb=0.0
+export CUDA_VISIBLE_DEVICES=0
# Users can select chnsenticorp, nlpcc_dbqa, lcqmc and so on for different tasks
DATASET="chnsenticorp"
......
@@ -17,4 +17,4 @@ python -u predict.py --checkpoint_dir=$CKPT_DIR \
    --max_seq_len=128 \
    --use_gpu=True \
    --dataset=${DATASET} \
-    --batch_size=150 \
+    --batch_size=32 \
@@ -47,7 +47,7 @@ if __name__ == '__main__':
    elif args.dataset.lower() == "tnews":
        dataset = hub.dataset.TNews()
        module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
-        metrics_choices = ["acc", "f1"]
+        metrics_choices = ["acc"]
    elif args.dataset.lower() == "nlpcc_dbqa":
        dataset = hub.dataset.NLPCC_DBQA()
        module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
@@ -59,19 +59,19 @@ if __name__ == '__main__':
    elif args.dataset.lower() == 'inews':
        dataset = hub.dataset.INews()
        module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
-        metrics_choices = ["acc", "f1"]
+        metrics_choices = ["acc"]
    elif args.dataset.lower() == 'bq':
        dataset = hub.dataset.BQ()
        module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
-        metrics_choices = ["acc", "f1"]
+        metrics_choices = ["acc"]
    elif args.dataset.lower() == 'thucnews':
        dataset = hub.dataset.THUCNEWS()
        module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
-        metrics_choices = ["acc", "f1"]
+        metrics_choices = ["acc"]
    elif args.dataset.lower() == 'iflytek':
        dataset = hub.dataset.IFLYTEK()
        module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
-        metrics_choices = ["acc", "f1"]
+        metrics_choices = ["acc"]
    elif args.dataset.lower() == "mrpc":
        dataset = hub.dataset.GLUE("MRPC")
        module = hub.Module(name="ernie_v2_eng_base")
@@ -97,7 +97,7 @@ if __name__ == '__main__':
        dataset = hub.dataset.GLUE("RTE")
        module = hub.Module(name="ernie_v2_eng_base")
        metrics_choices = ["acc"]
-    elif args.dataset.lower() == "mnli" or args.dataset.lower() == "mnli":
+    elif args.dataset.lower() == "mnli" or args.dataset.lower() == "mnli_m":
        dataset = hub.dataset.GLUE("MNLI_m")
        module = hub.Module(name="ernie_v2_eng_base")
        metrics_choices = ["acc"]
......
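The `--checkpoint_dir`, `--max_seq_len`, `--use_gpu`, `--dataset` and `--batch_size` flags passed by the run scripts in this diff are consumed through argparse before the `args = parser.parse_args()` line shown earlier. Below is a minimal hedged sketch of such a parser; the defaults, help strings and boolean handling are assumptions, not taken from this commit.

```python
# Hedged sketch (not from this commit): an argparse setup matching the flags
# passed by the shell scripts in this diff.
import argparse

parser = argparse.ArgumentParser(description="PaddleHub text classification demo")
parser.add_argument("--dataset", type=str, default="chnsenticorp",
                    help="chnsenticorp, tnews, nlpcc_dbqa, lcqmc, inews, bq, "
                         "thucnews, iflytek, ...")
parser.add_argument("--checkpoint_dir", type=str, default="./ckpt_chnsenticorp",
                    help="Directory holding the fine-tuned checkpoint")
parser.add_argument("--max_seq_len", type=int, default=128)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--use_gpu", type=lambda v: v.lower() == "true", default=False,
                    help="The scripts pass --use_gpu=True as a string")
args = parser.parse_args()
```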
@@ -49,6 +49,7 @@ class ImageClassificationReader(object):
        self.data_augmentation = data_augmentation
        self.images_std = images_std
        self.images_mean = images_mean
+        self.num_examples = {'train': -1, 'dev': -1, 'test': -1}
        if self.images_mean is None:
            try:
@@ -80,12 +81,15 @@ class ImageClassificationReader(object):
            raise ValueError("The dataset is none and it's not allowed!")
        if phase == "train":
            data = self.dataset.train_data(shuffle)
+            self.num_examples['train'] = len(self.get_train_examples())
        elif phase == "test":
            shuffle = False
            data = self.dataset.test_data(shuffle)
+            self.num_examples['test'] = len(self.get_test_examples())
        elif phase == "val" or phase == "dev":
            shuffle = False
            data = self.dataset.validate_data(shuffle)
+            self.num_examples['dev'] = len(self.get_dev_examples())
        elif phase == "predict":
            data = data
......
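The new `num_examples` dictionary records how many samples each split contributes the first time a generator for that phase is built. Below is a simplified, self-contained sketch of the same bookkeeping pattern; `DummyDataset` and `CountingReader` are illustrative stand-ins, not PaddleHub classes.

```python
# Hedged, simplified sketch of the num_examples bookkeeping added in this diff;
# the class and method names below are illustrative, not PaddleHub API.
class DummyDataset:
    def __init__(self, train, dev, test):
        self._splits = {"train": train, "dev": dev, "test": test}

    def examples(self, phase):
        return self._splits[phase]


class CountingReader:
    def __init__(self, dataset):
        self.dataset = dataset
        # -1 means "size not known yet", mirroring the diff above.
        self.num_examples = {"train": -1, "dev": -1, "test": -1}

    def data_generator(self, phase):
        examples = self.dataset.examples(phase)
        # Record the split size as soon as a generator for this phase is
        # requested, so later code can query it for evaluation/progress.
        self.num_examples[phase] = len(examples)

        def generator():
            for example in examples:
                yield example

        return generator


reader = CountingReader(DummyDataset(train=[1, 2, 3], dev=[4], test=[5, 6]))
reader.data_generator("train")
print(reader.num_examples)  # {'train': 3, 'dev': -1, 'test': -1}
```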