check_op_desc.py 17.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import sys
17 18
from paddle.utils import OpLastCheckpointChecker
from paddle.fluid.core import OpUpdateType
19 20 21 22 23

# Names of the three sections of an op description.
INPUTS = "Inputs"
OUTPUTS = "Outputs"
ATTRS = "Attrs"

# Keys of the diff dictionaries built by `diff_vars` and `diff_attr`.
# `ADD` lists every newly added item. The two specialised keys list the
# added items that break a compatibility rule:
#   * `ADD_WITH_DEFAULT`: added attributes whose default value is missing;
#   * `ADD_DISPENSABLE`: added inputs/outputs that are not dispensable
#     (i.e. not optional).
ADD_WITH_DEFAULT = "Add_with_default"
ADD_DISPENSABLE = "Add_dispensable"
ADD = "Add"

# Diff keys for removed items and for items whose properties changed.
DELETE = "Delete"
CHANGE = "Change"

# Property names of a variable entry. Only `DISPENSABLE` is read in this
# file; the other two are presumably kept to mirror the spec format.
DUPLICABLE = "duplicable"
INTERMEDIATE = "intermediate"
DISPENSABLE = "dispensable"

# Property names of an attribute entry. Only `DEFAULT_VALUE` is read in
# this file.
TYPE = "type"
GENERATED = "generated"
DEFAULT_VALUE = "default_value"

# `EXTRA` and `QUANT` are boolean-like properties looked up on added
# inputs/outputs/attrs. `QUANT` and `DEF` double as diff keys for the added
# items that are quant, or that are neither extra nor quant, respectively.
EXTRA = "extra"
QUANT = "quant"
DEF = "def"

# Module-wide flag; `diff_vars` and `diff_attr` set it to True as soon as
# they find any incompatible change.
error = False

50 51 52 53 54 55 56 57 58 59 60 61 62
# For each section of an op description and each kind of change, the
# `OpUpdateType` value handed to `OpLastCheckpointChecker.filter_updates`
# when looking for a matching registered update.
version_update_map = {
    INPUTS: {ADD: OpUpdateType.kNewInput},
    OUTPUTS: {ADD: OpUpdateType.kNewOutput},
    ATTRS: {
        ADD: OpUpdateType.kNewAttr,
        CHANGE: OpUpdateType.kModifyAttr,
    },
}

63 64 65 66 67

def diff_vars(origin_vars, new_vars):
    """Diff the inputs (or outputs) of one op between two descriptions.

    Args:
        origin_vars: dict mapping a variable name to its property dict, taken
            from the original description.
        new_vars: the same kind of dict, taken from the new description.

    Returns:
        A tuple ``(var_error, var_diff_message)``. ``var_error`` is True when
        at least one incompatible change was found. ``var_diff_message`` only
        contains the non-empty findings, keyed by ``ADD``,
        ``ADD_DISPENSABLE``, ``CHANGE``, ``DELETE``, ``QUANT`` and ``DEF``.

    Side effects:
        Sets the module-level ``error`` flag to True whenever ``var_error``
        is set to True.
    """
    global error
    var_error = False

    origin_names = set(origin_vars.keys())
    new_names = set(new_vars.keys())

    # Variables present on both sides: record every property whose value
    # differs, looking only at the properties listed on the original side.
    changed = {}
    for name in origin_names & new_names:
        origin_props = origin_vars.get(name)
        new_props = new_vars.get(name)
        if origin_props == new_props:
            continue
        error, var_error = True, True
        for prop in origin_props:
            old_value, new_value = origin_props.get(prop), new_props.get(prop)
            if new_value != old_value:
                changed.setdefault(name, {})[prop] = (old_value, new_value)

    # Variables that disappeared: always an error.
    deleted = list(origin_names - new_names)
    if deleted:
        error, var_error = True, True

    # Variables that appeared.
    added = []
    added_not_dispensable = []
    added_quant = []
    added_def = []
    for name in new_names - origin_names:
        props = new_vars.get(name)
        added.append(name)
        if not props.get(DISPENSABLE):
            error, var_error = True, True
            added_not_dispensable.append(name)

        # An added `extra` variable needs no further check.
        if props.get(EXTRA):
            continue

        error, var_error = True, True
        if props.get(QUANT):
            # quant: slim needs to review, needs to register.
            added_quant.append(name)
        else:
            # neither extra nor quant: inference needs to review, needs to
            # register.
            added_def.append(name)

    findings = (
        (ADD, added),
        (ADD_DISPENSABLE, added_not_dispensable),
        (CHANGE, changed),
        (DELETE, deleted),
        (QUANT, added_quant),
        (DEF, added_def),
    )
    var_diff_message = {}
    for key, found in findings:
        if found:
            var_diff_message[key] = found

    return var_error, var_diff_message


