Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
akg
提交
6335daad
A
akg
项目概览
MindSpore
/
akg
通知
58
Star
7
Fork
7
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
akg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6335daad
编写于
7月 16, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 16, 2020
浏览文件
操作
浏览文件
下载
差异文件
!71 support composite and unified build/launch for gpu
Merge pull request !71 from Gaoxiong/master
上级
f60af9df
cabffbd5
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
221 addition
and
23 deletion
+221
-23
python/akg/__init__.py
python/akg/__init__.py
+7
-0
python/akg/composite/build_module.py
python/akg/composite/build_module.py
+18
-1
python/akg/ms/gpu/__init__.py
python/akg/ms/gpu/__init__.py
+2
-1
python/akg/ms/gpu/mul.py
python/akg/ms/gpu/mul.py
+2
-19
python/akg/utils/dump_cuda_meta.py
python/akg/utils/dump_cuda_meta.py
+100
-0
python/akg/utils/kernel_exec.py
python/akg/utils/kernel_exec.py
+21
-2
src/composite/composite.cc
src/composite/composite.cc
+37
-0
tests/operators/gpu/test_ms_mul.py
tests/operators/gpu/test_ms_mul.py
+34
-0
未找到文件。
python/akg/__init__.py
浏览文件 @
6335daad
...
...
@@ -64,6 +64,13 @@ class AKGMetaPathLoader:
sys
.
modules
[
fullname
]
=
self
.
__target_module
return
self
.
__target_module
def
schedule
(
sch
,
target
=
'cuda'
):
def
decorator
(
func
):
def
wrapper
(
*
args
,
**
kwargs
):
output
=
func
(
*
args
,
**
kwargs
)
return
{
'schedule'
:
sch
,
'target'
:
target
,
'output'
:
output
,
'op_name'
:
func
.
__name__
}
return
wrapper
return
decorator
sys
.
meta_path
.
insert
(
0
,
AKGMetaPathFinder
())
...
...
python/akg/composite/build_module.py
浏览文件 @
6335daad
...
...
@@ -19,7 +19,8 @@ import json
from
akg
import
tvm
from
akg.tvm
import
_api_internal
from
.repository
import
__all__
as
repository
import
topi
from
akg.utils
import
dump_cuda_meta
def
generate_trait
(
desc
):
""" generate trait of kernel description """
...
...
@@ -116,6 +117,9 @@ def _build_to_func(desc_s, desc_d, attr=None):
return
func
(
desc_s
,
attr
)
def
_build
(
desc_s
,
desc_d
,
attr
=
None
):
if
desc_d
[
'process'
]
==
'gpu'
:
func
=
tvm
.
get_global_func
(
"composite_with_json"
)
return
func
(
desc_s
,
attr
)
rst
=
_build_to_func
(
desc_s
,
desc_d
,
attr
)
return
_api_internal
.
_BuildToModule
(
rst
)
...
...
@@ -163,3 +167,16 @@ def get_tiling_space(kernel_desc, level=1, attr=None):
if
level
>=
2
:
spaces
[
'tuning_space'
]
=
ret
.
tiling_candidate
.
asnumpy
().
tolist
()
return
spaces
@
tvm
.
register_func
(
"akg_build_gpu_module"
)
def
build_cuda
(
outputs
,
args
,
sch_name
,
kernel_name
):
scheduler
=
{
"injective"
:
topi
.
cuda
.
schedule_injective
,
"reduce"
:
topi
.
cuda
.
schedule_reduce
,
}
with
tvm
.
target
.
cuda
()
as
cuda
:
s
=
scheduler
[
sch_name
](
outputs
)
with
tvm
.
build_config
(
dump_pass_ir
=
True
):
mod
=
tvm
.
build
(
s
,
args
,
cuda
,
name
=
kernel_name
)
dump_cuda_meta
.
dump
(
mod
,
kernel_name
,
s
,
list
(
args
))
return
mod
python/akg/ms/gpu/__init__.py
浏览文件 @
6335daad
...
...
@@ -27,4 +27,5 @@ from .squeeze import Squeeze, gpu_schedule_Squeeze
from
.squeeze_grad
import
SqueezeGrad
,
gpu_schedule_SqueezeGrad
from
.mean
import
SimpleMean
,
gpu_schedule_SimpleMean
from
.mean_grad
import
SimpleMeanGrad
,
gpu_schedule_SimpleMeanGrad
from
.mul
import
Mul
,
gpu_schedule_Mul
from
.mul
import
Mul
python/akg/ms/gpu/mul.py
浏览文件 @
6335daad
...
...
@@ -15,29 +15,12 @@
# limitations under the License.
"""mul"""
import
akg
import
akg.topi
as
topi
import
akg.tvm
as
tvm
from
akg.ops.math
import
mul
@
akg
.
schedule
(
topi
.
cuda
.
schedule_injective
)
def
Mul
(
x
,
y
):
"""mul."""
return
mul
.
mul
(
x
,
y
)
def
gpu_schedule_Mul
(
outs
):
"""
gpu schedule for mul.
Args:
outs (tvm.tensor.Tensor): outputs of compute.
Returns:
sch (schedule.Schedule): The created schedule.
"""
device
=
'cuda'
ctx
=
tvm
.
context
(
device
,
0
)
if
not
ctx
.
exist
:
raise
SystemError
(
"Skip because %s is not enabled"
%
device
)
with
tvm
.
target
.
create
(
device
):
sch
=
topi
.
cuda
.
schedule_broadcast
(
outs
)
return
sch
python/akg/utils/dump_cuda_meta.py
0 → 100644
浏览文件 @
6335daad
#!/usr/bin/env python3
# coding: utf-8
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""save gpu param"""
import
os
import
fcntl
import
hashlib
import
akg.tvm
def
get_dim
(
dim
,
axis
=
True
):
"""get dim info"""
dims_str
=
{
"grid_dim0"
:
"// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = "
,
"grid_dim1"
:
"// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = "
,
"grid_dim2"
:
"// attr [iter_var(blockIdx.z, , blockIdx.z)] thread_extent = "
,
"block_dim0"
:
"// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = "
,
"block_dim1"
:
"// attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = "
,
"block_dim2"
:
"// attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = "
}
dim_to_axis
=
{
"grid_dim0"
:
'"blockIdx.x" : '
,
"grid_dim1"
:
'"blockIdx.y" : '
,
"grid_dim2"
:
'"blockIdx.z" : '
,
"block_dim0"
:
'"threadIdx.x" : '
,
"block_dim1"
:
'"threadIdx.y" : '
,
"block_dim2"
:
'"threadIdx.z" : '
}
if
axis
:
return
dim_to_axis
.
get
(
dim
)
return
dims_str
.
get
(
dim
)
def
parse_params
(
file
,
dim
,
ir
):
"""parse parameters"""
dim_str
=
get_dim
(
dim
,
axis
=
False
)
pos
=
ir
.
find
(
dim_str
)
if
pos
!=
-
1
:
index
=
pos
+
len
(
dim_str
)
param_temp
=
get_dim
(
dim
)
while
ir
[
index
].
isdigit
():
param_temp
+=
ir
[
index
]
index
+=
1
file
.
write
(
param_temp
+
",
\n
"
)
else
:
param_temp
=
get_dim
(
dim
)
+
'1'
file
.
write
(
param_temp
+
",
\n
"
)
def
save_gpu_params
(
s
,
args
,
kernel_info
):
"""save gpu parameters"""
ptx_code
=
kernel_info
[
0
]
file_name
=
kernel_info
[
1
]
kernel_name
=
kernel_info
[
2
]
ir
=
str
(
akg
.
tvm
.
lower
(
s
,
args
,
simple_mode
=
True
))
file_path
=
os
.
path
.
realpath
(
file_name
)
if
os
.
path
.
exists
(
file_path
):
os
.
remove
(
file_path
)
sha256
=
hashlib
.
sha256
()
sha256
.
update
(
ptx_code
.
encode
(
"utf-8"
))
hash_str
=
sha256
.
hexdigest
()
with
os
.
fdopen
(
os
.
open
(
file_path
,
os
.
O_WRONLY
|
os
.
O_CREAT
,
0o400
),
'w'
)
as
fo
:
fo
.
write
(
"{
\n
"
)
fo
.
write
(
'"kernelName" : '
+
'"'
+
kernel_name
+
"_kernel0"
+
'",
\n
'
)
parse_params
(
fo
,
"grid_dim0"
,
ir
)
parse_params
(
fo
,
"grid_dim1"
,
ir
)
parse_params
(
fo
,
"grid_dim2"
,
ir
)
parse_params
(
fo
,
"block_dim0"
,
ir
)
parse_params
(
fo
,
"block_dim1"
,
ir
)
parse_params
(
fo
,
"block_dim2"
,
ir
)
fo
.
write
(
'"sha256" : '
+
'"'
+
hash_str
+
'"
\n
'
)
fo
.
write
(
"}
\n
"
)
def
dump
(
mod
,
kernel_name
,
sch
,
args
):
meta_path
=
"./cuda_meta/"
cuda_path
=
os
.
path
.
realpath
(
meta_path
)
if
not
os
.
path
.
isdir
(
cuda_path
):
os
.
makedirs
(
cuda_path
)
ptx_file
=
os
.
path
.
realpath
(
meta_path
+
kernel_name
+
".ptx"
)
with
open
(
ptx_file
,
"at"
)
as
f
:
fcntl
.
flock
(
f
.
fileno
(),
fcntl
.
LOCK_EX
)
f
.
seek
(
0
,
2
)
if
f
.
tell
()
==
0
:
ptx_code
=
mod
.
imported_modules
[
0
].
get_source
(
'ptx'
)
f
.
write
(
ptx_code
)
param_path
=
os
.
path
.
realpath
(
meta_path
+
kernel_name
+
'.json'
)
save_gpu_params
(
sch
,
args
,
(
ptx_code
,
param_path
,
kernel_name
))
\ No newline at end of file
python/akg/utils/kernel_exec.py
浏览文件 @
6335daad
...
...
@@ -42,7 +42,7 @@ from akg.utils import format_transform as ft_util
from
akg.utils
import
custom_tiling
as
ct_util
from
akg.utils
import
validation_check
as
vc_util
from
akg.utils.dsl_create
import
TensorUtils
from
akg.utils
import
dump_cuda_meta
sh
=
logging
.
StreamHandler
(
sys
.
stdout
)
logging
.
getLogger
().
addHandler
(
sh
)
...
...
@@ -435,6 +435,12 @@ def mod_launch(mod, args, outputs=(-1,), tuning=False, device_id=0, expect=None)
"""
gc
.
collect
()
if
mod
.
imported_modules
[
0
].
type_key
==
'cuda'
:
ctx
=
akg
.
tvm
.
context
(
'cuda'
,
device_id
)
mod_args
=
[
akg
.
tvm
.
nd
.
array
(
a
,
ctx
)
for
a
in
args
]
mod
(
*
mod_args
)
out_list
=
[
mod_args
[
len
(
args
)
+
i
if
i
<
0
else
i
].
asnumpy
()
for
i
in
outputs
]
return
out_list
[
0
]
if
len
(
out_list
)
==
1
else
tuple
(
out_list
)
stat_info
=
{}
profiling_mode
=
get_profiling_mode
()
...
...
@@ -679,7 +685,7 @@ def op_build(op_func, input_shapes, input_types, op_attrs=None, kernel_name="",
attrs
[
'dim'
]
=
dim_info
compute_func
=
None
# func which is defined in dsl for doing compute_inline or other
sch_tmpl
=
None
if
isinstance
(
output
,
(
list
,
tuple
)):
from
inspect
import
isfunction
new_outputs
=
[]
...
...
@@ -696,6 +702,9 @@ def op_build(op_func, input_shapes, input_types, op_attrs=None, kernel_name="",
new_outputs
.
append
(
elem
)
output
=
new_outputs
elif
isinstance
(
output
,
dict
):
sch_tmpl
=
output
output
=
sch_tmpl
[
'output'
]
binds
=
None
if
not
attrs
else
attrs
.
pop
(
BINDS
,
None
)
op_var
=
[]
...
...
@@ -715,6 +724,16 @@ def op_build(op_func, input_shapes, input_types, op_attrs=None, kernel_name="",
if
TensorUtils
.
is_output_value
(
output
):
op_var
=
op_var
+
[
output
]
if
sch_tmpl
!=
None
:
assert
(
sch_tmpl
[
'target'
]
==
'cuda'
)
kernel_name
=
kernel_name
if
kernel_name
!=
""
else
sch_tmpl
[
'op_name'
]
with
akg
.
tvm
.
target
.
cuda
()
as
target
:
s
=
sch_tmpl
[
'schedule'
](
sch_tmpl
[
'output'
])
with
akg
.
tvm
.
build_config
(
dump_pass_ir
=
True
):
mod
=
akg
.
tvm
.
build
(
s
,
op_var
,
target
,
target_host
=
'stackvm'
,
name
=
kernel_name
)
dump_cuda_meta
.
dump
(
mod
,
kernel_name
,
s
,
op_var
)
return
mod
if
isinstance
(
output
,
(
list
,
tuple
)):
tmp
=
[]
for
x
in
list
(
output
):
...
...
src/composite/composite.cc
浏览文件 @
6335daad
...
...
@@ -459,7 +459,44 @@ NodeRef composite_with_json_to_func(const std::string &json_str, Map<std::string
return
build_rst
;
}
std
::
string
get_process
(
const
std
::
string
&
json_str
)
{
size_t
pos
=
json_str
.
find
(
"
\"
process
\"
"
);
if
(
pos
!=
std
::
string
::
npos
&&
json_str
.
find
(
"gpu"
,
pos
)
!=
std
::
string
::
npos
)
{
return
"gpu"
;
}
return
"aicore"
;
}
std
::
string
get_schedule
(
Array
<
Tensor
>
&
outputs
)
{
for
(
const
Tensor
&
t
:
outputs
)
{
if
(
t
->
op
->
tag
==
"comm_reduce"
||
t
->
op
->
tag
==
"comm_reduce_idx"
)
{
return
"reduce"
;
}
}
return
"injective"
;
}
Module
composite_with_json_gpu
(
const
std
::
string
&
json_str
,
Map
<
std
::
string
,
NodeRef
>
attrs
)
{
picojson
::
value
v
;
std
::
string
err
=
picojson
::
parse
(
v
,
json_str
);
if
(
!
err
.
empty
())
{
LOG
(
ERROR
)
<<
"json parse error, error message: "
<<
err
;
}
Array
<
Tensor
>
tensors
;
Array
<
NodeRef
>
args
;
Map
<
Tensor
,
Buffer
>
in_binds
;
std
::
string
kernel_name
;
extract_op_info
(
v
,
&
tensors
,
&
args
,
&
kernel_name
,
&
in_binds
);
const
auto
*
build_func
=
air
::
runtime
::
Registry
::
Get
(
"akg_build_gpu_module"
);
CHECK
(
build_func
!=
nullptr
);
std
::
string
sch
=
get_schedule
(
tensors
);
return
(
*
build_func
)(
tensors
,
args
,
sch
,
kernel_name
);
}
Module
composite_with_json
(
const
std
::
string
&
json_str
,
Map
<
std
::
string
,
NodeRef
>
attrs
)
{
if
(
get_process
(
json_str
)
==
"gpu"
)
{
return
composite_with_json_gpu
(
json_str
,
attrs
);
}
auto
build_rst
=
composite_with_json_to_func
(
json_str
,
attrs
);
return
BuildToModule
(
build_rst
);
}
...
...
tests/operators/gpu/test_ms_mul.py
0 → 100644
浏览文件 @
6335daad
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import
numpy
as
np
from
akg.ms.gpu
import
Mul
from
gen_random
import
random_gaussian
from
akg.utils
import
kernel_exec
as
utils
def
gen_data
(
shape
,
dtype
):
support_list
=
{
"float16"
:
np
.
float16
,
"float32"
:
np
.
float32
}
lhd
=
random_gaussian
(
shape
,
miu
=
1
,
sigma
=
0.1
).
astype
(
support_list
[
dtype
])
rhd
=
random_gaussian
(
shape
,
miu
=
1
,
sigma
=
0.1
).
astype
(
support_list
[
dtype
])
expect
=
np
.
multiply
(
lhd
,
rhd
)
output
=
np
.
full
(
shape
,
np
.
nan
,
dtype
)
return
lhd
,
rhd
,
output
,
expect
def
test_ms_mul
(
shape
,
dtype
):
mod
=
utils
.
op_build
(
Mul
,
(
shape
,
shape
),
(
dtype
,
dtype
))
lhd
,
rhd
,
output
,
expect
=
gen_data
(
shape
,
dtype
)
output
=
utils
.
mod_launch
(
mod
,
(
lhd
,
rhd
,
output
),
expect
=
expect
)
np
.
allclose
(
output
,
expect
,
rtol
=
5e-03
,
atol
=
1.e-8
)
if
__name__
==
'__main__'
:
test_ms_mul
((
1024
,
4096
),
'float32'
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录