提交 7b96067a 编写于 作者: Z Zeyu Chen

add Strategy

上级 dc0b2847
...@@ -43,6 +43,7 @@ args = parser.parse_args() ...@@ -43,6 +43,7 @@ args = parser.parse_args()
# yapf: enable. # yapf: enable.
if __name__ == '__main__': if __name__ == '__main__':
strategy = hub.BERTFinetuneStrategy(weight_decay=args.weight_decay)
config = hub.FinetuneConfig( config = hub.FinetuneConfig(
log_interval=10, log_interval=10,
eval_interval=100, eval_interval=100,
...@@ -51,9 +52,7 @@ if __name__ == '__main__': ...@@ -51,9 +52,7 @@ if __name__ == '__main__':
learning_rate=args.learning_rate, learning_rate=args.learning_rate,
num_epoch=args.num_epoch, num_epoch=args.num_epoch,
batch_size=args.batch_size, batch_size=args.batch_size,
max_seq_len=args.max_seq_len, strategy=strategy)
weight_decay=args.weight_decay,
finetune_strategy="bert_finetune")
# loading Paddlehub BERT # loading Paddlehub BERT
module = hub.Module(module_dir=args.hub_module_dir) module = hub.Module(module_dir=args.hub_module_dir)
......
...@@ -15,6 +15,7 @@ from . import module ...@@ -15,6 +15,7 @@ from . import module
from . import common from . import common
from . import io from . import io
from . import dataset from . import dataset
from . import finetune
from .common.dir import USER_HOME from .common.dir import USER_HOME
from .common.dir import HUB_HOME from .common.dir import HUB_HOME
...@@ -35,6 +36,8 @@ from .finetune.network import append_mlp_classifier ...@@ -35,6 +36,8 @@ from .finetune.network import append_mlp_classifier
from .finetune.finetune import finetune_and_eval from .finetune.finetune import finetune_and_eval
from .finetune.config import FinetuneConfig from .finetune.config import FinetuneConfig
from .finetune.task import Task from .finetune.task import Task
from .finetune.strategy import BERTFinetuneStrategy
from .finetune.strategy import DefaultStrategy
from .reader import BERTTokenizeReader from .reader import BERTTokenizeReader
from .reader.cv_reader import ImageClassificationReader from .reader.cv_reader import ImageClassificationReader
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections from .strategy import DefaultStrategy
class FinetuneConfig(object): class FinetuneConfig(object):
...@@ -30,8 +30,8 @@ class FinetuneConfig(object): ...@@ -30,8 +30,8 @@ class FinetuneConfig(object):
max_seq_len=128, max_seq_len=128,
weight_decay=None, weight_decay=None,
warmup_proportion=0.0, warmup_proportion=0.0,
finetune_strategy=None,
enable_memory_optim=True, enable_memory_optim=True,
strategy=None,
optimizer="adam"): optimizer="adam"):
""" Construct finetune Config """ """ Construct finetune Config """
self._log_interval = log_interval self._log_interval = log_interval
...@@ -43,9 +43,10 @@ class FinetuneConfig(object): ...@@ -43,9 +43,10 @@ class FinetuneConfig(object):
self._num_epoch = num_epoch self._num_epoch = num_epoch
self._batch_size = batch_size self._batch_size = batch_size
self._max_seq_len = max_seq_len self._max_seq_len = max_seq_len
self._weight_decay = weight_decay if strategy is None:
self._warmup_proportion = warmup_proportion self._strategy = DefaultStrategy()
self._finetune_strategy = finetune_strategy else:
self._strategy = strategy
self._enable_memory_optim = enable_memory_optim self._enable_memory_optim = enable_memory_optim
self._optimizer = optimizer self._optimizer = optimizer
...@@ -94,8 +95,8 @@ class FinetuneConfig(object): ...@@ -94,8 +95,8 @@ class FinetuneConfig(object):
return self._warmup_proportion return self._warmup_proportion
@property @property
def finetune_strategy(self): def strategy(self):
return self._finetune_strategy return self._strategy
@property @property
def enable_memory_optim(self): def enable_memory_optim(self):
......
...@@ -18,13 +18,16 @@ from __future__ import print_function ...@@ -18,13 +18,16 @@ from __future__ import print_function
import os import os
import time import time
import multiprocessing
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle_hub as hub
from visualdl import LogWriter from visualdl import LogWriter
from paddle_hub.common.logger import logger from paddle_hub.common.logger import logger
from paddle_hub.finetune.optimization import bert_finetune from paddle_hub.finetune.optimization import bert_finetune
from paddle_hub.finetune.strategy import BERTFinetuneStrategy, DefaultStrategy
from paddle_hub.finetune.checkpoint import load_checkpoint, save_checkpoint from paddle_hub.finetune.checkpoint import load_checkpoint, save_checkpoint
...@@ -76,12 +79,12 @@ def _finetune_model(task, data_reader, feed_list, config=None, do_eval=False): ...@@ -76,12 +79,12 @@ def _finetune_model(task, data_reader, feed_list, config=None, do_eval=False):
exe = fluid.Executor(place=place) exe = fluid.Executor(place=place)
data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place) data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
if config.finetune_strategy == "bert_finetune": # select strategy
scheduled_lr = bert_finetune(task, main_program, data_reader, if isinstance(config.strategy, hub.BERTFinetuneStrategy):
config, dev_count) scheduled_lr = config.strategy.execute(loss, main_program,
elif config.optimizer == "adam": data_reader, config)
optimizer = fluid.optimizer.Adam(learning_rate=config.learning_rate) elif isinstance(config.optimizer, hub.DefaultStrategy):
optimizer.minimize(loss) config.strategy.execute(loss)
#TODO: add more finetune strategy #TODO: add more finetune strategy
_do_memory_optimization(task, config) _do_memory_optimization(task, config)
......
...@@ -19,12 +19,9 @@ from __future__ import print_function ...@@ -19,12 +19,9 @@ from __future__ import print_function
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
"""
Finetune optimization strategy
"""
def bert_finetune(task, train_program, data_processor, config, dev_count): def bert_finetune(task, main_program, data_processor, config, dev_count):
# calculate warmup step # calculate warmup step
num_train_examples = data_processor.get_num_examples(phase='train') num_train_examples = data_processor.get_num_examples(phase='train')
max_train_steps = config.num_epoch * num_train_examples // config.batch_size // dev_count max_train_steps = config.num_epoch * num_train_examples // config.batch_size // dev_count
...@@ -32,20 +29,19 @@ def bert_finetune(task, train_program, data_processor, config, dev_count): ...@@ -32,20 +29,19 @@ def bert_finetune(task, train_program, data_processor, config, dev_count):
loss = task.variable("loss") loss = task.variable("loss")
scheduled_lr = adam_weight_decay_optimizer_with_linear_warmup( scheduled_lr = adam_weight_decay_optimizer_with_linear_warmup(
loss, warmup_steps, max_train_steps, config.learning_rate, loss, warmup_steps, max_train_steps, config.learning_rate, main_program,
train_program, config.weight_decay) config.weight_decay)
return scheduled_lr return scheduled_lr
def adam_weight_decay_optimizer_with_noam_decay( def adam_weight_decay_optimization(loss,
loss, warmup_steps,
warmup_steps, num_train_steps,
num_train_steps, learning_rate,
learning_rate, main_program,
train_program, weight_decay,
weight_decay, scheduler='linear_warmup_decay'):
scheduler='linear_warmup_decay'):
if warmup_steps > 0: if warmup_steps > 0:
if scheduler == 'noam_decay': if scheduler == 'noam_decay':
scheduled_lr = fluid.layers.learning_rate_scheduler\ scheduled_lr = fluid.layers.learning_rate_scheduler\
...@@ -77,64 +73,7 @@ def adam_weight_decay_optimizer_with_noam_decay( ...@@ -77,64 +73,7 @@ def adam_weight_decay_optimizer_with_noam_decay(
param_list = dict() param_list = dict()
for param in train_program.global_block().all_parameters(): for param in main_program.global_block().all_parameters():
param_list[param.name] = param * 1.0
param_list[param.name].stop_gradient = True
_, param_grads = optimizer.minimize(loss)
if weight_decay > 0:
for param, grad in param_grads:
if exclude_from_weight_decay(param.name):
continue
with param.block.program._optimized_guard(
[param, grad]), fluid.framework.name_scope("weight_decay"):
updated_param = param - param_list[
param.name] * weight_decay * scheduled_lr
fluid.layers.assign(output=param, input=updated_param)
return scheduled_lr
def adam_weight_decay_optimizer_with_linear_warmup(loss,
warmup_steps,
num_train_steps,
learning_rate,
train_program,
weight_decay,
scheduler='noam_decay'):
if warmup_steps > 0:
if scheduler == 'noam_decay':
scheduled_lr = fluid.layers.learning_rate_scheduler\
.noam_decay(1/(warmup_steps *(learning_rate ** 2)),
warmup_steps)
elif scheduler == 'linear_warmup_decay':
scheduled_lr = linear_warmup_decay(learning_rate, warmup_steps,
num_train_steps)
else:
raise ValueError("Unkown learning rate scheduler, should be "
"'noam_decay' or 'linear_warmup_decay'")
optimizer = fluid.optimizer.Adam(learning_rate=scheduled_lr)
else:
optimizer = fluid.optimizer.Adam(learning_rate=learning_rate)
scheduled_lr = learning_rate
clip_norm_thres = 1.0
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=clip_norm_thres))
def exclude_from_weight_decay(name):
if name.find("layer_norm") > -1:
return True
bias_suffix = ["_bias", "_b", ".b_0"]
for suffix in bias_suffix:
if name.endswith(suffix):
return True
return False
param_list = dict()
for param in train_program.global_block().all_parameters():
param_list[param.name] = param * 1.0 param_list[param.name] = param * 1.0
param_list[param.name].stop_gradient = True param_list[param.name].stop_gradient = True
......
...@@ -191,7 +191,7 @@ class Module(object): ...@@ -191,7 +191,7 @@ class Module(object):
def _init_with_module_file(self, module_dir): def _init_with_module_file(self, module_dir):
checker = ModuleChecker(module_dir) checker = ModuleChecker(module_dir)
if not checker.check(): if not checker.check():
logger.error("module check fail") logger.error("Module init failed on {}".format(module_dir))
exit(1) exit(1)
self.helper = ModuleHelper(module_dir) self.helper = ModuleHelper(module_dir)
...@@ -205,7 +205,7 @@ class Module(object): ...@@ -205,7 +205,7 @@ class Module(object):
self._load_assets() self._load_assets()
self._recover_from_desc() self._recover_from_desc()
self._generate_sign_attr() self._generate_sign_attr()
self._recovery_parameter(self.program) self._restore_parameter(self.program)
self._recover_variable_info(self.program) self._recover_variable_info(self.program)
def _init_with_signature(self, signatures): def _init_with_signature(self, signatures):
...@@ -228,7 +228,7 @@ class Module(object): ...@@ -228,7 +228,7 @@ class Module(object):
self.default_signature = sign self.default_signature = sign
self.signatures[sign.name] = sign self.signatures[sign.name] = sign
def _recovery_parameter(self, program): def _restore_parameter(self, program):
global_block = program.global_block() global_block = program.global_block()
param_attrs = self.desc.extra_info.map.data['param_attrs'] param_attrs = self.desc.extra_info.map.data['param_attrs']
for key, param_attr in param_attrs.map.data.items(): for key, param_attr in param_attrs.map.data.items():
...@@ -477,7 +477,7 @@ class Module(object): ...@@ -477,7 +477,7 @@ class Module(object):
if regularizer != "Default": if regularizer != "Default":
paddle_helper.set_parameter_regularizer(program, regularizer) paddle_helper.set_parameter_regularizer(program, regularizer)
self._recovery_parameter(program) self._restore_parameter(program)
self._recover_variable_info(program) self._recover_variable_info(program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册