def diff_attr(ori_attrs, new_attrs):
    """Diff the attributes of one op between two descriptions.

    Args:
        ori_attrs: dict mapping an attribute name to its property dict, taken
            from the original description.
        new_attrs: the same kind of dict, taken from the new description.

    Returns:
        A tuple ``(attr_error, attr_diff_message)``. ``attr_error`` is True
        when at least one incompatible change was found. ``attr_diff_message``
        only contains the non-empty findings, keyed by ``ADD``,
        ``ADD_WITH_DEFAULT``, ``CHANGE``, ``DELETE``, ``DEF`` and ``QUANT``.

    Side effects:
        Sets the module-level ``error`` flag to True whenever ``attr_error``
        is set to True.
    """
    global error
    attr_error = False

    attr_changed_error_message = {}
    attr_added_error_message = []
    attr_added_no_default_message = []
    attr_deleted_error_message = []

    attr_added_quant_message = []
    attr_added_define_message = []

    common_attrs = set(ori_attrs.keys()) & set(new_attrs.keys())
    attrs_only_in_origin = set(ori_attrs.keys()) - set(new_attrs.keys())
    attrs_only_in_new = set(new_attrs.keys()) - set(ori_attrs.keys())

    # Attributes present on both sides: record every property whose value
    # differs, looking only at the properties listed on the original side.
    for attr_name in common_attrs:
        ori_info = ori_attrs.get(attr_name)
        new_info = new_attrs.get(attr_name)
        if ori_info == new_info:
            continue
        error, attr_error = True, True
        for arg_name in ori_info:
            new_arg_value = new_info.get(arg_name)
            origin_arg_value = ori_info.get(arg_name)
            if new_arg_value != origin_arg_value:
                attr_changed_error_message.setdefault(attr_name, {})[
                    arg_name] = (origin_arg_value, new_arg_value)

    # Deleted attributes are always an error.
    for attr_name in attrs_only_in_origin:
        error, attr_error = True, True
        attr_deleted_error_message.append(attr_name)

    for attr_name in attrs_only_in_new:
        attr_info = new_attrs.get(attr_name)
        attr_added_error_message.append(attr_name)
        # An added attribute must come with a default value.
        if attr_info.get(DEFAULT_VALUE) is None:
            error, attr_error = True, True
            attr_added_no_default_message.append(attr_name)

        # The two branches below used to assign a stray `var_error` local,
        # which left the returned `attr_error` False for these violations.

        # if added attr is quant, slim needs to review, needs to register
        if attr_info.get(QUANT):
            error, attr_error = True, True
            attr_added_quant_message.append(attr_name)

        # if added attr is def, inference needs to review, needs to register
        if not attr_info.get(EXTRA) and not attr_info.get(QUANT):
            error, attr_error = True, True
            attr_added_define_message.append(attr_name)

    attr_diff_message = {}
    if attr_added_error_message:
        attr_diff_message[ADD] = attr_added_error_message
    if attr_added_no_default_message:
        attr_diff_message[ADD_WITH_DEFAULT] = attr_added_no_default_message
    if attr_changed_error_message:
        attr_diff_message[CHANGE] = attr_changed_error_message
    if attr_deleted_error_message:
        attr_diff_message[DELETE] = attr_deleted_error_message
    if attr_added_define_message:
        attr_diff_message[DEF] = attr_added_define_message
    if attr_added_quant_message:
        attr_diff_message[QUANT] = attr_added_quant_message

    return attr_error, attr_diff_message


203 204 205 206
def check_io_registry(io_type, op, diff):
    """Collect the added inputs/outputs of `op` that lack a registered update.

    Args:
        io_type: `INPUTS` or `OUTPUTS`; selects the row of
            `version_update_map` and is stored in each reported tuple.
        op: name of the operator being checked.
        diff: diff dictionary for that section, as returned by `diff_vars`.

    Returns:
        A dict that is either empty or ``{ADD: [(op, name, io_type), ...]}``,
        listing the added quant/def variables for which
        `OpLastCheckpointChecker.filter_updates` returned nothing. Added
        `extra` variables are never reported.
    """
    checker = OpLastCheckpointChecker()
    # extra not need to register: only quant and def items can be reported.
    quant_ios = diff.get(QUANT, [])
    def_ios = diff.get(DEF, [])

    results = {}
    for update_type in [ADD]:
        unregistered = []
        for item in diff.get(update_type, []):
            infos = checker.filter_updates(
                op, version_update_map[io_type][update_type], item)
            if not infos and (item in quant_ios or item in def_ios):
                unregistered.append((op, item, io_type))
        # Only store non-empty lists. An empty list would make the result
        # truthy and produce an empty per-op section in the report; this
        # matches the pruning done by `check_attr_registry`.
        if unregistered:
            results[update_type] = unregistered

    return results


