提交 2773c85b 编写于 作者: Z Zeyu Chen

update config and demo

上级 1510df97
import paddle.fluid as fluid
import paddlehub as hub
# Step1: Select pre-trained model
module = hub.Module(name="ernie")
inputs, outputs, program = module.context(trainable=True, max_seq_len=128)
# Step2: Prepare Dataset and DataReader
dataset = hub.dataset.ChnSentiCorp()
reader = hub.reader.ClassifyReader(
dataset=dataset, vocab_path=module.get_vocab_path(), max_seq_len=128)
# Step3: Construct transfer learning task
with fluid.program_guard(program):
label = fluid.layers.data(name="label", shape=[1], dtype='int64')
pooled_output = outputs["pooled_output"]
cls_task = hub.create_text_cls_task(
feature=pooled_output, label=label, num_classes=dataset.num_labels)
# Step4: Setup config then start finetune
strategy = hub.AdamWeightDecayStrategy(learning_rate=5e-5, weight_decay=0.01)
config = hub.RunConfig(
use_cuda=True,
checkpoint_dir="./ckpt",
num_epoch=3,
batch_size=32,
strategy=strategy)
inputs, outputs, program = module.context(trainable=True)
reader = hub.reader.ClassifyReader(hub.dataset.ChnSentiCorp(),
module.get_vocab_path())
task = hub.create_text_cls_task(feature=outputs["pooled_output"], num_classes=2)
strategy = hub.AdamWeightDecayStrategy(learning_rate=5e-5)
config = hub.RunConfig(use_cuda=True, num_epoch=3, strategy=strategy)
feed_list = [
inputs["input_ids"].name, inputs["position_ids"].name,
inputs["segment_ids"].name, inputs["input_mask"].name, label.name
inputs["segment_ids"].name, inputs["input_mask"].name,
task.variable("label").name
]
hub.finetune_and_eval(
task=cls_task, data_reader=reader, feed_list=feed_list, config=config)
hub.finetune_and_eval(task, reader, feed_list, config)
......@@ -30,10 +30,10 @@ class RunConfig(object):
log_interval=10,
eval_interval=100,
save_ckpt_interval=None,
use_cuda=False,
use_cuda=True,
checkpoint_dir=None,
num_epoch=10,
batch_size=None,
num_epoch=1,
batch_size=32,
enable_memory_optim=True,
strategy=None):
""" Construct finetune Config """
......
......@@ -469,7 +469,7 @@ class Module(object):
def context(self,
sign_name=None,
for_test=False,
trainable=False,
trainable=True,
regularizer=None,
max_seq_len=128,
learning_rate=1e-3):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册