Commit 1166014b authored by Zeyu Chen

migrate BERTFinetune to AdamWeightDecay

Parent 1916d77b
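Before the per-file hunks, here is a minimal sketch of how downstream user code migrates after this rename. Only the class name changes; the keyword arguments mirror the documentation snippet updated in this commit, and the concrete values are illustrative only.

```python
import paddlehub as hub

# Old API (removed in this commit):
# strategy = hub.BERTFinetuneStrategy(
#     weight_decay=0.01,
#     learning_rate=5e-5,
#     warmup_strategy="linear_warmup_decay")

# New API: same constructor arguments, new class name.
strategy = hub.AdamWeightDecayStrategy(
    weight_decay=0.01,
    learning_rate=5e-5,
    warmup_strategy="linear_warmup_decay")
```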
@@ -95,7 +95,7 @@ with fluid.program_guard(program):  # NOTE: the fluid.program_guard interface must be used
 ### Step4: Select an optimization strategy and start Finetune
 ```python
-strategy = hub.BERTFinetuneStrategy(
+strategy = hub.AdamWeightDecayStrategy(
     weight_decay=0.01,
     learning_rate=5e-5,
     warmup_strategy="linear_warmup_decay",
...
@@ -21,7 +21,7 @@ with fluid.program_guard(program):
     feature=pooled_output, label=label, num_classes=reader.get_num_labels())
 # Step4
-strategy = hub.BERTFinetuneStrategy(
+strategy = hub.AdamWeightDecayStrategy(
     learning_rate=5e-5,
     warmup_proportion=0.1,
     warmup_strategy="linear_warmup_decay",
...
@@ -62,7 +62,7 @@ if __name__ == '__main__':
         pooled_output, label, num_classes=num_labels)
     # Step4: Select finetune strategy, setup config and finetune
-    strategy = hub.BERTFinetuneStrategy(
+    strategy = hub.AdamWeightDecayStrategy(
         weight_decay=args.weight_decay,
         learning_rate=args.learning_rate,
         warmup_strategy="linear_warmup_decay",
...
@@ -62,7 +62,7 @@ if __name__ == '__main__':
         pooled_output, label, num_classes=num_labels)
     # Step4: Select finetune strategy, setup config and finetune
-    strategy = hub.BERTFinetuneStrategy(
+    strategy = hub.AdamWeightDecayStrategy(
         weight_decay=args.weight_decay,
         learning_rate=args.learning_rate,
         warmup_strategy="linear_warmup_decay",
...
@@ -61,7 +61,7 @@ if __name__ == '__main__':
         pooled_output, label, num_classes=reader.get_num_labels())
     # Step4: Select finetune strategy, setup config and finetune
-    strategy = hub.BERTFinetuneStrategy(
+    strategy = hub.AdamWeightDecayStrategy(
         weight_decay=args.weight_decay,
         learning_rate=args.learning_rate,
         warmup_strategy="linear_warmup_decay",
...
@@ -69,7 +69,7 @@ if __name__ == '__main__':
         num_classes=num_labels)
     # Select a finetune strategy
-    strategy = hub.BERTFinetuneStrategy(
+    strategy = hub.AdamWeightDecayStrategy(
         weight_decay=args.weight_decay,
         learning_rate=args.learning_rate,
         warmup_strategy="linear_warmup_decay",
...
@@ -39,5 +39,5 @@ from .finetune.task import create_text_classification_task
 from .finetune.task import create_img_classification_task
 from .finetune.finetune import finetune_and_eval
 from .finetune.config import RunConfig
-from .finetune.strategy import BERTFinetuneStrategy
+from .finetune.strategy import AdamWeightDecayStrategy
 from .finetune.strategy import DefaultStrategy
@@ -25,7 +25,7 @@ import numpy as np
 from visualdl import LogWriter
 from paddlehub.common.logger import logger
-from paddlehub.finetune.strategy import BERTFinetuneStrategy, DefaultStrategy
+from paddlehub.finetune.strategy import AdamWeightDecayStrategy, DefaultStrategy
 from paddlehub.finetune.checkpoint import load_checkpoint, save_checkpoint
 from paddlehub.finetune.evaluate import evaluate_cls_task, evaluate_seq_labeling_task
 import paddlehub as hub
@@ -74,7 +74,7 @@ def _finetune_seq_label_task(task,
     data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
     # Select strategy
-    if isinstance(config.strategy, hub.BERTFinetuneStrategy):
+    if isinstance(config.strategy, hub.AdamWeightDecayStrategy):
         scheduled_lr = config.strategy.execute(loss, main_program,
                                                data_reader, config)
     elif isinstance(config.strategy, hub.DefaultStrategy):
@@ -173,7 +173,7 @@ def _finetune_cls_task(task, data_reader, feed_list, config=None,
     data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
     # select strategy
-    if isinstance(config.strategy, hub.BERTFinetuneStrategy):
+    if isinstance(config.strategy, hub.AdamWeightDecayStrategy):
         scheduled_lr = config.strategy.execute(loss, main_program,
                                                data_reader, config)
     elif isinstance(config.strategy, hub.DefaultStrategy):
...
@@ -61,7 +61,7 @@ class DefaultStrategy(object):
         return "DefaultStrategy"
 
-class BERTFinetuneStrategy(DefaultStrategy):
+class AdamWeightDecayStrategy(DefaultStrategy):
     def __init__(self,
                  learning_rate=1e-4,
                  warmup_strategy="linear_warmup_decay",
@@ -114,7 +114,7 @@ class BERTFinetuneStrategy(DefaultStrategy):
     # TODO complete __str__()
     def __str__(self):
-        return "BERTFintuneStrategy"
+        return "AdamWeightDecayStrategy"
 
 class DefaultFinetuneStrategy(DefaultStrategy):
...
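As the last hunk shows, the renamed class still subclasses DefaultStrategy, so the isinstance dispatch in finetune.py checks the more specific AdamWeightDecayStrategy first and falls back to DefaultStrategy for other strategies. A small illustrative check, with placeholder parameter values:

```python
import paddlehub as hub

strategy = hub.AdamWeightDecayStrategy(
    learning_rate=5e-5,
    warmup_strategy="linear_warmup_decay")

# The rename keeps the class hierarchy intact, so the
# `elif isinstance(config.strategy, hub.DefaultStrategy)` branch in
# finetune.py only handles strategies that are not AdamWeightDecayStrategy.
assert isinstance(strategy, hub.AdamWeightDecayStrategy)
assert isinstance(strategy, hub.DefaultStrategy)
```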