diff --git a/tools/train.py b/tools/train.py index 85c98eaddfe69c08e0e29921edcb1d26539b871f..cd6dd8c06ba5e38c491b1b28ea88c7276f4e70ed 100755 --- a/tools/train.py +++ b/tools/train.py @@ -93,8 +93,7 @@ def main(config, device, logger, vdl_writer, seed): 'DistillationSARLoss'][ 'ignore_index'] = char_num + 1 out_channels_list['SARLabelDecode'] = char_num + 2 - elif list(config['Loss']['loss_config_list'][-1].keys())[ - 0] == 'DistillationNRTRLoss': + elif any('DistillationNRTRLoss' in d for d in config['Loss']['loss_config_list']): out_channels_list['NRTRLabelDecode'] = char_num + 3 config['Architecture']['Models'][key]['Head'][