diff --git a/core/utils/validation.py b/core/utils/validation.py index 836b3373ffe7dec9fecb4fea1a9a3c92e63fe719..691fb1b7d68c321fd49ac488673edafede9c37e8 100644 --- a/core/utils/validation.py +++ b/core/utils/validation.py @@ -16,11 +16,10 @@ from paddlerec.core.utils import envs class ValueFormat: - def __init__(self, type, value, value_handler): - self.type = type + def __init__(self, value_type, value, value_handler): + self.value_type = value_type self.value = value self.value_handler = value_handler - self.help = help def is_valid(self, name, value): ret = self.is_type_valid(name, value) @@ -31,24 +30,24 @@ class ValueFormat: return ret def is_type_valid(self, name, value): - if self.type == "int": + if self.value_type == "int": if not isinstance(value, int): print("\nattr {} should be int, but {} now\n".format( - name, self.type)) + name, self.value_type)) return False return True - elif self.type == "str": + elif self.value_type == "str": if not isinstance(value, str): print("\nattr {} should be str, but {} now\n".format( - name, self.type)) + name, self.value_type)) return False return True - elif self.type == "strs": + elif self.value_type == "strs": if not isinstance(value, list): print("\nattr {} should be list(str), but {} now\n".format( - name, self.type)) + name, self.value_type)) return False for v in value: if not isinstance(v, str): @@ -57,10 +56,10 @@ class ValueFormat: return False return True - elif self.type == "ints": + elif self.value_type == "ints": if not isinstance(value, list): print("\nattr {} should be list(int), but {} now\n".format( - name, self.type)) + name, self.value_type)) return False for v in value: if not isinstance(v, int): @@ -113,13 +112,6 @@ def le_value_handler(name, value, values): def register(): validations = {} - validations["train.workspace"] = ValueFormat("str", None, eq_value_handler) - validations["train.device"] = ValueFormat("str", ["cpu", "gpu"], - in_value_handler) - validations["train.epochs"] = ValueFormat("int", 1, ge_value_handler) - validations["train.engine"] = ValueFormat( - "str", ["single", "local_cluster", "cluster"], in_value_handler) - requires = ["workspace", "dataset", "mode", "runner", "phase"] return validations, requires