diff --git a/tools/infrt/get_compat_kernel_signature.py b/tools/infrt/get_compat_kernel_signature.py index 45dc931fac19d4f64237742288816f2a87080829..a66a236b0f9759759f83aa4419cfb9cdcf9b3712 100644 --- a/tools/infrt/get_compat_kernel_signature.py +++ b/tools/infrt/get_compat_kernel_signature.py @@ -19,6 +19,13 @@ import json skip_list = ["adam_sig.cc", "adamw_sig.cc"] +def is_grad_kernel(kernel_info): + kernel_name = kernel_info.split(",")[0] + if kernel_name.endswith("_grad"): + return True + return False + + def parse_compat_registry(kernel_info): name, inputs_str, attrs_str, outputs_str = kernel_info.split(",{") kernel_info = {} @@ -62,6 +69,8 @@ def get_compat_kernels_info(): "").strip("return").strip("KernelSignature(").strip( "\);").replace("\"", "").replace("\\", "") registry = False + if is_grad_kernel(data): + continue name, registry_info = parse_compat_registry(data) if name in kernels_info: