diff --git a/demo/text-classification/text_classifier.py b/demo/text-classification/text_classifier.py index d21900cae76a2dabcd9218f7f25a2832c1342d49..b0e1042f64500d7dddc8c6db7efd51c226b7e31b 100644 --- a/demo/text-classification/text_classifier.py +++ b/demo/text-classification/text_classifier.py @@ -22,6 +22,7 @@ import paddlehub as hub # yapf: disable parser = argparse.ArgumentParser(__doc__) parser.add_argument("--num_epoch", type=int, default=3, help="Number of epoches for fine-tuning.") +parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for finetuning, input should be True or False") parser.add_argument("--dataset", type=str, default="chnsenticorp", help="Directory to model checkpoint", choices=["chnsenticorp", "nlpcc_dbqa", "lcqmc"]) parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.") parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.") @@ -71,7 +72,7 @@ if __name__ == '__main__': inputs["segment_ids"].name, inputs["input_mask"].name, label.name ] # Define a classfication finetune task by PaddleHub's API - cls_task = hub.create_text_classification_task( + cls_task = hub.create_text_cls_task( feature=pooled_output, label=label, num_classes=dataset.num_labels) # Step4: Select finetune strategy, setup config and finetune