diff --git a/python/akg/__init__.py b/python/akg/__init__.py
index ba1cb13b54cef7683945a0420cb505b1eb3c77d7..a1a01629513df28ee61b81094268ea948190b4a0 100644
--- a/python/akg/__init__.py
+++ b/python/akg/__init__.py
@@ -64,6 +64,13 @@ class AKGMetaPathLoader:
         sys.modules[fullname] = self.__target_module
         return self.__target_module
 
+def schedule(sch, target='cuda'):
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            output = func(*args, **kwargs)
+            return {'schedule': sch, 'target': target, 'output': output, 'op_name': func.__name__}
+        return wrapper
+    return decorator
 
 sys.meta_path.insert(0, AKGMetaPathFinder())
 
diff --git a/python/akg/composite/build_module.py b/python/akg/composite/build_module.py
index 5464cbe0816c9fb763ea0085af7703c430db0f7d..369ef761b3a34176b0424c1ca09a2e837fddaf4e 100644
--- a/python/akg/composite/build_module.py
+++ b/python/akg/composite/build_module.py
@@ -19,7 +19,8 @@
 import json
 from akg import tvm
 from akg.tvm import _api_internal
 from .repository import __all__ as repository
-
+import topi
+from akg.utils import dump_cuda_meta
 def generate_trait(desc):
     """ generate trait of kernel description """
@@ -116,6 +117,9 @@ def _build_to_func(desc_s, desc_d, attr=None):
     return func(desc_s, attr)
 
 def _build(desc_s, desc_d, attr=None):
+    if desc_d['process'] == 'gpu':
+        func = tvm.get_global_func("composite_with_json")
+        return func(desc_s, attr)
     rst = _build_to_func(desc_s, desc_d, attr)
     return _api_internal._BuildToModule(rst)
@@ -163,3 +167,16 @@ def get_tiling_space(kernel_desc, level=1, attr=None):
     if level >= 2:
         spaces['tuning_space'] = ret.tiling_candidate.asnumpy().tolist()
     return spaces
+
+@tvm.register_func("akg_build_gpu_module")
+def build_cuda(outputs, args, sch_name, kernel_name):
+    scheduler = {
+        "injective": topi.cuda.schedule_injective,
+        "reduce": topi.cuda.schedule_reduce,
+    }
+    with tvm.target.cuda() as cuda:
+        s = scheduler[sch_name](outputs)
+        with tvm.build_config(dump_pass_ir=True):
+            mod = tvm.build(s, args, cuda, name=kernel_name)
+            dump_cuda_meta.dump(mod, kernel_name, s, list(args))
+            return mod
diff --git a/python/akg/ms/gpu/__init__.py b/python/akg/ms/gpu/__init__.py
index a711858f68269c3f1048162af16c44b0c63e3067..1931881b7657e50384fc998b133ae9728d11f341 100644
--- a/python/akg/ms/gpu/__init__.py
+++ b/python/akg/ms/gpu/__init__.py
@@ -27,4 +27,5 @@ from .squeeze import Squeeze, gpu_schedule_Squeeze
 from .squeeze_grad import SqueezeGrad, gpu_schedule_SqueezeGrad
 from .mean import SimpleMean, gpu_schedule_SimpleMean
 from .mean_grad import SimpleMeanGrad, gpu_schedule_SimpleMeanGrad
-from .mul import Mul, gpu_schedule_Mul
+
+from .mul import Mul
diff --git a/python/akg/ms/gpu/mul.py b/python/akg/ms/gpu/mul.py
index 12bb4e15d30c87801f6d573d65406d99d6bf1746..4dda614e766a660aed565f793c77774b2a989a8e 100644
--- a/python/akg/ms/gpu/mul.py
+++ b/python/akg/ms/gpu/mul.py
@@ -15,29 +15,12 @@
 # limitations under the License.
 
 """mul"""
+import akg
 import akg.topi as topi
 import akg.tvm as tvm
 from akg.ops.math import mul
 
+@akg.schedule(topi.cuda.schedule_injective)
 def Mul(x, y):
     """mul."""
     return mul.mul(x, y)
-
-
-def gpu_schedule_Mul(outs):
-    """
-    gpu schedule for mul.
-
-    Args:
-        outs (tvm.tensor.Tensor): outputs of compute.
-
-    Returns:
-        sch (schedule.Schedule): The created schedule.
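Note on the decorator contract: after this change, calling `Mul` no longer returns a bare tensor but a schedule-template dict that `op_build` unpacks (see the `sch_tmpl` handling in `kernel_exec.py` below). A minimal sketch of that contract, assuming a CUDA-enabled akg build; the placeholder shapes here are arbitrary:

    import akg.tvm as tvm
    import akg.topi as topi
    from akg.ms.gpu import Mul

    x = tvm.placeholder((16, 16), name='x', dtype='float32')
    y = tvm.placeholder((16, 16), name='y', dtype='float32')
    tmpl = Mul(x, y)
    # The wrapper returns the schedule/target/op-name alongside the tensor:
    assert tmpl['op_name'] == 'Mul'
    assert tmpl['target'] == 'cuda'
    assert tmpl['schedule'] is topi.cuda.schedule_injective
    out = tmpl['output']  # the tvm.tensor.Tensor produced by mul.mul(x, y)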
-    """
-    device = 'cuda'
-    ctx = tvm.context(device, 0)
-    if not ctx.exist:
-        raise SystemError("Skip because %s is not enabled" % device)
-    with tvm.target.create(device):
-        sch = topi.cuda.schedule_broadcast(outs)
-    return sch
diff --git a/python/akg/utils/dump_cuda_meta.py b/python/akg/utils/dump_cuda_meta.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bdc590a7baac087f7b989118c3b256c4f741f97
--- /dev/null
+++ b/python/akg/utils/dump_cuda_meta.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+# coding: utf-8
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""save gpu param"""
+import os
+import fcntl
+import hashlib
+import akg.tvm
+
+def get_dim(dim, axis=True):
+    """get dim info"""
+    dims_str = {
+        "grid_dim0": "// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = ",
+        "grid_dim1": "// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = ",
+        "grid_dim2": "// attr [iter_var(blockIdx.z, , blockIdx.z)] thread_extent = ",
+        "block_dim0": "// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = ",
+        "block_dim1": "// attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = ",
+        "block_dim2": "// attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = "
+    }
+    dim_to_axis = {
+        "grid_dim0": '"blockIdx.x" : ',
+        "grid_dim1": '"blockIdx.y" : ',
+        "grid_dim2": '"blockIdx.z" : ',
+        "block_dim0": '"threadIdx.x" : ',
+        "block_dim1": '"threadIdx.y" : ',
+        "block_dim2": '"threadIdx.z" : '
+    }
+    if axis:
+        return dim_to_axis.get(dim)
+    return dims_str.get(dim)
+
+
+def parse_params(file, dim, ir):
+    """parse parameters"""
+    dim_str = get_dim(dim, axis=False)
+    pos = ir.find(dim_str)
+    if pos != -1:
+        index = pos + len(dim_str)
+        param_temp = get_dim(dim)
+
+        while ir[index].isdigit():
+            param_temp += ir[index]
+            index += 1
+        file.write(param_temp + ",\n")
+    else:
+        param_temp = get_dim(dim) + '1'
+        file.write(param_temp + ",\n")
+
+def save_gpu_params(s, args, kernel_info):
+    """save gpu parameters"""
+    ptx_code = kernel_info[0]
+    file_name = kernel_info[1]
+    kernel_name = kernel_info[2]
+    ir = str(akg.tvm.lower(s, args, simple_mode=True))
+    file_path = os.path.realpath(file_name)
+    if os.path.exists(file_path):
+        os.remove(file_path)
+
+    sha256 = hashlib.sha256()
+    sha256.update(ptx_code.encode("utf-8"))
+    hash_str = sha256.hexdigest()
+    with os.fdopen(os.open(file_path, os.O_WRONLY | os.O_CREAT, 0o400), 'w') as fo:
+        fo.write("{\n")
+        fo.write('"kernelName" : ' + '"' + kernel_name + "_kernel0" + '",\n')
+        parse_params(fo, "grid_dim0", ir)
+        parse_params(fo, "grid_dim1", ir)
+        parse_params(fo, "grid_dim2", ir)
+        parse_params(fo, "block_dim0", ir)
+        parse_params(fo, "block_dim1", ir)
+        parse_params(fo, "block_dim2", ir)
+        fo.write('"sha256" : ' + '"' + hash_str + '"\n')
+        fo.write("}\n")
+
+def dump(mod, kernel_name, sch, args):
+    meta_path = "./cuda_meta/"
+    cuda_path = os.path.realpath(meta_path)
+    if not os.path.isdir(cuda_path):
+        os.makedirs(cuda_path)
+    ptx_file = os.path.realpath(meta_path + kernel_name + ".ptx")
+    # fetch the PTX up front so save_gpu_params always has it (the .ptx may already exist)
+    ptx_code = mod.imported_modules[0].get_source('ptx')
+    with open(ptx_file, "at") as f:
+        fcntl.flock(f.fileno(), fcntl.LOCK_EX)
+        f.seek(0, 2)
+        if f.tell() == 0:
+            f.write(ptx_code)
+    param_path = os.path.realpath(meta_path + kernel_name + '.json')
+    save_gpu_params(sch, args, (ptx_code, param_path, kernel_name))
\ No newline at end of file
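For reference, `save_gpu_params` emits `cuda_meta/<kernel_name>.json` in roughly the following shape; the launch-dimension values below are illustrative (they are scraped from the `thread_extent` annotations in the lowered IR and default to 1 when a dimension is absent), and `sha256` is the digest of the PTX text:

    {
    "kernelName" : "Mul_kernel0",
    "blockIdx.x" : 4096,
    "blockIdx.y" : 1,
    "blockIdx.z" : 1,
    "threadIdx.x" : 1024,
    "threadIdx.y" : 1,
    "threadIdx.z" : 1,
    "sha256" : "<hex digest of the PTX>"
    }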
diff --git a/python/akg/utils/kernel_exec.py b/python/akg/utils/kernel_exec.py
index d900a0d83b066e56b73924c780893b8de0bcc213..861107b74e71a964c446541c18a1cc2e83a87378 100644
--- a/python/akg/utils/kernel_exec.py
+++ b/python/akg/utils/kernel_exec.py
@@ -42,7 +42,7 @@
 from akg.utils import format_transform as ft_util
 from akg.utils import custom_tiling as ct_util
 from akg.utils import validation_check as vc_util
 from akg.utils.dsl_create import TensorUtils
-
+from akg.utils import dump_cuda_meta
 sh = logging.StreamHandler(sys.stdout)
 logging.getLogger().addHandler(sh)
@@ -435,6 +435,12 @@ def mod_launch(mod, args, outputs=(-1,), tuning=False, device_id=0, expect=None)
     """
     gc.collect()
 
+    if mod.imported_modules[0].type_key == 'cuda':
+        ctx = akg.tvm.context('cuda', device_id)
+        mod_args = [akg.tvm.nd.array(a, ctx) for a in args]
+        mod(*mod_args)
+        out_list = [mod_args[len(args) + i if i < 0 else i].asnumpy() for i in outputs]
+        return out_list[0] if len(out_list) == 1 else tuple(out_list)
     stat_info = {}
     profiling_mode = get_profiling_mode()
 
@@ -679,7 +685,7 @@ def op_build(op_func, input_shapes, input_types, op_attrs=None, kernel_name="",
         attrs['dim'] = dim_info
 
     compute_func = None  # func which is defined in dsl for doing compute_inline or other
-
+    sch_tmpl = None
     if isinstance(output, (list, tuple)):
         from inspect import isfunction
         new_outputs = []
@@ -696,6 +702,9 @@ def op_build(op_func, input_shapes, input_types, op_attrs=None, kernel_name="",
             new_outputs.append(elem)
 
         output = new_outputs
+    elif isinstance(output, dict):
+        sch_tmpl = output
+        output = sch_tmpl['output']
 
     binds = None if not attrs else attrs.pop(BINDS, None)
     op_var = []
@@ -715,6 +724,16 @@ def op_build(op_func, input_shapes, input_types, op_attrs=None, kernel_name="",
     if TensorUtils.is_output_value(output):
         op_var = op_var + [output]
 
+    if sch_tmpl is not None:
+        assert sch_tmpl['target'] == 'cuda'
+        kernel_name = kernel_name if kernel_name != "" else sch_tmpl['op_name']
+        with akg.tvm.target.cuda() as target:
+            s = sch_tmpl['schedule'](sch_tmpl['output'])
+            with akg.tvm.build_config(dump_pass_ir=True):
+                mod = akg.tvm.build(s, op_var, target, target_host='stackvm', name=kernel_name)
+                dump_cuda_meta.dump(mod, kernel_name, s, op_var)
+        return mod
+
     if isinstance(output, (list, tuple)):
         tmp = []
         for x in list(output):
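Together, the `op_build` and `mod_launch` hunks above give single-op GPU kernels an end-to-end path. A minimal usage sketch (assumes a CUDA-enabled build and device 0; it mirrors the new test at the end of this change):

    import numpy as np
    from akg.ms.gpu import Mul
    from akg.utils import kernel_exec as utils

    mod = utils.op_build(Mul, ((16, 16), (16, 16)), ('float32', 'float32'))
    a = np.ones((16, 16), np.float32)
    b = np.full((16, 16), 2.0, np.float32)
    out = np.zeros((16, 16), np.float32)
    # mod_launch detects type_key == 'cuda', stages the arrays in the CUDA
    # context, runs the module and returns the tensor selected by outputs=(-1,).
    res = utils.mod_launch(mod, (a, b, out))
    assert np.allclose(res, a * b)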
diff --git a/src/composite/composite.cc b/src/composite/composite.cc
index d3e6e9813b0db5e3e289072a706cc0a27be183e3..8be07afe9c56d66c9173901e563e496e0a69f6e8 100644
--- a/src/composite/composite.cc
+++ b/src/composite/composite.cc
@@ -459,7 +459,44 @@ NodeRef composite_with_json_to_func(const std::string &json_str, Map<std::string, NodeRef> attrs) {
 }
 
+std::string get_schedule(Array<Tensor> &outputs) {
+  for (const Tensor &t : outputs) {
+    if (t->op->tag == "comm_reduce" || t->op->tag == "comm_reduce_idx") {
+      return "reduce";
+    }
+  }
+  return "injective";
+}
+
+Module composite_with_json_gpu(const std::string &json_str, Map<std::string, NodeRef> attrs) {
+  picojson::value v;
+  std::string err = picojson::parse(v, json_str);
+  if (!err.empty()) {
+    LOG(ERROR) << "json parse error, error message: " << err;
+  }
+  Array<Tensor> tensors;
+  Array<NodeRef> args;
+  Map<Tensor, Buffer> in_binds;
+  std::string kernel_name;
+  extract_op_info(v, &tensors, &args, &kernel_name, &in_binds);
+  const auto* build_func = air::runtime::Registry::Get("akg_build_gpu_module");
+  CHECK(build_func != nullptr);
+  std::string sch = get_schedule(tensors);
+  return (*build_func)(tensors, args, sch, kernel_name);
+}
+
 Module composite_with_json(const std::string &json_str, Map<std::string, NodeRef> attrs) {
+  if (get_process(json_str) == "gpu") {
+    return composite_with_json_gpu(json_str, attrs);
+  }
   auto build_rst = composite_with_json_to_func(json_str, attrs);
   return BuildToModule(build_rst);
 }
diff --git a/tests/operators/gpu/test_ms_mul.py b/tests/operators/gpu/test_ms_mul.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb58fc862142feacb7387e220e2260b3be695d9e
--- /dev/null
+++ b/tests/operators/gpu/test_ms_mul.py
@@ -0,0 +1,34 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+from akg.ms.gpu import Mul
+from gen_random import random_gaussian
+from akg.utils import kernel_exec as utils
+
+def gen_data(shape, dtype):
+    support_list = {"float16": np.float16, "float32": np.float32}
+    lhd = random_gaussian(shape, miu=1, sigma=0.1).astype(support_list[dtype])
+    rhd = random_gaussian(shape, miu=1, sigma=0.1).astype(support_list[dtype])
+    expect = np.multiply(lhd, rhd)
+    output = np.full(shape, np.nan, dtype)
+    return lhd, rhd, output, expect
+
+def test_ms_mul(shape, dtype):
+    mod = utils.op_build(Mul, (shape, shape), (dtype, dtype))
+    lhd, rhd, output, expect = gen_data(shape, dtype)
+    output = utils.mod_launch(mod, (lhd, rhd, output), expect=expect)
+    assert np.allclose(output, expect, rtol=5e-03, atol=1.e-8)
+
+if __name__ == '__main__':
+    test_ms_mul((1024, 4096), 'float32')
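For fused (composite) kernels the flow runs the other way around: `_build` in `build_module.py` hands the kernel JSON to the C++ `composite_with_json`, which classifies the outputs as "reduce" or "injective" and calls back into the Python `akg_build_gpu_module` registered above. A sketch of the Python-side dispatch; `build_gpu_composite` is a hypothetical helper, and `desc_s` is assumed to be a kernel description string whose `process` field is `"gpu"`:

    import json
    from akg import tvm

    def build_gpu_composite(desc_s, attr=None):
        """Hypothetical helper mirroring the 'gpu' branch of _build above."""
        desc_d = json.loads(desc_s)
        assert desc_d['process'] == 'gpu'
        # C++ composite_with_json -> composite_with_json_gpu -> the Python
        # akg_build_gpu_module callback, which builds and dumps the module.
        func = tvm.get_global_func("composite_with_json")
        return func(desc_s, attr)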