diff --git a/ppcls/arch/backbone/legendary_models/mobilenet_v1.py b/ppcls/arch/backbone/legendary_models/mobilenet_v1.py
index 499fad7f5fd836e8a1820b48344c66484abf339b..0cdd05529bd452aa236e210d96002c95bbea6be5 100644
--- a/ppcls/arch/backbone/legendary_models/mobilenet_v1.py
+++ b/ppcls/arch/backbone/legendary_models/mobilenet_v1.py
@@ -165,7 +165,6 @@ class MobileNet(TheseusLayer):
             return_stages=return_stages)
 
     @AMP_forward_decorator
-    @clas_forward_decorator
     def forward(self, x):
         x = self.conv(x)
         x = self.blocks(x)
diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py
index a35ff1ea63d3f4a7c78497af3788f5ac4e9a1722..2169f703f17173f3a6368ee5ddec693104f9d715 100644
--- a/ppcls/engine/evaluation/classification.py
+++ b/ppcls/engine/evaluation/classification.py
@@ -67,7 +67,7 @@ class ClassEval(object):
             if not self.config["Global"].get("use_multilabel", False):
                 batch[1] = batch[1].reshape([-1, 1]).astype("int64")
 
-            out = self.model(batch)
+            out = self.model(batch[0])
 
             # just for DistributedBatchSampler issue: repeat sampling
             current_samples = batch_size * paddle.distributed.get_world_size()
diff --git a/ppcls/engine/train/classification.py b/ppcls/engine/train/classification.py
index 735beaccc4d8c307d3ce831f55f3d96bd5d03626..5ed04e2a231a154e7c662ebb94b263a8080417aa 100644
--- a/ppcls/engine/train/classification.py
+++ b/ppcls/engine/train/classification.py
@@ -41,6 +41,12 @@ class ClassTrainer(object):
         # gradient accumulation
         self.update_freq = self.config["Global"].get("update_freq", 1)
 
+        if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec",
+                                                                    False):
+            self.is_rec = True
+        else:
+            self.is_rec = False
+
         # TODO(gaotingquan): mv to build_model
         # build EMA model
         self.model_ema = self._build_ema_model()
@@ -197,7 +203,11 @@ class ClassTrainer(object):
             batch[1] = batch[1].reshape([batch_size, -1])
 
             self.global_step += 1
-            out = self.model(batch)
+            if self.is_rec:
+                out = self.model(batch)
+            else:
+                out = self.model(batch[0])
+
             loss_dict = self.loss_func(out, batch[1])
             # TODO(gaotingquan): mv update_freq to loss and optimizer
             loss = loss_dict["loss"] / self.update_freq
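
Note on the intent of this patch: plain classification backbones like MobileNet take a bare tensor in forward(), so the trainer and evaluator now pass only the image tensor (batch[0]); recognition-style models, identified by a "Head" section or an is_rec: True flag under the Arch config, keep receiving the full batch. Below is a minimal, self-contained sketch of that dispatch, not part of the patch: DummyModel, arch_cfg, and the batch contents are hypothetical stand-ins, and only the gating expression mirrors the diff.

# Hedged sketch of the new is_rec dispatch (hypothetical model and batch).
import paddle


class DummyModel(paddle.nn.Layer):
    """Stand-in for a plain classification backbone: forward takes a tensor."""

    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# Mirrors the check added to ClassTrainer.__init__: a model is treated as a
# recognition model if its Arch config has a "Head" section or is_rec: True.
arch_cfg = {"name": "MobileNet"}  # no "Head" key, is_rec unset -> classifier
is_rec = "Head" in arch_cfg or arch_cfg.get("is_rec", False)

model = DummyModel()
batch = [paddle.rand([8, 4]), paddle.randint(0, 2, [8, 1])]  # (images, labels)

# Classification models get only the image tensor; recognition models would
# still receive the whole batch (images plus labels) as before this patch.
out = model(batch) if is_rec else model(batch[0])
print(out.shape)  # [8, 2]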