222
def check_attr_registry(op, diff, origin_attrs):
    """Collect added/changed attributes of `op` that lack a registered update.

    Args:
        op: name of the operator being checked.
        diff: attribute diff dictionary, as returned by `diff_attr`.
        origin_attrs: attribute dict of the original description; used to
            tell whether a changed attribute is `extra`.

    Returns:
        A dict with up to two keys, each present only when non-empty:
        ``ADD`` maps to a list of ``(op, attr_name)`` tuples and ``CHANGE``
        maps to ``{attr_name: changed_properties}``. An attribute is reported
        when `OpLastCheckpointChecker.filter_updates` returns nothing for it
        and it is not `extra`.
    """
    checker = OpLastCheckpointChecker()
    quant_attrs = diff.get(QUANT, [])
    def_attrs = diff.get(DEF, [])
    change_attrs = diff.get(CHANGE, {})

    unregistered_added = []
    for attr_name in diff.get(ADD, []):
        infos = checker.filter_updates(op, version_update_map[ATTRS][ADD],
                                       attr_name)
        # extra not need to register.
        if not infos and (attr_name in quant_attrs or attr_name in def_attrs):
            unregistered_added.append((op, attr_name))

    unregistered_changed = {}
    for attr_name, attr_change in change_attrs.items():
        infos = checker.filter_updates(op, version_update_map[ATTRS][CHANGE],
                                       attr_name)
        if infos:
            continue
        # Report only the attribute that was just checked. Previously one
        # unregistered change caused every changed attribute to be reported,
        # including those whose change was registered.
        # extra not need to register.
        if not origin_attrs.get(attr_name, {}).get(EXTRA):
            unregistered_changed[attr_name] = attr_change

    results = {}
    if unregistered_added:
        results[ADD] = unregistered_added
    if unregistered_changed:
        results[CHANGE] = unregistered_changed
    return results


253 254 255
def compare_op_desc(origin_op_desc, new_op_desc):
    """Compare two op descriptions given as JSON text.

    Args:
        origin_op_desc: JSON text of the original op descriptions.
        new_op_desc: JSON text of the new op descriptions.

    Returns:
        A tuple ``(desc_error_message, version_error_message)``. Both are
        dicts keyed by op type and then by section (`INPUTS`, `OUTPUTS`,
        `ATTRS`); the first holds the diffs, the second the changes that are
        not registered. Ops found only in one description are skipped.
    """
    origin = json.loads(origin_op_desc)
    new = json.loads(new_op_desc)
    desc_error_message, version_error_message = {}, {}

    # Identical text means there is nothing to report.
    if origin_op_desc == new_op_desc:
        return desc_error_message, version_error_message

    for op_type in origin:
        # no need to compare if the operator is deleted
        if op_type not in new:
            continue

        origin_info = origin.get(op_type, {})
        new_info = new.get(op_type, {})

        op_diffs = {}
        op_unregistered = {}
        for section in (INPUTS, OUTPUTS, ATTRS):
            origin_items = origin_info.get(section, {})
            new_items = new_info.get(section, {})
            if section == ATTRS:
                _, section_diff = diff_attr(origin_items, new_items)
                unregistered = check_attr_registry(op_type, section_diff,
                                                   origin_items)
            else:
                _, section_diff = diff_vars(origin_items, new_items)
                unregistered = check_io_registry(section, op_type,
                                                 section_diff)

            if section_diff:
                op_diffs[section] = section_diff
            if unregistered:
                op_unregistered[section] = unregistered

        if op_diffs:
            desc_error_message[op_type] = op_diffs
        if op_unregistered:
            version_error_message[op_type] = op_unregistered

    return desc_error_message, version_error_message
303

304 305 306 307

