op_version.py 2.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools

from ..fluid import core

# Public API of this module: only the checkpoint checker is exported.
__all__ = ['OpLastCheckpointChecker']


def Singleton(cls):
    """Class decorator that makes ``cls`` construct at most one instance.

    The first call creates ``cls(*args, **kargs)`` and caches it; every later
    call returns that cached instance and silently ignores its arguments.

    Args:
        cls (type): the class to wrap.

    Returns:
        callable: a factory used in place of ``cls``. It carries ``cls``'s
        ``__name__``, ``__qualname__``, ``__module__`` and ``__doc__`` (and
        exposes the class as ``__wrapped__``) so the decorated name still
        introspects and documents like the class itself.
    """
    _instance = {}

    # updated=() : copy only the identifying attributes; do not merge the
    # class's __dict__ (methods, descriptors) into the wrapper function.
    @functools.wraps(cls, updated=())
    def _singleton(*args, **kargs):
        if cls not in _instance:
            _instance[cls] = cls(*args, **kargs)
        return _instance[cls]

    return _singleton


class OpUpdateInfoHelper(object):
    """Wrap a single op update info object for key-based matching.

    Args:
        info: an update info object, e.g. the result of an update's ``info()``.
    """

    def __init__(self, info):
        self._info = info

    def verify_key_value(self, name=''):
        """Check whether the wrapped info matches the key ``name``.

        An empty ``name`` matches every info. Otherwise only infos whose exact
        type is ``core.OpAttrInfo`` or ``core.OpInputOutputInfo`` can match,
        and they match when their ``name()`` getter returns ``name``.

        Args:
            name (str): key to compare against; ``''`` disables the check.

        Returns:
            bool: True if the info matches, False otherwise.
        """
        if name == '':
            return True
        # Exact-type lookup (not isinstance): maps each keyed info type to the
        # name of the getter method that yields its key.
        getter_by_type = {
            core.OpAttrInfo: 'name',
            core.OpInputOutputInfo: 'name',
        }
        getter = getter_by_type.get(type(self._info))
        if getter is None:
            return False
        if getattr(self._info, getter)() == name:
            return True
        return False


@Singleton
class OpLastCheckpointChecker(object):
    """Query the updates recorded in each operator's latest version checkpoint.

    On construction, reads the operator version map from ``core`` and keeps,
    per operator, the update infos of its last checkpoint. Because the class
    is decorated with ``Singleton``, every ``OpLastCheckpointChecker()`` call
    returns the same instance, built only once.
    """

    def __init__(self):
        # Raw operator name -> version object, as returned by the C++ core.
        self.raw_version_map = core.get_op_version_map()
        # Operator name -> update infos of that operator's last checkpoint.
        self.checkpoints_map = {}
        self._construct_map()

    def _construct_map(self):
        """Fill ``checkpoints_map`` from ``raw_version_map``.

        An operator whose version has no checkpoints is mapped to an empty
        list instead of failing with ``IndexError`` on ``checkpoints()[-1]``.
        """
        for op_name in self.raw_version_map:
            checkpoints = self.raw_version_map[op_name].checkpoints()
            if checkpoints:
                infos = checkpoints[-1].version_desc().infos()
            else:
                # No checkpoint recorded: there are no updates to report.
                infos = []
            self.checkpoints_map[op_name] = infos

    def filter_updates(self, op_name, type=core.OpUpdateType.kInvalid, key=''):
        """Return update infos of ``op_name``'s last checkpoint matching filters.

        Args:
            op_name (str): operator name; an unknown operator yields ``[]``.
            type (core.OpUpdateType): keep only updates whose ``type()`` equals
                this value; ``kInvalid`` (the default) keeps every type. The
                name shadows the builtin but is kept for keyword callers.
            key (str): keep only infos accepted by
                ``OpUpdateInfoHelper(info).verify_key_value(key)``, i.e. attr /
                input-output infos whose ``name()`` equals ``key``; ``''``
                keeps all infos.

        Returns:
            list: the ``info()`` objects of the matching updates, in the order
            they appear in the checkpoint.
        """
        # Loop-invariant: kInvalid acts as a wildcard for the update type.
        match_any_type = type == core.OpUpdateType.kInvalid
        updates = []
        for update in self.checkpoints_map.get(op_name, []):
            if not (match_any_type or update.type() == type):
                continue
            info = update.info()
            if OpUpdateInfoHelper(info).verify_key_value(key):
                updates.append(info)
        return updates