envs.py 6.4 KB
Newer Older
T
tangwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

T
tangwei 已提交
15
from contextlib import closing
T
tangwei12 已提交
16
import copy
T
tangwei 已提交
17
import os
C
chengmo 已提交
18
import socket
T
tangwei 已提交
19
import sys
T
tangwei 已提交
20

T
tangwei12 已提交
21
# Flattened, dot-keyed configuration shared by this module: filled by
# set_global_envs(), read through get_global_env()/get_global_envs(), and
# rewritten in place by update_workspace().
global_envs = dict()
T
tangwei 已提交
24

T
tangwei 已提交
25
def flatten_environs(envs, separator="."):
    """Flatten a nested dict into a single-level dict of string values.

    Each leaf (any non-dict value) is stored under the keys on its path
    joined with ``separator`` and converted with ``str()``. Empty nested
    dicts contribute no entries. Leaves are visited depth-first in dict
    order, so a later path that joins to the same key overwrites an earlier
    one.

    Args:
        envs: the nested dict to flatten.
        separator: string placed between path components.

    Returns:
        A new dict mapping joined keys to ``str`` values.
    """
    assert isinstance(envs, dict)
    flat = {}

    def _walk(path, node):
        # Descend into dicts; anything else is a leaf.
        if isinstance(node, dict):
            for key, child in node.items():
                _walk(path + [key], child)
        else:
            flat[separator.join(path)] = str(node)

    for top_key, top_value in envs.items():
        _walk([top_key], top_value)
    return flat
T
tangwei 已提交
47

T
tangwei 已提交
48 49 50

def set_runtime_environs(environs):
    """Export every entry of ``environs`` into ``os.environ``.

    Values are converted with ``str()``; keys are used as-is.
    """
    for name in environs:
        os.environ[name] = str(environs[name])
T
tangwei 已提交
52

T
tangwei 已提交
53

T
tangwei 已提交
54
def get_runtime_environ(key):
    """Return the process environment variable ``key``, or ``None`` if unset."""
    return os.environ.get(key)

T
tangwei 已提交
57

T
tangwei 已提交
58
def get_trainer():
    """Return the ``train.trainer.trainer`` process environment variable.

    Returns ``None`` when the variable is not set.
    """
    return get_runtime_environ("train.trainer.trainer")


T
tangwei12 已提交
63 64
def set_global_envs(envs):
    """Flatten a nested config dict into the module-level ``global_envs``.

    Nested dicts become dot-joined keys, e.g. ``{"train": {"epochs": 2}}``
    sets ``global_envs["train.epochs"] = 2``. A list stored under a key named
    ``dataset`` or ``executor`` must contain dicts that each have a ``name``
    entry; every element is flattened under ``<path>.<key>.<name>``. All
    other values are stored unchanged (they are not converted to ``str``).

    Results are merged into ``global_envs``: keys not mentioned by ``envs``
    are kept and matching keys are overwritten.

    Args:
        envs: the (possibly nested) configuration dict.

    Raises:
        AssertionError: if ``envs`` is not a dict.
        ValueError: if an element of a ``dataset``/``executor`` list is not a
            dict or has no ``name``.
    """
    assert isinstance(envs, dict)

    def fatten_env_namespace(namespace_nests, local_envs):
        for k, v in local_envs.items():
            if isinstance(v, dict):
                fatten_env_namespace(namespace_nests + [k], v)
            elif k in ("dataset", "executor") and isinstance(v, list):
                # Each named entry gets its own namespace so several
                # datasets/executors can coexist under one key.
                for item in v:
                    if not isinstance(item, dict) or item.get("name") is None:
                        raise ValueError(
                            "name must be in {} list: {}".format(k, v))
                    fatten_env_namespace(
                        namespace_nests + [k, item["name"]], item)
            else:
                global_envs[".".join(namespace_nests + [k])] = v

    fatten_env_namespace([], envs)
T
tangwei 已提交
104

T
tangwei12 已提交
105
def get_global_env(env_name, default_value=None, namespace=None):
    """Look up a flattened config value in the module-level ``global_envs``.

    This reads ``global_envs`` (populated by ``set_global_envs``), not the
    process environment.

    Args:
        env_name: key to look up.
        default_value: returned when the key is absent.
        namespace: optional prefix; when given, the key looked up is
            ``"<namespace>.<env_name>"``.
    """
    if namespace is not None:
        env_name = ".".join([namespace, env_name])
    return global_envs.get(env_name, default_value)


T
tangwei 已提交
114 115 116 117
def get_global_envs():
    """Return the shared module-level ``global_envs`` dict itself.

    No copy is made, so mutating the result mutates the shared config.
    """
    return global_envs


T
tangwei 已提交
118
def path_adapter(path):
    """Resolve a ``paddlerec.``-prefixed dotted path to a filesystem path.

    ``"paddlerec.models.rank"`` becomes ``<PACKAGE_BASE>/models/rank``, where
    ``PACKAGE_BASE`` is read from the process environment. Every remaining
    ``.`` becomes ``/``, so this is intended for dotted package/directory
    names. Paths without the prefix are returned unchanged.

    Raises:
        ValueError: if ``path`` has the ``paddlerec.`` prefix but the
            ``PACKAGE_BASE`` environment variable is not set.
    """
    prefix = "paddlerec."
    if not path.startswith(prefix):
        return path

    package = get_runtime_environ("PACKAGE_BASE")
    if package is None:
        raise ValueError(
            "PACKAGE_BASE must be set to resolve path: {}".format(path))
    # Strip only the leading prefix; splitting on it would drop everything
    # after a second "paddlerec." occurring later in the path.
    l_p = path[len(prefix):].replace(".", "/")
    return os.path.join(package, l_p)