def print_desc_error_message(error_message):
    """Print the op-description diffs, one group of lines per op.

    Args:
        error_message: dict keyed by op name and then by section (`INPUTS`,
            `OUTPUTS`, `ATTRS`), holding diff dictionaries shaped like the
            ones returned by `diff_vars` / `diff_attr`.
    """
    print("\n======================= \n"
          "Op desc error for the changes of Inputs/Outputs/Attrs of OPs:\n")

    # One row per section: the section key, the diff key listing the added
    # items that break a rule, and the five message templates in print order
    # (bad addition, deletion, change, quant addition, def addition).
    layout = (
        (INPUTS, ADD_DISPENSABLE, (
            " * The added Input '{}' is not dispensable.",
            " * The Input '{}' is deleted.",
            " * The arg '{}' of Input '{}' is changed: from '{}' to '{}'.",
            " * The added Input '{}' is `quant`, need slim to review.",
            " * The added Input '{}' is `def`, need inference to review.",
        )),
        (OUTPUTS, ADD_DISPENSABLE, (
            " * The added Output '{}' is not dispensable.",
            " * The Output '{}' is deleted.",
            " * The arg '{}' of Output '{}' is changed: from '{}' to '{}'.",
            " * The added Output '{}' is `quant`, need slim to review.",
            " * The added Output '{}' is `def`, need inference to review.",
        )),
        (ATTRS, ADD_WITH_DEFAULT, (
            " * The added attr '{}' doesn't set default value.",
            " * The attr '{}' is deleted.",
            " * The arg '{}' of attr '{}' is changed: from '{}' to '{}'.",
            " * The added attr '{}' is `quant`, need slim to review.",
            " * The added attr '{}' is `def`, need inference to review.",
        )),
    )

    for op_name, op_diff in error_message.items():
        print("For OP '{}':".format(op_name))

        for section, bad_add_key, templates in layout:
            bad_add_tpl, delete_tpl, change_tpl, quant_tpl, def_tpl = templates
            section_diff = op_diff.get(section, {})

            for name in section_diff.get(bad_add_key, {}):
                print(bad_add_tpl.format(name))

            for name in section_diff.get(DELETE, {}):
                print(delete_tpl.format(name))

            for name, changed_args in section_diff.get(CHANGE, {}).items():
                for arg, (ori_value, new_value) in changed_args.items():
                    print(change_tpl.format(arg, name, ori_value, new_value))

            for name in section_diff.get(QUANT, {}):
                print(quant_tpl.format(name))

            for name in section_diff.get(DEF, {}):
                print(def_tpl.format(name))

387

388 389 390 391 392 393 394 395 396 397
def print_version_error_message(error_message):
    """Print the changes that are not registered, one group per op.

    Args:
        error_message: dict keyed by op name and then by section (`INPUTS`,
            `OUTPUTS`, `ATTRS`). Under `ADD` each section holds a list of
            tuples whose second element is the item name; `ATTRS` may also
            hold a dict of changed attributes under `CHANGE`.
    """
    print(
        "\n======================= \n"
        "Operator registration error for the changes of Inputs/Outputs/Attrs of OPs:\n"
    )

    # Section key paired with the message for an unregistered addition,
    # listed in print order.
    added_templates = (
        (INPUTS, " * The added input '{}' is not yet registered."),
        (OUTPUTS, " * The added output '{}' is not yet registered."),
        (ATTRS, " * The added attribute '{}' is not yet registered."),
    )

    for op_name, op_errors in error_message.items():
        print("For OP '{}':".format(op_name))

        for section, template in added_templates:
            for entry in op_errors.get(section, {}).get(ADD, []):
                print(template.format(entry[1]))

        # Changed attributes come last, after the attribute additions.
        for attr_name in op_errors.get(ATTRS, {}).get(CHANGE, {}):
            print(" * The change of attribute '{}' is not yet registered.".
                  format(attr_name))
424 425


426 427 428 429 430 431 432 433 434 435 436 437 438
def print_repeat_process():
    """Print the steps for reproducing the op-description check locally."""
    steps = (
        "\t1. Compile and install paddle from develop branch \n",
        "\t2. Run: python tools/print_op_desc.py  > OP_DESC_DEV.spec \n",
        "\t3. Compile and install paddle from PR branch \n",
        "\t4. Run: python tools/print_op_desc.py  > OP_DESC_PR.spec \n",
        "\t5. Run: python tools/check_op_desc.py OP_DESC_DEV.spec OP_DESC_PR.spec",
    )
    heading = ("Tips:"
               " If you want to repeat the process, please follow these steps:\n")
    print(heading + "".join(steps))


if len(sys.argv) == 3:
    '''
    Compare op_desc files generated by branch DEV and branch PR.
    And print error message.
    '''
    # Read the two spec files in command-line order: DEV first, then PR.
    spec_texts = []
    for spec_path in sys.argv[1:3]:
        with open(spec_path, 'r') as f:
            spec_texts.append(f.read())
    origin_op_desc, new_op_desc = spec_texts

    desc_error_message, version_error_message = compare_op_desc(
        origin_op_desc, new_op_desc)

    # `error` is the module-level flag raised by the diff functions.
    if error:
        separator = "-" * 30
        print(separator)
        print_desc_error_message(desc_error_message)
        print_version_error_message(version_error_message)
        print(separator)
else:
    print("Usage: python check_op_desc.py OP_DESC_DEV.spec OP_DESC_PR.spec")