parser.add_argument("--checkpoint_dir",type=str,default=None,help="Directory to model checkpoint")
parser.add_argument("--checkpoint_dir",type=str,default=None,help="Directory to model checkpoint")
parser.add_argument("--batch_size",type=int,default=1,help="Total examples' number in batch for training.")
parser.add_argument("--batch_size",type=int,default=1,help="Total examples' number in batch for training.")
parser.add_argument("--max_seq_len",type=int,default=512,help="Number of words of the longest seqence.")
parser.add_argument("--max_seq_len",type=int,default=512,help="Number of words of the longest seqence.")
parser.add_argument("--use_gpu",type=ast.literal_eval,default=False,help="Whether use GPU for finetuning, input should be True or False")
parser.add_argument("--use_gpu",type=ast.literal_eval,default=False,help="Whether use GPU for fine-tuning, input should be True or False")
parser.add_argument("--use_data_parallel",type=ast.literal_eval,default=False,help="Whether use data parallel.")
parser.add_argument("--use_data_parallel",type=ast.literal_eval,default=False,help="Whether use data parallel.")
parser.add_argument("--network",type=str,default='bilstm',help="Pre-defined network which was connected after Transformer model, such as ERNIE, BERT ,RoBERTa and ELECTRA.")
parser.add_argument("--network",type=str,default='bilstm',help="Pre-defined network which was connected after Transformer model, such as ERNIE, BERT ,RoBERTa and ELECTRA.")
args=parser.parse_args()
args=parser.parse_args()
...
@@ -71,7 +71,7 @@ if __name__ == '__main__':
...
@@ -71,7 +71,7 @@ if __name__ == '__main__':
inputs["input_mask"].name,
inputs["input_mask"].name,
]
]
# Setup runing config for PaddleHub Finetune API
# Setup RunConfig for PaddleHub Fine-tune API
config=hub.RunConfig(
config=hub.RunConfig(
use_data_parallel=args.use_data_parallel,
use_data_parallel=args.use_data_parallel,
use_cuda=args.use_gpu,
use_cuda=args.use_gpu,
...
@@ -79,7 +79,7 @@ if __name__ == '__main__':
...
@@ -79,7 +79,7 @@ if __name__ == '__main__':
checkpoint_dir=args.checkpoint_dir,
checkpoint_dir=args.checkpoint_dir,
strategy=hub.AdamWeightDecayStrategy())
strategy=hub.AdamWeightDecayStrategy())
# Define a classfication finetune task by PaddleHub's API
# Define a classfication fine-tune task by PaddleHub's API