提交 a4bc9da8 编写于 作者: 文幕地方's avatar 文幕地方

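fix bug: remove 'image' from keep_keys in the Train/Eval configs, select hidden_states by index in DistillationVQADistanceLoss, and read entities/relations labels by negative index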

上级 06194524
......@@ -83,7 +83,7 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
......@@ -122,7 +122,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
......
......@@ -57,14 +57,16 @@ Loss:
mode: "l2"
model_name_pairs:
- ["Student", "Teacher"]
key: hidden_states_5
key: hidden_states
index: 5
name: "loss_5"
- DistillationVQADistanceLoss:
weight: 0.5
mode: "l2"
model_name_pairs:
- ["Student", "Teacher"]
key: hidden_states_8
key: hidden_states
index: 8
name: "loss_8"
......@@ -126,7 +128,7 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
......@@ -166,7 +168,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
......
......@@ -417,11 +417,13 @@ class DistillationVQADistanceLoss(DistanceLoss):
mode="l2",
model_name_pairs=[],
key=None,
index=None,
name="loss_distance",
**kargs):
super().__init__(mode=mode, **kargs)
assert isinstance(model_name_pairs, list)
self.key = key
self.index = index
self.model_name_pairs = model_name_pairs
self.name = name + "_l2"
......@@ -434,6 +436,9 @@ class DistillationVQADistanceLoss(DistanceLoss):
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
if self.index is not None:
out1 = out1[:, self.index, :, :]
out2 = out2[:, self.index, :, :]
if attention_mask is not None:
max_len = attention_mask.shape[-1]
out1 = out1[:, :max_len]
......
......@@ -32,7 +32,7 @@ class VQAReTokenLayoutLMPostProcess(object):
return self._infer(pred_relations, *args, **kwargs)
def _metric(self, pred_relations, label):
return pred_relations, label[6], label[5]
return pred_relations, label[-1], label[-2]
def _infer(self, pred_relations, *args, **kwargs):
ser_results = kwargs['ser_results']
......
......@@ -64,7 +64,10 @@ class SerRePredictor(object):
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = dict(loss=outputs[0], pred_relations=outputs[1])
preds = dict(
loss=outputs[1],
pred_relations=outputs[2],
hidden_states=outputs[0], )
post_result = self.postprocess_op(
preds,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册