提交 2c4ee992 编写于 作者: M Megvii Engine Team

fix(dnn): shorten cutlass filenames on Windows

GitOrigin-RevId: 83a43fdf87b1fd3480a5d59584bf60e3cd5dd5af
上级 b17b56f3
...@@ -4,12 +4,12 @@ genrule( ...@@ -4,12 +4,12 @@ genrule(
name = "cutlass_kimpls", name = "cutlass_kimpls",
outs = cutlass_gen_list, outs = cutlass_gen_list,
cmd = """GEN=$(location //brain/megbrain/dnn/scripts/cutlass_generator:generator.py) cmd = """GEN=$(location //brain/megbrain/dnn/scripts/cutlass_generator:generator.py)
python3 $$GEN --operations gemm --type simt $(@D) CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type simt $(@D)
python3 $$GEN --operations gemv --type simt $(@D) CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemv --type simt $(@D)
python3 $$GEN --operations deconv --type simt $(@D) CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type simt $(@D)
python3 $$GEN --operations conv2d --type simt $(@D) CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type simt $(@D)
python3 $$GEN --operations conv2d --type tensorop8816 $(@D) CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8816 $(@D)
python3 $$GEN --operations conv2d --type tensorop8832 $(@D) CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8832 $(@D)
""", """,
tools = ["//brain/megbrain/dnn/scripts/cutlass_generator:generator.py"], tools = ["//brain/megbrain/dnn/scripts/cutlass_generator:generator.py"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
......
...@@ -531,9 +531,10 @@ void initialize_${configuration_name}(Manifest &manifest) { ...@@ -531,9 +531,10 @@ void initialize_${configuration_name}(Manifest &manifest) {
################################################################################################### ###################################################################################################
class EmitConvSingleKernelWrapper(): class EmitConvSingleKernelWrapper():
def __init__(self, kernel_path, operation): def __init__(self, kernel_path, operation, short_path=False):
self.kernel_path = kernel_path self.kernel_path = kernel_path
self.operation = operation self.operation = operation
self.short_path = short_path
if self.operation.conv_kind == ConvKind.Fprop: if self.operation.conv_kind == ConvKind.Fprop:
self.instance_emitter = EmitConv2dInstance() self.instance_emitter = EmitConv2dInstance()
...@@ -582,6 +583,10 @@ void initialize_${operation_name}(Manifest &manifest) { ...@@ -582,6 +583,10 @@ void initialize_${operation_name}(Manifest &manifest) {
# #
def __enter__(self): def __enter__(self):
if self.short_path:
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
GlobalCnt.cnt += 1
else:
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
self.kernel_file = open(self.kernel_path, "w") self.kernel_file = open(self.kernel_path, "w")
self.kernel_file.write(self.header_template) self.kernel_file.write(self.header_template)
......
...@@ -994,7 +994,8 @@ void initialize_${configuration_name}(Manifest &manifest) { ...@@ -994,7 +994,8 @@ void initialize_${configuration_name}(Manifest &manifest) {
################################################################################################### ###################################################################################################
class EmitGemmSingleKernelWrapper: class EmitGemmSingleKernelWrapper:
def __init__(self, kernel_path, gemm_operation): def __init__(self, kernel_path, gemm_operation, short_path=False):
self.short_path = short_path
self.kernel_path = kernel_path self.kernel_path = kernel_path
self.operation = gemm_operation self.operation = gemm_operation
...@@ -1070,10 +1071,11 @@ void initialize_${operation_name}(Manifest &manifest) { ...@@ -1070,10 +1071,11 @@ void initialize_${operation_name}(Manifest &manifest) {
################################################################################################### ###################################################################################################
class EmitGemvSingleKernelWrapper: class EmitGemvSingleKernelWrapper:
def __init__(self, kernel_path, gemm_operation, wrapper_path): def __init__(self, kernel_path, gemm_operation, wrapper_path, short_path=False):
self.kernel_path = kernel_path self.kernel_path = kernel_path
self.wrapper_path = wrapper_path self.wrapper_path = wrapper_path
self.operation = gemm_operation self.operation = gemm_operation
self.short_path = short_path
self.wrapper_template = """ self.wrapper_template = """
template void megdnn::cuda::cutlass_wrapper:: template void megdnn::cuda::cutlass_wrapper::
...@@ -1107,6 +1109,10 @@ ${operation_instance} ...@@ -1107,6 +1109,10 @@ ${operation_instance}
""" """
# #
def __enter__(self): def __enter__(self):
if self.short_path:
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
GlobalCnt.cnt += 1
else:
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
self.kernel_file = open(self.kernel_path, "w") self.kernel_file = open(self.kernel_path, "w")
self.kernel_file.write(SubstituteTemplate(self.header_template, { self.kernel_file.write(SubstituteTemplate(self.header_template, {
......
...@@ -8,6 +8,7 @@ import enum ...@@ -8,6 +8,7 @@ import enum
import os.path import os.path
import shutil import shutil
import argparse import argparse
import platform
from library import * from library import *
from manifest import * from manifest import *
...@@ -634,7 +635,7 @@ if __name__ == "__main__": ...@@ -634,7 +635,7 @@ if __name__ == "__main__":
default='simt', help="kernel type of CUTLASS kernel generator") default='simt', help="kernel type of CUTLASS kernel generator")
gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl" gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
short_path = (platform.system() == "Windows" or platform.system().find('NT') >= 0) and ('true'!= os.getenv("CUTLASS_WITH_LONG_PATH", default='False').lower())
args = parser.parse_args() args = parser.parse_args()
if args.operations == "gemm": if args.operations == "gemm":
...@@ -648,15 +649,15 @@ if __name__ == "__main__": ...@@ -648,15 +649,15 @@ if __name__ == "__main__":
if args.operations == "conv2d" or args.operations == "deconv": if args.operations == "conv2d" or args.operations == "deconv":
for operation in operations: for operation in operations:
with EmitConvSingleKernelWrapper(args.output, operation) as emitter: with EmitConvSingleKernelWrapper(args.output, operation, short_path) as emitter:
emitter.emit() emitter.emit()
elif args.operations == "gemm": elif args.operations == "gemm":
for operation in operations: for operation in operations:
with EmitGemmSingleKernelWrapper(args.output, operation) as emitter: with EmitGemmSingleKernelWrapper(args.output, operation, short_path) as emitter:
emitter.emit() emitter.emit()
elif args.operations == "gemv": elif args.operations == "gemv":
for operation in operations: for operation in operations:
with EmitGemvSingleKernelWrapper(args.output, operation, gemv_wrapper_path) as emitter: with EmitGemvSingleKernelWrapper(args.output, operation, gemv_wrapper_path, short_path) as emitter:
emitter.emit() emitter.emit()
if args.operations != "gemv": if args.operations != "gemv":
......
...@@ -612,3 +612,6 @@ class TensorDescription: ...@@ -612,3 +612,6 @@ class TensorDescription:
self.complex_transform = complex_transform self.complex_transform = complex_transform
################################################################################################### ###################################################################################################
class GlobalCnt:
    """Namespace for a single counter shared through a class attribute.

    The kernel-wrapper emitters read ``GlobalCnt.cnt`` to build a
    ``"<cnt>.cu"`` output file name when short paths are requested, and
    then increment it, so successive kernels get successive numbers.

    NOTE(review): the counter starts at 0 whenever this module is loaded,
    so separate generator runs presumably need distinct output
    directories to avoid reusing the same numbered file names -- confirm
    against the build scripts.
    """

    # Next number to hand out; mutated directly on the class by callers.
    cnt = 0
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册