未验证 提交 f9882854 编写于 作者: Z zyfncg 提交者: GitHub

Adjust the rule of configure in api_compat.yaml (#45672)

* set use_cudnn=true for conv2d

* refine the config rule of api_compat
上级 e2823c8c
...@@ -123,9 +123,8 @@ ...@@ -123,9 +123,8 @@
str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f, str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f,
int workspace_size_MB = platform::GetDefaultConvWorkspaceSizeLimitMB()] int workspace_size_MB = platform::GetDefaultConvWorkspaceSizeLimitMB()]
- api : diag - api : diag (diag_v2)
op_name : diag_v2 backward : diag_grad (diag_v2_grad)
grad_op_name : diag_v2_grad
inputs : inputs :
x : X x : X
outputs : outputs :
......
...@@ -56,18 +56,29 @@ def restruct_io(api): ...@@ -56,18 +56,29 @@ def restruct_io(api):
# replace name of op and params for OpMaker # replace name of op and params for OpMaker
def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict): def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict):
def get_api_and_op_name(api_item):
names = api_item.split('(')
if len(names) == 1:
return names[0].strip(), names[0].strip()
else:
return names[0].strip(), names[1].split(')')[0].strip()
for api_args in api_op_map: for api_args in api_op_map:
if api_args['api'] not in forward_api_dict: api_name, op_name = get_api_and_op_name(api_args['api'])
if api_name not in forward_api_dict:
continue continue
forward_api_item = forward_api_dict[api_args['api']] forward_api_item = forward_api_dict[api_name]
has_backward = True if forward_api_item['backward'] else False has_backward = True if forward_api_item['backward'] else False
if has_backward: if has_backward:
backward_api_item = backward_api_dict[forward_api_item['backward']] backward_api_item = backward_api_dict[forward_api_item['backward']]
if 'op_name' in api_args: if api_name != op_name:
forward_api_item['op_name'] = api_args['op_name'] forward_api_item['op_name'] = op_name
if 'grad_op_name' in api_args and has_backward: if 'backward' in api_args and has_backward:
forward_api_item['backward'] = api_args['grad_op_name'] bw_api_name, bw_op_name = get_api_and_op_name(
backward_api_item['op_name'] = api_args['grad_op_name'] api_args['backward'].split(',')[0])
forward_api_item['backward'] = bw_op_name
backward_api_item['op_name'] = bw_op_name
key_set = ['inputs', 'attrs', 'outputs'] key_set = ['inputs', 'attrs', 'outputs']
args_map = {} args_map = {}
......
...@@ -70,6 +70,13 @@ def generate_extra_info(api_compat_yaml_path, ops_extra_info_path): ...@@ -70,6 +70,13 @@ def generate_extra_info(api_compat_yaml_path, ops_extra_info_path):
with open(api_compat_yaml_path, 'rt') as f: with open(api_compat_yaml_path, 'rt') as f:
compat_apis = yaml.safe_load(f) compat_apis = yaml.safe_load(f)
def get_op_name(api_item):
names = api_item.split('(')
if len(names) == 1:
return names[0].strip()
else:
return names[1].split(')')[0].strip()
extra_map_str_list = [] extra_map_str_list = []
extra_checker_str_list = [] extra_checker_str_list = []
...@@ -96,17 +103,18 @@ def generate_extra_info(api_compat_yaml_path, ops_extra_info_path): ...@@ -96,17 +103,18 @@ def generate_extra_info(api_compat_yaml_path, ops_extra_info_path):
api_extra_attr_checkers = ",\n ".join( api_extra_attr_checkers = ",\n ".join(
attr_checker_func_list) attr_checker_func_list)
extra_map_str_list.append( extra_map_str_list.append(
f"{{\"{api_compat_args['api']}\", {{ {api_extra_attr_map} }}}}" f"{{\"{get_op_name(api_compat_args['api'])}\", {{ {api_extra_attr_map} }}}}"
) )
extra_checker_str_list.append( extra_checker_str_list.append(
f"{{\"{api_compat_args['api']}\", {{ {api_extra_attr_checkers} }}}}" f"{{\"{get_op_name(api_compat_args['api'])}\", {{ {api_extra_attr_checkers} }}}}"
) )
if 'backward' in api_compat_args: if 'backward' in api_compat_args:
for bw_item in api_compat_args['backward'].split(','):
bw_op_name = get_op_name(bw_item)
extra_map_str_list.append( extra_map_str_list.append(
f"{{\"{api_compat_args['backward']}\", {{ {api_extra_attr_map} }}}}" f"{{\"{bw_op_name}\", {{ {api_extra_attr_map} }}}}")
)
extra_checker_str_list.append( extra_checker_str_list.append(
f"{{\"{api_compat_args['backward']}\", {{ {api_extra_attr_checkers} }}}}" f"{{\"{bw_op_name}\", {{ {api_extra_attr_checkers} }}}}"
) )
ops_extra_info_file = open(ops_extra_info_path, 'w') ops_extra_info_file = open(ops_extra_info_path, 'w')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册