From 6b8b683e3b1926128406d0470c21806518cd918e Mon Sep 17 00:00:00 2001 From: Jiaqi Liu <709153940@qq.com> Date: Fri, 22 Apr 2022 15:28:24 +0800 Subject: [PATCH] support dict batch in dynabert (#1064) --- paddleslim/nas/ofa/utils/nlp_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/paddleslim/nas/ofa/utils/nlp_utils.py b/paddleslim/nas/ofa/utils/nlp_utils.py index aeca95d0..effce089 100644 --- a/paddleslim/nas/ofa/utils/nlp_utils.py +++ b/paddleslim/nas/ofa/utils/nlp_utils.py @@ -70,7 +70,11 @@ def compute_neuron_head_importance(task_name, data_loader = (data_loader, ) for data in data_loader: for batch in data: - input_ids, segment_ids, labels = batch + if isinstance(batch, dict): + input_ids, segment_ids, labels = batch['input_ids'], batch[ + 'token_type_ids'], batch['labels'] + else: + input_ids, segment_ids, labels = batch logits = model( input_ids, segment_ids, attention_mask=[None, head_mask]) loss = loss_fct(logits, labels) -- GitLab