提交 7a0a2814 编写于 作者: Z Zeyu Chen

update ernie tiny demo

上级 1a2d80d5
import paddle.fluid as fluid import paddle.fluid as fluid
import paddlehub as hub import paddlehub as hub
# Step1
module = hub.Module(name="ernie") module = hub.Module(name="ernie")
inputs, outputs, program = module.context(trainable=True, max_seq_len=128) inputs, outputs, program = module.context(trainable=True, max_seq_len=128)
# Step2
reader = hub.reader.ClassifyReader( reader = hub.reader.ClassifyReader(
dataset=hub.dataset.ChnSentiCorp(), dataset=hub.dataset.ChnSentiCorp(),
vocab_path=module.get_vocab_path(), vocab_path=module.get_vocab_path(),
max_seq_len=128) max_seq_len=128)
# Step3
with fluid.program_guard(program): with fluid.program_guard(program):
label = fluid.layers.data(name="label", shape=[1], dtype='int64') label = fluid.layers.data(name="label", shape=[1], dtype='int64')
pooled_output = outputs["pooled_output"] pooled_output = outputs["pooled_output"]
feed_list = [
inputs["input_ids"].name, inputs["position_ids"].name,
inputs["segment_ids"].name, inputs["input_mask"].name, label.name
]
cls_task = hub.create_text_classification_task( cls_task = hub.create_text_classification_task(
pooled_output, label, num_classes=reader.get_num_labels()) feature=pooled_output, label=label, num_classes=reader.get_num_labels())
# Step4
strategy = hub.BERTFinetuneStrategy(
learning_rate=5e-5,
warmup_proportion=0.1,
warmup_strategy="linear_warmup_decay",
weight_decay=0.01)
strategy = hub.BERTFinetuneStrategy( config = hub.RunConfig(
weight_decay=0.01, use_cuda=True, num_epoch=3, batch_size=32, strategy=strategy)
learning_rate=5e-5,
warmup_strategy="linear_warmup_decay",
)
config = hub.RunConfig( feed_list = [
use_cuda=True, num_epoch=3, batch_size=32, strategy=strategy) inputs["input_ids"].name, inputs["position_ids"].name,
inputs["segment_ids"].name, inputs["input_mask"].name, label.name
]
hub.finetune_and_eval( hub.finetune_and_eval(
task=cls_task, data_reader=reader, feed_list=feed_list, config=config) task=cls_task, data_reader=reader, feed_list=feed_list, config=config)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册