diff --git a/paddleslim/nas/ofa/utils/nlp_utils.py b/paddleslim/nas/ofa/utils/nlp_utils.py
index 598b1a4e94eed740d228158937200e515cd6186d..1a61878a6cf97d27ede0e83f61621962f40fba25 100644
--- a/paddleslim/nas/ofa/utils/nlp_utils.py
+++ b/paddleslim/nas/ofa/utils/nlp_utils.py
@@ -66,21 +66,26 @@ def compute_neuron_head_importance(task_name,
     for w in intermediate_weight:
         neuron_importance.append(np.zeros(shape=[w.shape[1]], dtype='float32'))
 
-    for batch in data_loader:
-        input_ids, segment_ids, labels = batch
-        logits = model(input_ids, segment_ids, attention_mask=[None, head_mask])
-        loss = loss_fct(logits, labels)
-        loss.backward()
-        head_importance += paddle.abs(paddle.to_tensor(head_mask.gradient()))
-
-        for w1, b1, w2, current_importance in zip(
-                intermediate_weight, intermediate_bias, output_weight,
-                neuron_importance):
-            current_importance += np.abs(
-                (np.sum(w1.numpy() * w1.gradient(), axis=0) + b1.numpy() *
-                 b1.gradient()))
-            current_importance += np.abs(
-                np.sum(w2.numpy() * w2.gradient(), axis=1))
+    if task_name.lower() != 'mnli':
+        data_loader = (data_loader, )
+    for data in data_loader:
+        for batch in data:
+            input_ids, segment_ids, labels = batch
+            logits = model(
+                input_ids, segment_ids, attention_mask=[None, head_mask])
+            loss = loss_fct(logits, labels)
+            loss.backward()
+            head_importance += paddle.abs(
+                paddle.to_tensor(head_mask.gradient()))
+
+            for w1, b1, w2, current_importance in zip(
+                    intermediate_weight, intermediate_bias, output_weight,
+                    neuron_importance):
+                current_importance += np.abs(
+                    (np.sum(w1.numpy() * w1.gradient(), axis=0) + b1.numpy() *
+                     b1.gradient()))
+                current_importance += np.abs(
+                    np.sum(w2.numpy() * w2.gradient(), axis=1))
 
     return head_importance, neuron_importance
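
Note on the calling convention this patch introduces (a reviewer sketch, not part of the diff): compute_neuron_head_importance now accepts either a single data loader or a tuple of loaders. MNLI evaluates on two dev splits (matched and mismatched), so for task_name='mnli' the caller passes both loaders as a tuple; every other task passes one loader, which the function wraps in a one-element tuple before entering the unified double loop. A minimal usage sketch follows, assuming the import path from the file location; model, the loader variables, and the num_layers/num_heads arguments are illustrative assumptions, since the diff only shows the task_name parameter of the signature.

    from paddleslim.nas.ofa.utils.nlp_utils import compute_neuron_head_importance

    # Non-MNLI task: pass the single dev loader directly; the function
    # wraps it in a one-element tuple internally.
    head_importance, neuron_importance = compute_neuron_head_importance(
        task_name='sst-2',
        model=model,                  # assumed: a fine-tuned BERT-style model
        data_loader=dev_loader,       # assumed: one paddle.io.DataLoader
        num_layers=12,                # assumed arguments beyond task_name
        num_heads=12)

    # MNLI: pass both dev splits as a tuple; the outer
    # `for data in data_loader` loop visits each split in turn.
    head_importance, neuron_importance = compute_neuron_head_importance(
        task_name='mnli',
        model=model,
        data_loader=(dev_loader_matched, dev_loader_mismatched),
        num_layers=12,
        num_heads=12)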