未验证 提交 69baaf96 编写于 作者: Z Zeyu Chen 提交者: GitHub

Update README.md

上级 e1e93667
......@@ -23,8 +23,7 @@
--weight_decay:
--checkpoint_dir: 模型保存路径,PaddleHub会自动保存验证集上表现最好的模型
--num_epoch: Finetune迭代的轮数
--max_seq_len: ERNIE模型使用的最大序列长度,最大不能超过512,
若出现显存不足错误,请调低这一参数
--max_seq_len: ERNIE模型使用的最大序列长度,最大不能超过512, 若出现显存不足错误,请调低这一参数
```
## 代码步骤
......@@ -34,9 +33,8 @@
### Step1: 加载预训练模型
```python
module = hub.Module(name="ernie")
inputs, outputs, program = module.context(
trainable=True, max_seq_len=128)
module = hub.Module(name="ernie")
inputs, outputs, program = module.context(trainable=True, max_seq_len=128)
```
其中最大序列长度`max_seq_len`是可以调整的参数,建议值128,根据任务文本长度不同可以调整该值,但最大不超过512。
......@@ -54,6 +52,51 @@ BERT-Base, Chinese | bert_chinese_L-12_H-768_A-12
```python
# 即可无缝切换BERT中文模型
module = hub.Module(name="bert_chinese_L-12_H-768_A-12")
# 更换name参数即可无缝切换BERT中文模型
module = hub.Module(name="bert_chinese_L-12_H-768_A-12")
```
### Step2: 准备数据集并使用ClassifyReader读取数据
```python
with fluid.program_guard(program):
label = fluid.layers.data(name="label", shape=[1], dtype='int64')
pooled_output = outputs["pooled_output"]
feed_list = [
inputs["input_ids"].name, inputs["position_ids"].name,
inputs["segment_ids"].name, inputs["input_mask"].name, label.name
]
cls_task = hub.create_text_classification_task(
pooled_output, label, num_classes=reader.get_num_labels())
```
### Step3: 构建网络并创建分类迁移任务
```python
with fluid.program_guard(program):
label = fluid.layers.data(name="label", shape=[1], dtype='int64')
pooled_output = outputs["pooled_output"]
feed_list = [
inputs["input_ids"].name, inputs["position_ids"].name,
inputs["segment_ids"].name, inputs["input_mask"].name, label.name
]
cls_task = hub.create_text_classification_task(
pooled_output, label, num_classes=reader.get_num_labels())
```
### Step4:选择优化策略并开始Finetune
```python
strategy = hub.BERTFinetuneStrategy(
weight_decay=0.01,
learning_rate=5e-5,
warmup_strategy="linear_warmup_decay",
)
config = hub.RunConfig(use_cuda=True, num_epoch=3, batch_size=32, strategy=strategy)
hub.finetune_and_eval(task=cls_task, data_reader=reader, feed_list=feed_list, config=config)
```
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册