# gen_list.py: writes list.bzl, the Starlark list of generated CUTLASS .cu source files.
from generator import (
    GenerateGemmOperations,
    GenerateGemvOperations,
    GenerateConv2dOperations,
    GenerateDeconvOperations,
)


class GenArg:
    """Lightweight argument bundle handed to the ``Generate*Operations`` functions.

    NOTE(review): presumably mirrors the attribute layout of the argument
    namespace the generator module normally receives; confirm against
    ``generator``.
    """

    def __init__(self, gen_op, gen_type):
        # ``operations`` holds the operation family (e.g. "gemm", "conv2d");
        # ``type`` holds the kernel type string (e.g. "simt", "tensorop8816").
        self.operations, self.type = gen_op, gen_type


def write_op_list(f, gen_op, gen_type):
    """Write the generated ``.cu`` file names for one operation family.

    Each name is written as a quoted, comma-terminated, indented line so the
    output can sit inside a list literal.

    Args:
        f: Writable text file object.
        gen_op: Operation family: "gemm", "gemv", "conv2d" or "deconv".
        gen_type: Kernel type string forwarded to the generator (e.g. "simt").

    Raises:
        ValueError: If ``gen_op`` is not one of the supported families.
    """
    generators = {
        "gemm": GenerateGemmOperations,
        "gemv": GenerateGemvOperations,
        "conv2d": GenerateConv2dOperations,
        "deconv": GenerateDeconvOperations,
    }
    # Previously an unknown gen_op left `operations` unbound and failed with
    # an obscure UnboundLocalError; report the bad value explicitly instead.
    if gen_op not in generators:
        raise ValueError("unsupported gen_op: %r" % (gen_op,))
    operations = generators[gen_op](GenArg(gen_op, gen_type))
    for op in operations:
        f.write('    "%s.cu",\n' % op.procedural_name())
    # No aggregate "all_*_operations.cu" entry is emitted for gemv.
    if gen_op != "gemv":
        f.write('    "all_%s_%s_operations.cu",\n' % (gen_op, gen_type))


if __name__ == "__main__":
    # (operation family, kernel type) pairs, in the order their files
    # appear in the generated list.
    op_lists = (
        ("gemm", "simt"),
        ("gemm", "tensorop1688"),
        ("gemm", "tensorop884"),
        ("gemv", "simt"),
        ("deconv", "simt"),
        ("deconv", "tensorop8816"),
        ("conv2d", "simt"),
        ("conv2d", "tensorop8816"),
        ("conv2d", "tensorop8832"),
    )
    with open("list.bzl", "w") as f:
        f.write("# Generated by dnn/scripts/cutlass_generator/gen_list.py\n\n")
        f.write("cutlass_gen_list = [\n")
        for gen_op, gen_type in op_lists:
            write_op_list(f, gen_op, gen_type)
        # Terminate the list and end the file with a newline so it is a
        # well-formed text file.
        f.write("]\n")