提交 876732fb 编写于 作者: T Tingquan Gao

Revert "debug"

This reverts commit ac27cb19.
上级 97a8bb7f
...@@ -101,7 +101,8 @@ class RecModel(TheseusLayer): ...@@ -101,7 +101,8 @@ class RecModel(TheseusLayer):
else: else:
self.head = None self.head = None
def forward(self, x, label=None): def forward(self, batch):
x, label = batch[0], batch[1]
out = dict() out = dict()
x = self.backbone(x) x = self.backbone(x)
out["backbone"] = x out["backbone"] = x
...@@ -153,7 +154,8 @@ class DistillationModel(nn.Layer): ...@@ -153,7 +154,8 @@ class DistillationModel(nn.Layer):
load_dygraph_pretrain( load_dygraph_pretrain(
self.model_name_list[idx], path=pretrained) self.model_name_list[idx], path=pretrained)
def forward(self, x, label=None): def forward(self, batch):
x, label = batch[0], batch[1]
result_dict = dict() result_dict = dict()
for idx, model_name in enumerate(self.model_name_list): for idx, model_name in enumerate(self.model_name_list):
if label is None: if label is None:
...@@ -171,7 +173,8 @@ class AttentionModel(DistillationModel): ...@@ -171,7 +173,8 @@ class AttentionModel(DistillationModel):
**kargs): **kargs):
super().__init__(models, pretrained_list, freeze_params_list, **kargs) super().__init__(models, pretrained_list, freeze_params_list, **kargs)
def forward(self, x, label=None): def forward(self, batch):
x, label = batch[0], batch[1]
result_dict = dict() result_dict = dict()
out = x out = x
for idx, model_name in enumerate(self.model_name_list): for idx, model_name in enumerate(self.model_name_list):
......
...@@ -204,7 +204,7 @@ class ClassTrainer(object): ...@@ -204,7 +204,7 @@ class ClassTrainer(object):
self.global_step += 1 self.global_step += 1
if self.is_rec: if self.is_rec:
out = self.model(batch[0], batch[1]) out = self.model(batch)
else: else:
out = self.model(batch[0]) out = self.model(batch[0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册