未验证 提交 6b8b683e 编写于 作者: J Jiaqi Liu 提交者: GitHub

support dict batch in dynabert (#1064)

上级 bd4ecf05
......@@ -70,7 +70,11 @@ def compute_neuron_head_importance(task_name,
data_loader = (data_loader, )
for data in data_loader:
for batch in data:
input_ids, segment_ids, labels = batch
if isinstance(batch, dict):
input_ids, segment_ids, labels = batch['input_ids'], batch[
'token_type_ids'], batch['labels']
else:
input_ids, segment_ids, labels = batch
logits = model(
input_ids, segment_ids, attention_mask=[None, head_mask])
loss = loss_fct(logits, labels)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册