未验证 提交 d6b062e8 编写于 作者: Z zyfncg 提交者: GitHub

Add validity check for config in yaml (#49049)

* add validity check for config in yaml

* delete debug log
上级 ba422913
...@@ -276,11 +276,51 @@ def parse_forward(op_name: str, forward_config: str) -> Dict[str, Any]: ...@@ -276,11 +276,51 @@ def parse_forward(op_name: str, forward_config: str) -> Dict[str, Any]:
return forward_cfg return forward_cfg
def check_op_config(op_entry, op_name):
base_key_set = (
'op',
'backward_op',
'forward',
'args',
'output',
'infer_meta',
'kernel',
'backward',
'invoke',
'inplace',
'view',
'optional',
'intermediate',
'no_need_buffer',
'data_transform',
)
infer_meta_key_set = ('func', 'param')
kernel_key_set = ('func', 'param', 'data_type', 'layout', 'backend')
for key in op_entry.keys():
assert (
key in base_key_set
), f"Op ({op_name}) : invalid key ({key}) in Yaml."
if 'infer_meta' in op_entry:
for infer_meta_key in op_entry['infer_meta'].keys():
assert (
infer_meta_key in infer_meta_key_set
), f"Op ({op_name}) : invalid key (infer_meta.{infer_meta_key}) in Yaml."
if 'kernel' in op_entry:
for kernel_key in op_entry['kernel'].keys():
assert (
kernel_key in kernel_key_set
), f"Op ({op_name}) : invalid key (kernel.{kernel_key}) in Yaml."
def parse_op_entry(op_entry: Dict[str, Any], name_field="op"): def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
op_name = op_entry[name_field] op_name = op_entry[name_field]
inputs, attrs = parse_input_and_attr(op_name, op_entry["args"]) inputs, attrs = parse_input_and_attr(op_name, op_entry["args"])
outputs = parse_outputs(op_name, op_entry["output"]) outputs = parse_outputs(op_name, op_entry["output"])
check_op_config(op_entry, op_name)
# validate default value of DataType and DataLayout # validate default value of DataType and DataLayout
for attr in attrs: for attr in attrs:
if "default_value" in attr: if "default_value" in attr:
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
func : AccuracyInferMeta func : AccuracyInferMeta
kernel : kernel :
func : accuracy func : accuracy
dtype : x data_type : x
- op : adadelta_ - op : adadelta_
args : (Tensor param, Tensor grad, Tensor avg_squared_grad, Tensor avg_squared_update, float rho, float epsilon) args : (Tensor param, Tensor grad, Tensor avg_squared_grad, Tensor avg_squared_update, float rho, float epsilon)
...@@ -1132,9 +1132,9 @@ ...@@ -1132,9 +1132,9 @@
output : Tensor(solution), Tensor(residuals), Tensor(rank), Tensor(singular_values) output : Tensor(solution), Tensor(residuals), Tensor(rank), Tensor(singular_values)
infer_meta : infer_meta :
func : LstsqInferMeta func : LstsqInferMeta
dtype : x
kernel : kernel :
func : lstsq func : lstsq
data_type : x
- op : lu - op : lu
args : (Tensor x, bool pivot) args : (Tensor x, bool pivot)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册