未验证 提交 0bcbfe83 编写于 作者: M ming1753 提交者: GitHub

add odd rules for getting kernels (#55178)

上级 463a4f25
......@@ -55,10 +55,28 @@ def get_model_ops(model_file, ops_set):
def get_model_phi_kernels(ops_set):
phi_set = set()
phi_raw_list = [
"add",
"subtract",
"multiply",
"multiply_sr",
"divide",
"maximum",
"minimum",
"remainder",
"floor_divide",
"elementwise_pow",
]
phi_odd_dist = {"batch_norm": "batch_norm_infer"}
for op in ops_set:
print(op)
print(_get_phi_kernel_name(op))
phi_set.add(_get_phi_kernel_name(op))
phi_kernel = _get_phi_kernel_name(op)
print(phi_kernel)
phi_set.add(phi_kernel)
if phi_kernel in phi_raw_list:
phi_set.add(phi_kernel + "_raw")
if phi_kernel in phi_odd_dist.keys():
phi_set.add(phi_odd_dist[phi_kernel])
return phi_set
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册