未验证 提交 befeaeb5 编写于 作者: S shangliang Xu 提交者: GitHub

[dev] add white and black list for amp train (#6576)

上级 3e4d5697
......@@ -69,6 +69,8 @@ class Trainer(object):
self.is_loaded_weights = False
self.use_amp = self.cfg.get('amp', False)
self.amp_level = self.cfg.get('amp_level', 'O1')
self.custom_white_list = self.cfg.get('custom_white_list', None)
self.custom_black_list = self.cfg.get('custom_black_list', None)
# build data loader
capital_mode = self.mode.capitalize()
......@@ -155,8 +157,10 @@ class Trainer(object):
self.pruner = create('UnstructuredPruner')(self.model,
steps_per_epoch)
if self.use_amp and self.amp_level == 'O2':
self.model = paddle.amp.decorate(
models=self.model, level=self.amp_level)
self.model, self.optimizer = paddle.amp.decorate(
models=self.model,
optimizers=self.optimizer,
level=self.amp_level)
self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
if self.use_ema:
ema_decay = self.cfg.get('ema_decay', 0.9998)
......@@ -456,7 +460,9 @@ class Trainer(object):
DataParallel) and use_fused_allreduce_gradients:
with model.no_sync():
with paddle.amp.auto_cast(
enable=self.cfg.use_gpus,
enable=self.cfg.use_gpu,
custom_white_list=self.custom_white_list,
custom_black_list=self.custom_black_list,
level=self.amp_level):
# model forward
outputs = model(data)
......@@ -468,7 +474,10 @@ class Trainer(object):
list(model.parameters()), None)
else:
with paddle.amp.auto_cast(
enable=self.cfg.use_gpu, level=self.amp_level):
enable=self.cfg.use_gpu,
custom_white_list=self.custom_white_list,
custom_black_list=self.custom_black_list,
level=self.amp_level):
# model forward
outputs = model(data)
loss = outputs['loss']
......@@ -477,7 +486,6 @@ class Trainer(object):
scaled_loss.backward()
# in dygraph mode, optimizer.minimize is equal to optimizer.step
scaler.minimize(self.optimizer, scaled_loss)
else:
if isinstance(
model, paddle.
......@@ -575,7 +583,10 @@ class Trainer(object):
# forward
if self.use_amp:
with paddle.amp.auto_cast(
enable=self.cfg.use_gpu, level=self.amp_level):
enable=self.cfg.use_gpu,
custom_white_list=self.custom_white_list,
custom_black_list=self.custom_black_list,
level=self.amp_level):
outs = self.model(data)
else:
outs = self.model(data)
......
......@@ -66,7 +66,10 @@ class ModelEMA(object):
def resume(self, state_dict, step=0):
for k, v in state_dict.items():
if k in self.state_dict:
if self.state_dict[k].dtype == v.dtype:
self.state_dict[k] = v
else:
self.state_dict[k] = v.astype(self.state_dict[k].dtype)
self.step = step
def update(self, model=None):
......
......@@ -84,9 +84,14 @@ def load_weight(model, weight, optimizer=None, ema=None):
model_weight = {}
incorrect_keys = 0
for key in model_dict.keys():
for key, value in model_dict.items():
if key in param_state_dict.keys():
if isinstance(param_state_dict[key], np.ndarray):
param_state_dict[key] = paddle.to_tensor(param_state_dict[key])
if value.dtype == param_state_dict[key].dtype:
model_weight[key] = param_state_dict[key]
else:
model_weight[key] = param_state_dict[key].astype(value.dtype)
else:
logger.info('Unmatched key: {}'.format(key))
incorrect_keys += 1
......@@ -209,6 +214,12 @@ def load_pretrain_weight(model, pretrain_weight):
param_state_dict = paddle.load(weights_path)
param_state_dict = match_state_dict(model_dict, param_state_dict)
for k, v in param_state_dict.items():
if isinstance(v, np.ndarray):
v = paddle.to_tensor(v)
if model_dict[k].dtype != v.dtype:
param_state_dict[k] = v.astype(model_dict[k].dtype)
model.set_dict(param_state_dict)
logger.info('Finish loading model weights: {}'.format(weights_path))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册