未验证 提交 f4dd169a 编写于 作者: L liym27 提交者: GitHub

move get_all_ops_desc from check_op_desc.py to print_op_desc.py (#21613)

* move get_all_ops_desc from check_op_desc.py to print_op_desc.py

* polish error message. test=develop
上级 88960684
...@@ -12,10 +12,7 @@ ...@@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle.fluid.framework as framework
from paddle.fluid import core
import json import json
from paddle import compat as cpt
import sys import sys
SAME = 0 SAME = 0
...@@ -39,52 +36,6 @@ DEFAULT_VALUE = "default_value" ...@@ -39,52 +36,6 @@ DEFAULT_VALUE = "default_value"
error = False error = False
def get_attr_default_value(op_name):
return core.get_op_attrs_default_value(cpt.to_bytes(op_name))
def get_vars_info(op_vars_proto):
vars_info = {}
for var_proto in op_vars_proto:
name = str(var_proto.name)
vars_info[name] = {}
vars_info[name][DUPLICABLE] = var_proto.duplicable
vars_info[name][DISPENSABLE] = var_proto.dispensable
vars_info[name][INTERMEDIATE] = var_proto.intermediate
return vars_info
def get_attrs_info(op_proto, op_attrs_proto):
attrs_info = {}
attrs_default_values = get_attr_default_value(op_proto.type)
for attr_proto in op_attrs_proto:
attr_name = str(attr_proto.name)
attrs_info[attr_name] = {}
attrs_info[attr_name][TYPE] = attr_proto.type
attrs_info[attr_name][GENERATED] = attr_proto.generated
attrs_info[attr_name][DEFAULT_VALUE] = attrs_default_values[
attr_name] if attr_name in attrs_default_values else None
return attrs_info
def get_op_desc(op_proto):
op_info = {}
op_info[INPUTS] = get_vars_info(op_proto.inputs)
op_info[OUTPUTS] = get_vars_info(op_proto.outputs)
op_info[ATTRS] = get_attrs_info(op_proto, op_proto.attrs)
return op_info
def get_all_ops_desc():
all_op_protos_dict = {}
all_op_protos = framework.get_all_op_protos()
for op_proto in all_op_protos:
op_type = str(op_proto.type)
all_op_protos_dict[op_type] = get_op_desc(op_proto)
return all_op_protos_dict
def diff_vars(origin_vars, new_vars): def diff_vars(origin_vars, new_vars):
global error global error
var_error = False var_error = False
...@@ -202,7 +153,6 @@ def compare_op_desc(origin_op_desc, new_op_desc): ...@@ -202,7 +153,6 @@ def compare_op_desc(origin_op_desc, new_op_desc):
new_attrs = new_info.get(ATTRS, {}) new_attrs = new_info.get(ATTRS, {})
attrs_error, attrs_diff = diff_attr(origin_attrs, new_attrs) attrs_error, attrs_diff = diff_attr(origin_attrs, new_attrs)
if ins_error or outs_error or attrs_error:
if ins_error: if ins_error:
error_message.setdefault(op_type, {})[INPUTS] = ins_diff error_message.setdefault(op_type, {})[INPUTS] = ins_diff
if outs_error: if outs_error:
...@@ -222,80 +172,67 @@ def print_error_message(error_message): ...@@ -222,80 +172,67 @@ def print_error_message(error_message):
# 1. print inputs error message # 1. print inputs error message
Inputs_error = error_message.get(op_name, {}).get(INPUTS, {}) Inputs_error = error_message.get(op_name, {}).get(INPUTS, {})
for name in Inputs_error.get(ADD, {}): for name in Inputs_error.get(ADD, {}):
print("The added Input '{}' is not dispensable.".format(name)) print(" * The added Input '{}' is not dispensable.".format(name))
for name in Inputs_error.get(DELETE, {}): for name in Inputs_error.get(DELETE, {}):
print("The Input '{}' is deleted.".format(name)) print(" * The Input '{}' is deleted.".format(name))
for name in Inputs_error.get(CHANGE, {}): for name in Inputs_error.get(CHANGE, {}):
changed_args = Inputs_error.get(CHANGE, {}).get(name, {}) changed_args = Inputs_error.get(CHANGE, {}).get(name, {})
for arg in changed_args: for arg in changed_args:
ori_value, new_value = changed_args.get(arg) ori_value, new_value = changed_args.get(arg)
print( print(
"The arg '{}' of Input '{}' is changed: from '{}' to '{}'.". " * The arg '{}' of Input '{}' is changed: from '{}' to '{}'.".
format(arg, name, ori_value, new_value)) format(arg, name, ori_value, new_value))
# 2. print outputs error message # 2. print outputs error message
Outputs_error = error_message.get(op_name, {}).get(OUTPUTS, {}) Outputs_error = error_message.get(op_name, {}).get(OUTPUTS, {})
for name in Outputs_error.get(ADD, {}): for name in Outputs_error.get(ADD, {}):
print("The added Output '{}' is not dispensable.".format(name)) print(" * The added Output '{}' is not dispensable.".format(name))
for name in Outputs_error.get(DELETE, {}): for name in Outputs_error.get(DELETE, {}):
print("The Output '{}' is deleted.".format(name)) print(" * The Output '{}' is deleted.".format(name))
for name in Outputs_error.get(CHANGE, {}): for name in Outputs_error.get(CHANGE, {}):
changed_args = Outputs_error.get(CHANGE, {}).get(name, {}) changed_args = Outputs_error.get(CHANGE, {}).get(name, {})
for arg in changed_args: for arg in changed_args:
ori_value, new_value = changed_args.get(arg) ori_value, new_value = changed_args.get(arg)
print( print(
"The arg '{}' of Output '{}' is changed: from '{}' to '{}'.". " * The arg '{}' of Output '{}' is changed: from '{}' to '{}'.".
format(arg, name, ori_value, new_value)) format(arg, name, ori_value, new_value))
# 3. print attrs error message # 3. print attrs error message
attrs_error = error_message.get(op_name, {}).get(ATTRS, {}) attrs_error = error_message.get(op_name, {}).get(ATTRS, {})
for name in attrs_error.get(ADD, {}): for name in attrs_error.get(ADD, {}):
print("The added attr '{}' doesn't set default value.".format(name)) print(" * The added attr '{}' doesn't set default value.".format(
name))
for name in attrs_error.get(DELETE, {}): for name in attrs_error.get(DELETE, {}):
print("The attr '{}' is deleted.".format(name)) print(" * The attr '{}' is deleted.".format(name))
for name in attrs_error.get(CHANGE, {}): for name in attrs_error.get(CHANGE, {}):
changed_args = attrs_error.get(CHANGE, {}).get(name, {}) changed_args = attrs_error.get(CHANGE, {}).get(name, {})
for arg in changed_args: for arg in changed_args:
ori_value, new_value = changed_args.get(arg) ori_value, new_value = changed_args.get(arg)
print( print(
"The arg '{}' of attr '{}' is changed: from '{}' to '{}'.". " * The arg '{}' of attr '{}' is changed: from '{}' to '{}'.".
format(arg, name, ori_value, new_value)) format(arg, name, ori_value, new_value))
print("-" * 30)
if len(sys.argv) == 1: def print_repeat_process():
''' print(
Print all ops desc in dict: "Tips:"
{op1_name: " If you want to repeat the process, please follow these steps:\n"
{INPUTS: "\t1. Compile and install paddle from develop branch \n"
{input_name1: "\t2. Run: python tools/print_op_desc.py > OP_DESC_DEV.spec \n"
{DISPENSABLE: bool, "\t3. Compile and install paddle from PR branch \n"
INTERMEDIATE: bool, "\t4. Run: python tools/print_op_desc.py > OP_DESC_PR.spec \n"
DUPLICABLE: bool "\t5. Run: python tools/check_op_desc.py OP_DESC_DEV.spec OP_DESC_PR.spec"
}, )
input_name2:{}
},
OUTPUTS:{}, if len(sys.argv) == 3:
ATTRS:
{attr_name1:
{TYPE: int,
GENERATED: bool,
DEFAULT_VALUE: int/str/etc,
}
}
}
op2_name:{}
}
'''
all_op_protos_dict = get_all_ops_desc()
result = json.dumps(all_op_protos_dict)
print(result)
elif len(sys.argv) == 3:
''' '''
Compare op_desc files generated by branch DEV and branch PR. Compare op_desc files generated by branch DEV and branch PR.
And print error message. And print error message.
...@@ -309,9 +246,6 @@ elif len(sys.argv) == 3: ...@@ -309,9 +246,6 @@ elif len(sys.argv) == 3:
error_message = compare_op_desc(origin_op_desc, new_op_desc) error_message = compare_op_desc(origin_op_desc, new_op_desc)
if error: if error:
print_error_message(error_message) print_error_message(error_message)
print_repeat_process()
else: else:
print("Usage:\n" \ print("Usage: python check_op_desc.py OP_DESC_DEV.spec OP_DESC_PR.spec")
"\t1. python check_op_desc.py > OP_DESC_DEV.spec\n" \
"\t2. python check_op_desc.py > OP_DESC_PR.spec\n"\
"\t3. python check_op_desc.py OP_DESC_DEV.spec OP_DESC_PR.spec > error_message")
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Print all ops desc in dict:
{op1_name:
{INPUTS:
{input_name1:
{DISPENSABLE: bool,
INTERMEDIATE: bool,
DUPLICABLE: bool
},
input_name2:{}
},
OUTPUTS:{},
ATTRS:
{attr_name1:
{TYPE: int,
GENERATED: bool,
DEFAULT_VALUE: int/str/etc,
}
}
}
op2_name:{}
}
Usage:
python print_op_desc.py > op_desc.spec
"""
import paddle.fluid.framework as framework
from paddle.fluid import core
import json
from paddle import compat as cpt
INPUTS = "Inputs"
OUTPUTS = "Outputs"
ATTRS = "Attrs"
DUPLICABLE = "duplicable"
INTERMEDIATE = "intermediate"
DISPENSABLE = "dispensable"
TYPE = "type"
GENERATED = "generated"
DEFAULT_VALUE = "default_value"
def get_attr_default_value(op_name):
return core.get_op_attrs_default_value(cpt.to_bytes(op_name))
def get_vars_info(op_vars_proto):
vars_info = {}
for var_proto in op_vars_proto:
name = str(var_proto.name)
vars_info[name] = {}
vars_info[name][DUPLICABLE] = var_proto.duplicable
vars_info[name][DISPENSABLE] = var_proto.dispensable
vars_info[name][INTERMEDIATE] = var_proto.intermediate
return vars_info
def get_attrs_info(op_proto, op_attrs_proto):
attrs_info = {}
attrs_default_values = get_attr_default_value(op_proto.type)
for attr_proto in op_attrs_proto:
attr_name = str(attr_proto.name)
attrs_info[attr_name] = {}
attrs_info[attr_name][TYPE] = attr_proto.type
attrs_info[attr_name][GENERATED] = attr_proto.generated
attrs_info[attr_name][DEFAULT_VALUE] = attrs_default_values[
attr_name] if attr_name in attrs_default_values else None
return attrs_info
def get_op_desc(op_proto):
op_info = {}
op_info[INPUTS] = get_vars_info(op_proto.inputs)
op_info[OUTPUTS] = get_vars_info(op_proto.outputs)
op_info[ATTRS] = get_attrs_info(op_proto, op_proto.attrs)
return op_info
def get_all_ops_desc():
all_op_protos_dict = {}
all_op_protos = framework.get_all_op_protos()
for op_proto in all_op_protos:
op_type = str(op_proto.type)
all_op_protos_dict[op_type] = get_op_desc(op_proto)
return all_op_protos_dict
all_op_protos_dict = get_all_ops_desc()
result = json.dumps(all_op_protos_dict)
print(result)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册