From befeaeb5424fcadaa70a2ff646a6a4b9c2ebf848 Mon Sep 17 00:00:00 2001
From: shangliang Xu
Date: Thu, 4 Aug 2022 17:31:10 +0800
Subject: [PATCH] [dev] add white and black list for amp train (#6576)

---
 ppdet/engine/trainer.py   | 23 +++++++++++++++++------
 ppdet/optimizer/ema.py    |  5 ++++-
 ppdet/utils/checkpoint.py | 15 +++++++++++++--
 3 files changed, 34 insertions(+), 9 deletions(-)

diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py
index 4983a1eb0..c253b40aa 100644
--- a/ppdet/engine/trainer.py
+++ b/ppdet/engine/trainer.py
@@ -69,6 +69,8 @@ class Trainer(object):
         self.is_loaded_weights = False
         self.use_amp = self.cfg.get('amp', False)
         self.amp_level = self.cfg.get('amp_level', 'O1')
+        self.custom_white_list = self.cfg.get('custom_white_list', None)
+        self.custom_black_list = self.cfg.get('custom_black_list', None)
 
         # build data loader
         capital_mode = self.mode.capitalize()
@@ -155,8 +157,10 @@ class Trainer(object):
             self.pruner = create('UnstructuredPruner')(self.model,
                                                        steps_per_epoch)
         if self.use_amp and self.amp_level == 'O2':
-            self.model = paddle.amp.decorate(
-                models=self.model, level=self.amp_level)
+            self.model, self.optimizer = paddle.amp.decorate(
+                models=self.model,
+                optimizers=self.optimizer,
+                level=self.amp_level)
         self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
         if self.use_ema:
             ema_decay = self.cfg.get('ema_decay', 0.9998)
@@ -456,7 +460,9 @@ class Trainer(object):
                             DataParallel) and use_fused_allreduce_gradients:
                     with model.no_sync():
                         with paddle.amp.auto_cast(
-                                enable=self.cfg.use_gpus,
+                                enable=self.cfg.use_gpu,
+                                custom_white_list=self.custom_white_list,
+                                custom_black_list=self.custom_black_list,
                                 level=self.amp_level):
                             # model forward
                             outputs = model(data)
@@ -468,7 +474,10 @@ class Trainer(object):
                         list(model.parameters()), None)
                 else:
                     with paddle.amp.auto_cast(
-                            enable=self.cfg.use_gpu, level=self.amp_level):
+                            enable=self.cfg.use_gpu,
+                            custom_white_list=self.custom_white_list,
+                            custom_black_list=self.custom_black_list,
+                            level=self.amp_level):
                         # model forward
                         outputs = model(data)
                         loss = outputs['loss']
@@ -477,7 +486,6 @@ class Trainer(object):
                         scaled_loss.backward()
                         # in dygraph mode, optimizer.minimize is equal to optimizer.step
                         scaler.minimize(self.optimizer, scaled_loss)
-
                 else:
                     if isinstance(
                             model, paddle.
@@ -575,7 +583,10 @@ class Trainer(object):
             # forward
             if self.use_amp:
                 with paddle.amp.auto_cast(
-                        enable=self.cfg.use_gpu, level=self.amp_level):
+                        enable=self.cfg.use_gpu,
+                        custom_white_list=self.custom_white_list,
+                        custom_black_list=self.custom_black_list,
+                        level=self.amp_level):
                     outs = self.model(data)
             else:
                 outs = self.model(data)
diff --git a/ppdet/optimizer/ema.py b/ppdet/optimizer/ema.py
index e06a01ba0..bd8cb825c 100644
--- a/ppdet/optimizer/ema.py
+++ b/ppdet/optimizer/ema.py
@@ -66,7 +66,10 @@ class ModelEMA(object):
     def resume(self, state_dict, step=0):
         for k, v in state_dict.items():
             if k in self.state_dict:
-                self.state_dict[k] = v
+                if self.state_dict[k].dtype == v.dtype:
+                    self.state_dict[k] = v
+                else:
+                    self.state_dict[k] = v.astype(self.state_dict[k].dtype)
         self.step = step
 
     def update(self, model=None):
diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py
index e4325de8b..add087c89 100644
--- a/ppdet/utils/checkpoint.py
+++ b/ppdet/utils/checkpoint.py
@@ -84,9 +84,14 @@ def load_weight(model, weight, optimizer=None, ema=None):
 
     model_weight = {}
     incorrect_keys = 0
-    for key in model_dict.keys():
+    for key, value in model_dict.items():
         if key in param_state_dict.keys():
-            model_weight[key] = param_state_dict[key]
+            if isinstance(param_state_dict[key], np.ndarray):
+                param_state_dict[key] = paddle.to_tensor(param_state_dict[key])
+            if value.dtype == param_state_dict[key].dtype:
+                model_weight[key] = param_state_dict[key]
+            else:
+                model_weight[key] = param_state_dict[key].astype(value.dtype)
         else:
             logger.info('Unmatched key: {}'.format(key))
             incorrect_keys += 1
@@ -209,6 +214,12 @@ def load_pretrain_weight(model, pretrain_weight):
 
     param_state_dict = paddle.load(weights_path)
     param_state_dict = match_state_dict(model_dict, param_state_dict)
+    for k, v in param_state_dict.items():
+        if isinstance(v, np.ndarray):
+            v = paddle.to_tensor(v)
+        if model_dict[k].dtype != v.dtype:
+            param_state_dict[k] = v.astype(model_dict[k].dtype)
+
     model.set_dict(param_state_dict)
     logger.info('Finish loading model weights: {}'.format(weights_path))
--
GitLab
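
Note: the two new config keys are read once in Trainer.__init__ and passed straight
through to paddle.amp.auto_cast, so they take the same op-name lists as the Paddle
API itself. Below is a minimal sketch of that flow outside the Trainer; the tiny
Linear model, SGD optimizer, and the op names in the lists are illustrative
placeholders, not values prescribed by this patch.

    import paddle

    cfg = {
        'amp_level': 'O2',                         # 'O1': per-op casting, 'O2': FP16 weights
        'custom_white_list': ['elementwise_mul'],  # example ops forced to run in FP16
        'custom_black_list': ['reduce_mean'],      # example ops kept in FP32
    }
    use_gpu = paddle.is_compiled_with_cuda()       # stand-in for cfg.use_gpu in the Trainer

    model = paddle.nn.Linear(4, 2)
    optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                     parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    if use_gpu and cfg['amp_level'] == 'O2':
        # O2 rewrites the optimizer as well, hence the two return values in the patch.
        model, optimizer = paddle.amp.decorate(
            models=model, optimizers=optimizer, level='O2')

    x = paddle.rand([8, 4])
    with paddle.amp.auto_cast(
            enable=use_gpu,
            custom_white_list=cfg['custom_white_list'],
            custom_black_list=cfg['custom_black_list'],
            level=cfg['amp_level']):
        loss = model(x).mean()

    scaled_loss = scaler.scale(loss)
    scaled_loss.backward()
    scaler.minimize(optimizer, scaled_loss)  # optimizer step + loss-scale update
    optimizer.clear_grad()

Passing the optimizer through paddle.amp.decorate matters for O2: the decorated
optimizer keeps FP32 master weights while the model parameters are cast to FP16.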
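
The ema.py and checkpoint.py hunks exist because O2 training can leave model (and
EMA) parameters in float16 while a saved checkpoint holds float32 values, or the
reverse, so loaded values are now cast to the destination dtype. The following is a
hedged sketch of that alignment as a standalone helper; align_state_dict is a
hypothetical name, not a function in PaddleDetection, and the Linear model is only
an example.

    import numpy as np
    import paddle


    def align_state_dict(model_dict, loaded_dict):
        """Cast loaded values (numpy arrays or tensors) to each parameter's dtype."""
        aligned = {}
        for key, param in model_dict.items():
            if key not in loaded_dict:
                continue  # the real loaders log such keys as unmatched
            value = loaded_dict[key]
            if isinstance(value, np.ndarray):
                value = paddle.to_tensor(value)    # weights stored as numpy arrays
            if value.dtype != param.dtype:
                value = value.astype(param.dtype)  # e.g. FP32 checkpoint -> FP16 (O2) model
            aligned[key] = value
        return aligned


    # Usage sketch: an FP16 model (as produced by paddle.amp.decorate(level='O2'))
    # restored from an FP32 checkpoint.
    model = paddle.nn.Linear(4, 2)
    fp32_ckpt = {k: v.numpy() for k, v in model.state_dict().items()}
    model.to(dtype='float16')                      # stand-in for the O2 decoration
    model.set_state_dict(align_state_dict(model.state_dict(), fp32_ckpt))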