diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 6b43d0d0eb36154ab84d45cd9b99c5b7e7e7a4cb..45963f8ba0c4ee82b39998a4a710bf0bf8bc347c 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -478,9 +478,9 @@ def export(net, *inputs, file_name, file_format='AIR'): supported_formats = ['AIR', 'ONNX', 'MINDIR'] if file_format not in supported_formats: raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}') - # switch network mode to infer when it is training - is_training = net.training - if is_training: + # When dumping ONNX file, switch network mode to infer when it is training(NOTE: ONNX only designed for prediction) + is_dump_onnx_in_training = net.training and file_format == 'ONNX' + if is_dump_onnx_in_training: net.set_train(mode=False) # export model net.init_parameters_data() @@ -503,7 +503,7 @@ def export(net, *inputs, file_name, file_format='AIR'): os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) f.write(onnx_stream) # restore network training mode - if is_training: + if is_dump_onnx_in_training: net.set_train(mode=True)