# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import copy
import sys

# Module-level registry mapping flattened "a.b.c" config keys to their values;
# written by set_global_envs() and read by get_global_env()/get_global_envs().
global_envs = dict()


def flatten_environs(envs):
    """Flatten a nested config dict into dotted keys with string values.

    Nested dict levels are joined with ``"."`` (``{"a": {"b": 1}}`` becomes
    ``{"a.b": "1"}``), and every leaf value is converted with ``str()``.
    Empty nested dicts contribute no keys.

    Args:
        envs: the (possibly nested) configuration dict to flatten.

    Returns:
        A new flat dict mapping dotted key paths to stringified leaf values.
    """
    flatten_dict = {}
    assert isinstance(envs, dict)

    def _flatten(namespace_nests, node):
        # A dict node extends the key path; anything else is a leaf.
        # ``namespace_nests + [k]`` builds a fresh list, so sibling paths stay
        # independent without copying.
        if isinstance(node, dict):
            for k, v in node.items():
                _flatten(namespace_nests + [k], v)
        else:
            flatten_dict[".".join(namespace_nests)] = str(node)

    for k, v in envs.items():
        _flatten([k], v)

    return flatten_dict


def set_runtime_environs(environs):
    """Export every entry of ``environs`` into ``os.environ``.

    Values are converted with ``str()`` before export; keys are used as-is.
    Entries are written in the iteration order of ``environs``.
    """
    os.environ.update((name, str(value)) for name, value in environs.items())


def get_runtime_environ(key):
    """Return the process environment variable ``key``, or ``None`` if unset."""
    return os.environ.get(key)


def get_trainer():
    """Return the ``train.trainer.trainer`` runtime environment variable.

    Looked up through ``get_runtime_environ``, so the result is ``None`` when
    the variable is not set.
    """
    return get_runtime_environ("train.trainer.trainer")


def set_global_envs(envs):
    """Merge a nested config dict into the module-level ``global_envs``.

    Nested dict levels are joined into dotted keys, so
    ``{"train": {"epochs": 2}}`` is stored as ``global_envs["train.epochs"] = 2``.
    Leaf values are stored as-is (not stringified, unlike
    ``flatten_environs``). Existing keys are overwritten; keys absent from
    ``envs`` are left untouched.

    Args:
        envs: the (possibly nested) configuration dict.
    """
    assert isinstance(envs, dict)

    def _flatten(namespace_nests, node):
        # Dicts extend the dotted path; any other value is a leaf. Handling the
        # leaf case at every level means top-level scalars such as
        # ``{"debug": True}`` are stored instead of failing on ``.items()``.
        if isinstance(node, dict):
            for k, v in node.items():
                _flatten(namespace_nests + [k], v)
        else:
            global_envs[".".join(namespace_nests)] = node

    for k, v in envs.items():
        _flatten([k], v)


def get_global_env(env_name, default_value=None, namespace=None):
    """Look up a value in the flattened ``global_envs`` registry.

    This reads the module-level ``global_envs`` dict, not the process
    environment.

    Args:
        env_name: key to look up, relative to ``namespace`` when one is given.
        default_value: returned when the key is absent.
        namespace: optional prefix; when not ``None`` the lookup key becomes
            ``"<namespace>.<env_name>"``.

    Returns:
        The stored value, or ``default_value`` if the key is missing.
    """
    if namespace is None:
        key = env_name
    else:
        key = ".".join([namespace, env_name])
    return global_envs.get(key, default_value)


def get_global_envs():
    """Return the module-level ``global_envs`` dict itself (not a copy).

    Mutating the returned dict mutates the shared registry.
    """
    registry = global_envs
    return registry


def update_workspace():
    """Substitute the ``{workspace}`` placeholder in all string ``global_envs``.

    The workspace is read from ``global_envs["train.workspace"]``.

    - A value starting with ``fleetrec.`` (a fleet inner model) is resolved to
      ``<PACKAGE_BASE>/<rest/with/slashes>``. ``PACKAGE_BASE`` is read from the
      process environment.
    - Any other value is used verbatim as the path.
    - Does nothing when no workspace is configured.
    """
    workspace = global_envs.get("train.workspace", None)
    if not workspace:
        return

    fleet_prefix = "fleetrec."
    # is fleet inner models
    if workspace.startswith(fleet_prefix):
        # NOTE(review): if PACKAGE_BASE is unset this is None and os.path.join
        # raises TypeError.
        fleet_package = get_runtime_environ("PACKAGE_BASE")
        # Strip only the leading prefix. split(prefix)[1] would drop
        # everything after a second occurrence of "fleetrec." in the name.
        workspace_dir = workspace[len(fleet_prefix):].replace(".", "/")
        path = os.path.join(fleet_package, workspace_dir)
    else:
        path = workspace

    for name, value in global_envs.items():
        # Reassigning existing keys while iterating is safe: the dict's size
        # does not change.
        if isinstance(value, str):
            global_envs[name] = value.replace("{workspace}", path)


