提交 ac27cb19 编写于 作者: G gaotingquan 提交者: Wei Shengyu

debug

上级 afc9b4c6
......@@ -101,8 +101,7 @@ class RecModel(TheseusLayer):
else:
self.head = None
def forward(self, batch):
x, label = batch[0], batch[1]
def forward(self, x, label=None):
out = dict()
x = self.backbone(x)
out["backbone"] = x
......@@ -154,8 +153,7 @@ class DistillationModel(nn.Layer):
load_dygraph_pretrain(
self.model_name_list[idx], path=pretrained)
def forward(self, batch):
x, label = batch[0], batch[1]
def forward(self, x, label=None):
result_dict = dict()
for idx, model_name in enumerate(self.model_name_list):
if label is None:
......@@ -173,8 +171,7 @@ class AttentionModel(DistillationModel):
**kargs):
super().__init__(models, pretrained_list, freeze_params_list, **kargs)
def forward(self, batch):
x, label = batch[0], batch[1]
def forward(self, x, label=None):
result_dict = dict()
out = x
for idx, model_name in enumerate(self.model_name_list):
......
......@@ -204,7 +204,7 @@ class ClassTrainer(object):
self.global_step += 1
if self.is_rec:
out = self.model(batch)
out = self.model(batch[0], batch[1])
else:
out = self.model(batch[0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册