未验证 提交 060beb78 编写于 作者: C ceci3 提交者: GitHub

fix model mode in ofa (#803)

上级 a6f42259
...@@ -246,8 +246,6 @@ class OFA(OFABase): ...@@ -246,8 +246,6 @@ class OFA(OFABase):
self._add_teacher = True self._add_teacher = True
self._prepare_distill() self._prepare_distill()
self.model.train()
def _prepare_distill(self): def _prepare_distill(self):
if self.distill_config.teacher_model == None: if self.distill_config.teacher_model == None:
_logger.error( _logger.error(
...@@ -259,8 +257,9 @@ class OFA(OFABase): ...@@ -259,8 +257,9 @@ class OFA(OFABase):
# load teacher parameter # load teacher parameter
if self.distill_config.teacher_model_path != None: if self.distill_config.teacher_model_path != None:
param_state_dict, _ = paddle.load_dygraph( param_state_dict = self.distill_config.teacher_model_path if isinstance(
self.distill_config.teacher_model_path) self.distill_config.teacher_model_path,
dict) else paddle.load(self.distill_config.teacher_model_path)
self.distill_config.teacher_model.set_dict(param_state_dict) self.distill_config.teacher_model.set_dict(param_state_dict)
self.ofa_teacher_model = OFABase(self.distill_config.teacher_model) self.ofa_teacher_model = OFABase(self.distill_config.teacher_model)
......
...@@ -373,6 +373,7 @@ class TestOFACase2(TestOFA): ...@@ -373,6 +373,7 @@ class TestOFACase2(TestOFA):
self.data = paddle.to_tensor(data_np) self.data = paddle.to_tensor(data_np)
def init_config(self): def init_config(self):
teacher_model_state_dict = self.teacher_model.state_dict()
default_run_config = { default_run_config = {
'train_batch_size': 1, 'train_batch_size': 1,
'n_epochs': [[2, 5]], 'n_epochs': [[2, 5]],
...@@ -384,6 +385,7 @@ class TestOFACase2(TestOFA): ...@@ -384,6 +385,7 @@ class TestOFACase2(TestOFA):
default_distill_config = { default_distill_config = {
'teacher_model': self.teacher_model, 'teacher_model': self.teacher_model,
'mapping_layers': ['models.3.fn'], 'mapping_layers': ['models.3.fn'],
'teacher_model_path': teacher_model_state_dict
} }
self.distill_config = DistillConfig(**default_distill_config) self.distill_config = DistillConfig(**default_distill_config)
self.elastic_order = None self.elastic_order = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册