diff --git a/paddleslim/nas/ofa/utils/nlp_utils.py b/paddleslim/nas/ofa/utils/nlp_utils.py index aeca95d0d5bdde4b7648f758233b44461b49c0f7..effce0890405650e3794634c0b35eb22eac6f3d4 100644 --- a/paddleslim/nas/ofa/utils/nlp_utils.py +++ b/paddleslim/nas/ofa/utils/nlp_utils.py @@ -70,7 +70,11 @@ def compute_neuron_head_importance(task_name, data_loader = (data_loader, ) for data in data_loader: for batch in data: - input_ids, segment_ids, labels = batch + if isinstance(batch, dict): + input_ids, segment_ids, labels = batch['input_ids'], batch[ + 'token_type_ids'], batch['labels'] + else: + input_ids, segment_ids, labels = batch logits = model( input_ids, segment_ids, attention_mask=[None, head_mask]) loss = loss_fct(logits, labels)