From f4dd169a2f527a687fbd53620ffd564c1dfb55d4 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Sat, 7 Dec 2019 12:20:24 +0800 Subject: [PATCH] move get_all_ops_desc from check_op_desc.py to print_op_desc.py (#21613) * move get_all_ops_desc from check_op_desc.py to print_op_desc.py * polish error message. test=develop --- tools/check_op_desc.py | 130 ++++++++++------------------------------- tools/print_op_desc.py | 107 +++++++++++++++++++++++++++++++++ 2 files changed, 139 insertions(+), 98 deletions(-) create mode 100644 tools/print_op_desc.py diff --git a/tools/check_op_desc.py b/tools/check_op_desc.py index bf2998a4de8..92324ea0070 100644 --- a/tools/check_op_desc.py +++ b/tools/check_op_desc.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import paddle.fluid.framework as framework -from paddle.fluid import core import json -from paddle import compat as cpt import sys SAME = 0 @@ -39,52 +36,6 @@ DEFAULT_VALUE = "default_value" error = False -def get_attr_default_value(op_name): - return core.get_op_attrs_default_value(cpt.to_bytes(op_name)) - - -def get_vars_info(op_vars_proto): - vars_info = {} - for var_proto in op_vars_proto: - name = str(var_proto.name) - vars_info[name] = {} - vars_info[name][DUPLICABLE] = var_proto.duplicable - vars_info[name][DISPENSABLE] = var_proto.dispensable - vars_info[name][INTERMEDIATE] = var_proto.intermediate - return vars_info - - -def get_attrs_info(op_proto, op_attrs_proto): - attrs_info = {} - attrs_default_values = get_attr_default_value(op_proto.type) - for attr_proto in op_attrs_proto: - attr_name = str(attr_proto.name) - attrs_info[attr_name] = {} - attrs_info[attr_name][TYPE] = attr_proto.type - attrs_info[attr_name][GENERATED] = attr_proto.generated - attrs_info[attr_name][DEFAULT_VALUE] = attrs_default_values[ - attr_name] if attr_name in attrs_default_values else None - return attrs_info - - -def get_op_desc(op_proto): - op_info = {} - op_info[INPUTS] = get_vars_info(op_proto.inputs) - op_info[OUTPUTS] = get_vars_info(op_proto.outputs) - op_info[ATTRS] = get_attrs_info(op_proto, op_proto.attrs) - return op_info - - -def get_all_ops_desc(): - all_op_protos_dict = {} - all_op_protos = framework.get_all_op_protos() - for op_proto in all_op_protos: - op_type = str(op_proto.type) - all_op_protos_dict[op_type] = get_op_desc(op_proto) - - return all_op_protos_dict - - def diff_vars(origin_vars, new_vars): global error var_error = False @@ -202,13 +153,12 @@ def compare_op_desc(origin_op_desc, new_op_desc): new_attrs = new_info.get(ATTRS, {}) attrs_error, attrs_diff = diff_attr(origin_attrs, new_attrs) - if ins_error or outs_error or attrs_error: - if ins_error: - error_message.setdefault(op_type, {})[INPUTS] = ins_diff - if outs_error: - error_message.setdefault(op_type, {})[OUTPUTS] = outs_diff - if attrs_error: - error_message.setdefault(op_type, {})[ATTRS] = attrs_diff + if ins_error: + error_message.setdefault(op_type, {})[INPUTS] = ins_diff + if outs_error: + error_message.setdefault(op_type, {})[OUTPUTS] = outs_diff + if attrs_error: + error_message.setdefault(op_type, {})[ATTRS] = attrs_diff return error_message @@ -222,80 +172,67 @@ def print_error_message(error_message): # 1. print inputs error message Inputs_error = error_message.get(op_name, {}).get(INPUTS, {}) for name in Inputs_error.get(ADD, {}): - print("The added Input '{}' is not dispensable.".format(name)) + print(" * The added Input '{}' is not dispensable.".format(name)) for name in Inputs_error.get(DELETE, {}): - print("The Input '{}' is deleted.".format(name)) + print(" * The Input '{}' is deleted.".format(name)) for name in Inputs_error.get(CHANGE, {}): changed_args = Inputs_error.get(CHANGE, {}).get(name, {}) for arg in changed_args: ori_value, new_value = changed_args.get(arg) print( - "The arg '{}' of Input '{}' is changed: from '{}' to '{}'.". + " * The arg '{}' of Input '{}' is changed: from '{}' to '{}'.". format(arg, name, ori_value, new_value)) # 2. print outputs error message Outputs_error = error_message.get(op_name, {}).get(OUTPUTS, {}) for name in Outputs_error.get(ADD, {}): - print("The added Output '{}' is not dispensable.".format(name)) + print(" * The added Output '{}' is not dispensable.".format(name)) for name in Outputs_error.get(DELETE, {}): - print("The Output '{}' is deleted.".format(name)) + print(" * The Output '{}' is deleted.".format(name)) for name in Outputs_error.get(CHANGE, {}): changed_args = Outputs_error.get(CHANGE, {}).get(name, {}) for arg in changed_args: ori_value, new_value = changed_args.get(arg) print( - "The arg '{}' of Output '{}' is changed: from '{}' to '{}'.". + " * The arg '{}' of Output '{}' is changed: from '{}' to '{}'.". format(arg, name, ori_value, new_value)) # 3. print attrs error message attrs_error = error_message.get(op_name, {}).get(ATTRS, {}) for name in attrs_error.get(ADD, {}): - print("The added attr '{}' doesn't set default value.".format(name)) + print(" * The added attr '{}' doesn't set default value.".format( + name)) for name in attrs_error.get(DELETE, {}): - print("The attr '{}' is deleted.".format(name)) + print(" * The attr '{}' is deleted.".format(name)) for name in attrs_error.get(CHANGE, {}): changed_args = attrs_error.get(CHANGE, {}).get(name, {}) for arg in changed_args: ori_value, new_value = changed_args.get(arg) print( - "The arg '{}' of attr '{}' is changed: from '{}' to '{}'.". + " * The arg '{}' of attr '{}' is changed: from '{}' to '{}'.". format(arg, name, ori_value, new_value)) + print("-" * 30) -if len(sys.argv) == 1: - ''' - Print all ops desc in dict: - {op1_name: - {INPUTS: - {input_name1: - {DISPENSABLE: bool, - INTERMEDIATE: bool, - DUPLICABLE: bool - }, - input_name2:{} - }, - OUTPUTS:{}, - ATTRS: - {attr_name1: - {TYPE: int, - GENERATED: bool, - DEFAULT_VALUE: int/str/etc, - } - } - } - op2_name:{} - } - ''' - all_op_protos_dict = get_all_ops_desc() - result = json.dumps(all_op_protos_dict) - print(result) -elif len(sys.argv) == 3: +def print_repeat_process(): + print( + "Tips:" + " If you want to repeat the process, please follow these steps:\n" + "\t1. Compile and install paddle from develop branch \n" + "\t2. Run: python tools/print_op_desc.py > OP_DESC_DEV.spec \n" + "\t3. Compile and install paddle from PR branch \n" + "\t4. Run: python tools/print_op_desc.py > OP_DESC_PR.spec \n" + "\t5. Run: python tools/check_op_desc.py OP_DESC_DEV.spec OP_DESC_PR.spec" + ) + + +if len(sys.argv) == 3: ''' Compare op_desc files generated by branch DEV and branch PR. And print error message. @@ -309,9 +246,6 @@ elif len(sys.argv) == 3: error_message = compare_op_desc(origin_op_desc, new_op_desc) if error: print_error_message(error_message) - + print_repeat_process() else: - print("Usage:\n" \ - "\t1. python check_op_desc.py > OP_DESC_DEV.spec\n" \ - "\t2. python check_op_desc.py > OP_DESC_PR.spec\n"\ - "\t3. python check_op_desc.py OP_DESC_DEV.spec OP_DESC_PR.spec > error_message") + print("Usage: python check_op_desc.py OP_DESC_DEV.spec OP_DESC_PR.spec") diff --git a/tools/print_op_desc.py b/tools/print_op_desc.py new file mode 100644 index 00000000000..64445bab3a6 --- /dev/null +++ b/tools/print_op_desc.py @@ -0,0 +1,107 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Print all ops desc in dict: + {op1_name: + {INPUTS: + {input_name1: + {DISPENSABLE: bool, + INTERMEDIATE: bool, + DUPLICABLE: bool + }, + input_name2:{} + }, + OUTPUTS:{}, + ATTRS: + {attr_name1: + {TYPE: int, + GENERATED: bool, + DEFAULT_VALUE: int/str/etc, + } + } + } + op2_name:{} + } + +Usage: + python print_op_desc.py > op_desc.spec +""" + +import paddle.fluid.framework as framework +from paddle.fluid import core +import json +from paddle import compat as cpt + +INPUTS = "Inputs" +OUTPUTS = "Outputs" +ATTRS = "Attrs" + +DUPLICABLE = "duplicable" +INTERMEDIATE = "intermediate" +DISPENSABLE = "dispensable" + +TYPE = "type" +GENERATED = "generated" +DEFAULT_VALUE = "default_value" + + +def get_attr_default_value(op_name): + return core.get_op_attrs_default_value(cpt.to_bytes(op_name)) + + +def get_vars_info(op_vars_proto): + vars_info = {} + for var_proto in op_vars_proto: + name = str(var_proto.name) + vars_info[name] = {} + vars_info[name][DUPLICABLE] = var_proto.duplicable + vars_info[name][DISPENSABLE] = var_proto.dispensable + vars_info[name][INTERMEDIATE] = var_proto.intermediate + return vars_info + + +def get_attrs_info(op_proto, op_attrs_proto): + attrs_info = {} + attrs_default_values = get_attr_default_value(op_proto.type) + for attr_proto in op_attrs_proto: + attr_name = str(attr_proto.name) + attrs_info[attr_name] = {} + attrs_info[attr_name][TYPE] = attr_proto.type + attrs_info[attr_name][GENERATED] = attr_proto.generated + attrs_info[attr_name][DEFAULT_VALUE] = attrs_default_values[ + attr_name] if attr_name in attrs_default_values else None + return attrs_info + + +def get_op_desc(op_proto): + op_info = {} + op_info[INPUTS] = get_vars_info(op_proto.inputs) + op_info[OUTPUTS] = get_vars_info(op_proto.outputs) + op_info[ATTRS] = get_attrs_info(op_proto, op_proto.attrs) + return op_info + + +def get_all_ops_desc(): + all_op_protos_dict = {} + all_op_protos = framework.get_all_op_protos() + for op_proto in all_op_protos: + op_type = str(op_proto.type) + all_op_protos_dict[op_type] = get_op_desc(op_proto) + + return all_op_protos_dict + + +all_op_protos_dict = get_all_ops_desc() +result = json.dumps(all_op_protos_dict) +print(result) -- GitLab