未验证 提交 befeaeb5 编写于 作者: S shangliang Xu 提交者: GitHub

[dev] add white and black list for amp train (#6576)

上级 3e4d5697
...@@ -69,6 +69,8 @@ class Trainer(object): ...@@ -69,6 +69,8 @@ class Trainer(object):
self.is_loaded_weights = False self.is_loaded_weights = False
self.use_amp = self.cfg.get('amp', False) self.use_amp = self.cfg.get('amp', False)
self.amp_level = self.cfg.get('amp_level', 'O1') self.amp_level = self.cfg.get('amp_level', 'O1')
self.custom_white_list = self.cfg.get('custom_white_list', None)
self.custom_black_list = self.cfg.get('custom_black_list', None)
# build data loader # build data loader
capital_mode = self.mode.capitalize() capital_mode = self.mode.capitalize()
...@@ -155,8 +157,10 @@ class Trainer(object): ...@@ -155,8 +157,10 @@ class Trainer(object):
self.pruner = create('UnstructuredPruner')(self.model, self.pruner = create('UnstructuredPruner')(self.model,
steps_per_epoch) steps_per_epoch)
if self.use_amp and self.amp_level == 'O2': if self.use_amp and self.amp_level == 'O2':
self.model = paddle.amp.decorate( self.model, self.optimizer = paddle.amp.decorate(
models=self.model, level=self.amp_level) models=self.model,
optimizers=self.optimizer,
level=self.amp_level)
self.use_ema = ('use_ema' in cfg and cfg['use_ema']) self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
if self.use_ema: if self.use_ema:
ema_decay = self.cfg.get('ema_decay', 0.9998) ema_decay = self.cfg.get('ema_decay', 0.9998)
...@@ -456,7 +460,9 @@ class Trainer(object): ...@@ -456,7 +460,9 @@ class Trainer(object):
DataParallel) and use_fused_allreduce_gradients: DataParallel) and use_fused_allreduce_gradients:
with model.no_sync(): with model.no_sync():
with paddle.amp.auto_cast( with paddle.amp.auto_cast(
enable=self.cfg.use_gpus, enable=self.cfg.use_gpu,
custom_white_list=self.custom_white_list,
custom_black_list=self.custom_black_list,
level=self.amp_level): level=self.amp_level):
# model forward # model forward
outputs = model(data) outputs = model(data)
...@@ -468,7 +474,10 @@ class Trainer(object): ...@@ -468,7 +474,10 @@ class Trainer(object):
list(model.parameters()), None) list(model.parameters()), None)
else: else:
with paddle.amp.auto_cast( with paddle.amp.auto_cast(
enable=self.cfg.use_gpu, level=self.amp_level): enable=self.cfg.use_gpu,
custom_white_list=self.custom_white_list,
custom_black_list=self.custom_black_list,
level=self.amp_level):
# model forward # model forward
outputs = model(data) outputs = model(data)
loss = outputs['loss'] loss = outputs['loss']
...@@ -477,7 +486,6 @@ class Trainer(object): ...@@ -477,7 +486,6 @@ class Trainer(object):
scaled_loss.backward() scaled_loss.backward()
# in dygraph mode, optimizer.minimize is equal to optimizer.step # in dygraph mode, optimizer.minimize is equal to optimizer.step
scaler.minimize(self.optimizer, scaled_loss) scaler.minimize(self.optimizer, scaled_loss)
else: else:
if isinstance( if isinstance(
model, paddle. model, paddle.
...@@ -575,7 +583,10 @@ class Trainer(object): ...@@ -575,7 +583,10 @@ class Trainer(object):
# forward # forward
if self.use_amp: if self.use_amp:
with paddle.amp.auto_cast( with paddle.amp.auto_cast(
enable=self.cfg.use_gpu, level=self.amp_level): enable=self.cfg.use_gpu,
custom_white_list=self.custom_white_list,
custom_black_list=self.custom_black_list,
level=self.amp_level):
outs = self.model(data) outs = self.model(data)
else: else:
outs = self.model(data) outs = self.model(data)
......
...@@ -66,7 +66,10 @@ class ModelEMA(object): ...@@ -66,7 +66,10 @@ class ModelEMA(object):
def resume(self, state_dict, step=0): def resume(self, state_dict, step=0):
for k, v in state_dict.items(): for k, v in state_dict.items():
if k in self.state_dict: if k in self.state_dict:
if self.state_dict[k].dtype == v.dtype:
self.state_dict[k] = v self.state_dict[k] = v
else:
self.state_dict[k] = v.astype(self.state_dict[k].dtype)
self.step = step self.step = step
def update(self, model=None): def update(self, model=None):
......
...@@ -84,9 +84,14 @@ def load_weight(model, weight, optimizer=None, ema=None): ...@@ -84,9 +84,14 @@ def load_weight(model, weight, optimizer=None, ema=None):
model_weight = {} model_weight = {}
incorrect_keys = 0 incorrect_keys = 0
for key in model_dict.keys(): for key, value in model_dict.items():
if key in param_state_dict.keys(): if key in param_state_dict.keys():
if isinstance(param_state_dict[key], np.ndarray):
param_state_dict[key] = paddle.to_tensor(param_state_dict[key])
if value.dtype == param_state_dict[key].dtype:
model_weight[key] = param_state_dict[key] model_weight[key] = param_state_dict[key]
else:
model_weight[key] = param_state_dict[key].astype(value.dtype)
else: else:
logger.info('Unmatched key: {}'.format(key)) logger.info('Unmatched key: {}'.format(key))
incorrect_keys += 1 incorrect_keys += 1
...@@ -209,6 +214,12 @@ def load_pretrain_weight(model, pretrain_weight): ...@@ -209,6 +214,12 @@ def load_pretrain_weight(model, pretrain_weight):
param_state_dict = paddle.load(weights_path) param_state_dict = paddle.load(weights_path)
param_state_dict = match_state_dict(model_dict, param_state_dict) param_state_dict = match_state_dict(model_dict, param_state_dict)
for k, v in param_state_dict.items():
if isinstance(v, np.ndarray):
v = paddle.to_tensor(v)
if model_dict[k].dtype != v.dtype:
param_state_dict[k] = v.astype(model_dict[k].dtype)
model.set_dict(param_state_dict) model.set_dict(param_state_dict)
logger.info('Finish loading model weights: {}'.format(weights_path)) logger.info('Finish loading model weights: {}'.format(weights_path))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册