From 876732fbcc439b9a45b2b648506dca5d7e69ad0e Mon Sep 17 00:00:00 2001 From: Tingquan Gao <35441050@qq.com> Date: Tue, 14 Mar 2023 16:16:40 +0800 Subject: [PATCH] Revert "debug" This reverts commit ac27cb1917fcf9e7cb9ce2e817cec096369facab. --- ppcls/arch/__init__.py | 9 ++++++--- ppcls/engine/train/classification.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/ppcls/arch/__init__.py b/ppcls/arch/__init__.py index cf30b12c..78405398 100644 --- a/ppcls/arch/__init__.py +++ b/ppcls/arch/__init__.py @@ -101,7 +101,8 @@ class RecModel(TheseusLayer): else: self.head = None - def forward(self, x, label=None): + def forward(self, batch): + x, label = batch[0], batch[1] out = dict() x = self.backbone(x) out["backbone"] = x @@ -153,7 +154,8 @@ class DistillationModel(nn.Layer): load_dygraph_pretrain( self.model_name_list[idx], path=pretrained) - def forward(self, x, label=None): + def forward(self, batch): + x, label = batch[0], batch[1] result_dict = dict() for idx, model_name in enumerate(self.model_name_list): if label is None: @@ -171,7 +173,8 @@ class AttentionModel(DistillationModel): **kargs): super().__init__(models, pretrained_list, freeze_params_list, **kargs) - def forward(self, x, label=None): + def forward(self, batch): + x, label = batch[0], batch[1] result_dict = dict() out = x for idx, model_name in enumerate(self.model_name_list): diff --git a/ppcls/engine/train/classification.py b/ppcls/engine/train/classification.py index aa4ec795..6f84133b 100644 --- a/ppcls/engine/train/classification.py +++ b/ppcls/engine/train/classification.py @@ -204,7 +204,7 @@ class ClassTrainer(object): self.global_step += 1 if self.is_rec: - out = self.model(batch[0], batch[1]) + out = self.model(batch) else: out = self.model(batch[0]) -- GitLab