def pretty_print_envs(envs, header=None):
    """Render ``envs`` as a two-column, fixed-width text table.

    The key column is at least 45 characters wide and grows to fit the
    longest key. The value column is 50 characters wide. Values are shown via
    ``str()``. Any value whose text is 50 characters or longer is cut down to
    ``"... "`` plus its last characters, so it fits the value column.

    Args:
        envs: dict of string keys to values.
        header: optional ``(key_title, value_title)`` pair. Defaults to
            ``("fleetrec Global Envs", "Value")``.

    Returns:
        The table as a string with a leading and a trailing newline.
    """
    spacing = 5
    max_v = 50
    ellipsis = "... "
    max_k = max([45] + [len(k) for k in envs])

    h_format = "{{:^{}s}}{}{{:<{}s}}\n".format(max_k, " " * spacing, max_v)
    l_format = "{{:<{}s}}{{}}{{:<{}s}}\n".format(max_k, max_v)
    length = max_k + max_v + spacing

    border = "=" * length
    line = "-" * length

    if header:
        title_k, title_v = header[0], header[1]
    else:
        title_k, title_v = "fleetrec Global Envs", "Value"

    draws = [border + "\n", h_format.format(title_k, title_v), line + "\n"]
    for k, v in envs.items():
        # Stringify first, so non-string values (lists, dicts, ...) are
        # truncated too. Keep the tail so the row stays within max_v chars.
        str_v = str(v)
        if len(str_v) >= max_v:
            str_v = ellipsis + str_v[-(max_v - len(ellipsis)):]
        draws.append(l_format.format(k, " " * spacing, str_v))
    draws.append(border)

    return "\n{}\n".format("".join(draws))


def lazy_instance_by_fliename(package, class_name):
    """Import the module named by dotted path ``package`` and return ``class_name``.

    NOTE(review): this module defines ``lazy_instance_by_fliename`` a second
    time further down. That later definition replaces this one when the
    module is executed, so this body is not reachable via the module
    attribute.
    """
    import importlib

    # import_module returns the leaf module of a dotted name, which is the
    # same module __import__ returns when given a non-empty fromlist.
    model_package = importlib.import_module(package)
    return getattr(model_package, class_name)


def lazy_instance_by_fliename(package, class_name):
    """Load ``class_name`` from the Python file configured as ``train.model.models``.

    ``train.model.models`` is looked up with ``get_global_env`` and treated as
    a file path. Its directory is added to ``sys.path``, and the file is
    imported by its base name with the extension stripped.

    Args:
        package: accepted for interface compatibility; not used.
        class_name: attribute to fetch from the imported module
            (presumably ``"Model"`` at call sites — verify against callers).

    Returns:
        The requested attribute of the imported model module.

    Raises:
        ValueError: if ``train.model.models`` is not configured.

    Note: imported modules are cached in ``sys.modules`` by bare name, so two
    model files with the same base name resolve to whichever loaded first.
    """
    import importlib

    models = get_global_env("train.model.models")
    if not models:
        raise ValueError("train.model.models is not configured")

    dirname = os.path.dirname(models)
    module_name = os.path.splitext(os.path.basename(models))[0]
    # Make the model file importable by its bare module name, without growing
    # sys.path with a duplicate entry on every call.
    if dirname not in sys.path:
        sys.path.append(dirname)

    model_module = importlib.import_module(module_name)
    return getattr(model_module, class_name)


def get_platform():
    """Return ``"LINUX"``, ``"DARWIN"`` or ``"WINDOWS"`` for the host OS.

    Falls through (returning ``None``) on any other system.
    """
    import platform

    # platform.system() reports the bare OS name ("Linux", "Darwin",
    # "Windows"). platform.platform() is not usable for this: since Python
    # 3.8 it starts with "macOS-..." on macOS, so a "Darwin" test never
    # matches.
    system = platform.system()
    if 'Linux' in system:
        return "LINUX"
    if 'Darwin' in system:
        return "DARWIN"
    if 'Windows' in system:
        return "WINDOWS"