From 392b75b1acac742b74e808353059d0281df26dcc Mon Sep 17 00:00:00 2001
From: gaotingquan
Date: Thu, 9 Mar 2023 17:37:14 +0000
Subject: [PATCH] revert for running

---
 ppcls/arch/backbone/legendary_models/mobilenet_v1.py |  1 -
 ppcls/engine/evaluation/classification.py            |  2 +-
 ppcls/engine/train/classification.py                 | 12 +++++++++++-
 3 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/ppcls/arch/backbone/legendary_models/mobilenet_v1.py b/ppcls/arch/backbone/legendary_models/mobilenet_v1.py
index 499fad7f..0cdd0552 100644
--- a/ppcls/arch/backbone/legendary_models/mobilenet_v1.py
+++ b/ppcls/arch/backbone/legendary_models/mobilenet_v1.py
@@ -165,7 +165,6 @@ class MobileNet(TheseusLayer):
             return_stages=return_stages)
 
     @AMP_forward_decorator
-    @clas_forward_decorator
     def forward(self, x):
         x = self.conv(x)
         x = self.blocks(x)
diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py
index a35ff1ea..2169f703 100644
--- a/ppcls/engine/evaluation/classification.py
+++ b/ppcls/engine/evaluation/classification.py
@@ -67,7 +67,7 @@ class ClassEval(object):
             if not self.config["Global"].get("use_multilabel", False):
                 batch[1] = batch[1].reshape([-1, 1]).astype("int64")
 
-            out = self.model(batch)
+            out = self.model(batch[0])
 
             # just for DistributedBatchSampler issue: repeat sampling
             current_samples = batch_size * paddle.distributed.get_world_size()
diff --git a/ppcls/engine/train/classification.py b/ppcls/engine/train/classification.py
index 735beacc..5ed04e2a 100644
--- a/ppcls/engine/train/classification.py
+++ b/ppcls/engine/train/classification.py
@@ -41,6 +41,12 @@ class ClassTrainer(object):
         # gradient accumulation
         self.update_freq = self.config["Global"].get("update_freq", 1)
 
+        if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec",
+                                                                    False):
+            self.is_rec = True
+        else:
+            self.is_rec = False
+
         # TODO(gaotingquan): mv to build_model
         # build EMA model
         self.model_ema = self._build_ema_model()
@@ -197,7 +203,11 @@ class ClassTrainer(object):
             batch[1] = batch[1].reshape([batch_size, -1])
             self.global_step += 1
 
-            out = self.model(batch)
+            if self.is_rec:
+                out = self.model(batch)
+            else:
+                out = self.model(batch[0])
+
            loss_dict = self.loss_func(out, batch[1])
             # TODO(gaotingquan): mv update_freq to loss and optimizer
             loss = loss_dict["loss"] / self.update_freq
-- 
GitLab
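
Note (not part of the patch): a minimal sketch of the dispatch behavior this change introduces in the train loop. The helper name forward_dispatch is hypothetical, used only to illustrate the branch; the actual code inlines it in ClassTrainer.

# Sketch: recognition-style models (a "Head" section in the Arch config,
# or is_rec=True) receive the full (data, label) batch so the head can
# consume labels, e.g. for margin-based losses; plain classification
# backbones take only the image tensor batch[0].
def forward_dispatch(model, batch, is_rec):
    if is_rec:
        return model(batch)
    return model(batch[0])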