envs.py 8.6 KB
Newer Older
T
tangwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

T
tangwei 已提交
15
from contextlib import closing
Y
yaoxuefeng 已提交
16
import yaml
T
tangwei12 已提交
17
import copy
T
tangwei 已提交
18
import os
C
chengmo 已提交
19
import socket
T
tangwei 已提交
20
import sys
C
Chengmo 已提交
21
import six
X
xionghang 已提交
22
import traceback
L
liuyuhui 已提交
23
import warnings
T
tangwei 已提交
24

T
tangwei12 已提交
25
# Config values loaded by set_global_envs(), keyed by dotted path
# (e.g. "runner.<name>.<field>").
global_envs = dict()
# NOTE(review): no function visible in this module writes this dict;
# presumably populated by other modules -- confirm before removing.
global_envs_flatten = dict()
T
tangwei 已提交
27

X
fix  
xjqbest 已提交
28

T
tangwei 已提交
29
def flatten_environs(envs, separator="."):
    """Flatten a nested dict into a single-level dict of strings.

    Every non-dict leaf is stored under the key formed by joining its path
    of keys with ``separator``; the leaf itself is converted with ``str``.
    Empty nested dicts contribute nothing.

    Args:
        envs: the (possibly nested) dict to flatten.
        separator: text placed between path components in result keys.

    Returns:
        A new dict mapping joined key paths to stringified leaf values.
    """
    assert isinstance(envs, dict)
    result = {}

    def _walk(path, node):
        # A non-dict node is a leaf: record it under its joined path.
        if not isinstance(node, dict):
            result[separator.join(path)] = str(node)
            return
        for child_key, child in node.items():
            _walk(path + [child_key], child)

    for top_key, top_value in envs.items():
        _walk([top_key], top_value)
    return result
T
tangwei 已提交
51

T
tangwei 已提交
52 53 54

def set_runtime_environs(environs):
    """Export every entry of ``environs`` into the process environment.

    Values are converted with ``str``; keys are used as-is.
    """
    # A generator of pairs makes os.environ.update assign one entry at a
    # time, in iteration order.
    os.environ.update((name, str(value)) for name, value in environs.items())
T
tangwei 已提交
56

T
tangwei 已提交
57

T
tangwei 已提交
58
def get_runtime_environ(key):
    """Return the process environment value for ``key``, or None if unset."""
    return os.environ.get(key)

T
tangwei 已提交
61

T
tangwei 已提交
62
def get_trainer():
    """Return the ``train.trainer.trainer`` runtime environment value.

    Returns None when the variable is not set.
    """
    return get_runtime_environ("train.trainer.trainer")


C
Chengmo 已提交
67 68 69 70 71
def get_fleet_mode():
    """Return the ``fleet_mode`` runtime environment value, or None if unset."""
    return get_runtime_environ("fleet_mode")


T
tangwei 已提交
72
def set_global_envs(envs):
    """Load a parsed config dict into the module-level ``global_envs``.

    Nested dicts are flattened into dotted keys. Lists under ``dataset``,
    ``phase`` and ``runner`` are keyed by each item's ``name``
    (``runner.<name>.<field>``). String values get ``{workspace}``
    substituted and path separators normalized for the current platform.

    Two compatibility overrides follow:
    * save-by-step is disabled for runners whose first phase reads a
      ``QueueDataset``;
    * every dataset type is forced to ``DataLoader`` unless running on
      Linux under Python 2.

    Args:
        envs: the config dict (e.g. as returned by ``load_yaml``).

    Raises:
        ValueError: if an item of a ``dataset``/``phase``/``runner`` list
            has no ``name``.
    """
    assert isinstance(envs, dict)

    _flatten_into_global_envs([], envs)

    # Resolve "{workspace}" and normalize separators in every string value.
    for name, value in global_envs.items():
        if isinstance(value, str):
            global_envs[name] = os_path_adapter(workspace_adapter(value))

    _disable_step_saving_for_queue_dataset(envs)

    platform_name = get_platform()
    force_dataloader = platform_name != "LINUX"
    if platform_name == "LINUX" and six.PY3:
        print("QueueDataset can not support PY3, change to DataLoader")
        force_dataloader = True
    if force_dataloader:
        for dataset in envs.get("dataset") or []:
            name = ".".join(["dataset", dataset["name"], "type"])
            global_envs[name] = "DataLoader"


