Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
akg
提交
944380e9
A
akg
项目概览
MindSpore
/
akg
通知
58
Star
7
Fork
7
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
akg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
944380e9
编写于
6月 20, 2020
作者:
W
wangzhuo325
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
use float32 for mad to improve precision
上级
9ed88ba9
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
4 addition
and
73 deletion
+4
-73
tests/common/test_op/matmul.py
tests/common/test_op/matmul.py
+4
-1
tests/operators/cube/test_matmul4d_ad_001.py
tests/operators/cube/test_matmul4d_ad_001.py
+0
-72
未找到文件。
tests/common/test_op/matmul.py
浏览文件 @
944380e9
...
...
@@ -21,6 +21,7 @@ from akg import backend as cce
from
akg.utils
import
kernel_exec
as
utils
from
akg.utils
import
custom_tiling
as
ct_util
from
akg.utils
import
validation_check
as
vc_util
from
akg.ops.math
import
cast
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
...
...
@@ -166,7 +167,7 @@ def matmul4D_compute(x, y, bias_value, out_dtype, left_format, right_format, out
if
adj_y
:
y_indices
=
indices
[:(
N
-
4
)]
+
(
ko
,)
+
indices
[(
N
-
4
):(
N
-
3
)]
+
indices
[(
N
-
1
):]
+
(
ki
,)
return
akg
.
lang
.
cce
.
mmad
((
x
(
*
x_indices
)
*
y
(
*
y_indices
)).
astype
(
out_dtype
),
axis
=
[
ko
,
ki
])
return
akg
.
lang
.
cce
.
mmad
((
x
(
*
x_indices
)
*
y
(
*
y_indices
)).
astype
(
"float32"
),
axis
=
[
ko
,
ki
])
if
left_format
==
"zZ"
:
...
...
@@ -223,6 +224,8 @@ def matmul4D_compute(x, y, bias_value, out_dtype, left_format, right_format, out
"bias"
:
bias_name
,
})
if
out_dtype
==
"float16"
:
result_matmul
=
cast
.
cast
(
result_matmul
,
out_dtype
)
def
matmul_reshape
(
shape
,
result_matmul
,
*
indices
):
N
=
len
(
shape
)
...
...
tests/operators/cube/test_matmul4d_ad_001.py
已删除
100644 → 0
浏览文件 @
9ed88ba9
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
matmul4d_ad
"""
import
datetime
import
os
import
pytest
from
base
import
TestBase
from
nose.plugins.attrib
import
attr
from
test_run.matmul4d_ad_run
import
matmul4d_ad_run
class
TestCase
(
TestBase
):
def
setup
(
self
):
case_name
=
"test_akg_matmul_001"
case_path
=
os
.
getcwd
()
self
.
params_init
(
case_name
,
case_path
)
self
.
caseresult
=
True
self
.
_log
.
info
(
"============= {0} Setup case============"
.
format
(
self
.
casename
))
self
.
testarg
=
[
# caseflag, opfuncname, testRunArgs, dimArgs
# shape_x, shape_y, bias, bypass, adj_x, adj_y, dtype, out_dtype, kernel_name, attrs
(
"matmul4d_ad_run_0"
,
matmul4d_ad_run
,
((
64
,
128
),
(
128
,
32
),
0
,
False
,
False
,
"float16"
,
"float16"
,
"matmul4d_ad_cce"
)),
(
"matmul4d_ad_run_1"
,
matmul4d_ad_run
,
((
64
,
1024
),
(
1024
,
32
),
0
,
False
,
False
,
"float16"
,
"float16"
,
"matmul4d_ad_cce"
)),
(
"matmul4d_ad_run_2"
,
matmul4d_ad_run
,
((
1024
,
64
),
(
1024
,
32
),
0
,
True
,
False
,
"float16"
,
"float16"
,
"matmul4d_ad_cce"
)),
(
"matmul4d_ad_run_3"
,
matmul4d_ad_run
,
((
64
,
1024
),
(
32
,
1024
),
0
,
False
,
True
,
"float16"
,
"float16"
,
"matmul4d_ad_cce"
)),
(
"matmul4d_ad_run_4"
,
matmul4d_ad_run
,
((
1024
,
64
),
(
32
,
1024
),
0
,
True
,
True
,
"float16"
,
"float16"
,
"matmul4d_ad_cce"
)),
]
return
@
pytest
.
mark
.
rpc_mini
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
platform_x86_ascend_training
def
test_run
(
self
):
"""
run case.#
:return:
"""
self
.
common_run
(
self
.
testarg
)
def
teardown
(
self
):
"""
clean environment
:return:
"""
self
.
_log
.
info
(
"============= {0} Teardown============"
.
format
(
self
.
casename
))
return
if
__name__
==
"__main__"
:
a
=
TestCase
()
a
.
setup
()
a
.
test_run
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录