diff --git a/tests/common/test_op/matmul.py b/tests/common/test_op/matmul.py index 94afe11f61f197d47af32df61250e4c3fae0726f..0874d2aa4b814d93513930936b97d317e62280e6 100644 --- a/tests/common/test_op/matmul.py +++ b/tests/common/test_op/matmul.py @@ -21,6 +21,7 @@ from akg import backend as cce from akg.utils import kernel_exec as utils from akg.utils import custom_tiling as ct_util from akg.utils import validation_check as vc_util +from akg.ops.math import cast logging.basicConfig(level=logging.DEBUG) @@ -166,7 +167,7 @@ def matmul4D_compute(x, y, bias_value, out_dtype, left_format, right_format, out if adj_y: y_indices = indices[:(N - 4)] + (ko,) + indices[(N - 4):(N - 3)] + indices[(N - 1):] + (ki,) - return akg.lang.cce.mmad((x(*x_indices) * y(*y_indices)).astype(out_dtype), axis=[ko, ki]) + return akg.lang.cce.mmad((x(*x_indices) * y(*y_indices)).astype("float32"), axis=[ko, ki]) if left_format == "zZ": @@ -223,6 +224,8 @@ def matmul4D_compute(x, y, bias_value, out_dtype, left_format, right_format, out "bias": bias_name, }) + if out_dtype == "float16": + result_matmul = cast.cast(result_matmul, out_dtype) def matmul_reshape(shape, result_matmul, *indices): N = len(shape) diff --git a/tests/operators/cube/test_matmul4d_ad_001.py b/tests/operators/cube/test_matmul4d_ad_001.py deleted file mode 100644 index 4f0f6bacad269c8f438f14ff0538ece1a03d47f6..0000000000000000000000000000000000000000 --- a/tests/operators/cube/test_matmul4d_ad_001.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -matmul4d_ad -""" -import datetime -import os -import pytest -from base import TestBase -from nose.plugins.attrib import attr -from test_run.matmul4d_ad_run import matmul4d_ad_run - - -class TestCase(TestBase): - - def setup(self): - case_name = "test_akg_matmul_001" - case_path = os.getcwd() - self.params_init(case_name, case_path) - self.caseresult = True - self._log.info("============= {0} Setup case============".format(self.casename)) - self.testarg = [ - # caseflag, opfuncname, testRunArgs, dimArgs - # shape_x, shape_y, bias, bypass, adj_x, adj_y, dtype, out_dtype, kernel_name, attrs - ("matmul4d_ad_run_0", matmul4d_ad_run, ((64, 128), (128, 32), 0, False, False, - "float16", "float16", "matmul4d_ad_cce")), - ("matmul4d_ad_run_1", matmul4d_ad_run, ((64, 1024), (1024, 32), 0, False, False, - "float16", "float16", "matmul4d_ad_cce")), - ("matmul4d_ad_run_2", matmul4d_ad_run, ((1024, 64), (1024, 32), 0, True, False, - "float16", "float16", "matmul4d_ad_cce")), - ("matmul4d_ad_run_3", matmul4d_ad_run, ((64, 1024), (32, 1024), 0, False, True, - "float16", "float16", "matmul4d_ad_cce")), - ("matmul4d_ad_run_4", matmul4d_ad_run, ((1024, 64), (32, 1024), 0, True, True, - "float16", "float16", "matmul4d_ad_cce")), - ] - return - - @pytest.mark.rpc_mini - @pytest.mark.level0 - @pytest.mark.env_onecard - @pytest.mark.platform_x86_ascend_training - def test_run(self): - """ - run case.# - :return: - """ - self.common_run(self.testarg) - - def teardown(self): - """ - clean environment - :return: - """ - self._log.info("============= {0} Teardown============".format(self.casename)) - return - -if __name__ == "__main__": - a = TestCase() - a.setup() - a.test_run()