未验证 提交 7bd50187 编写于 作者: K kangguangli 提交者: GitHub

[NewIR] support custom verify in op definition generation (#55428)

* support custom verify

* fix

* fix

* fix

* fix coverage ci

* remove custom verify in assert
上级 f7c0697c
...@@ -307,6 +307,9 @@ class OpInfoParser: ...@@ -307,6 +307,9 @@ class OpInfoParser:
self.inplace_map = self.parse_op_inplace_info() self.inplace_map = self.parse_op_inplace_info()
self.view_map = self.parse_op_view_info() self.view_map = self.parse_op_view_info()
# parse has_custom_verify
self.custom_verify = self.parse_custom_verify()
def cross_check(self, name_list, type_list, optional_list=None): def cross_check(self, name_list, type_list, optional_list=None):
assert len(name_list) == len( assert len(name_list) == len(
type_list type_list
...@@ -316,6 +319,11 @@ class OpInfoParser: ...@@ -316,6 +319,11 @@ class OpInfoParser:
optional_list optional_list
), "type list size != optional list size." ), "type list size != optional list size."
def parse_custom_verify(self):
if 'custom_verify' in self.op_yaml_item:
return self.op_yaml_item['custom_verify']
return False
def parse_op_phi_name(self): def parse_op_phi_name(self):
if (self.parse_op_inplace_info() is None) and ( if (self.parse_op_inplace_info() is None) and (
self.parse_op_view_info() is None self.parse_op_view_info() is None
...@@ -980,6 +988,8 @@ def OpGenerator( ...@@ -980,6 +988,8 @@ def OpGenerator(
) )
# generate op verify function str # generate op verify function str
op_verify_str = ''
if not op_info.custom_verify:
op_verify_str = gen_verify_func_str( op_verify_str = gen_verify_func_str(
op_class_name, op_class_name,
op_input_type_list, op_input_type_list,
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
data_transform: null data_transform: null
inplace: null inplace: null
backward: null backward: null
- name: fetch - name: fetch
inputs: inputs:
- typename: Tensor - typename: Tensor
...@@ -150,18 +151,6 @@ ...@@ -150,18 +151,6 @@
no_need_buffer: null no_need_buffer: null
data_transform: null data_transform: null
- name: py_func_
inputs:
- {typename: 'Tensor', name: x, optional: false, no_need_buffer: false, data_transform: {}}
attrs:
- {typename: 'int', name: forward_callable_id, default_value: '0'}
- {typename: 'int', name: backward_callable_id, default_value: '-1'}
- {typename: 'str[]', name: backward_skip_vars, default_value: '{}'}
outputs:
- {typename: 'Tensor', name: out, optional: false, intermediate: false}
no_need_buffer: null
data_transform: null
- name: embedding_grad_sparse - name: embedding_grad_sparse
inputs: inputs:
- typename: Tensor - typename: Tensor
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册