T
tangwei 已提交
125 126


T
tangwei 已提交
127 128 129 130 131 132 133
def windows_path_converter(path):
    """Normalize path separators for the current platform.

    On Windows every ``/`` becomes ``\\``; elsewhere every ``\\`` becomes ``/``.
    """
    on_windows = get_platform() == "WINDOWS"
    old_sep, new_sep = ("/", "\\") if on_windows else ("\\", "/")
    return path.replace(old_sep, new_sep)


T
tangwei 已提交
134
def update_workspace():
    """Expand ``{workspace}`` in every string value of ``global_envs``.

    Does nothing when ``global_envs["workspace"]`` is missing or empty.
    Otherwise the workspace is resolved with ``path_adapter``. For every
    ``str`` value in ``global_envs``, ``{workspace}`` is then substituted and
    separators are normalized with ``windows_path_converter``. Non-string
    values are left untouched.
    """
    raw_workspace = global_envs.get("workspace")
    if not raw_workspace:
        return
    resolved = path_adapter(raw_workspace)

    # NOTE(review): separator normalization is applied to *every* string
    # value, not only those that contained {workspace}.
    for key in list(global_envs):
        current = global_envs[key]
        if not isinstance(current, str):
            continue
        global_envs[key] = windows_path_converter(
            current.replace("{workspace}", resolved))

T
tangwei 已提交
146

T
tangwei12 已提交
147
def pretty_print_envs(envs, header=None):
    """Render ``envs`` as a bordered two-column table and return it.

    Args:
        envs: mapping of string keys to values; values are rendered with
            ``str()``.
        header: optional 2-item sequence ``(key_title, value_title)``;
            defaults to ``("paddlerec Global Envs", "Value")``.

    Returns:
        The table as a string, with a leading and a trailing newline.

    The key column is at least 45 characters wide and grows to fit the
    longest key. Values longer than the 50-character value column are
    shortened to ``"... "`` followed by their trailing characters, so the
    rendered value is exactly 50 characters.
    """
    spacing = 5
    max_k = 45
    max_v = 50
    ellipsis = "... "

    for k in envs:
        max_k = max(max_k, len(k))

    h_format = "{{:^{}s}}{}{{:<{}s}}\n".format(max_k, " " * spacing, max_v)
    l_format = "{{:<{}s}}{{}}{{:<{}s}}\n".format(max_k, max_v)
    length = max_k + max_v + spacing

    border = "=" * length
    line = "-" * length

    if header:
        key_title, value_title = header[0], header[1]
    else:
        key_title, value_title = "paddlerec Global Envs", "Value"

    parts = [border + "\n", h_format.format(key_title, value_title),
             line + "\n"]
    for k, v in envs.items():
        str_v = str(v)
        # Truncate only values that actually overflow the column, keeping
        # the tail so the rendered width is exactly max_v.
        if len(str_v) > max_v:
            str_v = ellipsis + str_v[-(max_v - len(ellipsis)):]
        parts.append(l_format.format(k, " " * spacing, str_v))
    parts.append(border)

    return "\n{}\n".format("".join(parts))
T
tangwei 已提交
184 185


T
tangwei 已提交
186
def lazy_instance_by_package(package, class_name):
    """Import the dotted module ``package`` and return its ``class_name`` attribute.

    Args:
        package: dotted module path, e.g. ``"a.b.c"``.
        class_name: attribute (typically a class) to fetch from the module.

    Raises:
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no attribute ``class_name``.
    """
    # A non-empty fromlist makes __import__ return the leaf module
    # ("a.b.c") rather than the top-level package ("a").
    model_package = __import__(package,
                               globals(), locals(), package.split("."))
    return getattr(model_package, class_name)
T
tangwei 已提交
192 193


T
tangwei 已提交
194 195
def lazy_instance_by_fliename(abs, class_name):
    """Import the Python file at ``abs`` and return its ``class_name`` attribute.

    The file's directory is added to ``sys.path`` (at most once) so the file
    can be imported as a top-level module named after its basename without
    the extension.

    Args:
        abs: path to a ``.py`` file.
        class_name: attribute (typically a class) to fetch from the module.

    Raises:
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no attribute ``class_name``.
    """
    dirname = os.path.dirname(abs)
    # Avoid growing sys.path with a duplicate entry on every call.
    if dirname not in sys.path:
        sys.path.append(dirname)
    package = os.path.splitext(os.path.basename(abs))[0]

    # NOTE(review): the directory is appended, so a same-named module that
    # is already importable (or already in sys.modules) takes precedence.
    model_package = __import__(package,
                               globals(), locals(), package.split("."))
    return getattr(model_package, class_name)
T
tangwei 已提交
203 204


T
tangwei 已提交
205 206 207 208 209 210 211 212 213
def get_platform():
    """Classify the host OS from ``platform.platform()``.

    Returns ``"LINUX"``, ``"DARWIN"`` or ``"WINDOWS"``, checked in that
    order by substring match, or ``None`` when none of them matches.
    """
    import platform
    description = platform.platform()
    for marker, label in (("Linux", "LINUX"), ("Darwin", "DARWIN"),
                          ("Windows", "WINDOWS")):
        if marker in description:
            return label
    return None
C
chengmo 已提交
214 215 216 217


def find_free_port():
    """Ask the OS for an unused TCP port and return its number.

    A TCP/IPv4 socket is bound to port 0 on all interfaces, the port the OS
    assigned is read back, and the socket is closed. Because the socket is
    closed before returning, another process could claim the port before
    the caller binds it.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as probe:
        probe.bind(('', 0))
        _, port = probe.getsockname()
    return port