diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index fe069b1dee8dea8a54121229e26c3f188faa3c0c..fca3a82bbf05f6e2285e1ffe2b526c688d0e1af6 100644
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -250,6 +250,8 @@ class Engine(object):
             self.scaler = paddle.amp.GradScaler(
                 init_loss_scaling=self.scale_loss,
                 use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
+            if self.config['AMP']['use_pure_fp16'] is True:
+                self.model = paddle.amp.decorate(models=self.model, level='O2')
 
         self.max_iter = len(self.train_dataloader) - 1 if platform.system(
         ) == "Windows" else len(self.train_dataloader)
diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py
index cbf868e4e6d1d118b417568625c493afea6cd23a..b7fa9d3a060bfe6134bb7f42d8bb9926d03b73bc 100644
--- a/ppcls/engine/train/train.py
+++ b/ppcls/engine/train/train.py
@@ -21,6 +21,7 @@ from ppcls.utils import profiler
 
 def train_epoch(engine, epoch_id, print_batch_step):
     tic = time.time()
+    v_current = [int(i) for i in paddle.__version__.split(".")]
     for iter_id, batch in enumerate(engine.train_dataloader):
         if iter_id >= engine.max_iter:
             break
@@ -41,14 +42,15 @@ def train_epoch(engine, epoch_id, print_batch_step):
 
         # image input
         if engine.amp:
-            with paddle.amp.auto_cast(custom_black_list={
-                    "flatten_contiguous_range", "greater_than"
-            }):
+            amp_level = 'O1'
+            if engine.config['AMP']['use_pure_fp16'] is True:
+                amp_level = 'O2'
+            with paddle.amp.auto_cast(custom_black_list={"flatten_contiguous_range", "greater_than"}, level=amp_level):
                 out = forward(engine, batch)
+                loss_dict = engine.train_loss_func(out, batch[1])
         else:
             out = forward(engine, batch)
-
-        loss_dict = engine.train_loss_func(out, batch[1])
+            loss_dict = engine.train_loss_func(out, batch[1])
 
         # step opt and lr
         if engine.amp:
diff --git a/ppcls/optimizer/optimizer.py b/ppcls/optimizer/optimizer.py
index f429755fcc4fc189871526bae11639b83e870d05..4422ea70d32a3ed1ce89c33ec806e2035aa25420 100644
--- a/ppcls/optimizer/optimizer.py
+++ b/ppcls/optimizer/optimizer.py
@@ -17,6 +17,7 @@ from __future__ import division
 from __future__ import print_function
 
 from paddle import optimizer as optim
+import paddle
 
 from ppcls.utils import logger
 
@@ -36,7 +37,7 @@ class Momentum(object):
                  momentum,
                  weight_decay=None,
                  grad_clip=None,
-                 multi_precision=False):
+                 multi_precision=True):
         super().__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
@@ -55,6 +56,15 @@ class Momentum(object):
             grad_clip=self.grad_clip,
             multi_precision=self.multi_precision,
             parameters=parameters)
+        if hasattr(opt, '_use_multi_tensor'):
+            opt = optim.Momentum(
+                learning_rate=self.learning_rate,
+                momentum=self.momentum,
+                weight_decay=self.weight_decay,
+                grad_clip=self.grad_clip,
+                multi_precision=self.multi_precision,
+                parameters=parameters,
+                use_multi_tensor=True)
         return opt
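
The engine.py and train.py changes together enable "pure fp16" (AMP level O2) training: the model is decorated once at startup, and both the forward pass and the loss are evaluated under auto_cast, which is why the diff moves train_loss_func inside the with block. Below is a minimal, self-contained sketch of that flow under stated assumptions: the toy Linear model, random batch, and cross_entropy loss are hypothetical stand-ins for the real PaddleClas model, dataloader, and train_loss_func, and actual fp16 execution assumes a CUDA device.

```python
import paddle
from paddle import nn, optimizer as optim

# Hypothetical stand-ins for the real PaddleClas model and loss.
model = nn.Linear(8, 2)
opt = optim.Momentum(
    learning_rate=0.1,
    momentum=0.9,
    multi_precision=True,  # keep fp32 master weights inside the optimizer
    parameters=model.parameters())

# Mirrors Engine.__init__: build the scaler, then decorate the model for O2,
# which casts its weights to fp16 up front.
scaler = paddle.amp.GradScaler(
    init_loss_scaling=128.0, use_dynamic_loss_scaling=True)
model = paddle.amp.decorate(models=model, level='O2')

x = paddle.randn([4, 8])
label = paddle.randint(0, 2, [4])

# Mirrors train_epoch: under O2 the loss is computed inside auto_cast as well,
# so the loss function sees the same fp16 activations as the forward pass.
with paddle.amp.auto_cast(
        custom_black_list={"flatten_contiguous_range", "greater_than"},
        level='O2'):
    out = model(x)
    loss = nn.functional.cross_entropy(out, label)

scaled = scaler.scale(loss)   # scale up to avoid fp16 gradient underflow
scaled.backward()
scaler.minimize(opt, scaled)  # unscales grads, steps, updates the loss scale
opt.clear_grad()
```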
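In optimizer.py, multi_precision now defaults to True so that Momentum keeps fp32 master weights when the decorated model's parameters are fp16, and the optimizer is rebuilt with use_multi_tensor=True when the installed Paddle supports it, probed via the private _use_multi_tensor attribute since older releases lack the argument. As an alternative sketch (not what this patch does), the same capability probe can inspect the public constructor signature instead, which avoids building a throwaway optimizer instance:

```python
import inspect

import paddle
from paddle import optimizer as optim

params = paddle.nn.Linear(8, 2).parameters()  # hypothetical parameter list
kwargs = dict(
    learning_rate=0.1,
    momentum=0.9,
    multi_precision=True,
    parameters=params)

# Enable the fused multi-tensor update only if this Paddle build exposes
# the argument; passing it to an older release would raise TypeError.
if 'use_multi_tensor' in inspect.signature(optim.Momentum).parameters:
    kwargs['use_multi_tensor'] = True

opt = optim.Momentum(**kwargs)
```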