提交 d237aa7d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!111 dump cuda_meta inside akg.build instead of after akg.build

Merge pull request !111 from looop5/dump_cuda_meta
......@@ -81,5 +81,6 @@ from .autodiff import get_variables
from .autodiff import register_variables
from .lang.cce.te_compute.common import fargmax, fargmin, mad
from . import lang
from .utils.dump_cuda_meta import dump_cuda_meta
__all__ = ["differentiate"]
......@@ -22,7 +22,6 @@ from akg import tvm
from akg.tvm import _api_internal
from .repository import __all__ as repository
import topi
from akg.utils import dump_cuda_meta
def generate_trait(desc):
""" generate trait of kernel description """
......@@ -181,5 +180,4 @@ def build_cuda(outputs, args, sch_name, kernel_name):
dump_ir = os.getenv('MS_AKG_DUMP_IR') == "on"
with tvm.build_config(dump_pass_ir = dump_ir):
mod = akg.build(s, list(args), "cuda", name = kernel_name)
dump_cuda_meta.dump(mod, kernel_name, s, list(args))
return mod
......@@ -20,93 +20,66 @@ import fcntl
import hashlib
import akg.tvm
def get_dim(dim, axis=True):
"""get dim info"""
dims_str = {
"grid_dim0": "// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = ",
"grid_dim1": "// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = ",
"grid_dim2": "// attr [iter_var(blockIdx.z, , blockIdx.z)] thread_extent = ",
"block_dim0": "// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = ",
"block_dim1": "// attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = ",
"block_dim2": "// attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = "
}
dim_to_axis = {
"grid_dim0": '"blockIdx.x" : ',
"grid_dim1": '"blockIdx.y" : ',
"grid_dim2": '"blockIdx.z" : ',
"block_dim0": '"threadIdx.x" : ',
"block_dim1": '"threadIdx.y" : ',
"block_dim2": '"threadIdx.z" : '
}
if axis:
return dim_to_axis.get(dim)
return dims_str.get(dim)
def parse_params(file, dim, ir):
"""parse parameters"""
dim_str = get_dim(dim, axis=False)
pos = ir.find(dim_str)
if pos != -1:
index = pos + len(dim_str)
param_temp = get_dim(dim)
while ir[index].isdigit():
param_temp += ir[index]
index += 1
file.write(param_temp + ",\n")
else:
param_temp = get_dim(dim) + '1'
file.write(param_temp + ",\n")
def save_gpu_params(s, args, kernel_info):
"""save gpu parameters"""
ptx_code = kernel_info[0]
file_name = kernel_info[1]
kernel_name = kernel_info[2]
@akg.tvm.register_func
def dump_cuda_meta(code, ptx, thread_info):
"""
Function for dumping cuda meta.
Args:
code: gpu code.
ptx: ptx code.
thread_info: thread info, written to json file.
"""
# kernel name
kernel_name = code.split("_kernel")[0].split(" ")[-1]
dump_ir = os.getenv('MS_AKG_DUMP_IR') == "on"
if dump_ir:
schedule_path = os.path.realpath(kernel_name)
all_passes = os.listdir(schedule_path)
for cur_pass in all_passes:
if cur_pass.startswith("00_"):
with open(schedule_path + '/' + cur_pass, "r") as file:
ir = file.read()
break
else:
ir = str(akg.tvm.lower(s, args, simple_mode=True))
file_path = os.path.realpath(file_name)
if os.path.exists(file_path):
os.remove(file_path)
# sha256 of ptx
sha256 = hashlib.sha256()
sha256.update(ptx_code.encode("utf-8"))
sha256.update(ptx.encode("utf-8"))
hash_str = sha256.hexdigest()
with os.fdopen(os.open(file_path, os.O_WRONLY | os.O_CREAT, 0o400), 'w') as fo:
fo.write("{\n")
fo.write('"kernelName" : ' + '"' + kernel_name + "_kernel0" + '",\n')
parse_params(fo, "grid_dim0", ir)
parse_params(fo, "grid_dim1", ir)
parse_params(fo, "grid_dim2", ir)
parse_params(fo, "block_dim0", ir)
parse_params(fo, "block_dim1", ir)
parse_params(fo, "block_dim2", ir)
fo.write('"sha256" : ' + '"' + hash_str + '"\n')
fo.write("}\n")
def dump(mod, kernel_name, sch, args):
# thread info
thread_info_dict = {
"blockIdx.x": "1",
"blockIdx.y": "1",
"blockIdx.z": "1",
"threadIdx.x": "1",
"threadIdx.y": "1",
"threadIdx.z": "1"
}
for thread_tag in thread_info_dict.keys():
if thread_tag in thread_info:
if isinstance(thread_info[thread_tag], int):
thread_info_dict[thread_tag] = str(thread_info[thread_tag])
elif isinstance(thread_info[thread_tag], akg.tvm.expr.IntImm):
thread_info_dict[thread_tag] = str(thread_info[thread_tag].value)
meta_path = "./cuda_meta_" + str(os.getpid()) + "/"
cuda_path = os.path.realpath(meta_path)
if not os.path.isdir(cuda_path):
os.makedirs(cuda_path)
# save ptx file to cuda meta
ptx_file = os.path.realpath(meta_path + kernel_name + ".ptx")
with open(ptx_file, "at") as f:
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
f.seek(0, 2)
if f.tell() == 0:
ptx_code = mod.imported_modules[0].get_source('ptx')
f.write(ptx_code)
param_path = os.path.realpath(meta_path + kernel_name + '.json')
save_gpu_params(sch, args, (ptx_code, param_path, kernel_name))
f.write(ptx)
# save json file to cuda meta
json_file = os.path.realpath(meta_path + kernel_name + ".json")
if os.path.exists(json_file):
os.remove(json_file)
with os.fdopen(os.open(json_file, os.O_WRONLY | os.O_CREAT, 0o400), 'w') as fo:
fo.write("{\n")
fo.write('"kernelName" : ' + '"' + kernel_name + "_kernel0" + '",\n')
fo.write('"blockIdx.x" : ' + thread_info_dict["blockIdx.x"] + ',\n')
fo.write('"blockIdx.y" : ' + thread_info_dict["blockIdx.y"] + ',\n')
fo.write('"blockIdx.z" : ' + thread_info_dict["blockIdx.z"] + ',\n')
fo.write('"threadIdx.x" : ' + thread_info_dict["threadIdx.x"] + ',\n')
fo.write('"threadIdx.y" : ' + thread_info_dict["threadIdx.y"] + ',\n')
fo.write('"threadIdx.z" : ' + thread_info_dict["threadIdx.z"] + ',\n')
fo.write('"sha256" : ' + '"' + hash_str + '"\n')
fo.write("}\n")
......@@ -42,7 +42,6 @@ from akg.utils import format_transform as ft_util
from akg.utils import custom_tiling as ct_util
from akg.utils import validation_check as vc_util
from akg.utils.dsl_create import TensorUtils
from akg.utils import dump_cuda_meta
sh = logging.StreamHandler(sys.stdout)
logging.getLogger().addHandler(sh)
......@@ -746,7 +745,6 @@ def op_build(op_func, input_shapes, input_types, op_attrs=None, kernel_name="",
with akg.tvm.build_config(dump_pass_ir=dump_ir):
mod = akg.build(s, op_var, "cuda", shape_var, name=kernel_name, attrs=attrs,
polyhedral=polyhedral, binds=binds)
dump_cuda_meta.dump(mod, kernel_name, s, op_var)
if dump_code:
source_code = mod.imported_modules[0].get_source()
create_code(kernel_name, "./", source_code, "CUDA")
......
......@@ -23,6 +23,12 @@
*
* \file build_cuda.cc
*/
/*
* 2020.8.14 - Get thread info inside BuildCUDA function,
* enbale dump cuda meta.
*/
#if defined(__linux__)
#include <sys/stat.h>
#endif
......@@ -133,8 +139,18 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
CodeGenCUDA cg;
cg.Init(output_ssa);
Map<std::string, Expr> thread_info;
for (LoweredFunc f : funcs) {
cg.AddFunction(f);
for (const auto &axis : f->thread_axis) {
auto thread_tag = axis->thread_tag;
auto node = axis->dom.get();
if (node != nullptr) {
CHECK(axis->dom->extent.as<IntImm>());
thread_info.Set(thread_tag, axis->dom->extent);
}
}
}
std::string code = cg.Finish();
......@@ -151,6 +167,11 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
} else {
ptx = NVRTCCompile(code, cg.need_include_path());
}
if (const auto* f = Registry::Get("dump_cuda_meta")) {
(*f)(code, ptx, thread_info);
}
return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(funcs), code);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册