Commit 392b75b1 authored by gaotingquan, committed by Wei Shengyu

revert for running

Parent 9beb154b
@@ -165,7 +165,6 @@ class MobileNet(TheseusLayer):
            return_stages=return_stages)

    @AMP_forward_decorator
    @clas_forward_decorator
    def forward(self, x):
        x = self.conv(x)
        x = self.blocks(x)
......
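For context on the hunk above: the bodies of `AMP_forward_decorator` and `clas_forward_decorator` are not part of this diff. Below is a minimal, hypothetical sketch of what an AMP-style forward decorator might look like, assuming it switches `paddle.amp.auto_cast` on a `use_amp` attribute; the flag and the body are assumptions for illustration, not the actual PaddleClas code:

```python
import functools

import paddle

def AMP_forward_decorator(forward_fn):
    # Hypothetical: run the wrapped forward under AMP autocast when the
    # layer carries a truthy use_amp flag; otherwise run it unchanged.
    @functools.wraps(forward_fn)
    def wrapper(self, *args, **kwargs):
        if getattr(self, "use_amp", False):
            with paddle.amp.auto_cast():
                return forward_fn(self, *args, **kwargs)
        return forward_fn(self, *args, **kwargs)
    return wrapper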
@@ -67,7 +67,7 @@ class ClassEval(object):
            if not self.config["Global"].get("use_multilabel", False):
                batch[1] = batch[1].reshape([-1, 1]).astype("int64")
-           out = self.model(batch)
+           out = self.model(batch[0])
            # just for DistributedBatchSampler issue: repeat sampling
            current_samples = batch_size * paddle.distributed.get_world_size()
......
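This hunk reverts evaluation to passing only `batch[0]` (the images) to the model instead of the whole batch. A minimal sketch of the assumed batch convention, where each batch is `[images, labels]` and `eval_step` is a hypothetical helper, not part of the repo:

```python
import paddle

def eval_step(model, batch):
    # Assumed batch layout: batch[0] = input images, batch[1] = labels.
    images, labels = batch[0], batch[1]
    labels = labels.reshape([-1, 1]).astype("int64")
    # A plain classifier consumes only the images; a rec-style model may
    # take the whole batch instead so its head can see the labels.
    out = model(images)
    return out, labels

# Usage sketch with a dummy model and a batch in the assumed layout.
model = paddle.nn.Linear(8, 4)
batch = [paddle.rand([2, 8]), paddle.to_tensor([1, 3])]
out, labels = eval_step(model, batch)
```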
@@ -41,6 +41,12 @@ class ClassTrainer(object):
        # gradient accumulation
        self.update_freq = self.config["Global"].get("update_freq", 1)

+       if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec",
+                                                                    False):
+           self.is_rec = True
+       else:
+           self.is_rec = False

        # TODO(gaotingquan): mv to build_model
        # build EMA model
        self.model_ema = self._build_ema_model()
......@@ -197,7 +203,11 @@ class ClassTrainer(object):
batch[1] = batch[1].reshape([batch_size, -1])
self.global_step += 1
if self.is_rec:
out = self.model(batch)
else:
out = self.model(batch[0])
loss_dict = self.loss_func(out, batch[1])
# TODO(gaotingquan): mv update_freq to loss and optimizer
loss = loss_dict["loss"] / self.update_freq
......
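The trainer hunks combine two ideas: a `self.is_rec` flag that decides whether the model receives the whole batch (rec-style models, i.e. those with a "Head" or `is_rec: True` in the config, whose head may need the labels) or only `batch[0]`, and gradient accumulation, where the loss is scaled by `1 / update_freq` so that stepping the optimizer once every `update_freq` iterations emulates a `update_freq`-times-larger effective batch. A condensed, hypothetical sketch of both; `train_step` and its signature are assumptions for illustration, not the PaddleClas training loop itself:

```python
def train_step(model, loss_func, optimizer, batch, step, is_rec, update_freq):
    # Rec-style models receive the full [images, labels] batch so a
    # label-aware head can use batch[1]; plain classifiers get images only.
    out = model(batch) if is_rec else model(batch[0])

    # Scale the loss so gradients summed over update_freq micro-steps match
    # one optimizer step on an update_freq-times-larger batch.
    loss = loss_func(out, batch[1])["loss"] / update_freq
    loss.backward()  # gradients accumulate until clear_grad() is called

    if (step + 1) % update_freq == 0:
        optimizer.step()
        optimizer.clear_grad()
    return loss
```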