diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py
index 4983a1eb0394f3bbd6aacf7936ed6aca1a7d4f54..c253b40aa58a32f2b8cbe89d902480a3a537e1e9 100644
--- a/ppdet/engine/trainer.py
+++ b/ppdet/engine/trainer.py
@@ -69,6 +69,8 @@ class Trainer(object):
         self.is_loaded_weights = False
         self.use_amp = self.cfg.get('amp', False)
         self.amp_level = self.cfg.get('amp_level', 'O1')
+        self.custom_white_list = self.cfg.get('custom_white_list', None)
+        self.custom_black_list = self.cfg.get('custom_black_list', None)
 
         # build data loader
         capital_mode = self.mode.capitalize()
@@ -155,8 +157,10 @@ class Trainer(object):
             self.pruner = create('UnstructuredPruner')(self.model,
                                                        steps_per_epoch)
         if self.use_amp and self.amp_level == 'O2':
-            self.model = paddle.amp.decorate(
-                models=self.model, level=self.amp_level)
+            self.model, self.optimizer = paddle.amp.decorate(
+                models=self.model,
+                optimizers=self.optimizer,
+                level=self.amp_level)
         self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
         if self.use_ema:
             ema_decay = self.cfg.get('ema_decay', 0.9998)
@@ -456,7 +460,9 @@ class Trainer(object):
                         DataParallel) and use_fused_allreduce_gradients:
                     with model.no_sync():
                         with paddle.amp.auto_cast(
-                                enable=self.cfg.use_gpus,
+                                enable=self.cfg.use_gpu,
+                                custom_white_list=self.custom_white_list,
+                                custom_black_list=self.custom_black_list,
                                 level=self.amp_level):
                             # model forward
                             outputs = model(data)
@@ -468,7 +474,10 @@ class Trainer(object):
                         list(model.parameters()), None)
                 else:
                     with paddle.amp.auto_cast(
-                            enable=self.cfg.use_gpu, level=self.amp_level):
+                            enable=self.cfg.use_gpu,
+                            custom_white_list=self.custom_white_list,
+                            custom_black_list=self.custom_black_list,
+                            level=self.amp_level):
                         # model forward
                         outputs = model(data)
                         loss = outputs['loss']
@@ -477,7 +486,6 @@ class Trainer(object):
                     scaled_loss.backward()
                     # in dygraph mode, optimizer.minimize is equal to optimizer.step
                     scaler.minimize(self.optimizer, scaled_loss)
-
             else:
                 if isinstance(
                         model, paddle.
@@ -575,7 +583,10 @@ class Trainer(object):
             # forward
             if self.use_amp:
                 with paddle.amp.auto_cast(
-                        enable=self.cfg.use_gpu, level=self.amp_level):
+                        enable=self.cfg.use_gpu,
+                        custom_white_list=self.custom_white_list,
+                        custom_black_list=self.custom_black_list,
+                        level=self.amp_level):
                     outs = self.model(data)
             else:
                 outs = self.model(data)
diff --git a/ppdet/optimizer/ema.py b/ppdet/optimizer/ema.py
index e06a01ba0235647e79d6d8ffe3be009034b17e0d..bd8cb825ca0ecd33ca174acea7adb7ad37ba6185 100644
--- a/ppdet/optimizer/ema.py
+++ b/ppdet/optimizer/ema.py
@@ -66,7 +66,10 @@ class ModelEMA(object):
     def resume(self, state_dict, step=0):
         for k, v in state_dict.items():
             if k in self.state_dict:
-                self.state_dict[k] = v
+                if self.state_dict[k].dtype == v.dtype:
+                    self.state_dict[k] = v
+                else:
+                    self.state_dict[k] = v.astype(self.state_dict[k].dtype)
         self.step = step
 
     def update(self, model=None):
diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py
index e4325de8bb3988495fe90b3ab078805718408cbc..add087c890d4fbe82ecaec5635c19fc2c2090059 100644
--- a/ppdet/utils/checkpoint.py
+++ b/ppdet/utils/checkpoint.py
@@ -84,9 +84,14 @@ def load_weight(model, weight, optimizer=None, ema=None):
 
     model_weight = {}
     incorrect_keys = 0
-    for key in model_dict.keys():
+    for key, value in model_dict.items():
         if key in param_state_dict.keys():
-            model_weight[key] = param_state_dict[key]
+            if isinstance(param_state_dict[key], np.ndarray):
+                param_state_dict[key] = paddle.to_tensor(param_state_dict[key])
+            if value.dtype == param_state_dict[key].dtype:
+                model_weight[key] = param_state_dict[key]
+            else:
+                model_weight[key] = param_state_dict[key].astype(value.dtype)
         else:
             logger.info('Unmatched key: {}'.format(key))
             incorrect_keys += 1
@@ -209,6 +214,12 @@ def load_pretrain_weight(model, pretrain_weight):
 
     param_state_dict = paddle.load(weights_path)
     param_state_dict = match_state_dict(model_dict, param_state_dict)
+    for k, v in param_state_dict.items():
+        if isinstance(v, np.ndarray):
+            v = paddle.to_tensor(v)
+        if model_dict[k].dtype != v.dtype:
+            param_state_dict[k] = v.astype(model_dict[k].dtype)
+
     model.set_dict(param_state_dict)
     logger.info('Finish loading model weights: {}'.format(weights_path))
 
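
Taken together, the patch does two things: it threads the new custom_white_list / custom_black_list config keys into every paddle.amp.auto_cast call and decorates the optimizer alongside the model for AMP level O2, and it makes EMA resume and weight loading cast entries whose dtype no longer matches the (now possibly float16) parameters. The standalone sketch below shows how these pieces are expected to interact with Paddle's AMP API; the toy Linear model, Momentum optimizer, op names in the white/black lists, and loss-scaling value are illustrative assumptions, not part of the patch.

# Sketch of an O2 AMP training step mirroring the trainer.py flow.
# Assumes a CUDA device; model, optimizer, op names and init_loss_scaling
# are placeholders chosen for illustration only.
import paddle

model = paddle.nn.Linear(16, 4)
optimizer = paddle.optimizer.Momentum(
    learning_rate=0.01, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

amp_level = 'O2'
custom_white_list = ['elementwise_add']  # ops forced to run in FP16
custom_black_list = ['reduce_mean']      # ops kept in FP32

# Under O2 the parameters themselves are cast to FP16, so the optimizer must
# be decorated together with the model -- the reason trainer.py now passes
# optimizers= and keeps the returned optimizer.
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level=amp_level)

data = paddle.rand([8, 16])
with paddle.amp.auto_cast(
        enable=True,
        custom_white_list=custom_white_list,
        custom_black_list=custom_black_list,
        level=amp_level):
    loss = model(data).mean()

scaled_loss = scaler.scale(loss)
scaled_loss.backward()
scaler.minimize(optimizer, scaled_loss)  # equals optimizer.step() in dygraph

The ema.py and checkpoint.py hunks follow from the same O2 decision: once parameters may live in float16, a state dict saved in float32 (or vice versa) can no longer be assigned directly, so both loaders convert mismatched entries with Tensor.astype (wrapping any numpy arrays with paddle.to_tensor first) before storing them or calling model.set_dict.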