提交 dbb29416 编写于 作者: Z Zeyu Chen

test bert modules

上级 19ec79e8
......@@ -33,6 +33,7 @@ args = parser.parse_args()
if __name__ == '__main__':
# Step1: load Paddlehub ERNIE pretrained model
module = hub.Module(name="ernie")
# module = hub.Module(name="bert_multi_cased_L-12_H-768_A-12")
inputs, outputs, program = module.context(
trainable=True, max_seq_len=args.max_seq_len)
......@@ -59,9 +60,7 @@ if __name__ == '__main__':
]
# Define a classfication finetune task by PaddleHub's API
cls_task = hub.create_text_classification_task(
feature=pooled_output,
label=label,
num_classes=dataset.num_labels())
feature=pooled_output, label=label, num_classes=dataset.num_labels)
# Step4: Select finetune strategy, setup config and finetune
strategy = hub.AdamWeightDecayStrategy(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册