#!/usr/bin/env python3
# This file is part of MegBrain.
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
"""Format MegBrain C/C++/CUDA sources with a pinned clang-format version.

Example: ``./tools/format.py src -w``. Paths given on the command line are
resolved relative to the MegBrain root directory.
"""

import argparse
import os
import re
import shutil
import subprocess
import tempfile
from functools import partial

# The only clang-format release this tool accepts (checked in main()).
REQUIRED_CLANG_FORMAT_VERSION = "12.0.1"

# Extensions of files collected when a directory is passed on the command line.
FORMAT_TYPES = (".cpp", ".c", ".h", ".cu", ".cuh", ".inl")

# `MGB_DEFINE...(...) // {` openers keep their `{` inside a comment. Before
# running clang-format they are rewritten to `class MGB_DEFINE... {`
# (presumably so clang-format treats the macro as a class declaration), and
# the rewrite is undone on the formatted output. re.DOTALL makes `.` also
# match newlines, i.e. the same as the `(.|\n)` idiom.
_MGB_DEFINE_MASK_RE = re.compile(r"MGB_DEFINE(?P<body>.*?)// +{", re.DOTALL)
_MGB_DEFINE_UNMASK_RE = re.compile(r"class MGB_DEFINE(.*?){", re.DOTALL)

# Files reported as unformatted by process_file() calls made in *this*
# process. Worker processes each hold their own copy, so main() gathers the
# results from process_file()'s return values instead of relying on this.
failed_files = []


def mask_mgb_define(source):
    """Rewrite every ``MGB_DEFINE... // {`` opener as ``class MGB_DEFINE... {``.

    Args:
        source: C/C++ source text.

    Returns:
        The text with each opener rewritten; all other text is unchanged.
    """
    return _MGB_DEFINE_MASK_RE.sub(r"class MGB_DEFINE\g<body>{", source)


def unmask_mgb_define(text):
    """Turn ``class MGB_DEFINE... {`` back into ``MGB_DEFINE... // {``.

    Args:
        text: clang-format output produced from mask_mgb_define()'d source.

    Returns:
        The text with every ``class MGB_DEFINE`` opener restored.
    """
    return _MGB_DEFINE_UNMASK_RE.sub(r"MGB_DEFINE\1// {", text)


def _replace_file(path, data):
    """Atomically replace *path* with the bytes *data*, keeping its mode bits."""
    # Create the temp file next to *path* so os.replace stays on one filesystem.
    tmp_file = tempfile.NamedTemporaryFile(
        dir=os.path.dirname(path) or ".", delete=False
    )
    try:
        with tmp_file:
            tmp_file.write(data)
        # mkstemp creates files with mode 0600; copy the original permissions
        # so formatting does not strip group/other access from the source.
        shutil.copymode(path, tmp_file.name)
        # os.replace (unlike os.rename) also overwrites an existing file on
        # Windows.
        os.replace(tmp_file.name, path)
    except BaseException:
        os.unlink(tmp_file.name)
        raise


def process_file(file, clang_format, write):
    """Format one file with clang-format, either in place or as a diff check.

    Args:
        file: path of the source file, relative to the current directory.
        clang_format: clang-format executable name or path.
        write: if true, overwrite *file* with the formatted text. Otherwise
            print a colored ``diff`` against the formatted text and, when
            they differ, print the file name.

    Returns:
        True if *file* was checked (``write`` false) and is not properly
        formatted; False otherwise.

    Raises:
        subprocess.CalledProcessError: if clang-format exits non-zero.
        RuntimeError: if ``diff`` itself reports trouble (exit status 2).
    """
    with open(file, "r", encoding="utf-8") as src:
        source = mask_mgb_define(src.read())

    result = subprocess.check_output(
        [
            clang_format,
            "-style=file",
            "-verbose",
            "-assume-filename={}".format(file),
        ],
        input=source.encode("utf-8"),
    )
    result = unmask_mgb_define(result.decode("utf-8"))

    if write:
        _replace_file(file, result.encode("utf-8"))
        return False

    ret_code = subprocess.run(
        ["diff", "--color=always", file, "-"],
        input=result.encode("utf-8"),
    ).returncode
    # man diff: 0 for same, 1 for different, 2 if trouble.
    if ret_code == 2:
        raise RuntimeError("format process (without overwrite) failed")
    if ret_code != 0:
        print(file)
        failed_files.append(file)
        return True
    return False


def collect_source_files(path, extensions=FORMAT_TYPES):
    """Recursively list files under directory *path* that should be formatted.

    Directories are descended into (os.path.isdir also follows symlinked
    directories). Files are kept only if they are regular files, are not
    symlinks, and have an extension in *extensions*.

    Args:
        path: directory to scan.
        extensions: accepted file extensions, including the leading dot.

    Returns:
        A list of file paths joined onto *path*, in os.listdir order.
    """
    found = []
    for name in os.listdir(path):
        child = os.path.join(path, name)
        if os.path.isdir(child):
            found += collect_source_files(child, extensions)
        elif (
            os.path.isfile(child)
            and not os.path.islink(child)
            and os.path.splitext(child)[1] in extensions
        ):
            found.append(child)
    return found


def main():
    """Parse arguments, verify the clang-format version and format the files.

    Raises:
        ValueError: if a command-line path is neither a file nor a directory.
        RuntimeError: if clang-format has the wrong version, or (check mode)
            if any file is not properly formatted.
    """
    # Imported here so the helpers above can be imported without tqdm.
    from tqdm.contrib.concurrent import process_map

    # change workspace to MegBrain root dir; worker processes inherit the cwd
    os.chdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))

    parser = argparse.ArgumentParser(
        description="Format source files using clang-format, eg: "
        "`./tools/format.py src -w`. "
        "Require clang-format version == {}".format(REQUIRED_CLANG_FORMAT_VERSION),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "path", nargs="+", help="file name or path based on MegBrain root dir."
    )
    parser.add_argument(
        "-w",
        "--write",
        action="store_true",
        help="use formatted file to replace original file.",
    )
    parser.add_argument(
        "--clang-format",
        default=os.getenv("CLANG_FORMAT", "clang-format"),
        help="clang-format executable name; it can also be "
        "modified via the CLANG_FORMAT environment var",
    )
    args = parser.parse_args()

    files = []
    for path in args.path:
        if os.path.isdir(path):
            files += collect_source_files(path)
        elif os.path.isfile(path):
            files.append(path)
        else:
            raise ValueError("Invalid path {}".format(path))

    # check version, we only support REQUIRED_CLANG_FORMAT_VERSION now
    version = subprocess.check_output([args.clang_format, "--version"])
    version = version.decode("utf-8")
    need_version = REQUIRED_CLANG_FORMAT_VERSION
    if need_version not in version:
        print(
            "We only support {} now, please install {} version, find version: {}".format(
                need_version, need_version, version
            )
        )
        raise RuntimeError("clang-format version not equal {}".format(need_version))

    # Collect results via return values: appends to failed_files inside
    # worker processes are not visible here.
    results = process_map(
        partial(process_file, clang_format=args.clang_format, write=args.write),
        files,
        chunksize=10,
    )
    failed_files.extend(f for f, bad in zip(files, results) if bad)
    if failed_files:
        raise RuntimeError("above files are not properly formatted!")


if __name__ == "__main__":
    main()