提交 cabffbd5 编写于 作者: G Gaoxiong

support composite and op_build/mod_launch for gpu

上级 1f184177
......@@ -64,6 +64,13 @@ class AKGMetaPathLoader:
sys.modules[fullname] = self.__target_module
return self.__target_module
def schedule(sch, target = 'cuda'):
def decorator(func):
def wrapper(*args, **kwargs):
output = func(*args, **kwargs)
return {'schedule' : sch, 'target' : target, 'output' : output, 'op_name' : func.__name__}
return wrapper
return decorator
sys.meta_path.insert(0, AKGMetaPathFinder())
......
......@@ -19,7 +19,8 @@ import json
from akg import tvm
from akg.tvm import _api_internal
from .repository import __all__ as repository
import topi
from akg.utils import dump_cuda_meta
def generate_trait(desc):
""" generate trait of kernel description """
......@@ -116,6 +117,9 @@ def _build_to_func(desc_s, desc_d, attr=None):
return func(desc_s, attr)
def _build(desc_s, desc_d, attr=None):
if desc_d['process'] == 'gpu':
func = tvm.get_global_func("composite_with_json")
return func(desc_s, attr)
rst = _build_to_func(desc_s, desc_d, attr)
return _api_internal._BuildToModule(rst)
......@@ -163,3 +167,16 @@ def get_tiling_space(kernel_desc, level=1, attr=None):
if level >= 2:
spaces['tuning_space'] = ret.tiling_candidate.asnumpy().tolist()
return spaces
@tvm.register_func("akg_build_gpu_module")
def build_cuda(outputs, args, sch_name, kernel_name):
scheduler = {
"injective" : topi.cuda.schedule_injective,
"reduce" : topi.cuda.schedule_reduce,
}
with tvm.target.cuda() as cuda:
s = scheduler[sch_name](outputs)
with tvm.build_config(dump_pass_ir = True):
mod = tvm.build(s, args, cuda, name = kernel_name)
dump_cuda_meta.dump(mod, kernel_name, s, list(args))
return mod
......@@ -27,4 +27,5 @@ from .squeeze import Squeeze, gpu_schedule_Squeeze
from .squeeze_grad import SqueezeGrad, gpu_schedule_SqueezeGrad
from .mean import SimpleMean, gpu_schedule_SimpleMean
from .mean_grad import SimpleMeanGrad, gpu_schedule_SimpleMeanGrad
from .mul import Mul, gpu_schedule_Mul
from .mul import Mul
......@@ -15,29 +15,12 @@
# limitations under the License.
"""mul"""
import akg
import akg.topi as topi
import akg.tvm as tvm
from akg.ops.math import mul
@akg.schedule(topi.cuda.schedule_injective)
def Mul(x, y):
"""mul."""
return mul.mul(x, y)
def gpu_schedule_Mul(outs):
"""
gpu schedule for mul.
Args:
outs (tvm.tensor.Tensor): outputs of compute.
Returns:
sch (schedule.Schedule): The created schedule.
"""
device = 'cuda'
ctx = tvm.context(device, 0)
if not ctx.exist:
raise SystemError("Skip because %s is not enabled" % device)
with tvm.target.create(device):
sch = topi.cuda.schedule_broadcast(outs)
return sch
#!/usr/bin/env python3
# coding: utf-8
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""save gpu param"""
import os
import fcntl
import hashlib
import akg.tvm
def get_dim(dim, axis=True):
"""get dim info"""
dims_str = {
"grid_dim0": "// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = ",
"grid_dim1": "// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = ",
"grid_dim2": "// attr [iter_var(blockIdx.z, , blockIdx.z)] thread_extent = ",
"block_dim0": "// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = ",
"block_dim1": "// attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = ",
"block_dim2": "// attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = "
}
dim_to_axis = {
"grid_dim0": '"blockIdx.x" : ',
"grid_dim1": '"blockIdx.y" : ',
"grid_dim2": '"blockIdx.z" : ',
"block_dim0": '"threadIdx.x" : ',
"block_dim1": '"threadIdx.y" : ',
"block_dim2": '"threadIdx.z" : '
}
if axis:
return dim_to_axis.get(dim)
return dims_str.get(dim)
def parse_params(file, dim, ir):
"""parse parameters"""
dim_str = get_dim(dim, axis=False)
pos = ir.find(dim_str)
if pos != -1:
index = pos + len(dim_str)
param_temp = get_dim(dim)
while ir[index].isdigit():
param_temp += ir[index]
index += 1
file.write(param_temp + ",\n")
else:
param_temp = get_dim(dim) + '1'
file.write(param_temp + ",\n")
def save_gpu_params(s, args, kernel_info):
"""save gpu parameters"""
ptx_code = kernel_info[0]
file_name = kernel_info[1]
kernel_name = kernel_info[2]
ir = str(akg.tvm.lower(s, args, simple_mode=True))
file_path = os.path.realpath(file_name)
if os.path.exists(file_path):
os.remove(file_path)
sha256 = hashlib.sha256()
sha256.update(ptx_code.encode("utf-8"))
hash_str = sha256.hexdigest()
with os.fdopen(os.open(file_path, os.O_WRONLY | os.O_CREAT, 0o400), 'w') as fo:
fo.write("{\n")
fo.write('"kernelName" : ' + '"' + kernel_name + "_kernel0" + '",\n')
parse_params(fo, "grid_dim0", ir)
parse_params(fo, "grid_dim1", ir)
parse_params(fo, "grid_dim2", ir)
parse_params(fo, "block_dim0", ir)
parse_params(fo, "block_dim1", ir)
parse_params(fo, "block_dim2", ir)
fo.write('"sha256" : ' + '"' + hash_str + '"\n')
fo.write("}\n")
def dump(mod, kernel_name, sch, args):
meta_path = "./cuda_meta/"
cuda_path = os.path.realpath(meta_path)
if not os.path.isdir(cuda_path):
os.makedirs(cuda_path)
ptx_file = os.path.realpath(meta_path + kernel_name + ".ptx")
with open(ptx_file, "at") as f:
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
f.seek(0, 2)
if f.tell() == 0:
ptx_code = mod.imported_modules[0].get_source('ptx')
f.write(ptx_code)
param_path = os.path.realpath(meta_path + kernel_name + '.json')
save_gpu_params(sch, args, (ptx_code, param_path, kernel_name))
\ No newline at end of file
......@@ -42,7 +42,7 @@ from akg.utils import format_transform as ft_util
from akg.utils import custom_tiling as ct_util
from akg.utils import validation_check as vc_util
from akg.utils.dsl_create import TensorUtils
from akg.utils import dump_cuda_meta
sh = logging.StreamHandler(sys.stdout)
logging.getLogger().addHandler(sh)
......@@ -435,6 +435,12 @@ def mod_launch(mod, args, outputs=(-1,), tuning=False, device_id=0, expect=None)
"""
gc.collect()
if mod.imported_modules[0].type_key == 'cuda':
ctx = akg.tvm.context('cuda', device_id)
mod_args = [akg.tvm.nd.array(a, ctx) for a in args]
mod(*mod_args)
out_list = [mod_args[len(args) + i if i < 0 else i].asnumpy() for i in outputs]
return out_list[0] if len(out_list) == 1 else tuple(out_list)
stat_info = {}
profiling_mode = get_profiling_mode()
......@@ -679,7 +685,7 @@ def op_build(op_func, input_shapes, input_types, op_attrs=None, kernel_name="",
attrs['dim'] = dim_info
compute_func = None # func which is defined in dsl for doing compute_inline or other
sch_tmpl = None
if isinstance(output, (list, tuple)):
from inspect import isfunction
new_outputs = []
......@@ -696,6 +702,9 @@ def op_build(op_func, input_shapes, input_types, op_attrs=None, kernel_name="",
new_outputs.append(elem)
output = new_outputs
elif isinstance(output, dict):
sch_tmpl = output
output = sch_tmpl['output']
binds = None if not attrs else attrs.pop(BINDS, None)
op_var = []
......@@ -715,6 +724,16 @@ def op_build(op_func, input_shapes, input_types, op_attrs=None, kernel_name="",
if TensorUtils.is_output_value(output):
op_var = op_var + [output]
if sch_tmpl != None:
assert(sch_tmpl['target'] == 'cuda')
kernel_name = kernel_name if kernel_name != "" else sch_tmpl['op_name']
with akg.tvm.target.cuda() as target:
s = sch_tmpl['schedule'](sch_tmpl['output'])
with akg.tvm.build_config(dump_pass_ir = True):
mod = akg.tvm.build(s, op_var, target, target_host = 'stackvm', name = kernel_name)
dump_cuda_meta.dump(mod, kernel_name, s, op_var)
return mod
if isinstance(output, (list, tuple)):
tmp = []
for x in list(output):
......
......@@ -459,7 +459,44 @@ NodeRef composite_with_json_to_func(const std::string &json_str, Map<std::string
return build_rst;
}
std::string get_process(const std::string &json_str) {
size_t pos = json_str.find("\"process\"");
if (pos != std::string::npos && json_str.find("gpu", pos) != std::string::npos) {
return "gpu";
}
return "aicore";
}
std::string get_schedule(Array<Tensor> &outputs) {
for (const Tensor &t : outputs) {
if (t->op->tag == "comm_reduce" || t->op->tag == "comm_reduce_idx") {
return "reduce";
}
}
return "injective";
}
Module composite_with_json_gpu(const std::string &json_str, Map<std::string, NodeRef> attrs) {
picojson::value v;
std::string err = picojson::parse(v, json_str);
if (!err.empty()) {
LOG(ERROR) << "json parse error, error message: " << err;
}
Array<Tensor> tensors;
Array<NodeRef> args;
Map<Tensor, Buffer> in_binds;
std::string kernel_name;
extract_op_info(v, &tensors, &args, &kernel_name, &in_binds);
const auto* build_func = air::runtime::Registry::Get("akg_build_gpu_module");
CHECK(build_func != nullptr);
std::string sch = get_schedule(tensors);
return (*build_func)(tensors, args, sch, kernel_name);
}
Module composite_with_json(const std::string &json_str, Map<std::string, NodeRef> attrs) {
if (get_process(json_str) == "gpu") {
return composite_with_json_gpu(json_str, attrs);
}
auto build_rst = composite_with_json_to_func(json_str, attrs);
return BuildToModule(build_rst);
}
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import numpy as np
from akg.ms.gpu import Mul
from gen_random import random_gaussian
from akg.utils import kernel_exec as utils
def gen_data(shape, dtype):
support_list = {"float16": np.float16, "float32": np.float32}
lhd = random_gaussian(shape, miu=1, sigma=0.1).astype(support_list[dtype])
rhd = random_gaussian(shape, miu=1, sigma=0.1).astype(support_list[dtype])
expect = np.multiply(lhd, rhd)
output = np.full(shape, np.nan, dtype)
return lhd, rhd, output, expect
def test_ms_mul(shape, dtype):
mod = utils.op_build(Mul, (shape, shape), (dtype, dtype))
lhd, rhd, output, expect = gen_data(shape, dtype)
output = utils.mod_launch(mod, (lhd, rhd, output), expect = expect)
np.allclose(output, expect, rtol=5e-03, atol=1.e-8)
if __name__ == '__main__':
test_ms_mul((1024, 4096), 'float32')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册