提交 2773c85b 编写于 作者: Z Zeyu Chen

update config and demo

上级 1510df97
import paddle.fluid as fluid import paddle.fluid as fluid
import paddlehub as hub import paddlehub as hub
# Step1: Select pre-trained model
module = hub.Module(name="ernie") module = hub.Module(name="ernie")
inputs, outputs, program = module.context(trainable=True, max_seq_len=128) inputs, outputs, program = module.context(trainable=True)
reader = hub.reader.ClassifyReader(hub.dataset.ChnSentiCorp(),
# Step2: Prepare Dataset and DataReader module.get_vocab_path())
dataset = hub.dataset.ChnSentiCorp() task = hub.create_text_cls_task(feature=outputs["pooled_output"], num_classes=2)
reader = hub.reader.ClassifyReader( strategy = hub.AdamWeightDecayStrategy(learning_rate=5e-5)
dataset=dataset, vocab_path=module.get_vocab_path(), max_seq_len=128) config = hub.RunConfig(use_cuda=True, num_epoch=3, strategy=strategy)
# Step3: Construct transfer learning task
with fluid.program_guard(program):
label = fluid.layers.data(name="label", shape=[1], dtype='int64')
pooled_output = outputs["pooled_output"]
cls_task = hub.create_text_cls_task(
feature=pooled_output, label=label, num_classes=dataset.num_labels)
# Step4: Setup config then start finetune
strategy = hub.AdamWeightDecayStrategy(learning_rate=5e-5, weight_decay=0.01)
config = hub.RunConfig(
use_cuda=True,
checkpoint_dir="./ckpt",
num_epoch=3,
batch_size=32,
strategy=strategy)
feed_list = [ feed_list = [
inputs["input_ids"].name, inputs["position_ids"].name, inputs["input_ids"].name, inputs["position_ids"].name,
inputs["segment_ids"].name, inputs["input_mask"].name, label.name inputs["segment_ids"].name, inputs["input_mask"].name,
task.variable("label").name
] ]
hub.finetune_and_eval(task, reader, feed_list, config)
hub.finetune_and_eval(
task=cls_task, data_reader=reader, feed_list=feed_list, config=config)
...@@ -30,10 +30,10 @@ class RunConfig(object): ...@@ -30,10 +30,10 @@ class RunConfig(object):
log_interval=10, log_interval=10,
eval_interval=100, eval_interval=100,
save_ckpt_interval=None, save_ckpt_interval=None,
use_cuda=False, use_cuda=True,
checkpoint_dir=None, checkpoint_dir=None,
num_epoch=10, num_epoch=1,
batch_size=None, batch_size=32,
enable_memory_optim=True, enable_memory_optim=True,
strategy=None): strategy=None):
""" Construct finetune Config """ """ Construct finetune Config """
......
...@@ -469,7 +469,7 @@ class Module(object): ...@@ -469,7 +469,7 @@ class Module(object):
def context(self, def context(self,
sign_name=None, sign_name=None,
for_test=False, for_test=False,
trainable=False, trainable=True,
regularizer=None, regularizer=None,
max_seq_len=128, max_seq_len=128,
learning_rate=1e-3): learning_rate=1e-3):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册