未验证 提交 a65904d8 编写于 作者: W Wenyu 提交者: GitHub

add ema filter no_grad (#7691)

* add ema filter no_grad

* update file name and default value
上级 1e21400e
......@@ -187,12 +187,14 @@ class Trainer(object):
ema_decay_type = self.cfg.get('ema_decay_type', 'threshold')
cycle_epoch = self.cfg.get('cycle_epoch', -1)
ema_black_list = self.cfg.get('ema_black_list', None)
ema_filter_no_grad = self.cfg.get('ema_filter_no_grad', False)
self.ema = ModelEMA(
self.model,
decay=ema_decay,
ema_decay_type=ema_decay_type,
cycle_epoch=cycle_epoch,
ema_black_list=ema_black_list)
ema_black_list=ema_black_list,
ema_filter_no_grad=ema_filter_no_grad)
self._nranks = dist.get_world_size()
self._local_rank = dist.get_rank()
......@@ -1040,6 +1042,7 @@ class Trainer(object):
start = end
return results
def _get_save_image_name(self, output_dir, image_path):
"""
Get save image name from source image path.
......
......@@ -21,6 +21,8 @@ import paddle
import weakref
from copy import deepcopy
from .utils import get_bn_running_state_names
__all__ = ['ModelEMA', 'SimpleModelEMA']
......@@ -49,7 +51,8 @@ class ModelEMA(object):
decay=0.9998,
ema_decay_type='threshold',
cycle_epoch=-1,
ema_black_list=None):
ema_black_list=None,
ema_filter_no_grad=False):
self.step = 0
self.epoch = 0
self.decay = decay
......@@ -64,6 +67,12 @@ class ModelEMA(object):
else:
self.state_dict[k] = paddle.zeros_like(v)
bn_states_names = get_bn_running_state_names(model)
if ema_filter_no_grad:
for n, p in model.named_parameters():
if p.stop_gradient == True and n not in bn_states_names:
self.ema_black_list.append(n)
self._model_state = {
k: weakref.ref(p)
for k, p in model.state_dict().items()
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
from typing import List
def get_bn_running_state_names(model: nn.Layer) -> List[str]:
    """Collect the full names of the batch-norm running statistics in a model.

    Walks ``model.named_sublayers()`` and, for every sublayer that is an
    instance of ``nn.BatchNorm2D`` or ``nn.SyncBatchNorm``, records two names
    built from the sublayer's name: ``'<name>._mean'`` followed by
    ``'<name>._variance'``.

    Args:
        model: The layer whose sublayers are inspected.

    Returns:
        A list of names, two per matching sublayer, in the order the
        sublayers are yielded by ``named_sublayers()``. Empty when no
        sublayer matches.

    Raises:
        AssertionError: If a matching sublayer lacks a ``_mean`` or
            ``_variance`` attribute.
    """
    bn_types = (nn.BatchNorm2D, nn.SyncBatchNorm)
    state_names = []
    for prefix, sublayer in model.named_sublayers():
        # Only batch-norm layers carry the running statistics of interest.
        if not isinstance(sublayer, bn_types):
            continue
        # Sanity-check that the attributes the names refer to actually exist.
        assert hasattr(sublayer, '_mean'), f'assert {sublayer} has _mean'
        assert hasattr(sublayer, '_variance'), f'assert {sublayer} has _variance'
        state_names.append(f'{prefix}._mean')
        state_names.append(f'{prefix}._variance')
    return state_names
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册