未验证 提交 f9882854 编写于 作者: Z zyfncg 提交者: GitHub

Adjust the rule of configure in api_compat.yaml (#45672)

* set use_cudnn=true for conv2d

* refine the config rule of api_compat
上级 e2823c8c
......@@ -123,9 +123,8 @@
str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f,
int workspace_size_MB = platform::GetDefaultConvWorkspaceSizeLimitMB()]
- api : diag
op_name : diag_v2
grad_op_name : diag_v2_grad
- api : diag (diag_v2)
backward : diag_grad (diag_v2_grad)
inputs :
x : X
outputs :
......
......@@ -56,18 +56,29 @@ def restruct_io(api):
# replace name of op and params for OpMaker
def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict):
def get_api_and_op_name(api_item):
names = api_item.split('(')
if len(names) == 1:
return names[0].strip(), names[0].strip()
else:
return names[0].strip(), names[1].split(')')[0].strip()
for api_args in api_op_map:
if api_args['api'] not in forward_api_dict:
api_name, op_name = get_api_and_op_name(api_args['api'])
if api_name not in forward_api_dict:
continue
forward_api_item = forward_api_dict[api_args['api']]
forward_api_item = forward_api_dict[api_name]
has_backward = True if forward_api_item['backward'] else False
if has_backward:
backward_api_item = backward_api_dict[forward_api_item['backward']]
if 'op_name' in api_args:
forward_api_item['op_name'] = api_args['op_name']
if 'grad_op_name' in api_args and has_backward:
forward_api_item['backward'] = api_args['grad_op_name']
backward_api_item['op_name'] = api_args['grad_op_name']
if api_name != op_name:
forward_api_item['op_name'] = op_name
if 'backward' in api_args and has_backward:
bw_api_name, bw_op_name = get_api_and_op_name(
api_args['backward'].split(',')[0])
forward_api_item['backward'] = bw_op_name
backward_api_item['op_name'] = bw_op_name
key_set = ['inputs', 'attrs', 'outputs']
args_map = {}
......
......@@ -70,6 +70,13 @@ def generate_extra_info(api_compat_yaml_path, ops_extra_info_path):
with open(api_compat_yaml_path, 'rt') as f:
compat_apis = yaml.safe_load(f)
def get_op_name(api_item):
names = api_item.split('(')
if len(names) == 1:
return names[0].strip()
else:
return names[1].split(')')[0].strip()
extra_map_str_list = []
extra_checker_str_list = []
......@@ -96,18 +103,19 @@ def generate_extra_info(api_compat_yaml_path, ops_extra_info_path):
api_extra_attr_checkers = ",\n ".join(
attr_checker_func_list)
extra_map_str_list.append(
f"{{\"{api_compat_args['api']}\", {{ {api_extra_attr_map} }}}}"
f"{{\"{get_op_name(api_compat_args['api'])}\", {{ {api_extra_attr_map} }}}}"
)
extra_checker_str_list.append(
f"{{\"{api_compat_args['api']}\", {{ {api_extra_attr_checkers} }}}}"
f"{{\"{get_op_name(api_compat_args['api'])}\", {{ {api_extra_attr_checkers} }}}}"
)
if 'backward' in api_compat_args:
extra_map_str_list.append(
f"{{\"{api_compat_args['backward']}\", {{ {api_extra_attr_map} }}}}"
)
extra_checker_str_list.append(
f"{{\"{api_compat_args['backward']}\", {{ {api_extra_attr_checkers} }}}}"
)
for bw_item in api_compat_args['backward'].split(','):
bw_op_name = get_op_name(bw_item)
extra_map_str_list.append(
f"{{\"{bw_op_name}\", {{ {api_extra_attr_map} }}}}")
extra_checker_str_list.append(
f"{{\"{bw_op_name}\", {{ {api_extra_attr_checkers} }}}}"
)
ops_extra_info_file = open(ops_extra_info_path, 'w')
ops_extra_info_file.write(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册