def _flatten_into_global_envs(namespace_nests, local_envs):
    """Write ``local_envs`` into ``global_envs`` under dotted keys.

    Nested dicts extend the key path. A list stored under ``dataset``,
    ``phase`` or ``runner`` is expanded item by item, inserting each item's
    ``name`` into the path. Any other value is stored unchanged.

    Raises:
        ValueError: if an item of such a list has no ``name``.
    """
    for k, v in local_envs.items():
        if isinstance(v, dict):
            _flatten_into_global_envs(namespace_nests + [k], v)
        elif k in ("dataset", "phase", "runner") and isinstance(v, list):
            for item in v:
                if item.get("name") is None:
                    raise ValueError(
                        "name must be in {} list: {}".format(k, v))
                _flatten_into_global_envs(
                    namespace_nests + [k, item["name"]], item)
        else:
            global_envs[".".join(namespace_nests + [k])] = v


def _find_by_name(items, name):
    """Return the first dict in ``items`` whose ``name`` equals ``name``, else None."""
    return next((item for item in items if item.get("name") == name), None)


def _disable_step_saving_for_queue_dataset(envs):
    """Turn off save-by-step for runners whose first phase reads a QueueDataset.

    The runner entries in ``envs`` get ``save_step_interval`` and
    ``save_step_path`` set to None. The matching flattened
    ``runner.<name>.save_step_*`` keys are removed from ``global_envs``, so
    ``get_global_env`` falls back to its caller's default. ``global_envs``
    was filled before this runs, so editing ``envs`` alone would not reach it.
    """
    for runner in envs.get("runner") or []:
        if "save_step_interval" not in runner and "save_step_path" not in runner:
            continue
        phase_names = runner.get("phases")
        if not phase_names:
            # NOTE(review): without an explicit phase list the dataset can't
            # be resolved here, so the runner is left untouched.
            continue
        phase = _find_by_name(envs.get("phase") or [], phase_names[0])
        if phase is None:
            continue
        dataset = _find_by_name(envs.get("dataset") or [],
                                phase.get("dataset_name"))
        if dataset is None or dataset.get("type") != "QueueDataset":
            continue

        runner["save_step_interval"] = None
        runner["save_step_path"] = None
        for field in ("save_step_interval", "save_step_path"):
            global_envs.pop(".".join(["runner", runner["name"], field]), None)
        warnings.warn(
            "QueueDataset can not support save by step, please not config save_step_interval and save_step_path in your yaml"
        )

X
fix  
xjqbest 已提交
131

T
tangwei12 已提交
132
def get_global_env(env_name, default_value=None, namespace=None):
    """Look up a value stored in the module-level ``global_envs`` dict.

    This reads ``global_envs`` (filled by ``set_global_envs``), not the
    process environment.

    Args:
        env_name: dotted key, e.g. ``"runner.<name>.<field>"``.
        default_value: returned when the key is absent.
        namespace: optional prefix joined to ``env_name`` with ``"."``.

    Returns:
        The stored value, or ``default_value``.
    """
    if namespace is None:
        key = env_name
    else:
        key = ".".join([namespace, env_name])
    return global_envs.get(key, default_value)


T
tangwei 已提交
141 142 143 144
def get_global_envs():
    """Return the module-level ``global_envs`` dict itself (not a copy)."""
    shared_envs = global_envs
    return shared_envs


T
tangwei 已提交
145
def paddlerec_adapter(path):
    """Map a ``paddlerec.``-prefixed dotted path onto the package directory.

    ``"paddlerec.models.rank.dnn"`` becomes
    ``os.path.join(<PACKAGE_BASE>, "models/rank/dnn")``, where
    ``PACKAGE_BASE`` is read from the runtime environment. Any other path is
    returned unchanged.

    Args:
        path: dotted package path or ordinary filesystem path.

    Returns:
        The resolved path.
    """
    prefix = "paddlerec."
    if not path.startswith(prefix):
        return path
    package = get_runtime_environ("PACKAGE_BASE")
    # Strip only the leading prefix; splitting on "paddlerec." would drop
    # everything after a second occurrence of it in the remainder.
    relative = path[len(prefix):].replace(".", "/")
    # NOTE(review): if PACKAGE_BASE is unset, os.path.join(None, ...) raises
    # TypeError here.
    return os.path.join(package, relative)
