提交 af0f4890 编写于 作者: S Steffy-zxf

update demo for preset net

上级 c77d0dee
...@@ -34,6 +34,7 @@ parser.add_argument("--batch_size", type=int, default=1, help="Total examp ...@@ -34,6 +34,7 @@ parser.add_argument("--batch_size", type=int, default=1, help="Total examp
parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest seqence.") parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest seqence.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for finetuning, input should be True or False") parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for finetuning, input should be True or False")
parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether use data parallel.") parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether use data parallel.")
parser.add_argument("--network", type=str, default='bilstm', help="Preset network which was connected after Transformer model, such as ERNIE, BERT ,RoBERTa and ELECTRA.")
args = parser.parse_args() args = parser.parse_args()
# yapf: enable. # yapf: enable.
...@@ -59,7 +60,7 @@ if __name__ == '__main__': ...@@ -59,7 +60,7 @@ if __name__ == '__main__':
# Construct transfer learning network # Construct transfer learning network
# Use "pooled_output" for classification tasks on an entire sentence. # Use "pooled_output" for classification tasks on an entire sentence.
# Use "sequence_output" for token-level output. # Use "sequence_output" for token-level output.
pooled_output = outputs["pooled_output"] pooled_output = outputs["sequence_output"]
# Setup feed list for data feeder # Setup feed list for data feeder
# Must feed all the tensor of module need # Must feed all the tensor of module need
...@@ -79,10 +80,15 @@ if __name__ == '__main__': ...@@ -79,10 +80,15 @@ if __name__ == '__main__':
strategy=hub.AdamWeightDecayStrategy()) strategy=hub.AdamWeightDecayStrategy())
# Define a classfication finetune task by PaddleHub's API # Define a classfication finetune task by PaddleHub's API
# network choice: bilstm, bow, cnn, dpcnn, gru, lstm
# If you wanna add network after ERNIE/BERT/RoBERTa/ELECTRA module,
# you must use the outputs["sequence_output"] as the feature of TextClassifierTask
# rather than outputs["pooled_output"]
cls_task = hub.TextClassifierTask( cls_task = hub.TextClassifierTask(
data_reader=reader, data_reader=reader,
feature=pooled_output, feature=pooled_output,
feed_list=feed_list, feed_list=feed_list,
network=args.network,
num_classes=dataset.num_labels, num_classes=dataset.num_labels,
config=config) config=config)
......
export FLAGS_eager_delete_tensor_gb=0.0 export FLAGS_eager_delete_tensor_gb=0.0
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0,1,2,3
CKPT_DIR="./ckpt_chnsenticorp" CKPT_DIR="./ckpt_chnsenticorp"
...@@ -12,7 +12,8 @@ python -u text_classifier.py \ ...@@ -12,7 +12,8 @@ python -u text_classifier.py \
--max_seq_len=128 \ --max_seq_len=128 \
--warmup_proportion=0.1 \ --warmup_proportion=0.1 \
--num_epoch=3 \ --num_epoch=3 \
--use_data_parallel=True --use_data_parallel=True \
--network=bilstm
# The sugguested hyper parameters for difference task # The sugguested hyper parameters for difference task
# for ChineseGLUE: # for ChineseGLUE:
......
...@@ -7,3 +7,4 @@ python -u predict.py --checkpoint_dir=$CKPT_DIR \ ...@@ -7,3 +7,4 @@ python -u predict.py --checkpoint_dir=$CKPT_DIR \
--max_seq_len=128 \ --max_seq_len=128 \
--use_gpu=True \ --use_gpu=True \
--batch_size=24 \ --batch_size=24 \
--network=bilstm
...@@ -85,11 +85,15 @@ if __name__ == '__main__': ...@@ -85,11 +85,15 @@ if __name__ == '__main__':
strategy=strategy) strategy=strategy)
# Define a classfication finetune task by PaddleHub's API # Define a classfication finetune task by PaddleHub's API
# network choice: bilstm, bow, cnn, dpcnn, gru, lstm
# If you wanna add network after ERNIE/BERT/RoBERTa/ELECTRA module,
# you must use the outputs["sequence_output"] as the feature of TextClassifierTask
# rather than outputs["pooled_output"]
cls_task = hub.TextClassifierTask( cls_task = hub.TextClassifierTask(
data_reader=reader, data_reader=reader,
feature=pooled_output, feature=pooled_output,
feed_list=feed_list, feed_list=feed_list,
network='dpcnn', network=args.network,
num_classes=dataset.num_labels, num_classes=dataset.num_labels,
config=config, config=config,
metrics_choices=metrics_choices) metrics_choices=metrics_choices)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册