# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Function:
    Test mindconverter to convert user's PyTorch network script.
Usage:
    pytest tests/st/func/mindconverter
"""
import difflib
import os
import sys

import pytest

from mindinsight.mindconverter.converter import main

# Minimum percentage of expected lines that must appear unchanged in the
# converted output for the conversion to be considered successful.
_MIN_CONVERTED_RATIO = 80


def _converted_ratio(converted_lines, expected_lines):
    """Return the percentage of expected lines matched by the converted output.

    Every line that ``difflib.ndiff`` marks with ``+`` is present in
    ``expected_lines`` but missing from ``converted_lines``, and each one
    lowers the score. Lines that exist only in the converted output
    (marked ``-``) are not penalised.

    Args:
        converted_lines (list[str]): Lines of the converted script.
        expected_lines (list[str]): Lines of the reference script; must be non-empty.

    Returns:
        float, a score in the range [0, 100].

    Raises:
        ValueError: If ``expected_lines`` is empty, because the ratio is undefined.
    """
    if not expected_lines:
        raise ValueError('Expected script is empty; cannot compute a conversion ratio.')
    missing = sum(1 for line in difflib.ndiff(converted_lines, expected_lines)
                  if line.startswith('+'))
    return 100 - (missing * 100) / len(expected_lines)


@pytest.mark.usefixtures('create_output_dir')
class TestConverter:
    """Test Converter module."""

    @classmethod
    def setup_class(cls):
        """Prepend the test data directory to ``sys.path``."""
        cls.script_dir = os.path.join(os.path.dirname(__file__), 'data')
        # NOTE(review): presumably needed so the converter can import the
        # scripts under data/ — confirm.
        sys.path.insert(0, cls.script_dir)

    @classmethod
    def teardown_class(cls):
        """Remove the data directory that ``setup_class`` added to ``sys.path``."""
        # Guard so the teardown does not raise ValueError if the entry
        # has already been removed elsewhere.
        if cls.script_dir in sys.path:
            sys.path.remove(cls.script_dir)

    @pytest.mark.level0
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.platform_x86_cpu
    @pytest.mark.env_single
    def test_convert_lenet(self, output):
        """Convert the PyTorch LeNet script and compare it with the expected MindSpore script.

        The test passes when at least ``_MIN_CONVERTED_RATIO`` percent of the
        expected lines appear in the converted output.

        Args:
            output (str): Output directory provided by the ``output`` fixture.
        """
        script_filename = "lenet_script.py"
        expect_filename = "lenet_converted.py"
        files_config = {
            'root_path': self.script_dir,
            'in_files': [os.path.join(self.script_dir, script_filename)],
            'outfile_dir': output,
            'report_dir': output
        }
        main(files_config)

        converted_path = os.path.join(output, script_filename)
        assert os.path.isfile(converted_path)

        # Read the scripts as UTF-8 (Python's default source encoding) so the
        # comparison does not depend on the platform's locale encoding.
        with open(converted_path, encoding='utf-8') as converted_f:
            converted_source = converted_f.readlines()
        with open(os.path.join(self.script_dir, expect_filename), encoding='utf-8') as expect_f:
            expect_source = expect_f.readlines()

        converted_ratio = _converted_ratio(converted_source, expect_source)
        assert converted_ratio >= _MIN_CONVERTED_RATIO, (
            'Only {:.1f}% of expected lines matched (threshold {}%).'.format(
                converted_ratio, _MIN_CONVERTED_RATIO))