未验证 提交 63f5787b 编写于 作者: X xlg-go 提交者: GitHub

ch_PP-OCRv4_rec_distill.yml, fix KeyError: 'NRTRLabelDecode' (#10761)

上级 e3cd3433
......@@ -93,8 +93,7 @@ def main(config, device, logger, vdl_writer, seed):
'DistillationSARLoss'][
'ignore_index'] = char_num + 1
out_channels_list['SARLabelDecode'] = char_num + 2
elif list(config['Loss']['loss_config_list'][-1].keys())[
0] == 'DistillationNRTRLoss':
elif any('DistillationNRTRLoss' in d for d in config['Loss']['loss_config_list']):
out_channels_list['NRTRLabelDecode'] = char_num + 3
config['Architecture']['Models'][key]['Head'][
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册