T
tangwei 已提交
152 153


T
tangwei 已提交
154 155 156 157 158 159
def os_path_adapter(value):
    """Normalize path separators in ``value`` for the current platform.

    On Windows every forward slash becomes a backslash; on every other
    platform every backslash becomes a forward slash.
    """
    if get_platform() == "WINDOWS":
        return value.replace("/", "\\")
    return value.replace("\\", "/")
T
tangwei 已提交
160 161


T
tangwei 已提交
162
def workspace_adapter(value):
    """Substitute ``{workspace}`` in ``value`` using ``global_envs["workspace"]``."""
    return workspace_adapter_by_specific(value, global_envs.get("workspace"))


def workspace_adapter_by_specific(value, workspace):
    """Replace ``{workspace}`` in ``value`` with the resolved workspace path.

    ``workspace`` is passed through ``paddlerec_adapter`` first, so a
    ``paddlerec.``-prefixed workspace is mapped under ``PACKAGE_BASE``.

    Args:
        value: string that may contain the ``{workspace}`` placeholder.
        workspace: workspace path, or None when none is configured.

    Returns:
        ``value`` with every ``{workspace}`` replaced, or ``value`` itself
        when it contains no placeholder.

    Raises:
        ValueError: if ``value`` contains the placeholder but ``workspace``
            is None.
    """
    placeholder = "{workspace}"
    # Nothing to substitute: skip resolving the workspace entirely, so a
    # config without a usable workspace does not break plain string values.
    if placeholder not in value:
        return value
    if workspace is None:
        raise ValueError("{} references {} but no workspace is configured"
                         .format(value, placeholder))
    return value.replace(placeholder, paddlerec_adapter(workspace))
T
tangwei 已提交
171

T
tangwei 已提交
172

T
tangwei 已提交
173 174 175 176 177 178 179 180 181
def reader_adapter():
    """On Windows, force every configured dataset type to ``DataLoader``.

    Does nothing on other platforms.
    """
    if get_platform() != "WINDOWS":
        return

    # set_global_envs flattens the dataset list into "dataset.<name>.<field>"
    # keys, so a top-level "dataset" entry is normally absent (None) and must
    # not be iterated blindly.
    datasets = global_envs.get("dataset")
    if isinstance(datasets, list):
        for dataset in datasets:
            dataset["type"] = "DataLoader"

    for key in list(global_envs):
        if key.startswith("dataset.") and key.endswith(".type"):
            global_envs[key] = "DataLoader"


T
tangwei12 已提交
182
def pretty_print_envs(envs, header=None):
    """Render ``envs`` as a two-column, fixed-width text table.

    Values whose text form is ``max_v`` (50) characters or longer are
    shortened to ``"... "`` plus their tail, so every row has the same width.

    Args:
        envs: mapping of string keys to values.
        header: optional ``(key_title, value_title)`` pair; defaults to
            ``("paddlerec Global Envs", "Value")``.

    Returns:
        The table text, wrapped in a leading and a trailing newline.
    """
    spacing = 5
    max_k = 45
    max_v = 50
    ellipsis = "... "

    for k in envs:
        max_k = max(max_k, len(k))

    h_format = "{{:^{}s}}{}{{:<{}s}}\n".format(max_k, " " * spacing, max_v)
    l_format = "{{:<{}s}}{{}}{{:<{}s}}\n".format(max_k, max_v)
    length = max_k + max_v + spacing

    border = "=" * length
    line = "-" * length

    if header:
        title_k, title_v = header[0], header[1]
    else:
        title_k, title_v = "paddlerec Global Envs", "Value"

    rows = [border + "\n", h_format.format(title_k, title_v), line + "\n"]
    for k, v in envs.items():
        # Truncate the text form of every value, not only str values, so
        # long lists/dicts cannot overflow the value column.
        str_v = v if isinstance(v, str) else str(v)
        if len(str_v) >= max_v:
            str_v = ellipsis + str_v[-(max_v - len(ellipsis)):]
        rows.append(l_format.format(k, " " * spacing, str(str_v)))
    rows.append(border)

    return "\n{}\n".format("".join(rows))
