diff --git a/core/utils/validation.py b/core/utils/validation.py index 691fb1b7d68c321fd49ac488673edafede9c37e8..43c21d3e12d2ca84246e1d298b296c8eb5e76868 100644 --- a/core/utils/validation.py +++ b/core/utils/validation.py @@ -112,6 +112,13 @@ def le_value_handler(name, value, values): def register(): validations = {} + validations["train.workspace"] = ValueFormat("str", None, eq_value_handler) + validations["train.device"] = ValueFormat("str", ["cpu", "gpu"], + in_value_handler) + validations["train.epochs"] = ValueFormat("int", 1, ge_value_handler) + validations["train.engine"] = ValueFormat( + "str", ["single", "local_cluster", "cluster"], in_value_handler) + requires = ["workspace", "dataset", "mode", "runner", "phase"] return validations, requires