validation.py 5.9 KB
Newer Older
X
test  
xjqbest 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
import time
import logging
X
test  
xjqbest 已提交
17 18
from paddlerec.core.utils import envs

19 20 21 22
# Emit log records as "<timestamp> - <level> - <message>".
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
# getLogger() without a name returns the root logger, so the INFO level set
# below is applied to the root logger rather than to a module-specific one.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

X
test  
xjqbest 已提交
23 24

class ValueFormat:
    """Validation rule for one yaml attribute.

    A rule combines an optional type check with an optional value check.

    Args:
        value_type: name of the expected type. Supported names are "int",
            "str" and "dict", plus "ints", "strs" and "dicts" meaning a
            ``list`` whose elements are all of that type. A falsy value
            (e.g. ``None``) disables the type check in ``is_valid``.
        value: reference value handed to ``value_handler``; ``None`` means
            no value check is performed.
        value_handler: callable ``(name, value, reference) -> bool``;
            ``None`` means no value check is performed.
        required: stored on the instance as ``required``; not used by the
            checks in this class.

    Note:
        Type checks use ``isinstance``, so ``bool`` values pass the "int"
        checks (``bool`` is a subclass of ``int``).
    """

    # Type names that are checked directly against the value.
    _SCALAR_TYPES = {"int": int, "str": str, "dict": dict}
    # Type names checked element by element:
    # name -> (element type, element label used in log messages).
    _LIST_TYPES = {
        "ints": (int, "int"),
        "strs": (str, "str"),
        "dicts": (dict, "dict"),
    }

    def __init__(self, value_type, value, value_handler, required=False):
        self.value_type = value_type
        self.value_handler = value_handler
        self.value = value
        self.required = required

    def is_valid(self, name, value):
        """Run the type check, then the value check, for attribute ``name``.

        Returns:
            False if the type check fails; True if there is no value check
            configured; otherwise whatever ``value_handler`` returns.
        """
        if self.value_type and not self.is_type_valid(name, value):
            return False

        # Compare against None explicitly: a falsy reference value such as 0
        # or "" is a legitimate bound and must still reach the handler.
        if self.value is None or self.value_handler is None:
            return True

        return self.is_value_valid(name, value)

    def is_type_valid(self, name, value):
        """Return True if ``value`` matches ``self.value_type``.

        On a mismatch, or when ``self.value_type`` is not one of the supported
        names, an info-level message is logged and False is returned.
        """
        scalar_type = self._SCALAR_TYPES.get(self.value_type)
        if scalar_type is not None:
            if not isinstance(value, scalar_type):
                logger.info("\nattr {} should be {}, but {} now\n".format(
                    name, self.value_type, type(value)))
                return False
            return True

        list_spec = self._LIST_TYPES.get(self.value_type)
        if list_spec is not None:
            element_type, label = list_spec
            if not isinstance(value, list):
                logger.info("\nattr {} should be list({}), but {} now\n".
                            format(name, label, type(value)))
                return False
            # Stop at the first offending element and report its type.
            for v in value:
                if not isinstance(v, element_type):
                    logger.info(
                        "\nattr {} should be list({}), but list({}) now\n".
                        format(name, label, type(v)))
                    return False
            return True

        logger.info("\nattr {}'s type is {}, can not be supported now\n".
                    format(name, type(value)))
        return False

    def is_value_valid(self, name, value):
        """Delegate the value check to ``value_handler`` and return its result."""
        return self.value_handler(name, value, self.value)


def in_value_handler(name, value, values):
    """Check that ``value`` is contained in ``values``.

    Returns True on membership; otherwise logs an info-level message naming
    the attribute and returns False.
    """
    missing = value not in values
    if not missing:
        return True

    message = "\nattr {}'s value is {}, but {} is expected\n".format(
        name, value, values)
    logger.info(message)
    return False


def eq_value_handler(name, value, values):
    """Check that ``value`` equals the reference ``values``.

    Returns True when ``value != values`` is false; otherwise logs an
    info-level message naming the attribute and returns False.
    """
    differs = value != values
    if not differs:
        return True

    message = "\nattr {}'s value is {}, but == {} is expected\n".format(
        name, value, values)
    logger.info(message)
    return False


def ge_value_handler(name, value, values):
    """Check that ``value`` is not below the lower bound ``values``.

    Returns True when ``value < values`` is false; otherwise logs an
    info-level message naming the attribute and returns False.
    """
    below_bound = value < values
    if not below_bound:
        return True

    message = "\nattr {}'s value is {}, but >= {} is expected\n".format(
        name, value, values)
    logger.info(message)
    return False


def le_value_handler(name, value, values):
    """Check that ``value`` is not above the upper bound ``values``.

    Returns True when ``value > values`` is false; otherwise logs an
    info-level message naming the attribute and returns False.
    """
    above_bound = value > values
    if not above_bound:
        return True

    message = "\nattr {}'s value is {}, but <= {} is expected\n".format(
        name, value, values)
    logger.info(message)
    return False


def register():
    """Build the validation rules, keyed by yaml attribute name.

    Every rule here performs a type check only (no reference value and no
    value handler); the last field of each row marks the attribute as
    required.
    """
    # (attribute name, expected type name, required)
    rules = (
        ("workspace", "str", True),
        ("mode", None, True),
        ("runner", "dicts", True),
        ("phase", "dicts", True),
        ("hyper_parameters", "dict", False),
    )

    validations = {}
    for attr_name, type_name, is_required in rules:
        validations[attr_name] = ValueFormat(type_name, None, None,
                                             is_required)
    return validations
X
test  
xjqbest 已提交
158 159 160


def yaml_validation(config):
    """Validate a yaml config against the rules returned by ``register()``.

    Args:
        config: passed unchanged to ``envs.load_yaml`` (presumably the path
            of the yaml file -- confirm against ``envs.load_yaml``).

    Returns:
        True when the loaded config is a dict, contains every required
        attribute, and every attribute that has a registered rule passes it.
        False otherwise; the reason is logged at info level.
    """
    all_checkers = register()
    require_checkers = [
        name for name, checker in all_checkers.items() if checker.required
    ]

    _config = envs.load_yaml(config)

    # NOTE(review): assumes envs.load_yaml can hand back a non-dict (e.g. None
    # for an empty file). Without this guard such input fails below with an
    # AttributeError instead of a validation message.
    if not isinstance(_config, dict):
        logger.info("\nyaml should be a dict at top level, but {} now\n".
                    format(type(_config)))
        return False

    for required in require_checkers:
        if required not in _config:
            logger.info("\ncan not find {} in yaml, which is required\n".
                        format(required))
            return False

    # Attributes without a registered rule are accepted as-is.
    for name, value in _config.items():
        checker = all_checkers.get(name, None)
        if checker and not checker.is_valid(name, value):
            return False

    return True