Commit b296c718 authored by Zeyu Chen

add choices for scripts

Parent da2db0f3
......@@ -21,12 +21,13 @@ import paddlehub as hub
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--num_epoch", type=int, default=3, help="Number of epoches for fine-tuning.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether use GPU for finetuning, input should be True or False")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.")
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
parser.add_argument("--warmup_proportion", type=float, default=0.0, help="Warmup proportion params for warmup strategy")
parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest seqence.")
parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number in batch for training.")
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
args = parser.parse_args()
# yapf: enable.
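The `--use_gpu` flag above uses `ast.literal_eval` as its argparse `type`, so the literal strings `True`/`False` on the command line become real Python booleans (`type=bool` would treat any non-empty string, including `"False"`, as truthy). A minimal sketch of that behavior:

```python
import argparse
import ast

parser = argparse.ArgumentParser()
# ast.literal_eval("False") -> False; bool("False") would be True
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True)

args = parser.parse_args(["--use_gpu", "False"])
assert args.use_gpu is False
```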
......@@ -76,7 +77,7 @@ if __name__ == '__main__':
# Setup running config for PaddleHub Finetune API
config = hub.RunConfig(
use_cuda=True,
use_cuda=args.use_gpu,
num_epoch=args.num_epoch,
batch_size=args.batch_size,
checkpoint_dir=args.checkpoint_dir,
......
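Wiring `use_cuda` to `args.use_gpu` turns the device into a command-line switch instead of a hardcoded value. An illustrative sketch of the resulting flow (the invocations shown in comments are assumed, not taken from the repo):

```python
# Assumed usage, shown as shell commands in comments:
#   python text_classifier.py --use_gpu True    # fine-tune on GPU
#   python text_classifier.py --use_gpu False   # fall back to CPU
config = hub.RunConfig(
    use_cuda=args.use_gpu,      # parsed into a bool by ast.literal_eval
    num_epoch=args.num_epoch,
    batch_size=args.batch_size,
    checkpoint_dir=args.checkpoint_dir)
```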
......@@ -149,8 +149,10 @@ python cls_predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128
```
Here, CKPT_DIR is the path where the Finetune API saved the best model, and max_seq_len is the maximum sequence length for the ERNIE model; *keep this consistent with the value used during training*.
Once the parameters are set correctly, run `sh run_predict.sh` to see the text classification predictions below. For more on the prediction steps, see `cls_predict.py`.
Once the parameters are set correctly, run `sh run_predict.sh` to see the text classification predictions below, along with the final accuracy.
For more on the prediction steps, see `cls_predict.py`.
```
text=键盘缝隙大进灰,装系统自己不会装,屏幕有点窄玩游戏人物有点变形 label=0 predict=0
accuracy = 0.954267
```
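The final `accuracy` line is the fraction of test examples whose `predict` matches `label`. A minimal sketch of that computation (variable names are illustrative, not the actual `cls_predict.py` code):

```python
# Hypothetical (label, predict) pairs collected during prediction.
results = [(0, 0), (1, 1), (1, 0)]
accuracy = sum(label == pred for label, pred in results) / len(results)
print("accuracy = %f" % accuracy)  # -> accuracy = 0.666667
```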
......@@ -15,5 +15,5 @@ python -u text_classifier.py \
--checkpoint_dir=${CKPT_DIR} \
--learning_rate=5e-5 \
--weight_decay=0.01 \
--max_seq_len=128
--num_epoch=3 \
--max_seq_len=128 \
--num_epoch=3
......@@ -22,8 +22,7 @@ import paddlehub as hub
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--num_epoch", type=int, default=3, help="Number of epoches for fine-tuning.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for finetuning, input should be True or False")
parser.add_argument("--dataset", type=str, default="senticorp", help="Directory to model checkpoint")
parser.add_argument("--dataset", type=str, default="chnsenticorp", help="Directory to model checkpoint", choices=["chnsenticorp", "nlpcc_dbqa", "lcqmc"])
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.")
parser.add_argument("--warmup_proportion", type=float, default=0.0, help="Warmup proportion params for warmup strategy")
......
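The new `choices` list is what the commit title refers to: argparse now validates `--dataset` at parse time and exits with a usage error on anything outside the allowed set, instead of failing later on an unknown dataset name. A minimal sketch:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="chnsenticorp",
                    choices=["chnsenticorp", "nlpcc_dbqa", "lcqmc"])

print(parser.parse_args(["--dataset", "lcqmc"]).dataset)  # ok: prints "lcqmc"
parser.parse_args(["--dataset", "senticorp"])  # error: invalid choice, exits
```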
......@@ -25,7 +25,7 @@ from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger
DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/chnsenticorp.tar.gz"
_DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/chnsenticorp.tar.gz"
class ChnSentiCorp(HubDataset):
......@@ -38,7 +38,7 @@ class ChnSentiCorp(HubDataset):
self.dataset_dir = os.path.join(DATA_HOME, "chnsenticorp")
if not os.path.exists(self.dataset_dir):
ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
url=DATA_URL, save_path=DATA_HOME, print_progress=True)
url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
else:
logger.info("Dataset {} already cached.".format(self.dataset_dir))
......
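Each dataset class repeats the same download-or-reuse pattern: if the extracted directory is missing under `DATA_HOME`, fetch and unpack the archive, otherwise log a cache hit. A hedged sketch of the pattern in plain Python, using standard-library stand-ins (`urllib`, `tarfile`) rather than PaddleHub's `default_downloader`:

```python
import os
import tarfile
import urllib.request

def fetch_dataset(url, data_home, name):
    """Download and extract a dataset archive unless it is already cached."""
    dataset_dir = os.path.join(data_home, name)
    if not os.path.exists(dataset_dir):
        os.makedirs(data_home, exist_ok=True)
        archive = os.path.join(data_home, os.path.basename(url))
        urllib.request.urlretrieve(url, archive)   # download the .tar.gz
        with tarfile.open(archive) as tar:
            tar.extractall(data_home)              # unpack beside it
    else:
        print("Dataset {} already cached.".format(dataset_dir))
    return dataset_dir
```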
......@@ -25,7 +25,7 @@ from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger
DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/lcqmc.tar.gz"
_DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/lcqmc.tar.gz"
class LCQMC(HubDataset):
......@@ -33,7 +33,7 @@ class LCQMC(HubDataset):
self.dataset_dir = os.path.join(DATA_HOME, "lcqmc")
if not os.path.exists(self.dataset_dir):
ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
url=DATA_URL, save_path=DATA_HOME, print_progress=True)
url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
else:
logger.info("Dataset {} already cached.".format(self.dataset_dir))
......
......@@ -26,7 +26,7 @@ from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger
DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/msra_ner.tar.gz"
_DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/msra_ner.tar.gz"
class MSRA_NER(HubDataset):
......@@ -41,20 +41,14 @@ class MSRA_NER(HubDataset):
self.dataset_dir = os.path.join(DATA_HOME, "msra_ner")
if not os.path.exists(self.dataset_dir):
ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
url=DATA_URL, save_path=DATA_HOME, print_progress=True)
url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
else:
logger.info("Dataset {} already cached.".format(self.dataset_dir))
self._load_label_map()
self._load_train_examples()
self._load_test_examples()
self._load_dev_examples()
def _load_label_map(self):
self.label_map_file = os.path.join(self.dataset_dir, "label_map.json")
with open(self.label_map_file) as fi:
self.label_map = json.load(fi)
def _load_train_examples(self):
train_file = os.path.join(self.dataset_dir, "train.tsv")
self.train_examples = self._read_tsv(train_file)
......
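`_load_train_examples` depends on an inherited `_read_tsv` helper whose body is not part of this diff. A plausible minimal version, assuming one tab-separated record per line (an assumption for illustration, not PaddleHub's actual implementation):

```python
import csv

def _read_tsv(self, input_file, quotechar=None):
    """Assumed helper: read a TSV file into a list of field lists."""
    with open(input_file, "r") as f:
        reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
        return [line for line in reader]
```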
......@@ -25,7 +25,7 @@ from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger
DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/nlpcc-dbqa.tar.gz"
_DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/nlpcc-dbqa.tar.gz"
class NLPCC_DBQA(HubDataset):
......@@ -39,7 +39,7 @@ class NLPCC_DBQA(HubDataset):
self.dataset_dir = os.path.join(DATA_HOME, "nlpcc-dbqa")
if not os.path.exists(self.dataset_dir):
ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
url=DATA_URL, save_path=DATA_HOME, print_progress=True)
url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
else:
logger.info("Dataset {} already cached.".format(self.dataset_dir))
......
......@@ -76,10 +76,6 @@ class BaseReader(object):
"""Gets a collection of `InputExample`s for prediction."""
return self.dataset.get_test_examples()
def get_labels(self):
"""Gets the list of labels for this data set."""
return self.dataset.get_labels()
def get_train_progress(self):
"""Gets progress for training phase."""
return self.current_example, self.current_epoch
......