test_matmul4d_ad_001.py 2.7 KB
Newer Older
C
ckey_Dou 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
matmul4d_ad
"""
import datetime
import os
import pytest
from base import TestBase
from nose.plugins.attrib import attr
from test_run.matmul4d_ad_run import matmul4d_ad_run


class TestCase(TestBase):
    """Run ``matmul4d_ad_run`` over a fixed table of float16 configurations."""

    def setup(self):
        """Initialise the case and build the argument table consumed by ``test_run``."""
        # The case is registered under this name with the current working
        # directory as its path.
        self.params_init("test_akg_matmul_001", os.getcwd())
        self.caseresult = True

        # NOTE(review): ``casename`` and ``_log`` are not defined in this class;
        # presumably they come from TestBase / params_init -- confirm there.
        banner = "============= {0} Setup case============"
        self._log.info(banner.format(self.casename))

        runner = matmul4d_ad_run
        dtype = "float16"
        kernel = "matmul4d_ad_cce"

        # Each entry is (case flag, run function, positional run arguments).
        # The run arguments hold eight values; presumably
        # (shape_x, shape_y, bias, adj_x, adj_y, dtype, out_dtype, kernel_name)
        # -- TODO confirm against the signature of matmul4d_ad_run. No
        # ``bypass`` or ``attrs`` value is supplied by any entry below.
        self.testarg = [
            ("matmul4d_ad_run_0", runner,
             ((64, 128), (128, 32), 0, False, False, dtype, dtype, kernel)),
            ("matmul4d_ad_run_1", runner,
             ((64, 1024), (1024, 32), 0, False, False, dtype, dtype, kernel)),
            ("matmul4d_ad_run_2", runner,
             ((1024, 64), (1024, 32), 0, True, False, dtype, dtype, kernel)),
            ("matmul4d_ad_run_3", runner,
             ((64, 1024), (32, 1024), 0, False, True, dtype, dtype, kernel)),
            ("matmul4d_ad_run_4", runner,
             ((1024, 64), (32, 1024), 0, True, True, dtype, dtype, kernel)),
        ]

    @pytest.mark.rpc_mini
    @pytest.mark.level0
    @pytest.mark.env_onecard
    @pytest.mark.platform_x86_ascend_training
    def test_run(self):
        """Hand the whole argument table built in ``setup`` to ``common_run``."""
        self.common_run(self.testarg)

    def teardown(self):
        """Log the end-of-case banner; nothing else is cleaned up here."""
        banner = "============= {0} Teardown============"
        self._log.info(banner.format(self.casename))

if __name__ == "__main__":
    # Drive the case by hand when this file is executed directly instead of
    # being collected by a test runner.
    case = TestCase()
    case.setup()
    try:
        case.test_run()
    finally:
        # Complete the setup / test / teardown lifecycle even when the run
        # raises, so the teardown step is never skipped in script mode.
        case.teardown()