diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index e79d3bd11dba88e092e7ca094fc091f153f5974a..38d5b710c75de6cb63a6ed4effd26c5f4b762adc 100755 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -504,11 +504,7 @@ class Engine(object): batch_tensor = paddle.to_tensor(batch_data) if self.amp and self.amp_eval: - with paddle.amp.auto_cast( - custom_black_list={ - "flatten_contiguous_range", "greater_than" - }, - level=self.amp_level): + with paddle.amp.auto_cast(level=self.amp_level): out = self.model(batch_tensor) else: out = self.model(batch_tensor) diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py index 637b54f8cb7844d3dcb7e4d73231b35123b9c2bc..b9f3b87087dbed525786a520e25f820b29ac1986 100644 --- a/ppcls/engine/evaluation/classification.py +++ b/ppcls/engine/evaluation/classification.py @@ -56,11 +56,7 @@ def classification_eval(engine, epoch_id=0): # image input if engine.amp and engine.amp_eval: - with paddle.amp.auto_cast( - custom_black_list={ - "flatten_contiguous_range", "greater_than" - }, - level=engine.amp_level): + with paddle.amp.auto_cast(level=engine.amp_level): out = engine.model(batch[0]) else: out = engine.model(batch[0]) @@ -114,11 +110,7 @@ def classification_eval(engine, epoch_id=0): # calc loss if engine.eval_loss_func is not None: if engine.amp and engine.amp_eval: - with paddle.amp.auto_cast( - custom_black_list={ - "flatten_contiguous_range", "greater_than" - }, - level=engine.amp_level): + with paddle.amp.auto_cast(level=engine.amp_level): loss_dict = engine.eval_loss_func(preds, labels) else: loss_dict = engine.eval_loss_func(preds, labels) diff --git a/ppcls/engine/evaluation/retrieval.py b/ppcls/engine/evaluation/retrieval.py index 875a01c3bffb1b9fe81e621ccdea6ca7162e0e05..6644d57d666dd1028aee844b565c9a4129dd22de 100644 --- a/ppcls/engine/evaluation/retrieval.py +++ b/ppcls/engine/evaluation/retrieval.py @@ -137,11 +137,7 @@ def compute_feature(engine, name="gallery"): has_camera = True batch[2] = batch[2].reshape([-1, 1]).astype("int64") if engine.amp and engine.amp_eval: - with paddle.amp.auto_cast( - custom_black_list={ - "flatten_contiguous_range", "greater_than" - }, - level=engine.amp_level): + with paddle.amp.auto_cast(level=engine.amp_level): out = engine.model(batch[0]) else: out = engine.model(batch[0]) diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py index f5ec7a88df3c7c68a86c8d5f84e2d80ed653773a..3c8035fc11b951d8e087cf52c664d55c3210fcec 100644 --- a/ppcls/engine/train/train.py +++ b/ppcls/engine/train/train.py @@ -50,11 +50,7 @@ def train_epoch(engine, epoch_id, print_batch_step): # image input if engine.amp: amp_level = engine.config["AMP"].get("level", "O1").upper() - with paddle.amp.auto_cast( - custom_black_list={ - "flatten_contiguous_range", "greater_than" - }, - level=amp_level): + with paddle.amp.auto_cast(level=amp_level): out = forward(engine, batch) loss_dict = engine.train_loss_func(out, batch[1]) else: diff --git a/ppcls/engine/train/train_fixmatch.py b/ppcls/engine/train/train_fixmatch.py index 20e38f9bb99860c7526c5701204fe368a586652d..93839c89b53e472867f6ead24e4646bd6c18c318 100644 --- a/ppcls/engine/train/train_fixmatch.py +++ b/ppcls/engine/train/train_fixmatch.py @@ -64,11 +64,7 @@ def train_epoch_fixmatch(engine, epoch_id, print_batch_step): # image input if engine.amp: amp_level = engine.config['AMP'].get("level", "O1").upper() - with paddle.amp.auto_cast( - custom_black_list={ - "flatten_contiguous_range", "greater_than" - }, - level=amp_level): + with paddle.amp.auto_cast(level=amp_level): loss_dict, logits_label = get_loss( engine, inputs, batch_size_label, temperture, threshold, targets_x) diff --git a/ppcls/engine/train/train_metabin.py b/ppcls/engine/train/train_metabin.py index 25186f25a70e911425f8dfa4c43d4d28ae06b90b..eed4c4d9eb113f28afa6774864002055e42665c7 100644 --- a/ppcls/engine/train/train_metabin.py +++ b/ppcls/engine/train/train_metabin.py @@ -191,11 +191,7 @@ def forward(engine, batch, loss_func): batch_info = {"label": batch[1], "domain": batch[2]} if engine.amp: amp_level = engine.config["AMP"].get("level", "O1").upper() - with paddle.amp.auto_cast( - custom_black_list={ - "flatten_contiguous_range", "greater_than" - }, - level=amp_level): + with paddle.amp.auto_cast(level=amp_level): out = engine.model(batch[0], batch[1]) loss_dict = loss_func(out, batch_info) else: