未验证 提交 7bd50187 编写于 作者: K kangguangli 提交者: GitHub

[NewIR] support custom verify in op definition generation (#55428)

* support custom verify

* fix

* fix

* fix

* fix coverage ci

* remove custom verify in assert
上级 f7c0697c
......@@ -307,6 +307,9 @@ class OpInfoParser:
self.inplace_map = self.parse_op_inplace_info()
self.view_map = self.parse_op_view_info()
# parse has_custom_verify
self.custom_verify = self.parse_custom_verify()
def cross_check(self, name_list, type_list, optional_list=None):
assert len(name_list) == len(
type_list
......@@ -316,6 +319,11 @@ class OpInfoParser:
optional_list
), "type list size != optional list size."
def parse_custom_verify(self):
if 'custom_verify' in self.op_yaml_item:
return self.op_yaml_item['custom_verify']
return False
def parse_op_phi_name(self):
if (self.parse_op_inplace_info() is None) and (
self.parse_op_view_info() is None
......@@ -980,17 +988,19 @@ def OpGenerator(
)
# generate op verify function str
op_verify_str = gen_verify_func_str(
op_class_name,
op_input_type_list,
op_input_optional_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_type_list,
op_output_type_list,
op_output_optional_list,
)
op_verify_str = ''
if not op_info.custom_verify:
op_verify_str = gen_verify_func_str(
op_class_name,
op_input_type_list,
op_input_optional_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_type_list,
op_output_type_list,
op_output_optional_list,
)
op_infer_meta_str = gen_op_infer_meta_str(op_info, op_class_name)
......
......@@ -9,6 +9,7 @@
data_transform: null
inplace: null
backward: null
- name: fetch
inputs:
- typename: Tensor
......@@ -150,18 +151,6 @@
no_need_buffer: null
data_transform: null
- name: py_func_
inputs:
- {typename: 'Tensor', name: x, optional: false, no_need_buffer: false, data_transform: {}}
attrs:
- {typename: 'int', name: forward_callable_id, default_value: '0'}
- {typename: 'int', name: backward_callable_id, default_value: '-1'}
- {typename: 'str[]', name: backward_skip_vars, default_value: '{}'}
outputs:
- {typename: 'Tensor', name: out, optional: false, intermediate: false}
no_need_buffer: null
data_transform: null
- name: embedding_grad_sparse
inputs:
- typename: Tensor
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册