T
tangwei 已提交
219 220


T
tangwei 已提交
221
def lazy_instance_by_package(package, class_name):
    """Import dotted module ``package`` and return its ``class_name`` attribute.

    Any failure (import error, missing attribute, ...) is reported by
    printing the traceback and message, and None is returned.
    """
    try:
        module = __import__(package, globals(), locals(), package.split("."))
        return getattr(module, class_name)
    except Exception as err:
        traceback.print_exc()
        print('Catch Exception:%s' % str(err))
        return None
T
tangwei 已提交
231 232


T
tangwei 已提交
233
def lazy_instance_by_fliename(abs, class_name):
    """Import the module at file path ``abs`` and return its ``class_name``.

    The file's directory is added to ``sys.path`` (once) and the module is
    imported by its file stem.

    Args:
        abs: path to a ``.py`` file.
        class_name: attribute to fetch from the imported module.

    Returns:
        The attribute, or None if the import or lookup fails (the traceback
        and message are printed).
    """
    try:
        dirname = os.path.dirname(abs)
        # Register the directory only once so repeated calls don't keep
        # growing sys.path with duplicates.
        if dirname not in sys.path:
            sys.path.append(dirname)
        package = os.path.splitext(os.path.basename(abs))[0]

        # NOTE(review): the module is imported by bare name, so an already
        # imported module with the same stem (or one found earlier on
        # sys.path) is returned instead of this file.
        model_package = __import__(package,
                                   globals(), locals(), package.split("."))
        return getattr(model_package, class_name)
    except Exception as err:
        traceback.print_exc()
        print('Catch Exception:%s' % str(err))
        return None
T
tangwei 已提交
247 248


T
tangwei 已提交
249 250 251 252 253 254 255 256 257
def get_platform():
    """Return "LINUX", "DARWIN" or "WINDOWS" for the running OS, else None.

    ``platform.system()`` is checked first. Since Python 3.8,
    ``platform.platform()`` on macOS starts with "macOS-" and no longer
    contains "Darwin". The substring checks on ``platform.platform()`` remain
    as a fallback for unusual ``system()`` values.
    """
    import platform
    by_system = {"Linux": "LINUX", "Darwin": "DARWIN", "Windows": "WINDOWS"}
    system = platform.system()
    if system in by_system:
        return by_system[system]

    plats = platform.platform()
    if 'Linux' in plats:
        return "LINUX"
    if 'Darwin' in plats or 'macOS' in plats:
        return "DARWIN"
    if 'Windows' in plats:
        return "WINDOWS"
    return None
C
chengmo 已提交
258 259 260 261


def find_free_port():
    """Return a TCP port number the OS reports as currently free.

    Binds a throwaway IPv4 socket to port 0 and reads back the port the OS
    assigned. The socket is closed before returning, so another process
    could take the port before the caller uses it.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as probe:
        probe.bind(('', 0))
        _, port = probe.getsockname()[:2]
    return port
X
test  
xjqbest 已提交
268 269 270 271 272 273 274 275 276 277 278 279 280 281 282


def load_yaml(config):
    """Parse the yaml file at path ``config`` and return its content.

    Args:
        config: path to a yaml file.

    Returns:
        The parsed document (None for an empty file).

    Raises:
        ValueError: if ``config`` is not an existing file.
    """
    if not os.path.isfile(config):
        raise ValueError("config {} can not be supported".format(config))

    # Feature-detect FullLoader (added in PyYAML 5.1) instead of int()-parsing
    # yaml.__version__, which fails on pre-release versions such as "6.0b1".
    full_loader = getattr(yaml, "FullLoader", None)

    if six.PY2:
        rb = open(config, 'r')
    else:
        rb = open(config, 'r', encoding="utf-8")
    with rb:
        content = rb.read()

    # NOTE(review): yaml.load (even with FullLoader) can construct some
    # Python-specific tags; if configs may come from untrusted sources,
    # consider yaml.safe_load instead.
    if full_loader is not None:
        return yaml.load(content, Loader=full_loader)
    return yaml.load(content)