提交 2c4ee992 编写于 作者: M Megvii Engine Team

fix(dnn): short cutlass filename in windows

GitOrigin-RevId: 83a43fdf87b1fd3480a5d59584bf60e3cd5dd5af
上级 b17b56f3
......@@ -4,12 +4,12 @@ genrule(
name = "cutlass_kimpls",
outs = cutlass_gen_list,
cmd = """GEN=$(location //brain/megbrain/dnn/scripts/cutlass_generator:generator.py)
python3 $$GEN --operations gemm --type simt $(@D)
python3 $$GEN --operations gemv --type simt $(@D)
python3 $$GEN --operations deconv --type simt $(@D)
python3 $$GEN --operations conv2d --type simt $(@D)
python3 $$GEN --operations conv2d --type tensorop8816 $(@D)
python3 $$GEN --operations conv2d --type tensorop8832 $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemv --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8816 $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8832 $(@D)
""",
tools = ["//brain/megbrain/dnn/scripts/cutlass_generator:generator.py"],
visibility = ["//visibility:public"],
......
......@@ -531,9 +531,10 @@ void initialize_${configuration_name}(Manifest &manifest) {
###################################################################################################
class EmitConvSingleKernelWrapper():
def __init__(self, kernel_path, operation):
def __init__(self, kernel_path, operation, short_path=False):
self.kernel_path = kernel_path
self.operation = operation
self.short_path = short_path
if self.operation.conv_kind == ConvKind.Fprop:
self.instance_emitter = EmitConv2dInstance()
......@@ -582,7 +583,11 @@ void initialize_${operation_name}(Manifest &manifest) {
#
def __enter__(self):
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
if self.short_path:
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
GlobalCnt.cnt += 1
else:
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
self.kernel_file = open(self.kernel_path, "w")
self.kernel_file.write(self.header_template)
return self
......
......@@ -994,7 +994,8 @@ void initialize_${configuration_name}(Manifest &manifest) {
###################################################################################################
class EmitGemmSingleKernelWrapper:
def __init__(self, kernel_path, gemm_operation):
def __init__(self, kernel_path, gemm_operation, short_path=False):
self.short_path = short_path
self.kernel_path = kernel_path
self.operation = gemm_operation
......@@ -1070,10 +1071,11 @@ void initialize_${operation_name}(Manifest &manifest) {
###################################################################################################
class EmitGemvSingleKernelWrapper:
def __init__(self, kernel_path, gemm_operation, wrapper_path):
def __init__(self, kernel_path, gemm_operation, wrapper_path, short_path=False):
self.kernel_path = kernel_path
self.wrapper_path = wrapper_path
self.operation = gemm_operation
self.short_path = short_path
self.wrapper_template = """
template void megdnn::cuda::cutlass_wrapper::
......@@ -1107,7 +1109,11 @@ ${operation_instance}
"""
#
def __enter__(self):
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
if self.short_path:
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
GlobalCnt.cnt += 1
else:
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
self.kernel_file = open(self.kernel_path, "w")
self.kernel_file.write(SubstituteTemplate(self.header_template, {
'wrapper_path': self.wrapper_path,
......
......@@ -8,6 +8,7 @@ import enum
import os.path
import shutil
import argparse
import platform
from library import *
from manifest import *
......@@ -634,7 +635,7 @@ if __name__ == "__main__":
default='simt', help="kernel type of CUTLASS kernel generator")
gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
short_path = (platform.system() == "Windows" or platform.system().find('NT') >= 0) and ('true'!= os.getenv("CUTLASS_WITH_LONG_PATH", default='False').lower())
args = parser.parse_args()
if args.operations == "gemm":
......@@ -648,15 +649,15 @@ if __name__ == "__main__":
if args.operations == "conv2d" or args.operations == "deconv":
for operation in operations:
with EmitConvSingleKernelWrapper(args.output, operation) as emitter:
with EmitConvSingleKernelWrapper(args.output, operation, short_path) as emitter:
emitter.emit()
elif args.operations == "gemm":
for operation in operations:
with EmitGemmSingleKernelWrapper(args.output, operation) as emitter:
with EmitGemmSingleKernelWrapper(args.output, operation, short_path) as emitter:
emitter.emit()
elif args.operations == "gemv":
for operation in operations:
with EmitGemvSingleKernelWrapper(args.output, operation, gemv_wrapper_path) as emitter:
with EmitGemvSingleKernelWrapper(args.output, operation, gemv_wrapper_path, short_path) as emitter:
emitter.emit()
if args.operations != "gemv":
......
......@@ -612,3 +612,6 @@ class TensorDescription:
self.complex_transform = complex_transform
###################################################################################################
class GlobalCnt:
cnt = 0
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册