未验证 提交 aae6d797 编写于 作者: C ceci3 提交者: GitHub

Cherry pick fix (#563)

* fix nas (#553)

* fix when lambda_distill is None in ofa (#550)

* fix

* fix when lambda is None

* fix when the task is mnli (#540)

* fix

* fix when mnli

* Fix ofa (#536)

* fix

* fix embedding
Co-authored-by: Bai Yifan <me@ethanbai.com>
Co-authored-by: Bai Yifan <me@ethanbai.com>
上级 7aad9a73
...@@ -324,7 +324,9 @@ class OFA(OFABase): ...@@ -324,7 +324,9 @@ class OFA(OFABase):
else: else:
loss = distill_fn(Sact, Tact.detach()) loss = distill_fn(Sact, Tact.detach())
losses.append(loss) losses.append(loss)
return sum(losses) * self.distill_config.lambda_distill if self.distill_config.lambda_distill != None:
return sum(losses) * self.distill_config.lambda_distill
return sum(losses)
### TODO: complete it ### TODO: complete it
def search(self, eval_func, condition): def search(self, eval_func, condition):
......
...@@ -66,21 +66,26 @@ def compute_neuron_head_importance(task_name, ...@@ -66,21 +66,26 @@ def compute_neuron_head_importance(task_name,
for w in intermediate_weight: for w in intermediate_weight:
neuron_importance.append(np.zeros(shape=[w.shape[1]], dtype='float32')) neuron_importance.append(np.zeros(shape=[w.shape[1]], dtype='float32'))
for batch in data_loader: if task_name.lower() != 'mnli':
input_ids, segment_ids, labels = batch data_loader = (data_loader, )
logits = model(input_ids, segment_ids, attention_mask=[None, head_mask]) for data in data_loader:
loss = loss_fct(logits, labels) for batch in data:
loss.backward() input_ids, segment_ids, labels = batch
head_importance += paddle.abs(paddle.to_tensor(head_mask.gradient())) logits = model(
input_ids, segment_ids, attention_mask=[None, head_mask])
for w1, b1, w2, current_importance in zip( loss = loss_fct(logits, labels)
intermediate_weight, intermediate_bias, output_weight, loss.backward()
neuron_importance): head_importance += paddle.abs(
current_importance += np.abs( paddle.to_tensor(head_mask.gradient()))
(np.sum(w1.numpy() * w1.gradient(), axis=0) + b1.numpy() *
b1.gradient())) for w1, b1, w2, current_importance in zip(
current_importance += np.abs( intermediate_weight, intermediate_bias, output_weight,
np.sum(w2.numpy() * w2.gradient(), axis=1)) neuron_importance):
current_importance += np.abs(
(np.sum(w1.numpy() * w1.gradient(), axis=0) + b1.numpy() *
b1.gradient()))
current_importance += np.abs(
np.sum(w2.numpy() * w2.gradient(), axis=1))
return head_importance, neuron_importance return head_importance, neuron_importance
......
...@@ -75,6 +75,7 @@ class RLNAS(object): ...@@ -75,6 +75,7 @@ class RLNAS(object):
self.range_tables = self._search_space.range_table() self.range_tables = self._search_space.range_table()
self.save_controller = save_controller self.save_controller = save_controller
self.load_controller = load_controller self.load_controller = load_controller
self._is_server = is_server
if key.upper() in ['DDPG']: if key.upper() in ['DDPG']:
try: try:
......
...@@ -22,7 +22,6 @@ import time ...@@ -22,7 +22,6 @@ import time
import paddle.fluid as fluid import paddle.fluid as fluid
from ..common import SAController from ..common import SAController
from ..common import get_logger from ..common import get_logger
from ..analysis import flops
from ..common import ControllerServer from ..common import ControllerServer
from ..common import ControllerClient from ..common import ControllerClient
......
...@@ -330,7 +330,6 @@ class TestOFACase2(TestOFA): ...@@ -330,7 +330,6 @@ class TestOFACase2(TestOFA):
} }
self.run_config = RunConfig(**default_run_config) self.run_config = RunConfig(**default_run_config)
default_distill_config = { default_distill_config = {
'lambda_distill': 0.01,
'teacher_model': self.teacher_model, 'teacher_model': self.teacher_model,
'mapping_layers': ['models.3.fn'], 'mapping_layers': ['models.3.fn'],
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册