Commit a7a825b4 authored by qingqing01, committed by GitHub

Update parameter loading (#3735)

Parent 0b7f5d7f
@@ -64,7 +64,7 @@ Advanced Features:
 ## Get Started
 - [Installation guide](docs/INSTALL.md)
-- [Quick Start on small dataset](docs/QUICK_STARTED.md)
+- [Quick start on small dataset](docs/QUICK_STARTED.md)
 - [Guide to training, evaluation and arguments description](docs/GETTING_STARTED.md)
 - [Guide to preprocess pipeline and custom dataset](docs/DATA.md)
 - [Introduction to the configuration workflow](docs/CONFIG.md)
...
@@ -180,3 +180,7 @@ the batch size can reach 4 per GPU (Tesla V100 16GB).
 **Q:** How do I modify the data preprocessing? <br/>
 **A:** Set `sample_transform` in the config file. Note that the **complete preprocessing pipeline** must be listed in the config, e.g. `DecodeImage`, `NormalizeImage` and `Permute` for RCNN models. See the [configuration example](config_example) for more details; a toy sketch of the pipeline order follows after this hunk.
+**Q:** How are affine_channel and batch norm related?
+**A:** When an RCNN-series model is initialized from a pretrained model, the batch norm parameters are sometimes frozen: the global mean and variance from the pretrained model are used, and the batch norm scale and bias are not updated. Most of the released ResNet-series RCNN models work this way. In this case, set norm_type in the config to either bn or affine_channel, with freeze_norm set to true (the default); the two settings are equivalent. affine_channel computes `scale * x + bias`; the only difference is that with affine_channel the batch norm parameters are fused internally (see the NumPy sketch after the load_and_fusebn hunk below). A model trained with affine_channel can initialize further training with either affine_channel or batch norm; the parameters load correctly in both cases.
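To make the "complete preprocessing" answer concrete, below is a minimal, framework-free Python sketch; the function names only mirror the operators mentioned above (`DecodeImage`, `NormalizeImage`, `Permute`) and are not ppdet code. It shows why every stage must appear in the config: each operator feeds the next, and dropping `Permute`, for example, would hand the network HWC data instead of the CHW layout it expects.

```python
import numpy as np

def decode_image(sample):
    # stand-in for DecodeImage: raw uint8 pixels -> float32 ndarray
    sample['image'] = sample['image'].astype(np.float32)
    return sample

def normalize_image(sample, mean=0.5, std=0.5):
    # stand-in for NormalizeImage: scale to [0, 1], then standardize
    sample['image'] = (sample['image'] / 255.0 - mean) / std
    return sample

def permute(sample):
    # stand-in for Permute: HWC -> CHW, the layout conv layers consume
    sample['image'] = sample['image'].transpose((2, 0, 1))
    return sample

sample = {'image': np.zeros((800, 1333, 3), dtype=np.uint8)}
for op in (decode_image, normalize_image, permute):
    sample = op(sample)
assert sample['image'].shape == (3, 800, 1333)
```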
@@ -177,7 +177,8 @@ def load_and_fusebn(exe, prog, path):
         prog (fluid.Program): the Program object to load weights into.
         path (string): the path of the saved model.
     """
-    logger.info('Load model and fuse batch norm from {}...'.format(path))
+    logger.info('Load model and fuse batch norm if present from {}...'.format(
+        path))
     if is_url(path):
         path = _get_weight_path(path)
@@ -253,8 +254,11 @@ def load_and_fusebn(exe, prog, path):
             [scale_name, bias_name, mean_name, variance_name])

     if not bn_in_path:
-        raise ValueError("There is no params of batch norm in model {}.".format(
-            path))
+        fluid.io.load_vars(exe, path, prog, vars=all_vars)
+        logger.warning(
+            "There are no batch norm parameters in model {}. "
+            "Skipped fusing batch norm; parameters loaded.".format(path))
+        return

     # load running mean and running variance on cpu place into global scope.
     place = fluid.CPUPlace()
...
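The fusion that `load_and_fusebn` performs, and that the affine_channel FAQ above describes, reduces to a small amount of arithmetic once the statistics are frozen. Here is a minimal NumPy sketch (not PaddleDetection code) of folding batch norm into the affine_channel form `scale * x + bias`:

```python
import numpy as np

def fuse_bn_to_affine(scale, bias, mean, variance, eps=1e-5):
    """Fold frozen batch norm statistics into per-channel affine parameters.

    y = scale * (x - mean) / sqrt(variance + eps) + bias
      = fused_scale * x + fused_bias
    """
    fused_scale = scale / np.sqrt(variance + eps)
    fused_bias = bias - mean * fused_scale
    return fused_scale, fused_bias

# sanity check: frozen BN and the fused affine op agree
rng = np.random.default_rng(0)
scale, bias, mean = rng.normal(size=8), rng.normal(size=8), rng.normal(size=8)
var = rng.uniform(0.5, 2.0, size=8)
x = rng.normal(size=(4, 8))  # a batch of 8-channel features

bn_out = scale * (x - mean) / np.sqrt(var + 1e-5) + bias
fused_scale, fused_bias = fuse_bn_to_affine(scale, bias, mean, var)
assert np.allclose(bn_out, fused_scale * x + fused_bias)
```

This also explains the fallback added in the hunk above: when the saved weights contain no batch norm parameters there is nothing to fold, so the function now loads the variables as-is and returns instead of raising.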
@@ -77,9 +77,6 @@ def main():
     if 'log_iter' not in cfg:
         cfg.log_iter = 20

-    ignore_params = cfg.finetune_exclude_pretrained_params \
-                 if 'finetune_exclude_pretrained_params' in cfg else []
-
     # check whether use_gpu=True is set in a CPU-only PaddlePaddle build
     check_gpu(cfg.use_gpu)

     if not FLAGS.dist or trainer_id == 0:
@@ -193,8 +190,11 @@ def main():
         compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog)

     fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'
+    ignore_params = cfg.finetune_exclude_pretrained_params \
+                 if 'finetune_exclude_pretrained_params' in cfg else []
+
     start_iter = 0
     if FLAGS.resume_checkpoint:
         checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
         start_iter = checkpoint.global_step()
...
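The relocated `ignore_params` assignment feeds the `finetune_exclude_pretrained_params` option. Below is a hedged sketch of how such an exclusion list is typically applied when loading pretrained weights; `filter_pretrained` is a hypothetical helper and the exact matching rule in ppdet may differ. The usual use case is skipping the classification head when finetuning on a dataset with a different number of classes.

```python
import fnmatch

def filter_pretrained(param_names, ignore_params):
    # hypothetical helper: split parameters into those to load and those
    # excluded by the finetune_exclude_pretrained_params patterns
    keep, skipped = [], []
    for name in param_names:
        if any(fnmatch.fnmatch(name, pat + '*') for pat in ignore_params):
            skipped.append(name)
        else:
            keep.append(name)
    return keep, skipped

# example: drop the final FC layer's weights when the class count changes
keep, skipped = filter_pretrained(
    ['conv1_weights', 'fc_0.w_0', 'fc_0.b_0'],
    ignore_params=['fc_0'])
assert keep == ['conv1_weights']
assert skipped == ['fc_0.w_0', 'fc_0.b_0']
```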