Unverified commit aae6d797, authored by ceci3, committed by GitHub

Cherry pick fix (#563)

* fix nas (#553)

* fix when lambda_distill is None in ofa (#550)

* fix

* fix when lambda is None

* fix when the task is mnli (#540)

* fix

* fix when mnli

* Fix ofa (#536)

* fix

* fix embedding
Co-authored-by: Bai Yifan <me@ethanbai.com>
Parent 7aad9a73
@@ -324,7 +324,9 @@ class OFA(OFABase):
             else:
                 loss = distill_fn(Sact, Tact.detach())
             losses.append(loss)
-        return sum(losses) * self.distill_config.lambda_distill
+        if self.distill_config.lambda_distill != None:
+            return sum(losses) * self.distill_config.lambda_distill
+        return sum(losses)
 
     ### TODO: complete it
     def search(self, eval_func, condition):
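The hunk above only guards the final scaling step: when `lambda_distill` is unset, the summed distillation loss is now returned unscaled instead of being multiplied by `None`. A minimal sketch of that behaviour, using a hypothetical standalone helper rather than the actual OFA method:

```python
# Minimal sketch of the guarded scaling added above (hypothetical helper,
# not part of PaddleSlim): scale the summed distillation loss only when a
# lambda_distill weight is actually provided.
def scale_distill_loss(losses, lambda_distill=None):
    total = sum(losses)
    if lambda_distill is not None:
        return total * lambda_distill
    return total

print(scale_distill_loss([0.2, 0.3], lambda_distill=0.01))  # 0.005
print(scale_distill_loss([0.2, 0.3]))                       # 0.5
```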
@@ -66,21 +66,26 @@ def compute_neuron_head_importance(task_name,
     for w in intermediate_weight:
         neuron_importance.append(np.zeros(shape=[w.shape[1]], dtype='float32'))
 
-    for batch in data_loader:
-        input_ids, segment_ids, labels = batch
-        logits = model(input_ids, segment_ids, attention_mask=[None, head_mask])
-        loss = loss_fct(logits, labels)
-        loss.backward()
-        head_importance += paddle.abs(paddle.to_tensor(head_mask.gradient()))
-        for w1, b1, w2, current_importance in zip(
-                intermediate_weight, intermediate_bias, output_weight,
-                neuron_importance):
-            current_importance += np.abs(
-                (np.sum(w1.numpy() * w1.gradient(), axis=0) + b1.numpy() *
-                 b1.gradient()))
-            current_importance += np.abs(
-                np.sum(w2.numpy() * w2.gradient(), axis=1))
+    if task_name.lower() != 'mnli':
+        data_loader = (data_loader, )
+    for data in data_loader:
+        for batch in data:
+            input_ids, segment_ids, labels = batch
+            logits = model(
+                input_ids, segment_ids, attention_mask=[None, head_mask])
+            loss = loss_fct(logits, labels)
+            loss.backward()
+            head_importance += paddle.abs(
+                paddle.to_tensor(head_mask.gradient()))
+            for w1, b1, w2, current_importance in zip(
+                    intermediate_weight, intermediate_bias, output_weight,
+                    neuron_importance):
+                current_importance += np.abs(
+                    (np.sum(w1.numpy() * w1.gradient(), axis=0) + b1.numpy() *
+                     b1.gradient()))
+                current_importance += np.abs(
+                    np.sum(w2.numpy() * w2.gradient(), axis=1))
 
     return head_importance, neuron_importance
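The importance loop above now accepts either a single data loader or, for MNLI (which commonly ships matched and mismatched dev sets), a tuple of loaders. A hedged sketch of the normalisation pattern; the generator below is illustrative and not part of the repo:

```python
# Illustrative sketch of the loader-normalisation pattern used above.
# For non-MNLI tasks the single loader is wrapped in a one-element tuple,
# so the same nested loop handles both the single- and multi-loader cases.
def iter_eval_batches(task_name, data_loader):
    if task_name.lower() != 'mnli':
        data_loader = (data_loader, )   # single loader -> 1-tuple
    for loader in data_loader:          # MNLI: e.g. (matched, mismatched)
        for batch in loader:
            yield batch

# Usage sketch: the caller iterates the same way regardless of the task.
for batch in iter_eval_batches('sst-2', [[1, 2], [3, 4]]):
    pass
```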
@@ -75,6 +75,7 @@ class RLNAS(object):
         self.range_tables = self._search_space.range_table()
         self.save_controller = save_controller
         self.load_controller = load_controller
+        self._is_server = is_server
 
         if key.upper() in ['DDPG']:
             try:
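The one-line RLNAS fix above stores the `is_server` argument that the constructor previously dropped; without it, later references to `self._is_server` raise `AttributeError`. A toy illustration of that failure mode (the class below is illustrative, not the RLNAS internals):

```python
# Toy illustration (not RLNAS itself) of the bug class fixed above: an
# __init__ argument that is accepted but never stored on the instance.
class Searcher:
    def __init__(self, is_server=True):
        self._is_server = is_server     # the fix: keep the flag on the instance

    def role(self):
        # Without the assignment above, this lookup raises AttributeError.
        return "server" if self._is_server else "client"

print(Searcher(is_server=False).role())  # client
```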
@@ -22,7 +22,6 @@ import time
 
 import paddle.fluid as fluid
 from ..common import SAController
 from ..common import get_logger
-from ..analysis import flops
 from ..common import ControllerServer
 from ..common import ControllerClient
@@ -330,7 +330,6 @@ class TestOFACase2(TestOFA):
         }
         self.run_config = RunConfig(**default_run_config)
         default_distill_config = {
-            'lambda_distill': 0.01,
             'teacher_model': self.teacher_model,
            'mapping_layers': ['models.3.fn'],
         }
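The test change above drops `'lambda_distill'` from the distill config so that the new default-None path is exercised. A sketch of the intent; field names follow the hunk, while the placeholder teacher model and the commented constructor call are assumptions rather than the exact test code:

```python
# Sketch of the configuration the updated test now builds: no 'lambda_distill'
# key, so OFA falls back to returning the unscaled sum of distillation losses.
teacher_model = object()                    # placeholder for the real teacher net
default_distill_config = {
    'teacher_model': teacher_model,
    'mapping_layers': ['models.3.fn'],
}
# In the actual test this dict is unpacked into the distillation config
# (e.g. DistillConfig(**default_distill_config), assumed from context).
# Before the OFA fix, an unset lambda_distill meant multiplying the loss by
# None, which fails at runtime; now sum(losses) is returned as-is.
print('lambda_distill' in default_distill_config)   # False
```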