提交 4712acdc 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!74 enable gpu for API: compilewithjson

Merge pull request !74 from lingyunli63/support_gpu_ops
......@@ -20,7 +20,6 @@ from .equal import gpu_schedule_Equal
from .tile import Tile
from .tile import gpu_schedule_Tile
from .cast import Cast
from .cast import gpu_schedule_Cast
from .relu6 import ReLU6, gpu_schedule_ReLU6
from .relu6_grad import ReLU6Grad, gpu_schedule_ReLU6Grad
from .squeeze import Squeeze, gpu_schedule_Squeeze
......
......@@ -19,27 +19,9 @@ import logging
import akg.tvm
from akg.ops.math import cast
from akg.topi.generic import schedule_elemwise
import akg.topi as topi
@akg.schedule(topi.cuda.schedule_injective)
def Cast(x, dst_type):
"""cast."""
return cast.cast(x, dst_type)
def gpu_schedule_Cast(outs):
"""
gpu schedule for cast.
Args:
outs (tvm.tensor.Tensor): outputs of compute.
Returns:
sch (schedule.Schedule): The created schedule.
"""
device = 'cuda'
ctx = akg.tvm.context(device, 0)
if not ctx.exist:
logging.info("Skip because %s is not enabled", device)
return None
with akg.tvm.target.create(device):
sch = schedule_elemwise(outs)
return sch
......@@ -29,8 +29,8 @@ from akg.utils import validation_check as vc_util
from akg import composite
from akg.tvm import _api_internal
from . import cce
from . import op_build_to_func
from . import gpu
from . import op_build
@vc_util.check_input_type(str)
def compilewithjson_to_func(json_str):
......@@ -68,6 +68,17 @@ def compilewithjson_to_func(json_str):
if op_func is None:
if processor == 'cuda':
op_func = getattr(gpu, op_name, None)
input_shapes = []
input_types = []
for input_desc in kernel_info['input_desc']:
input_shapes.append(input_desc[0]['shape'])
input_types.append(input_desc[0]['data_type'])
op_attrs = []
if kernel_info['attr']:
for ext_arg in kernel_info['attr']:
op_attrs.append(ext_arg['value'])
mod = utils.op_build(op_func, input_shapes, input_types, op_attrs, kernel_info['op'])
return True
else:
op_func = getattr(cce, op_name, None)
......@@ -121,7 +132,7 @@ def compilewithjson_to_func(json_str):
output = [output]
tsr = tsr + [i for i in output if utils.TensorUtils.is_output_value(i)]
return op_build_to_func([op_name], output, tsr, schedule_func, processor, kernel_info['op'], attrs)
return op_build([op_name], output, tsr, schedule_func, processor, kernel_info['op'], attrs)
def compilewithjson(json_str):
tmp_rst = compilewithjson_to_func(json_str)
......
......@@ -33,7 +33,6 @@ BINDS = "binds"
MS_AKG_DUMP_IR = "MS_AKG_DUMP_IR"
MS_AKG_DUMP_CCE = "MS_AKG_DUMP_CCE"
MS_DAVINCI_KERNEL_PATH = "./kernel_meta/"
MS_CUDA_KERNEL_PATH = "./cuda_meta/"
@vc_util.check_input_type(list, (list, tuple), (list, tuple), (types.FunctionType, type(None)), str, str, dict)
......@@ -72,10 +71,11 @@ def op_build(opnames, computes, args, custom_schedule, device, kernel_name, attr
"""op_build"""
if device in ("aicore", "aicpu"):
tmp_rst = op_build_to_func(opnames, computes, args, custom_schedule, device, kernel_name, attrs)
return _api_internal._BuildToModule(tmp_rst)
return tmp_rst
if device == "cuda":
cuda_path = os.path.realpath(MS_CUDA_KERNEL_PATH)
kernel_meta_path = "./cuda_meta_" + str(os.getpid()) + "/"
cuda_path = os.path.realpath(kernel_meta_path)
if not os.path.isdir(cuda_path):
os.makedirs(cuda_path)
if not opnames:
......@@ -88,7 +88,7 @@ def op_build(opnames, computes, args, custom_schedule, device, kernel_name, attr
logging.error("no schedule func found %s", str(schedule_name))
return None
ptx_file = os.path.realpath(MS_CUDA_KERNEL_PATH + kernel_name + ".ptx")
ptx_file = os.path.realpath(kernel_meta_path + kernel_name + ".ptx")
if os.path.exists(ptx_file):
os.remove(ptx_file)
try:
......@@ -100,7 +100,7 @@ def op_build(opnames, computes, args, custom_schedule, device, kernel_name, attr
foo = akg.tvm.build(s, args, device, name=kernel_name)
ptx_code = foo.imported_modules[0].get_source("ptx")
file.write(ptx_code)
json_file = os.path.realpath(MS_CUDA_KERNEL_PATH + kernel_name + ".json")
json_file = os.path.realpath(kernel_meta_path + kernel_name + ".json")
kernel_info = (ptx_code, json_file, kernel_name)
gpu_utils.save_gpu_params(s, args, kernel_info)
os.chmod(ptx_file, 0o400)
......
......@@ -85,7 +85,7 @@ def save_gpu_params(s, args, kernel_info):
fo.write("}\n")
def dump(mod, kernel_name, sch, args):
meta_path = "./cuda_meta/"
meta_path = "./cuda_meta_/" + str(os.getpid()) + "/"
cuda_path = os.path.realpath(meta_path)
if not os.path.isdir(cuda_path):
os.makedirs(cuda_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册