设计缺陷,classify head的后处理不返回输出,仅支持输出到文件
Created by: 131250208
trainer里的predict是可以不传入output_dir并将results返回的,但是predict函数是调用分类头里的后处理函数处理返回results的,而分类头里后处理函数却没有相应的处理。不传入output_dir会抛出异常。希望该函数在不传入output_dir时返回results 源码:PALM/paddlepalm/head/cls.py
def epoch_postprocess(self, post_inputs, output_dir=None):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training:
if output_dir is None:
raise ValueError('argument output_dir not found in config. Please add it into config dict/file.')
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for i in range(len(self._preds)):
label = int(np.argmax(np.array(self._preds[i])))
result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]}
result = json.dumps(result)
writer.write(result+'\n')
print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
其他head的后处理函数也有相同的问题,建议都修改一下,谢谢。