diff --git a/ppcls/arch/__init__.py b/ppcls/arch/__init__.py index cf30b12c2c16910a0443d14b1d2b1419a0fbfb99..78405398de5b11ceb598f9f10345c10880606772 100644 --- a/ppcls/arch/__init__.py +++ b/ppcls/arch/__init__.py @@ -101,7 +101,8 @@ class RecModel(TheseusLayer): else: self.head = None - def forward(self, x, label=None): + def forward(self, batch): + x, label = batch[0], batch[1] out = dict() x = self.backbone(x) out["backbone"] = x @@ -153,7 +154,8 @@ class DistillationModel(nn.Layer): load_dygraph_pretrain( self.model_name_list[idx], path=pretrained) - def forward(self, x, label=None): + def forward(self, batch): + x, label = batch[0], batch[1] result_dict = dict() for idx, model_name in enumerate(self.model_name_list): if label is None: @@ -171,7 +173,8 @@ class AttentionModel(DistillationModel): **kargs): super().__init__(models, pretrained_list, freeze_params_list, **kargs) - def forward(self, x, label=None): + def forward(self, batch): + x, label = batch[0], batch[1] result_dict = dict() out = x for idx, model_name in enumerate(self.model_name_list): diff --git a/ppcls/engine/train/classification.py b/ppcls/engine/train/classification.py index aa4ec795591732429f602f53595f3df2bfaf0fba..6f84133b302cb1ea8136a25a37e9dd60936c1400 100644 --- a/ppcls/engine/train/classification.py +++ b/ppcls/engine/train/classification.py @@ -204,7 +204,7 @@ class ClassTrainer(object): self.global_step += 1 if self.is_rec: - out = self.model(batch[0], batch[1]) + out = self.model(batch) else: out = self.model(batch[0])