From f37cb543b104feebf1695cb506dd4e39a0c7476b Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Wed, 29 Mar 2023 06:27:29 +0000 Subject: [PATCH] rm op black list in amp the op flatten_contiguous_range and greater_than has supported amp mode since paddle 2.4 --- ppcls/engine/engine.py | 6 +----- ppcls/engine/evaluation/classification.py | 12 ++---------- ppcls/engine/evaluation/retrieval.py | 6 +----- ppcls/engine/train/train.py | 6 +----- ppcls/engine/train/train_fixmatch.py | 6 +----- ppcls/engine/train/train_metabin.py | 6 +----- 6 files changed, 7 insertions(+), 35 deletions(-) diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index e79d3bd1..38d5b710 100755 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -504,11 +504,7 @@ class Engine(object): batch_tensor = paddle.to_tensor(batch_data) if self.amp and self.amp_eval: - with paddle.amp.auto_cast( - custom_black_list={ - "flatten_contiguous_range", "greater_than" - }, - level=self.amp_level): + with paddle.amp.auto_cast(level=self.amp_level): out = self.model(batch_tensor) else: out = self.model(batch_tensor) diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py index 637b54f8..b9f3b870 100644 --- a/ppcls/engine/evaluation/classification.py +++ b/ppcls/engine/evaluation/classification.py @@ -56,11 +56,7 @@ def classification_eval(engine, epoch_id=0): # image input if engine.amp and engine.amp_eval: - with paddle.amp.auto_cast( - custom_black_list={ - "flatten_contiguous_range", "greater_than" - }, - level=engine.amp_level): + with paddle.amp.auto_cast(level=engine.amp_level): out = engine.model(batch[0]) else: out = engine.model(batch[0]) @@ -114,11 +110,7 @@ def classification_eval(engine, epoch_id=0): # calc loss if engine.eval_loss_func is not None: if engine.amp and engine.amp_eval: - with paddle.amp.auto_cast( - custom_black_list={ - "flatten_contiguous_range", "greater_than" - }, - level=engine.amp_level): + with paddle.amp.auto_cast(level=engine.amp_level): loss_dict = engine.eval_loss_func(preds, labels) else: loss_dict = engine.eval_loss_func(preds, labels) diff --git a/ppcls/engine/evaluation/retrieval.py b/ppcls/engine/evaluation/retrieval.py index 875a01c3..6644d57d 100644 --- a/ppcls/engine/evaluation/retrieval.py +++ b/ppcls/engine/evaluation/retrieval.py @@ -137,11 +137,7 @@ def compute_feature(engine, name="gallery"): has_camera = True batch[2] = batch[2].reshape([-1, 1]).astype("int64") if engine.amp and engine.amp_eval: - with paddle.amp.auto_cast( - custom_black_list={ - "flatten_contiguous_range", "greater_than" - }, - level=engine.amp_level): + with paddle.amp.auto_cast(level=engine.amp_level): out = engine.model(batch[0]) else: out = engine.model(batch[0]) diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py index f5ec7a88..3c8035fc 100644 --- a/ppcls/engine/train/train.py +++ b/ppcls/engine/train/train.py @@ -50,11 +50,7 @@ def train_epoch(engine, epoch_id, print_batch_step): # image input if engine.amp: amp_level = engine.config["AMP"].get("level", "O1").upper() - with paddle.amp.auto_cast( - custom_black_list={ - "flatten_contiguous_range", "greater_than" - }, - level=amp_level): + with paddle.amp.auto_cast(level=amp_level): out = forward(engine, batch) loss_dict = engine.train_loss_func(out, batch[1]) else: diff --git a/ppcls/engine/train/train_fixmatch.py b/ppcls/engine/train/train_fixmatch.py index 20e38f9b..93839c89 100644 --- a/ppcls/engine/train/train_fixmatch.py +++ b/ppcls/engine/train/train_fixmatch.py @@ -64,11 +64,7 @@ def train_epoch_fixmatch(engine, epoch_id, print_batch_step): # image input if engine.amp: amp_level = engine.config['AMP'].get("level", "O1").upper() - with paddle.amp.auto_cast( - custom_black_list={ - "flatten_contiguous_range", "greater_than" - }, - level=amp_level): + with paddle.amp.auto_cast(level=amp_level): loss_dict, logits_label = get_loss( engine, inputs, batch_size_label, temperture, threshold, targets_x) diff --git a/ppcls/engine/train/train_metabin.py b/ppcls/engine/train/train_metabin.py index 25186f25..eed4c4d9 100644 --- a/ppcls/engine/train/train_metabin.py +++ b/ppcls/engine/train/train_metabin.py @@ -191,11 +191,7 @@ def forward(engine, batch, loss_func): batch_info = {"label": batch[1], "domain": batch[2]} if engine.amp: amp_level = engine.config["AMP"].get("level", "O1").upper() - with paddle.amp.auto_cast( - custom_black_list={ - "flatten_contiguous_range", "greater_than" - }, - level=amp_level): + with paddle.amp.auto_cast(level=amp_level): out = engine.model(batch[0], batch[1]) loss_dict = loss_func(out, batch_info) else: -- GitLab