未验证 提交 009b0a9f 编写于 作者: Z zhouzj 提交者: GitHub

Add lr scheduler for auto_compress (#1127)

* Add lr scheduler.

* Add lr scheduler.

* add notes for lr scheduler.

* add notes for lr scheduler.
上级 78332fce
......@@ -116,7 +116,13 @@ TrainConfig:
optim_args:
weight_decay: 0.0005
```
- 学习率衰减策略:主要设置策略类名和策略参数,如下所示。目前在paddle中已经实现了多种衰减策略,请参考[lr文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.2/api/paddle/optimizer/lr/LRScheduler_cn.html),策略参数即类初始化参数。
```yaml
learning_rate:
type: PiecewiseDecay # 学习率衰减策略类名
boundaries: [4500] # 设置策略参数
values: [0.005, 0.0005] # 设置策略参数
```
## 其他参数配置
#### 1.自动蒸馏效果不理想,怎么自主选择蒸馏节点?
......
......@@ -494,7 +494,8 @@ class AutoCompression:
np_probs_float, = self._exe.run(train_program_info.program, \
feed=data, \
fetch_list=train_program_info.fetch_targets)
if not isinstance(train_program_info.learning_rate, float):
train_program_info.learning_rate.step()
if 'unstructure' in strategy:
self._pruner.step()
......
......@@ -30,6 +30,18 @@ __all__ = [
]
def _create_lr_scheduler(train_config):
if 'learning_rate' not in train_config:
raise RuntimeError(
'No `learning_rate` specified in the configuration file.')
if isinstance(train_config.get('learning_rate'), float):
return train_config.get('learning_rate')
params = train_config.get('learning_rate')
lr_type = params.pop('type')
return getattr(optimizer.lr, lr_type)(**params)
def _create_optimizer(train_config):
"""create optimizer"""
opt = getattr(optimizer, train_config.get('optimizer') or
......@@ -54,10 +66,11 @@ def _create_optimizer(train_config):
train_config['optim_args'] = {}
grad_clip = None
op = opt(learning_rate=train_config["learning_rate"],
lr = _create_lr_scheduler(train_config)
op = opt(learning_rate=lr,
grad_clip=grad_clip,
**train_config['optim_args'])
return op
return op, lr
def _parse_distill_loss(distill_node_pair,
......@@ -223,7 +236,7 @@ def build_distill_program(executor,
train_fetch_list = []
with paddle.static.program_guard(train_program, startup_program):
with paddle.utils.unique_name.guard('merge'):
optimizer = _create_optimizer(train_config)
optimizer, learning_rate = _create_optimizer(train_config)
if train_config.get('use_fleet'):
optimizer = fleet.distributed_optimizer(optimizer,
......@@ -277,7 +290,7 @@ def build_distill_program(executor,
train_program_info = ProgramInfo(startup_program, train_program,
feed_target_names, train_fetch_list,
optimizer)
optimizer, learning_rate)
test_program_info = ProgramInfo(startup_program, test_program,
feed_target_names, fetch_targets)
return train_program_info, test_program_info
......
......@@ -128,9 +128,11 @@ class ProgramInfo:
program,
feed_target_names,
fetch_targets,
optimizer=None):
optimizer=None,
learning_rate=None):
self.startup_program = startup_program
self.program = program
self.feed_target_names = feed_target_names
self.fetch_targets = fetch_targets
self.optimizer = optimizer
self.learning_rate = learning_rate
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册