未验证 提交 18f80c7e 编写于 作者: C ceci3 提交者: GitHub

fix when lambda_distill is None in ofa (#550)

* fix

* fix when lambda is None
上级 e21933ff
...@@ -324,7 +324,9 @@ class OFA(OFABase): ...@@ -324,7 +324,9 @@ class OFA(OFABase):
else: else:
loss = distill_fn(Sact, Tact.detach()) loss = distill_fn(Sact, Tact.detach())
losses.append(loss) losses.append(loss)
return sum(losses) * self.distill_config.lambda_distill if self.distill_config.lambda_distill != None:
return sum(losses) * self.distill_config.lambda_distill
return sum(losses)
### TODO: complete it ### TODO: complete it
def search(self, eval_func, condition): def search(self, eval_func, condition):
......
...@@ -330,7 +330,6 @@ class TestOFACase2(TestOFA): ...@@ -330,7 +330,6 @@ class TestOFACase2(TestOFA):
} }
self.run_config = RunConfig(**default_run_config) self.run_config = RunConfig(**default_run_config)
default_distill_config = { default_distill_config = {
'lambda_distill': 0.01,
'teacher_model': self.teacher_model, 'teacher_model': self.teacher_model,
'mapping_layers': ['models.3.fn'], 'mapping_layers': ['models.3.fn'],
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册