envs.py 2.8 KB
Newer Older
T
tangwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

T
tangwei 已提交
15
import os
T
tangwei12 已提交
16
import copy
T
tangwei 已提交
17

T
tangwei12 已提交
18
# Process-wide registry of flattened settings ("a.b.c" -> value), filled by
# set_global_envs and read by get_global_env / get_global_envs.
global_envs = dict()
T
tangwei 已提交
19 20


T
tangwei 已提交
21 22 23 24
def set_runtime_envions(envs):
    """Export every entry of ``envs`` into ``os.environ``.

    Values are converted with ``str()`` because the process environment only
    stores strings; keys are used as-is and must therefore already be str.

    Args:
        envs: dict of environment-variable names to values.
    """
    assert isinstance(envs, dict)

    for name, value in envs.items():
        os.environ[name] = str(value)
T
tangwei 已提交
26 27 28 29 30 31


def get_runtime_envion(key):
    """Return the process environment variable ``key``, or None when it is unset."""
    return os.environ.get(key, None)


T
tangwei 已提交
32 33
def get_trainer():
    """Return the value of the "trainer.trainer" environment variable, or None if unset."""
    return get_runtime_envion("trainer.trainer")


T
tangwei12 已提交
37 38
def set_global_envs(envs):
    """Flatten a (possibly nested) settings dict into the module-level ``global_envs``.

    Nested dict keys are joined with "." so that, for example,
    ``{"train": {"epochs": 3}}`` is stored as ``global_envs["train.epochs"] = 3``.
    Top-level non-dict values are stored under their own key. Entries are
    added/overwritten in place; ``global_envs`` is never cleared here.

    Args:
        envs: dict of settings; keys are expected to be str because they are
            joined with ".".
    """
    assert isinstance(envs, dict)

    def flatten_env_namespace(namespace_nests, local_envs):
        for k, v in local_envs.items():
            # List concatenation builds a fresh prefix per key, so sibling
            # branches never share (or mutate) the caller's list -- no deepcopy needed.
            nests = namespace_nests + [k]
            if isinstance(v, dict):
                flatten_env_namespace(nests, v)
            else:
                global_envs[".".join(nests)] = v

    # Starting from an empty prefix handles top-level scalars too; previously
    # they crashed because ``.items()`` was called on a non-dict value.
    flatten_env_namespace([], envs)
T
tangwei 已提交
52 53


T
tangwei12 已提交
54
def get_global_env(env_name, default_value=None, namespace=None):
    """Look up a flattened setting in the module-level ``global_envs`` dict.

    Note: this reads ``global_envs`` (populated by ``set_global_envs``), not
    the process environment.

    Args:
        env_name: setting name, or its suffix when ``namespace`` is given.
        default_value: returned when the key is absent.
        namespace: optional prefix; the looked-up key becomes
            ``namespace + "." + env_name``.

    Returns:
        The stored value, or ``default_value`` if the key is not present.
    """
    if namespace is None:
        full_name = env_name
    else:
        full_name = ".".join([namespace, env_name])
    return global_envs.get(full_name, default_value)


T
tangwei 已提交
62 63 64 65
def get_global_envs():
    """Return the module-level ``global_envs`` dict itself.

    The live object is returned (not a copy), so mutations by the caller are
    visible to ``get_global_env``.
    """
    registry = global_envs
    return registry


T
tangwei12 已提交
66
def pretty_print_envs(envs, header=None):
    """Render ``envs`` as a fixed-width two-column text table.

    Columns are at least 45 (keys) and 20 (values) characters wide and grow to
    fit the longest rendered key/value; they are separated by 5 spaces.

    Args:
        envs: mapping of names to values. Both keys and values are rendered
            with ``str()``, so non-string keys are accepted.
        header: optional 2-item sequence of column titles. When falsy, the
            titles "fleetrec Global Envs" and "Value" are used.

    Returns:
        The table as a single string, with a leading and trailing newline.
    """
    spacing = 5
    max_k = 45
    max_v = 20

    # Render every cell once so width computation and formatting agree; the
    # "s" format spec (and len()) would otherwise reject non-str keys.
    rows = [(str(k), str(v)) for k, v in envs.items()]
    for k, v in rows:
        max_k = max(max_k, len(k))
        max_v = max(max_v, len(v))

    gap = " " * spacing
    h_format = "{{:^{}s}}{}{{:<{}s}}".format(max_k, gap, max_v)
    l_format = "{{:<{}s}}{}{{:<{}s}}".format(max_k, gap, max_v)
    length = max_k + max_v + spacing

    border = "=" * length
    line = "-" * length

    if header:
        titles = (str(header[0]), str(header[1]))
    else:
        titles = ("fleetrec Global Envs", "Value")

    # Collect lines and join once instead of repeated string concatenation.
    parts = [border, h_format.format(*titles), line]
    parts.extend(l_format.format(k, v) for k, v in rows)
    parts.append(border)

    return "\n{}\n".format("\n".join(parts))
T
tangwei 已提交
99 100 101 102 103 104 105


def lazy_instance(package, class_name):
    """Import module ``package`` and return its attribute ``class_name``.

    Despite the name, this returns the attribute itself (e.g. a class), not an
    instance of it.

    Args:
        package: absolute dotted module path, e.g. "a.b.c".
        class_name: attribute to fetch from that module.

    Returns:
        ``getattr(<imported module>, class_name)``.

    Raises:
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no attribute ``class_name``.
    """
    import importlib

    # import_module returns the leaf module ("a.b" -> a.b) directly. The former
    # __import__ call needed a non-empty fromlist for that, and when the target
    # was a package it could also try importing submodules named after each
    # dotted component. The unused "train.model.models" lookup was dropped.
    model_package = importlib.import_module(package)
    return getattr(model_package, class_name)