diff --git a/tools/gen_header_for_bin_reduce.py b/tools/gen_header_for_bin_reduce.py index d084f87e0b89cefca083ef3cf6e1ffa92751a01d..39fdf3005c2d889fa01e6d73cef5e3b4e5f92d1d 100755 --- a/tools/gen_header_for_bin_reduce.py +++ b/tools/gen_header_for_bin_reduce.py @@ -185,6 +185,18 @@ class HeaderGen: ).stdout.decode('utf-8') self._fout.write('// midout \n') self._fout.write(cvt) + if cvt.find(" half,"): + change = open(self._fout.name).read().replace(" half,", " __fp16,") + with open("fix_fp16_bin_reduce.h", "w") as fix_fp16: + fix_fp16.write(change) + msg = ( + "WARNING:\n" + "hit half in trace, try use fix_fp16_bin_reduce.h when build failed with bin_reduce.h\n" + "which caused by LLVM mangle issue on __fp16 dtype, if you find msg 'error: use of undeclared identifier 'half'\n" + "then try use fix_fp16_bin_reduce.h, if build failed again, submit a issue to Engine team!!!" + ) + print(msg) + def main(): parser = argparse.ArgumentParser(