提交 dbb29416 编写于 作者: Z Zeyu Chen

test bert modules

上级 19ec79e8
...@@ -33,6 +33,7 @@ args = parser.parse_args() ...@@ -33,6 +33,7 @@ args = parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
# Step1: load Paddlehub ERNIE pretrained model # Step1: load Paddlehub ERNIE pretrained model
module = hub.Module(name="ernie") module = hub.Module(name="ernie")
# module = hub.Module(name="bert_multi_cased_L-12_H-768_A-12")
inputs, outputs, program = module.context( inputs, outputs, program = module.context(
trainable=True, max_seq_len=args.max_seq_len) trainable=True, max_seq_len=args.max_seq_len)
...@@ -59,9 +60,7 @@ if __name__ == '__main__': ...@@ -59,9 +60,7 @@ if __name__ == '__main__':
] ]
# Define a classfication finetune task by PaddleHub's API # Define a classfication finetune task by PaddleHub's API
cls_task = hub.create_text_classification_task( cls_task = hub.create_text_classification_task(
feature=pooled_output, feature=pooled_output, label=label, num_classes=dataset.num_labels)
label=label,
num_classes=dataset.num_labels())
# Step4: Select finetune strategy, setup config and finetune # Step4: Select finetune strategy, setup config and finetune
strategy = hub.AdamWeightDecayStrategy( strategy = hub.AdamWeightDecayStrategy(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册