Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
akg
提交
d237aa7d
A
akg
项目概览
MindSpore
/
akg
通知
52
Star
7
Fork
7
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
akg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
d237aa7d
编写于
8月 20, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 20, 2020
浏览文件
操作
浏览文件
下载
差异文件
!111 dump cuda_meta inside akg.build instead of after akg.build
Merge pull request !111 from looop5/dump_cuda_meta
上级
be2686db
075bd7e4
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
70 addition
and
79 deletion
+70
-79
python/akg/__init__.py
python/akg/__init__.py
+1
-0
python/akg/composite/build_module.py
python/akg/composite/build_module.py
+0
-2
python/akg/utils/dump_cuda_meta.py
python/akg/utils/dump_cuda_meta.py
+48
-75
python/akg/utils/kernel_exec.py
python/akg/utils/kernel_exec.py
+0
-2
third_party/incubator-tvm/src/codegen/opt/build_cuda_on.cc
third_party/incubator-tvm/src/codegen/opt/build_cuda_on.cc
+21
-0
未找到文件。
python/akg/__init__.py
浏览文件 @
d237aa7d
...
@@ -81,5 +81,6 @@ from .autodiff import get_variables
...
@@ -81,5 +81,6 @@ from .autodiff import get_variables
from
.autodiff
import
register_variables
from
.autodiff
import
register_variables
from
.lang.cce.te_compute.common
import
fargmax
,
fargmin
,
mad
from
.lang.cce.te_compute.common
import
fargmax
,
fargmin
,
mad
from
.
import
lang
from
.
import
lang
from
.utils.dump_cuda_meta
import
dump_cuda_meta
__all__
=
[
"differentiate"
]
__all__
=
[
"differentiate"
]
python/akg/composite/build_module.py
浏览文件 @
d237aa7d
...
@@ -22,7 +22,6 @@ from akg import tvm
...
@@ -22,7 +22,6 @@ from akg import tvm
from
akg.tvm
import
_api_internal
from
akg.tvm
import
_api_internal
from
.repository
import
__all__
as
repository
from
.repository
import
__all__
as
repository
import
topi
import
topi
from
akg.utils
import
dump_cuda_meta
def
generate_trait
(
desc
):
def
generate_trait
(
desc
):
""" generate trait of kernel description """
""" generate trait of kernel description """
...
@@ -181,5 +180,4 @@ def build_cuda(outputs, args, sch_name, kernel_name):
...
@@ -181,5 +180,4 @@ def build_cuda(outputs, args, sch_name, kernel_name):
dump_ir
=
os
.
getenv
(
'MS_AKG_DUMP_IR'
)
==
"on"
dump_ir
=
os
.
getenv
(
'MS_AKG_DUMP_IR'
)
==
"on"
with
tvm
.
build_config
(
dump_pass_ir
=
dump_ir
):
with
tvm
.
build_config
(
dump_pass_ir
=
dump_ir
):
mod
=
akg
.
build
(
s
,
list
(
args
),
"cuda"
,
name
=
kernel_name
)
mod
=
akg
.
build
(
s
,
list
(
args
),
"cuda"
,
name
=
kernel_name
)
dump_cuda_meta
.
dump
(
mod
,
kernel_name
,
s
,
list
(
args
))
return
mod
return
mod
python/akg/utils/dump_cuda_meta.py
浏览文件 @
d237aa7d
...
@@ -20,93 +20,66 @@ import fcntl
...
@@ -20,93 +20,66 @@ import fcntl
import
hashlib
import
hashlib
import
akg.tvm
import
akg.tvm
def
get_dim
(
dim
,
axis
=
True
):
"""get dim info"""
dims_str
=
{
"grid_dim0"
:
"// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = "
,
"grid_dim1"
:
"// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = "
,
"grid_dim2"
:
"// attr [iter_var(blockIdx.z, , blockIdx.z)] thread_extent = "
,
"block_dim0"
:
"// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = "
,
"block_dim1"
:
"// attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = "
,
"block_dim2"
:
"// attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = "
}
dim_to_axis
=
{
"grid_dim0"
:
'"blockIdx.x" : '
,
"grid_dim1"
:
'"blockIdx.y" : '
,
"grid_dim2"
:
'"blockIdx.z" : '
,
"block_dim0"
:
'"threadIdx.x" : '
,
"block_dim1"
:
'"threadIdx.y" : '
,
"block_dim2"
:
'"threadIdx.z" : '
}
if
axis
:
return
dim_to_axis
.
get
(
dim
)
return
dims_str
.
get
(
dim
)
def
parse_params
(
file
,
dim
,
ir
):
"""parse parameters"""
dim_str
=
get_dim
(
dim
,
axis
=
False
)
pos
=
ir
.
find
(
dim_str
)
if
pos
!=
-
1
:
index
=
pos
+
len
(
dim_str
)
param_temp
=
get_dim
(
dim
)
while
ir
[
index
].
isdigit
():
param_temp
+=
ir
[
index
]
index
+=
1
file
.
write
(
param_temp
+
",
\n
"
)
else
:
param_temp
=
get_dim
(
dim
)
+
'1'
file
.
write
(
param_temp
+
",
\n
"
)
def
save_gpu_params
(
s
,
args
,
kernel_info
):
@
akg
.
tvm
.
register_func
"""save gpu parameters"""
def
dump_cuda_meta
(
code
,
ptx
,
thread_info
):
ptx_code
=
kernel_info
[
0
]
"""
file_name
=
kernel_info
[
1
]
Function for dumping cuda meta.
kernel_name
=
kernel_info
[
2
]
Args:
dump_ir
=
os
.
getenv
(
'MS_AKG_DUMP_IR'
)
==
"on"
code: gpu code.
if
dump_ir
:
ptx: ptx code.
schedule_path
=
os
.
path
.
realpath
(
kernel_name
)
thread_info: thread info, written to json file.
all_passes
=
os
.
listdir
(
schedule_path
)
"""
for
cur_pass
in
all_passes
:
# kernel name
if
cur_pass
.
startswith
(
"00_"
):
kernel_name
=
code
.
split
(
"_kernel"
)[
0
].
split
(
" "
)[
-
1
]
with
open
(
schedule_path
+
'/'
+
cur_pass
,
"r"
)
as
file
:
ir
=
file
.
read
()
break
else
:
ir
=
str
(
akg
.
tvm
.
lower
(
s
,
args
,
simple_mode
=
True
))
file_path
=
os
.
path
.
realpath
(
file_name
)
if
os
.
path
.
exists
(
file_path
):
os
.
remove
(
file_path
)
# sha256 of ptx
sha256
=
hashlib
.
sha256
()
sha256
=
hashlib
.
sha256
()
sha256
.
update
(
ptx
_code
.
encode
(
"utf-8"
))
sha256
.
update
(
ptx
.
encode
(
"utf-8"
))
hash_str
=
sha256
.
hexdigest
()
hash_str
=
sha256
.
hexdigest
()
with
os
.
fdopen
(
os
.
open
(
file_path
,
os
.
O_WRONLY
|
os
.
O_CREAT
,
0o400
),
'w'
)
as
fo
:
fo
.
write
(
"{
\n
"
)
fo
.
write
(
'"kernelName" : '
+
'"'
+
kernel_name
+
"_kernel0"
+
'",
\n
'
)
parse_params
(
fo
,
"grid_dim0"
,
ir
)
parse_params
(
fo
,
"grid_dim1"
,
ir
)
parse_params
(
fo
,
"grid_dim2"
,
ir
)
parse_params
(
fo
,
"block_dim0"
,
ir
)
parse_params
(
fo
,
"block_dim1"
,
ir
)
parse_params
(
fo
,
"block_dim2"
,
ir
)
fo
.
write
(
'"sha256" : '
+
'"'
+
hash_str
+
'"
\n
'
)
fo
.
write
(
"}
\n
"
)
def
dump
(
mod
,
kernel_name
,
sch
,
args
):
# thread info
thread_info_dict
=
{
"blockIdx.x"
:
"1"
,
"blockIdx.y"
:
"1"
,
"blockIdx.z"
:
"1"
,
"threadIdx.x"
:
"1"
,
"threadIdx.y"
:
"1"
,
"threadIdx.z"
:
"1"
}
for
thread_tag
in
thread_info_dict
.
keys
():
if
thread_tag
in
thread_info
:
if
isinstance
(
thread_info
[
thread_tag
],
int
):
thread_info_dict
[
thread_tag
]
=
str
(
thread_info
[
thread_tag
])
elif
isinstance
(
thread_info
[
thread_tag
],
akg
.
tvm
.
expr
.
IntImm
):
thread_info_dict
[
thread_tag
]
=
str
(
thread_info
[
thread_tag
].
value
)
meta_path
=
"./cuda_meta_"
+
str
(
os
.
getpid
())
+
"/"
meta_path
=
"./cuda_meta_"
+
str
(
os
.
getpid
())
+
"/"
cuda_path
=
os
.
path
.
realpath
(
meta_path
)
cuda_path
=
os
.
path
.
realpath
(
meta_path
)
if
not
os
.
path
.
isdir
(
cuda_path
):
if
not
os
.
path
.
isdir
(
cuda_path
):
os
.
makedirs
(
cuda_path
)
os
.
makedirs
(
cuda_path
)
# save ptx file to cuda meta
ptx_file
=
os
.
path
.
realpath
(
meta_path
+
kernel_name
+
".ptx"
)
ptx_file
=
os
.
path
.
realpath
(
meta_path
+
kernel_name
+
".ptx"
)
with
open
(
ptx_file
,
"at"
)
as
f
:
with
open
(
ptx_file
,
"at"
)
as
f
:
fcntl
.
flock
(
f
.
fileno
(),
fcntl
.
LOCK_EX
)
fcntl
.
flock
(
f
.
fileno
(),
fcntl
.
LOCK_EX
)
f
.
seek
(
0
,
2
)
f
.
seek
(
0
,
2
)
if
f
.
tell
()
==
0
:
if
f
.
tell
()
==
0
:
ptx_code
=
mod
.
imported_modules
[
0
].
get_source
(
'ptx'
)
f
.
write
(
ptx
)
f
.
write
(
ptx_code
)
param_path
=
os
.
path
.
realpath
(
meta_path
+
kernel_name
+
'.json'
)
# save json file to cuda meta
save_gpu_params
(
sch
,
args
,
(
ptx_code
,
param_path
,
kernel_name
))
json_file
=
os
.
path
.
realpath
(
meta_path
+
kernel_name
+
".json"
)
if
os
.
path
.
exists
(
json_file
):
os
.
remove
(
json_file
)
with
os
.
fdopen
(
os
.
open
(
json_file
,
os
.
O_WRONLY
|
os
.
O_CREAT
,
0o400
),
'w'
)
as
fo
:
fo
.
write
(
"{
\n
"
)
fo
.
write
(
'"kernelName" : '
+
'"'
+
kernel_name
+
"_kernel0"
+
'",
\n
'
)
fo
.
write
(
'"blockIdx.x" : '
+
thread_info_dict
[
"blockIdx.x"
]
+
',
\n
'
)
fo
.
write
(
'"blockIdx.y" : '
+
thread_info_dict
[
"blockIdx.y"
]
+
',
\n
'
)
fo
.
write
(
'"blockIdx.z" : '
+
thread_info_dict
[
"blockIdx.z"
]
+
',
\n
'
)
fo
.
write
(
'"threadIdx.x" : '
+
thread_info_dict
[
"threadIdx.x"
]
+
',
\n
'
)
fo
.
write
(
'"threadIdx.y" : '
+
thread_info_dict
[
"threadIdx.y"
]
+
',
\n
'
)
fo
.
write
(
'"threadIdx.z" : '
+
thread_info_dict
[
"threadIdx.z"
]
+
',
\n
'
)
fo
.
write
(
'"sha256" : '
+
'"'
+
hash_str
+
'"
\n
'
)
fo
.
write
(
"}
\n
"
)
python/akg/utils/kernel_exec.py
浏览文件 @
d237aa7d
...
@@ -42,7 +42,6 @@ from akg.utils import format_transform as ft_util
...
@@ -42,7 +42,6 @@ from akg.utils import format_transform as ft_util
from
akg.utils
import
custom_tiling
as
ct_util
from
akg.utils
import
custom_tiling
as
ct_util
from
akg.utils
import
validation_check
as
vc_util
from
akg.utils
import
validation_check
as
vc_util
from
akg.utils.dsl_create
import
TensorUtils
from
akg.utils.dsl_create
import
TensorUtils
from
akg.utils
import
dump_cuda_meta
sh
=
logging
.
StreamHandler
(
sys
.
stdout
)
sh
=
logging
.
StreamHandler
(
sys
.
stdout
)
logging
.
getLogger
().
addHandler
(
sh
)
logging
.
getLogger
().
addHandler
(
sh
)
...
@@ -746,7 +745,6 @@ def op_build(op_func, input_shapes, input_types, op_attrs=None, kernel_name="",
...
@@ -746,7 +745,6 @@ def op_build(op_func, input_shapes, input_types, op_attrs=None, kernel_name="",
with
akg
.
tvm
.
build_config
(
dump_pass_ir
=
dump_ir
):
with
akg
.
tvm
.
build_config
(
dump_pass_ir
=
dump_ir
):
mod
=
akg
.
build
(
s
,
op_var
,
"cuda"
,
shape_var
,
name
=
kernel_name
,
attrs
=
attrs
,
mod
=
akg
.
build
(
s
,
op_var
,
"cuda"
,
shape_var
,
name
=
kernel_name
,
attrs
=
attrs
,
polyhedral
=
polyhedral
,
binds
=
binds
)
polyhedral
=
polyhedral
,
binds
=
binds
)
dump_cuda_meta
.
dump
(
mod
,
kernel_name
,
s
,
op_var
)
if
dump_code
:
if
dump_code
:
source_code
=
mod
.
imported_modules
[
0
].
get_source
()
source_code
=
mod
.
imported_modules
[
0
].
get_source
()
create_code
(
kernel_name
,
"./"
,
source_code
,
"CUDA"
)
create_code
(
kernel_name
,
"./"
,
source_code
,
"CUDA"
)
...
...
third_party/incubator-tvm/src/codegen/opt/build_cuda_on.cc
浏览文件 @
d237aa7d
...
@@ -23,6 +23,12 @@
...
@@ -23,6 +23,12 @@
*
*
* \file build_cuda.cc
* \file build_cuda.cc
*/
*/
/*
* 2020.8.14 - Get thread info inside BuildCUDA function,
* enbale dump cuda meta.
*/
#if defined(__linux__)
#if defined(__linux__)
#include <sys/stat.h>
#include <sys/stat.h>
#endif
#endif
...
@@ -133,8 +139,18 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
...
@@ -133,8 +139,18 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
CodeGenCUDA
cg
;
CodeGenCUDA
cg
;
cg
.
Init
(
output_ssa
);
cg
.
Init
(
output_ssa
);
Map
<
std
::
string
,
Expr
>
thread_info
;
for
(
LoweredFunc
f
:
funcs
)
{
for
(
LoweredFunc
f
:
funcs
)
{
cg
.
AddFunction
(
f
);
cg
.
AddFunction
(
f
);
for
(
const
auto
&
axis
:
f
->
thread_axis
)
{
auto
thread_tag
=
axis
->
thread_tag
;
auto
node
=
axis
->
dom
.
get
();
if
(
node
!=
nullptr
)
{
CHECK
(
axis
->
dom
->
extent
.
as
<
IntImm
>
());
thread_info
.
Set
(
thread_tag
,
axis
->
dom
->
extent
);
}
}
}
}
std
::
string
code
=
cg
.
Finish
();
std
::
string
code
=
cg
.
Finish
();
...
@@ -151,6 +167,11 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
...
@@ -151,6 +167,11 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
}
else
{
}
else
{
ptx
=
NVRTCCompile
(
code
,
cg
.
need_include_path
());
ptx
=
NVRTCCompile
(
code
,
cg
.
need_include_path
());
}
}
if
(
const
auto
*
f
=
Registry
::
Get
(
"dump_cuda_meta"
))
{
(
*
f
)(
code
,
ptx
,
thread_info
);
}
return
CUDAModuleCreate
(
ptx
,
fmt
,
ExtractFuncInfo
(
funcs
),
code
);
return
CUDAModuleCreate
(
ptx
,
fmt
,
ExtractFuncInfo
(
funcs
),
code
);
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录