Commit b296c718 authored by Zeyu Chen

add choices for scripts

Parent da2db0f3
@@ -21,12 +21,13 @@ import paddlehub as hub
 # yapf: disable
 parser = argparse.ArgumentParser(__doc__)
 parser.add_argument("--num_epoch", type=int, default=3, help="Number of epochs for fine-tuning.")
+parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether to use GPU for fine-tuning; input should be True or False.")
 parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
 parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.")
-parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint.")
+parser.add_argument("--warmup_proportion", type=float, default=0.0, help="Warmup proportion for the warmup strategy.")
 parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest sequence.")
 parser.add_argument("--batch_size", type=int, default=32, help="Total number of examples in a training batch.")
+parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint.")
 args = parser.parse_args()
 # yapf: enable
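The new `--use_gpu` flag is parsed with `ast.literal_eval` rather than `type=bool`, because argparse would pass any non-empty string, including "False", through `bool()` and get `True`. A minimal standalone sketch of the difference:

```python
import argparse
import ast

parser = argparse.ArgumentParser()
# ast.literal_eval("False") yields the real boolean False;
# type=bool would return True for any non-empty string.
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True)

args = parser.parse_args(["--use_gpu", "False"])
print(args.use_gpu)   # False (a genuine bool)
print(bool("False"))  # True -- the pitfall type=bool would introduce
```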
@@ -76,7 +77,7 @@ if __name__ == '__main__':
     # Setup running config for PaddleHub Finetune API
     config = hub.RunConfig(
-        use_cuda=True,
+        use_cuda=args.use_gpu,
         num_epoch=args.num_epoch,
         batch_size=args.batch_size,
         checkpoint_dir=args.checkpoint_dir,
...
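Taken together, the two hunks above let this demo run on CPU-only machines. A condensed sketch of the resulting flow, using only the arguments and `RunConfig` fields visible in this diff (the collapsed fields are omitted):

```python
import argparse
import ast

import paddlehub as hub

parser = argparse.ArgumentParser()
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True)
parser.add_argument("--num_epoch", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--checkpoint_dir", type=str, default=None)
args = parser.parse_args()

# The device choice now comes from the CLI instead of a hard-coded
# use_cuda=True, so `--use_gpu False` selects CPU execution.
config = hub.RunConfig(
    use_cuda=args.use_gpu,
    num_epoch=args.num_epoch,
    batch_size=args.batch_size,
    checkpoint_dir=args.checkpoint_dir)
```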
@@ -149,8 +149,10 @@ python cls_predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128
 ```
 where CKPT_DIR is the path in which the Finetune API saved the best model, and max_seq_len is the maximum sequence length of the ERNIE model; *keep these consistent with the parameters configured during training*.
-Once the parameters are set correctly, run the script `sh run_predict.sh` to see the text classification predictions below. For more on the prediction steps, see `cls_predict.py`.
+Once the parameters are set correctly, run the script `sh run_predict.sh` to see the text classification predictions below, along with the final accuracy.
+For more on the prediction steps, see `cls_predict.py`.
 ```
 text=键盘缝隙大进灰,装系统自己不会装,屏幕有点窄玩游戏人物有点变形 label=0 predict=0
+accuracy = 0.954267
 ```
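The new `accuracy` line in the sample output is simply the fraction of test predictions that match their labels. An illustrative computation over hypothetical `(label, predict)` pairs:

```python
# Hypothetical pairs; cls_predict.py derives them from the fine-tuned
# model's predictions on the test set.
pairs = [(0, 0), (1, 1), (1, 0), (0, 0)]
accuracy = sum(label == pred for label, pred in pairs) / len(pairs)
print("accuracy = %f" % accuracy)  # accuracy = 0.750000
```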
@@ -15,5 +15,5 @@ python -u text_classifier.py \
     --checkpoint_dir=${CKPT_DIR} \
     --learning_rate=5e-5 \
     --weight_decay=0.01 \
-    --max_seq_len=128
-    --num_epoch=3 \
+    --max_seq_len=128 \
+    --num_epoch=3
@@ -22,8 +22,7 @@ import paddlehub as hub
 # yapf: disable
 parser = argparse.ArgumentParser(__doc__)
 parser.add_argument("--num_epoch", type=int, default=3, help="Number of epochs for fine-tuning.")
-parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether to use GPU for fine-tuning; input should be True or False.")
-parser.add_argument("--dataset", type=str, default="senticorp", help="Dataset to fine-tune on.")
+parser.add_argument("--dataset", type=str, default="chnsenticorp", help="Dataset to fine-tune on.", choices=["chnsenticorp", "nlpcc_dbqa", "lcqmc"])
 parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
 parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.")
 parser.add_argument("--warmup_proportion", type=float, default=0.0, help="Warmup proportion for the warmup strategy.")
...
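This hunk is the commit's headline change: `choices` makes argparse validate the dataset name up front. The old default `"senticorp"` appears to be a typo for `"chnsenticorp"`; with `choices`, any unsupported value now fails fast instead of being passed through:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="chnsenticorp",
                    choices=["chnsenticorp", "nlpcc_dbqa", "lcqmc"])

print(parser.parse_args(["--dataset", "lcqmc"]).dataset)  # lcqmc

try:
    # "senticorp" (the old, misspelled default) is no longer accepted.
    parser.parse_args(["--dataset", "senticorp"])
except SystemExit:
    print("rejected: invalid choice 'senticorp'")
```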
@@ -25,7 +25,7 @@ from paddlehub.common.downloader import default_downloader
 from paddlehub.common.dir import DATA_HOME
 from paddlehub.common.logger import logger
-DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/chnsenticorp.tar.gz"
+_DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/chnsenticorp.tar.gz"
 class ChnSentiCorp(HubDataset):
@@ -38,7 +38,7 @@ class ChnSentiCorp(HubDataset):
         self.dataset_dir = os.path.join(DATA_HOME, "chnsenticorp")
         if not os.path.exists(self.dataset_dir):
             ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
-                url=DATA_URL, save_path=DATA_HOME, print_progress=True)
+                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
         else:
             logger.info("Dataset {} already cached.".format(self.dataset_dir))
...
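All the dataset classes touched by this commit share the same download-once, reuse-cache pattern. A self-contained sketch of the idea (names here are hypothetical, not the PaddleHub API):

```python
import os
import urllib.request

CACHE_HOME = os.path.expanduser("~/.cache/demo_datasets")  # hypothetical cache root

def fetch(url, name):
    """Download `url` once; later calls reuse the cached copy."""
    target = os.path.join(CACHE_HOME, name)
    if not os.path.exists(target):
        os.makedirs(CACHE_HOME, exist_ok=True)
        urllib.request.urlretrieve(url, target)
        # the real helper, download_file_and_uncompress, also unpacks the archive
    else:
        print("Dataset {} already cached.".format(target))
    return target
```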
@@ -25,7 +25,7 @@ from paddlehub.common.downloader import default_downloader
 from paddlehub.common.dir import DATA_HOME
 from paddlehub.common.logger import logger
-DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/lcqmc.tar.gz"
+_DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/lcqmc.tar.gz"
 class LCQMC(HubDataset):
@@ -33,7 +33,7 @@ class LCQMC(HubDataset):
         self.dataset_dir = os.path.join(DATA_HOME, "lcqmc")
         if not os.path.exists(self.dataset_dir):
             ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
-                url=DATA_URL, save_path=DATA_HOME, print_progress=True)
+                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
         else:
             logger.info("Dataset {} already cached.".format(self.dataset_dir))
...
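The `DATA_URL` → `_DATA_URL` rename, repeated across these files, marks the constant as module-private by convention; one concrete effect is that wildcard imports skip it. A hypothetical module for illustration:

```python
# datasets_demo.py -- hypothetical module
DATA_URL = "https://example.com/a.tar.gz"   # picked up by `from datasets_demo import *`
_DATA_URL = "https://example.com/b.tar.gz"  # leading underscore: skipped by
                                            # `import *` (absent __all__) and
                                            # flagged by linters as internal
```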
@@ -26,7 +26,7 @@ from paddlehub.common.downloader import default_downloader
 from paddlehub.common.dir import DATA_HOME
 from paddlehub.common.logger import logger
-DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/msra_ner.tar.gz"
+_DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/msra_ner.tar.gz"
 class MSRA_NER(HubDataset):
@@ -41,20 +41,14 @@ class MSRA_NER(HubDataset):
         self.dataset_dir = os.path.join(DATA_HOME, "msra_ner")
         if not os.path.exists(self.dataset_dir):
             ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
-                url=DATA_URL, save_path=DATA_HOME, print_progress=True)
+                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
         else:
             logger.info("Dataset {} already cached.".format(self.dataset_dir))
-        self._load_label_map()
         self._load_train_examples()
         self._load_test_examples()
         self._load_dev_examples()
-    def _load_label_map(self):
-        self.label_map_file = os.path.join(self.dataset_dir, "label_map.json")
-        with open(self.label_map_file) as fi:
-            self.label_map = json.load(fi)
     def _load_train_examples(self):
         train_file = os.path.join(self.dataset_dir, "train.tsv")
         self.train_examples = self._read_tsv(train_file)
...
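With `_load_label_map` gone, MSRA_NER only loads its three `.tsv` splits via `_read_tsv`. A hedged sketch of what such a helper typically looks like (not necessarily PaddleHub's exact implementation):

```python
import csv

def read_tsv(input_file, quotechar=None):
    """Read a tab-separated split file into a list of rows."""
    with open(input_file, "r", encoding="utf-8") as f:
        reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
        return [line for line in reader]
```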
@@ -25,7 +25,7 @@ from paddlehub.common.downloader import default_downloader
 from paddlehub.common.dir import DATA_HOME
 from paddlehub.common.logger import logger
-DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/nlpcc-dbqa.tar.gz"
+_DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/nlpcc-dbqa.tar.gz"
 class NLPCC_DBQA(HubDataset):
@@ -39,7 +39,7 @@ class NLPCC_DBQA(HubDataset):
         self.dataset_dir = os.path.join(DATA_HOME, "nlpcc-dbqa")
         if not os.path.exists(self.dataset_dir):
             ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
-                url=DATA_URL, save_path=DATA_HOME, print_progress=True)
+                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
         else:
             logger.info("Dataset {} already cached.".format(self.dataset_dir))
...
@@ -76,10 +76,6 @@ class BaseReader(object):
         """Gets a collection of `InputExample`s for prediction."""
         return self.dataset.get_test_examples()
-    def get_labels(self):
-        """Gets the list of labels for this data set."""
-        return self.dataset.get_labels()
     def get_train_progress(self):
         """Gets progress for training phase."""
         return self.current_example, self.current_epoch
...
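Since `BaseReader.get_labels` only forwarded to the dataset, its removal presumably means callers now query the dataset directly. A minimal shape of the removed delegation (hypothetical stand-in classes):

```python
class Dataset:                 # stand-in for a HubDataset
    def get_labels(self):
        return ["0", "1"]

class Reader:                  # stand-in for BaseReader after this commit
    def __init__(self, dataset):
        self.dataset = dataset
    # get_labels() used to live here and merely returned
    # self.dataset.get_labels(); callers now do that themselves.

print(Reader(Dataset()).dataset.get_labels())  # ['0', '1']
```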