未验证 提交 060beb78 编写于 作者: C ceci3 提交者: GitHub

fix model mode in ofa (#803)

上级 a6f42259
......@@ -246,8 +246,6 @@ class OFA(OFABase):
self._add_teacher = True
self._prepare_distill()
self.model.train()
def _prepare_distill(self):
if self.distill_config.teacher_model == None:
_logger.error(
......@@ -259,8 +257,9 @@ class OFA(OFABase):
# load teacher parameter
if self.distill_config.teacher_model_path != None:
param_state_dict, _ = paddle.load_dygraph(
self.distill_config.teacher_model_path)
param_state_dict = self.distill_config.teacher_model_path if isinstance(
self.distill_config.teacher_model_path,
dict) else paddle.load(self.distill_config.teacher_model_path)
self.distill_config.teacher_model.set_dict(param_state_dict)
self.ofa_teacher_model = OFABase(self.distill_config.teacher_model)
......
......@@ -373,6 +373,7 @@ class TestOFACase2(TestOFA):
self.data = paddle.to_tensor(data_np)
def init_config(self):
teacher_model_state_dict = self.teacher_model.state_dict()
default_run_config = {
'train_batch_size': 1,
'n_epochs': [[2, 5]],
......@@ -384,6 +385,7 @@ class TestOFACase2(TestOFA):
default_distill_config = {
'teacher_model': self.teacher_model,
'mapping_layers': ['models.3.fn'],
'teacher_model_path': teacher_model_state_dict
}
self.distill_config = DistillConfig(**default_distill_config)
self.elastic_order = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册