提交 944380e9 编写于 作者: W wangzhuo325

Use float32 as the mmad accumulation dtype to improve precision; cast the result back to float16 when out_dtype is float16.

上级 9ed88ba9
......@@ -21,6 +21,7 @@ from akg import backend as cce
from akg.utils import kernel_exec as utils
from akg.utils import custom_tiling as ct_util
from akg.utils import validation_check as vc_util
from akg.ops.math import cast
logging.basicConfig(level=logging.DEBUG)
......@@ -166,7 +167,7 @@ def matmul4D_compute(x, y, bias_value, out_dtype, left_format, right_format, out
if adj_y:
y_indices = indices[:(N - 4)] + (ko,) + indices[(N - 4):(N - 3)] + indices[(N - 1):] + (ki,)
return akg.lang.cce.mmad((x(*x_indices) * y(*y_indices)).astype(out_dtype), axis=[ko, ki])
return akg.lang.cce.mmad((x(*x_indices) * y(*y_indices)).astype("float32"), axis=[ko, ki])
if left_format == "zZ":
......@@ -223,6 +224,8 @@ def matmul4D_compute(x, y, bias_value, out_dtype, left_format, right_format, out
"bias": bias_name,
})
if out_dtype == "float16":
result_matmul = cast.cast(result_matmul, out_dtype)
def matmul_reshape(shape, result_matmul, *indices):
N = len(shape)
......
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
matmul4d_ad
"""
import datetime
import os
import pytest
from base import TestBase
from nose.plugins.attrib import attr
from test_run.matmul4d_ad_run import matmul4d_ad_run
class TestCase(TestBase):
    """Test case that drives ``matmul4d_ad_run`` over a fixed table of argument sets."""

    def setup(self):
        """Initialise the case parameters and build ``self.testarg``."""
        self.params_init("test_akg_matmul_001", os.getcwd())
        self.caseresult = True
        self._log.info("============= {0} Setup case============".format(self.casename))

        # Trailing arguments shared by every entry below.
        # NOTE(review): presumably (dtype, out_dtype, kernel_name) - confirm
        # against the signature of matmul4d_ad_run.
        common_tail = ("float16", "float16", "matmul4d_ad_cce")

        # Each entry is (caseflag, opfuncname, testRunArgs).
        # NOTE(review): the leading run arguments look like
        # (shape_x, shape_y, bias, adj_x, adj_y); the earlier field list also
        # named "bypass" and "attrs", which do not appear in these tuples -
        # confirm against matmul4d_ad_run.
        self.testarg = [
            (
                "matmul4d_ad_run_0",
                matmul4d_ad_run,
                ((64, 128), (128, 32), 0, False, False) + common_tail,
            ),
            (
                "matmul4d_ad_run_1",
                matmul4d_ad_run,
                ((64, 1024), (1024, 32), 0, False, False) + common_tail,
            ),
            (
                "matmul4d_ad_run_2",
                matmul4d_ad_run,
                ((1024, 64), (1024, 32), 0, True, False) + common_tail,
            ),
            (
                "matmul4d_ad_run_3",
                matmul4d_ad_run,
                ((64, 1024), (32, 1024), 0, False, True) + common_tail,
            ),
            (
                "matmul4d_ad_run_4",
                matmul4d_ad_run,
                ((1024, 64), (32, 1024), 0, True, True) + common_tail,
            ),
        ]

    @pytest.mark.rpc_mini
    @pytest.mark.level0
    @pytest.mark.env_onecard
    @pytest.mark.platform_x86_ascend_training
    def test_run(self):
        """Pass the argument table built in ``setup`` to ``self.common_run``."""
        self.common_run(self.testarg)

    def teardown(self):
        """Log that the case is being torn down."""
        self._log.info("============= {0} Teardown============".format(self.casename))
if __name__ == "__main__":
    # Script entry point: build the case and run it directly, without a test
    # runner. teardown() is not invoked on this path.
    case = TestCase()
    case.setup()
    case.test_run()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册