# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import itertools
import os
import re
from typing import Tuple


def divisor(num, reverse=False):
    """Return all positive divisors of ``num`` as a sorted list.

    Args:
        num: Number whose divisors are wanted. Values below 1 produce an
            empty list.
        reverse: When True the divisors are sorted in descending order,
            otherwise ascending.

    Returns:
        list: The distinct divisors of ``num``.
    """
    results = set()
    i = 1
    # Divisors come in pairs (i, num // i) with i <= sqrt(num), so scanning up
    # to the square root finds every one of them. Unlike a ``num // 2`` bound
    # this also covers num == 1, whose only divisor is 1 itself.
    while i * i <= num:
        if num % i == 0:
            results.add(i)
            results.add(num // i)
        i += 1
    return sorted(results, reverse=reverse)


def dist_degree(mode, num_gpus, num_nodes):
    """Return the candidate degrees of a parallel mode for the given resources.

    Args:
        mode: One of ``"dp"``, ``"mp"``, ``"pp"`` or ``"sharding"``.
        num_gpus: Total number of GPUs.
        num_nodes: Number of nodes.

    Returns:
        list: Candidate degrees.
            * ``"dp"``: divisors of ``num_gpus`` in ascending order.
            * ``"pp"``: ``1..num_nodes`` when there is more than one node,
              otherwise the divisors of ``num_gpus`` in descending order.
            * ``"mp"``: divisors of ``num_gpus // num_nodes`` in descending
              order.
            * ``"sharding"``: divisors of ``num_gpus`` in descending order.

    Raises:
        AssertionError: If ``mode`` is not one of the supported modes.
    """
    assert mode in ["dp", "mp", "pp", "sharding"]
    results = []
    if mode == "dp":
        results = divisor(num_gpus, reverse=False)

    elif mode == "pp":
        if num_nodes > 1:
            results = list(range(1, num_nodes + 1))
        else:
            results = divisor(num_gpus, reverse=True)

    elif mode == "mp":
        gpus_per_node = num_gpus // num_nodes
        results = divisor(gpus_per_node, reverse=True)

    elif mode == "sharding":
        results = divisor(num_gpus, reverse=True)

    return results


def default_candidates(tuner_cfg):
    """Build the candidate list of every tunable hyper-parameter.

    For each parameter the value found in ``tuner_cfg`` decides the result:
    ``"auto"`` expands to a generated default list, any other truthy value is
    used as given, and a missing or falsy value falls back to a
    single-element default (``[1]`` for the degrees, ``[None]`` otherwise).

    Args:
        tuner_cfg: Tuner configuration mapping. ``"num_gpus"`` (must be > 0)
            and ``"nodes"`` are required. ``"model_cfg"["global_batch_size"]``
            is read only when ``"micro_batch_size"`` is ``"auto"``.

    Returns:
        dict: Parameter name -> candidates, with keys inserted in the order
        dp_degree, mp_degree, pp_degree, sharding_degree, sharding_stage,
        use_recompute, recompute_granularity, micro_batch_size.

    Raises:
        AssertionError: If ``num_gpus`` is not positive.
    """
    num_gpus = tuner_cfg["num_gpus"]
    num_nodes = tuner_cfg["nodes"]
    assert num_gpus > 0

    def degrees(mode):
        # Deferred so dist_degree only runs for parameters set to "auto".
        return lambda: dist_degree(mode, num_gpus, num_nodes)

    # (parameter name, factory for the "auto" candidates, fallback candidates)
    rules = (
        ("dp_degree", degrees("dp"), [1]),
        ("mp_degree", degrees("mp"), [1]),
        ("pp_degree", degrees("pp"), [1]),
        ("sharding_degree", degrees("sharding"), [1]),
        ("sharding_stage", lambda: [1, 2, 3], [None]),
        ("use_recompute", lambda: [False, True], [None]),
        ("recompute_granularity", lambda: ["full_attn", "full"], [None]),
        (
            "micro_batch_size",
            # Counts down from the global batch size to 1.
            lambda: list(
                range(tuner_cfg["model_cfg"]["global_batch_size"], 0, -1)
            ),
            [None],
        ),
    )

    candidates = {}
    for name, auto_values, fallback in rules:
        setting = tuner_cfg.get(name, None)
        if setting == "auto":
            candidates[name] = auto_values()
        elif setting:
            candidates[name] = setting
        else:
            candidates[name] = fallback

    return candidates


def search_all(tuner_cfg):
    """Enumerate every combination of the hyper-parameter candidates.

    Args:
        tuner_cfg: Tuner configuration mapping whose ``"candidates"`` entry
            maps each of the eight parameter names below to an iterable of
            candidate values.

    Returns:
        list: One dict per combination, mapping parameter name to the chosen
        value. Combinations follow ``itertools.product`` order, so the
        last-listed parameter varies fastest.
    """
    candidates = tuner_cfg["candidates"]
    # Order: dp -> sharding -> mbs -> pp -> mp  -> recompute
    ordered_names = (
        "dp_degree",
        "sharding_degree",
        "sharding_stage",
        "micro_batch_size",
        "pp_degree",
        "mp_degree",
        "use_recompute",
        "recompute_granularity",
    )
    all_cfgs = itertools.product(*(candidates[name] for name in ordered_names))
    return [dict(zip(ordered_names, cfg)) for cfg in all_cfgs]


def _append_tuned_arg(res_args, cmd, name, value):
    """Fill ``value`` into the ``cmd[name]`` template and append it to ``res_args``."""
    template = cmd[name]
    if "--" in template[0]:
        # Flag-style template: the value is concatenated directly onto the
        # second element.
        template[1] = template[1] + str(value)
    else:
        # Otherwise the second element and the value are joined with "=".
        template[1] = template[1] + "=" + str(value)
    res_args.extend(template)


def gen_new_args(raw_args, cfg, tuner_cfg):
    """Generate new script args.

    Starts from a copy of ``raw_args`` and appends one argument group for
    every tuned parameter that has a template in ``tuner_cfg["run_cmd"]``.
    Neither ``raw_args`` nor ``tuner_cfg`` is modified.

    Args:
        raw_args: Original argument list of the script.
        cfg: Mapping of parameter name to the value chosen for this run.
        tuner_cfg: Tuner configuration mapping. ``"run_cmd"`` maps parameter
            names to list templates with at least two elements;
            ``"model_cfg"["global_batch_size"]`` is read for the derived
            ``local_batch_size`` and ``gradient_accumulation_steps`` values.

    Returns:
        list: The extended argument list.

    Raises:
        AssertionError: If ``tuner_cfg`` has no ``"run_cmd"`` entry.
    """
    assert "run_cmd" in tuner_cfg
    # Work on copies so the templates can be filled in place.
    cmd = copy.deepcopy(tuner_cfg["run_cmd"])
    res_args = copy.deepcopy(raw_args)

    # Parameters passed through as-is, in the order they are appended.
    tuned_names = (
        "dp_degree",
        "mp_degree",
        "pp_degree",
        "micro_batch_size",
        "sharding_degree",
        "sharding_stage",
        "use_recompute",
        "recompute_granularity",
    )
    for name in tuned_names:
        if name in cmd and name in cfg:
            _append_tuned_arg(res_args, cmd, name, cfg[name])

    if "local_batch_size" in cmd:
        local_batch_size = (
            tuner_cfg["model_cfg"]["global_batch_size"]
            // cfg["sharding_degree"]
            // cfg["dp_degree"]
        )
        _append_tuned_arg(res_args, cmd, "local_batch_size", local_batch_size)

    if "gradient_accumulation_steps" in cmd:
        try:
            gradient_accumulation_steps = (
                tuner_cfg["model_cfg"]["global_batch_size"]
                // cfg["sharding_degree"]
                // cfg["dp_degree"]
                // cfg["micro_batch_size"]
            )
        except (KeyError, TypeError, ZeroDivisionError):
            # Best effort: when the value cannot be derived (a required key
            # is missing, or an operand is None or zero) the argument is
            # simply left out.
            pass
        else:
            _append_tuned_arg(
                res_args,
                cmd,
                "gradient_accumulation_steps",
                gradient_accumulation_steps,
            )

    return res_args


def read_log(
    path, file="workerlog.0", target_metric='step/s'
) -> Tuple[float, bool]:
    """Extract the target metric from a log file.

    Each line is searched for ``target_metric`` followed by a number
    (optionally separated by colons/spaces) or a number followed by
    ``target_metric``; only the first occurrence per line is used.

    Args:
        path: Directory that contains the log file.
        file: Name of the log file inside ``path``.
        target_metric: Metric name to look for. It is inserted into the
            regular expression unescaped, so regex metacharacters in it are
            interpreted as pattern syntax.

    Returns:
        Tuple[float, bool]: ``(metric, flag)``. ``flag`` is True when the
        file does not exist or holds no metric value, in which case the
        metric is 0.0. Otherwise the metric, rounded to 5 decimal places, is
            * the last value when fewer than 10 values were found,
            * the mean of the values from the 10th onward when 10 to 19
              values were found,
            * the mean of the last 10 values when 20 or more were found.
    """
    target_file = os.path.join(path, file)
    if not os.path.exists(target_file):
        return (0.0, True)

    metric_pattern = re.compile(
        target_metric + r":* *(\d+(\.\d*)?)|(\d+(\.\d*)?) *" + target_metric
    )

    metric_list = []
    with open(target_file, "r") as f:
        for line in f:
            match = metric_pattern.search(line)
            if match is None:
                continue
            # Group 1 holds the number for "<metric>: <number>" and group 3
            # for "<number> <metric>"; only one of them takes part in a match.
            value = match.group(1)
            if value is None:
                value = match.group(3)
            metric_list.append(float(value))

    if not metric_list:
        metric_ave = 0.0
        flag = True
    elif len(metric_list) < 10:
        metric_ave = metric_list[-1]
        flag = False
    elif len(metric_list) < 20:
        tail = metric_list[9:]
        metric_ave = sum(tail) / len(tail)
        flag = False
    else:
        metric_ave = sum(metric_list[-10:]) / 10
        flag = False
    # round to 5 decimal places
    metric_ave = round(metric_ave, 5)
    return metric_ave, flag