提交 876732fb 编写于 作者: T Tingquan Gao

Revert "debug"

This reverts commit ac27cb19.
上级 97a8bb7f
......@@ -101,7 +101,8 @@ class RecModel(TheseusLayer):
else:
self.head = None
def forward(self, x, label=None):
def forward(self, batch):
x, label = batch[0], batch[1]
out = dict()
x = self.backbone(x)
out["backbone"] = x
......@@ -153,7 +154,8 @@ class DistillationModel(nn.Layer):
load_dygraph_pretrain(
self.model_name_list[idx], path=pretrained)
def forward(self, x, label=None):
def forward(self, batch):
x, label = batch[0], batch[1]
result_dict = dict()
for idx, model_name in enumerate(self.model_name_list):
if label is None:
......@@ -171,7 +173,8 @@ class AttentionModel(DistillationModel):
**kargs):
super().__init__(models, pretrained_list, freeze_params_list, **kargs)
def forward(self, x, label=None):
def forward(self, batch):
x, label = batch[0], batch[1]
result_dict = dict()
out = x
for idx, model_name in enumerate(self.model_name_list):
......
......@@ -204,7 +204,7 @@ class ClassTrainer(object):
self.global_step += 1
if self.is_rec:
out = self.model(batch[0], batch[1])
out = self.model(batch)
else:
out = self.model(batch[0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册