未验证 提交 8476d724 编写于 作者: Z Zeyu Chen 提交者: GitHub

Update simple_demo.py

上级 0ad41389
#coding:utf-8
import paddle.fluid as fluid import paddle.fluid as fluid
import paddlehub as hub import paddlehub as hub
# Load ERNIE pretrained model
module = hub.Module(name="ernie") module = hub.Module(name="ernie")
inputs, outputs, program = module.context(trainable=True, max_seq_len=128) inputs, outputs, program = module.context(trainable=True, max_seq_len=128)
# Create ClassifyReader
reader = hub.reader.ClassifyReader( reader = hub.reader.ClassifyReader(
hub.dataset.ChnSentiCorp(), module.get_vocab_path(), max_seq_len=128) hub.dataset.ChnSentiCorp(), module.get_vocab_path(), max_seq_len=128)
# Create Text Classification Task
task = hub.create_text_cls_task(feature=outputs["pooled_output"], num_classes=2) task = hub.create_text_cls_task(feature=outputs["pooled_output"], num_classes=2)
# Configure Fine-tune strategy
strategy = hub.AdamWeightDecayStrategy(learning_rate=5e-5) strategy = hub.AdamWeightDecayStrategy(learning_rate=5e-5)
# Setting runing config
config = hub.RunConfig( config = hub.RunConfig(
use_cuda=True, num_epoch=3, batch_size=32, strategy=strategy) use_cuda=True, num_epoch=3, batch_size=32, strategy=strategy)
feed_list = [ feed_list = [
inputs["input_ids"].name, inputs["position_ids"].name, inputs["input_ids"].name, inputs["position_ids"].name,
inputs["segment_ids"].name, inputs["input_mask"].name, inputs["segment_ids"].name, inputs["input_mask"].name,
task.variable("label").name task.variable("label").name
] ]
# Start fine-tuning
hub.finetune_and_eval(task, reader, feed_list, config) hub.finetune_and_eval(task, reader, feed_list, config)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册