提交 1166014b 编写于 作者: Z Zeyu Chen

migrate BERTFinetune to AdamWeightDecay

上级 1916d77b
......@@ -95,7 +95,7 @@ with fluid.program_guard(program): # NOTE: 必须使用fluid.program_guard接口
### Step4:选择优化策略并开始Finetune
```python
strategy = hub.BERTFinetuneStrategy(
strategy = hub.AdamWeightDecayStrategy(
weight_decay=0.01,
learning_rate=5e-5,
warmup_strategy="linear_warmup_decay",
......
......@@ -21,7 +21,7 @@ with fluid.program_guard(program):
feature=pooled_output, label=label, num_classes=reader.get_num_labels())
# Step4
strategy = hub.BERTFinetuneStrategy(
strategy = hub.AdamWeightDecayStrategy(
learning_rate=5e-5,
warmup_proportion=0.1,
warmup_strategy="linear_warmup_decay",
......
......@@ -62,7 +62,7 @@ if __name__ == '__main__':
pooled_output, label, num_classes=num_labels)
# Step4: Select finetune strategy, setup config and finetune
strategy = hub.BERTFinetuneStrategy(
strategy = hub.AdamWeightDecayStrategy(
weight_decay=args.weight_decay,
learning_rate=args.learning_rate,
warmup_strategy="linear_warmup_decay",
......
......@@ -62,7 +62,7 @@ if __name__ == '__main__':
pooled_output, label, num_classes=num_labels)
# Step4: Select finetune strategy, setup config and finetune
strategy = hub.BERTFinetuneStrategy(
strategy = hub.AdamWeightDecayStrategy(
weight_decay=args.weight_decay,
learning_rate=args.learning_rate,
warmup_strategy="linear_warmup_decay",
......
......@@ -61,7 +61,7 @@ if __name__ == '__main__':
pooled_output, label, num_classes=reader.get_num_labels())
# Step4: Select finetune strategy, setup config and finetune
strategy = hub.BERTFinetuneStrategy(
strategy = hub.AdamWeightDecayStrategy(
weight_decay=args.weight_decay,
learning_rate=args.learning_rate,
warmup_strategy="linear_warmup_decay",
......
......@@ -69,7 +69,7 @@ if __name__ == '__main__':
num_classes=num_labels)
# Select a finetune strategy
strategy = hub.BERTFinetuneStrategy(
strategy = hub.AdamWeightDecayStrategy(
weight_decay=args.weight_decay,
learning_rate=args.learning_rate,
warmup_strategy="linear_warmup_decay",
......
......@@ -39,5 +39,5 @@ from .finetune.task import create_text_classification_task
from .finetune.task import create_img_classification_task
from .finetune.finetune import finetune_and_eval
from .finetune.config import RunConfig
from .finetune.strategy import BERTFinetuneStrategy
from .finetune.strategy import AdamWeightDecayStrategy
from .finetune.strategy import DefaultStrategy
......@@ -25,7 +25,7 @@ import numpy as np
from visualdl import LogWriter
from paddlehub.common.logger import logger
from paddlehub.finetune.strategy import BERTFinetuneStrategy, DefaultStrategy
from paddlehub.finetune.strategy import AdamWeightDecayStrategy, DefaultStrategy
from paddlehub.finetune.checkpoint import load_checkpoint, save_checkpoint
from paddlehub.finetune.evaluate import evaluate_cls_task, evaluate_seq_labeling_task
import paddlehub as hub
......@@ -74,7 +74,7 @@ def _finetune_seq_label_task(task,
data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
# Select strategy
if isinstance(config.strategy, hub.BERTFinetuneStrategy):
if isinstance(config.strategy, hub.AdamWeightDecayStrategy):
scheduled_lr = config.strategy.execute(loss, main_program,
data_reader, config)
elif isinstance(config.strategy, hub.DefaultStrategy):
......@@ -173,7 +173,7 @@ def _finetune_cls_task(task, data_reader, feed_list, config=None,
data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
# select strategy
if isinstance(config.strategy, hub.BERTFinetuneStrategy):
if isinstance(config.strategy, hub.AdamWeightDecayStrategy):
scheduled_lr = config.strategy.execute(loss, main_program,
data_reader, config)
elif isinstance(config.strategy, hub.DefaultStrategy):
......
......@@ -61,7 +61,7 @@ class DefaultStrategy(object):
return "DefaultStrategy"
class BERTFinetuneStrategy(DefaultStrategy):
class AdamWeightDecayStrategy(DefaultStrategy):
def __init__(self,
learning_rate=1e-4,
warmup_strategy="linear_warmup_decay",
......@@ -114,7 +114,7 @@ class BERTFinetuneStrategy(DefaultStrategy):
# TODO complete __str__()
def __str__(self):
return "BERTFintuneStrategy"
return "AdamWeightDecayStrategy"
class DefaultFinetuneStrategy(DefaultStrategy):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册