run.py 4.4 KB
Newer Older
T
tangwei 已提交
1 2
import argparse
import os
T
tangwei 已提交
3 4 5 6
import sys

import yaml
from paddle.fluid.incubate.fleet.parameter_server import version
T
tangwei 已提交
7 8 9

from fleetrec.core.factory import TrainerFactory
from fleetrec.core.utils import envs
T
tangwei 已提交
10
from fleetrec.core.utils import util
T
tangwei 已提交
11 12 13 14 15 16 17


def run(model_yaml):
    """Create a trainer for ``model_yaml`` through ``TrainerFactory`` and run it."""
    TrainerFactory.create(model_yaml).run()


T
tangwei 已提交
18
def single_engine(single_envs, model_yaml):
    """Print ``single_envs``, register them as runtime envs, then run the model.

    Args:
        single_envs: Mapping of environment settings; it is printed via
            ``envs.pretty_print_envs`` and handed to
            ``envs.set_runtime_envions``.
        model_yaml: Model configuration forwarded unchanged to ``run``.
    """
    table_header = ("Single Envs", "Value")
    env_table = envs.pretty_print_envs(single_envs, table_header)
    print(env_table)

    envs.set_runtime_envions(single_envs)
    run(model_yaml)


def local_cluster_engine(cluster_envs, model_yaml):
    """Print ``cluster_envs``, register them, and run a ``LocalClusterEngine``.

    Args:
        cluster_envs: Mapping of cluster settings; printed, passed to
            ``envs.set_runtime_envions`` and to the engine constructor.
        model_yaml: Model configuration passed to the engine constructor.
    """
    # The engine class is imported at call time, not at module import time.
    from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine

    table_header = ("Local Cluster Envs", "Value")
    env_table = envs.pretty_print_envs(cluster_envs, table_header)
    print(env_table)

    envs.set_runtime_envions(cluster_envs)
    LocalClusterEngine(cluster_envs, model_yaml).run()


T
tangwei 已提交
33 34
def local_mpi_engine(cluster_envs, model_yaml):
    """Print ``cluster_envs``, register them, and run a ``LocalMPIEngine``.

    Args:
        cluster_envs: Mapping of cluster settings; printed, passed to
            ``envs.set_runtime_envions`` and to the engine constructor.
        model_yaml: Model configuration passed to the engine constructor.
    """
    # The engine class is imported at call time, not at module import time.
    from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine

    table_header = ("Local MPI Cluster Envs", "Value")
    env_table = envs.pretty_print_envs(cluster_envs, table_header)
    print(env_table)

    envs.set_runtime_envions(cluster_envs)
    LocalMPIEngine(cluster_envs, model_yaml).run()
T
tangwei 已提交
40 41 42


def yaml_engine(engine_yaml, model_yaml):
    """Run a user-defined trainer described by an engine YAML file.

    The parsed engine config is installed with ``envs.set_global_envs``. The
    ``engine.file`` global env is then read as the path of a Python file: its
    directory is appended to ``sys.path`` and its base name (without
    extension) is passed to ``envs.lazy_instance`` together with the class
    name ``"UserDefineTraining"``. The result is called with ``model_yaml``
    and the returned trainer is run.

    Args:
        engine_yaml: Path of the engine YAML file to load.
        model_yaml: Model configuration passed to the trainer class.

    Raises:
        ValueError: If the engine YAML file parses to ``None`` (e.g. it is
            empty). This was previously an ``assert``, which is skipped when
            Python runs with ``-O``.
    """
    with open(engine_yaml, 'r') as config_file:
        # NOTE(review): FullLoader is used here; prefer yaml.safe_load if the
        # engine file can ever come from an untrusted source.
        _config = yaml.load(config_file.read(), Loader=yaml.FullLoader)

    # Explicit check instead of `assert`, so validation survives `python -O`.
    if _config is None:
        raise ValueError("engine yaml: {} is empty, can not load any config".format(engine_yaml))

    envs.set_global_envs(_config)

    train_location = envs.get_global_env("engine.file")
    train_dirname = os.path.dirname(train_location)
    base_name = os.path.splitext(os.path.basename(train_location))[0]

    # Make the user's trainer module importable by its bare module name.
    sys.path.append(train_dirname)

    trainer_class = envs.lazy_instance(base_name, "UserDefineTraining")
    trainer = trainer_class(model_yaml)
    trainer.run()
T
tangwei 已提交
56 57 58 59 60 61


if __name__ == "__main__":
    # Command-line entry point: parse the arguments, validate the model file,
    # then dispatch on the requested engine name (case-insensitive).
    parser = argparse.ArgumentParser(description='fleet-rec run')
    parser.add_argument("--model", type=str)
    parser.add_argument("--engine", type=str)
    parser.add_argument("--engine_extras", type=str)

    args = parser.parse_args()

    # isfile() is already False for a missing path, so no separate exists()
    # check is needed. The truthiness guard covers an omitted --model, which
    # would otherwise raise TypeError inside os.path.
    if not args.model or not os.path.isfile(args.model):
        raise ValueError("argument model: {} error, must specify a existed yaml file".format(args.model))

    # An omitted --engine becomes "" and falls through to the final
    # "engine only support ..." error instead of failing on None.upper().
    engine = (args.engine or "").upper()

    if engine == "SINGLE":
        if version.is_transpiler():
            print("use SingleTraining to run model: {}".format(args.model))
            single_envs = {"train.trainer": "SingleTraining"}
            single_engine(single_envs, args.model)
        else:
            print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))

            mpi_path = util.run_which("mpirun")
            if not mpi_path:
                raise RuntimeError("can not find mpirun, please check environment")

            cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining", "log_dir": "logs"}
            local_mpi_engine(cluster_envs, args.model)
    elif engine == "LOCAL_CLUSTER":
        print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model))
        if version.is_transpiler():
            cluster_envs = {
                "server_num": 1,
                "worker_num": 1,
                "start_port": 36001,
                "log_dir": "logs",
                "train.trainer": "ClusterTraining",
                "train.strategy.mode": "async",
            }

            local_cluster_engine(cluster_envs, args.model)
        else:
            print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))

            mpi_path = util.run_which("mpirun")
            if not mpi_path:
                raise RuntimeError("can not find mpirun, please check environment")

            cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining", "log_dir": "logs"}
            local_mpi_engine(cluster_envs, args.model)
    elif engine == "CLUSTER":
        print("launch ClusterTraining with cluster to run model: {}".format(args.model))
        run(args.model)
    elif engine == "USER_DEFINE":
        engine_file = args.engine_extras
        # The error previously formatted `args.engine_file`, an attribute that
        # does not exist (the option is --engine_extras), so a bad path raised
        # AttributeError instead of this ValueError.
        if not engine_file or not os.path.isfile(engine_file):
            raise ValueError(
                "argument engine_extras: {} error, must specify a existed yaml file".format(engine_file))
        yaml_engine(engine_file, args.model)
    else:
        raise ValueError("engine only support SINGLE/LOCAL_CLUSTER/CLUSTER/USER_DEFINE")