run.py 5.0 KB
Newer Older
T
tangwei 已提交
1 2
import argparse
import os
T
tangwei 已提交
3 4 5

import yaml
from paddle.fluid.incubate.fleet.parameter_server import version
T
tangwei 已提交
6 7 8

from fleetrec.core.factory import TrainerFactory
from fleetrec.core.utils import envs
T
tangwei 已提交
9
from fleetrec.core.utils import util
T
tangwei 已提交
10

T
tangwei 已提交
11
# Registry of engine launcher callables, keyed first by backend family
# ("TRAINSPILER" or "PSLIB") and then by cluster mode. It starts out empty and
# is filled in by engine_registry().
# NOTE(review): "TRAINSPILER" looks like a misspelling of "TRANSPILER"; the key
# is used with this exact spelling elsewhere in the file, so keep them in sync.
engines = {"TRAINSPILER": {}, "PSLIB": {}}
T
tangwei 已提交
12 13 14 15 16 17 18 19 20 21 22 23 24
# Cluster modes accepted by the --engine command-line argument; the argument is
# upper-cased before it is checked against this list.
clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"]


def set_runtime_envs(cluster_envs, engine_yaml):
    """Merge optional YAML overrides into ``cluster_envs`` and publish them.

    Args:
        cluster_envs: dict of runtime settings chosen by the engine, or None
            for an empty set. When a dict is passed it is updated in place, so
            the caller sees the merged values afterwards.
        engine_yaml: path to a YAML file whose top-level mapping overrides or
            extends ``cluster_envs``, or None when there are no extras.

    The merged settings are handed to ``envs.set_runtime_envions`` and printed
    as a table via ``envs.pretty_print_envs``.
    """
    if cluster_envs is None:
        cluster_envs = {}

    if engine_yaml is not None:
        with open(engine_yaml, 'r') as rb:
            # NOTE(review): FullLoader can build more than plain data types; if
            # the extras file is only ever a simple mapping, yaml.safe_load
            # would be the more conservative choice.
            _envs = yaml.load(rb.read(), Loader=yaml.FullLoader)

        # An empty YAML document loads as None; there is nothing to merge then.
        if _envs is not None:
            cluster_envs.update(_envs)

    envs.set_runtime_envions(cluster_envs)
    print(envs.pretty_print_envs(cluster_envs, ("Runtime Envs", "Value")))
T
tangwei 已提交
25 26


T
tangwei 已提交
27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
def engine_registry():
    """Fill the module-level ``engines`` table with the launcher functions.

    Each backend family ("TRAINSPILER", "PSLIB") maps every supported cluster
    mode to the function that builds the matching engine. For "PSLIB", both
    "SINGLE" and "LOCAL_CLUSTER" use the local MPI launcher.
    """
    engines["TRAINSPILER"].update({
        "SINGLE": single_engine,
        "LOCAL_CLUSTER": local_cluster_engine,
        "CLUSTER": cluster_engine,
    })
    engines["PSLIB"].update({
        "SINGLE": local_mpi_engine,
        "LOCAL_CLUSTER": local_mpi_engine,
        "CLUSTER": cluster_mpi_engine,
    })


def get_engine(engine):
    """Look up the launcher function for the requested cluster mode.

    Args:
        engine: cluster mode name; matched case-insensitively.

    Returns:
        The launcher stored in ``engines`` for the active backend family:
        "TRAINSPILER" when ``version.is_transpiler()`` is true, "PSLIB"
        otherwise.

    Raises:
        ValueError: if no launcher is stored for that mode. This is also the
            case while ``engines`` has not been populated by
            ``engine_registry()``.
    """
    mode = engine.upper()
    family = "TRAINSPILER" if version.is_transpiler() else "PSLIB"
    launcher = engines[family].get(mode, None)

    if launcher is not None:
        return launcher
    raise ValueError("engine only support SINGLE/LOCAL_CLUSTER/CLUSTER")


def single_engine(args):
    """Build the trainer used for single-process training.

    Args:
        args: parsed command-line namespace; ``args.model`` and
            ``args.engine_extras`` are read.

    Returns:
        Whatever ``TrainerFactory.create`` returns for ``args.model``.
    """
    print("use single engine to run model: {}".format(args.model))

    # Select the trainer type, then let the optional extras file override it.
    set_runtime_envs({"trainer.trainer": "SingleTraining"}, args.engine_extras)

    return TrainerFactory.create(args.model)


def cluster_engine(args):
    """Build the trainer used for cluster training with the transpiler backend.

    Args:
        args: parsed command-line namespace; ``args.model`` and
            ``args.engine_extras`` are read.

    Returns:
        Whatever ``TrainerFactory.create`` returns for ``args.model``.
    """
    print("launch cluster engine with cluster to run model: {}".format(args.model))

    cluster_envs = {"trainer.trainer": "ClusterTraining"}
    # set_runtime_envs merges the extras into cluster_envs and already passes
    # the merged dict to envs.set_runtime_envions, so no second call is needed.
    set_runtime_envs(cluster_envs, args.engine_extras)

    trainer = TrainerFactory.create(args.model)
    return trainer


def cluster_mpi_engine(args):
    """Build the trainer used for cluster training with the PSLIB backend.

    Args:
        args: parsed command-line namespace; ``args.model`` and
            ``args.engine_extras`` are read.

    Returns:
        Whatever ``TrainerFactory.create`` returns for ``args.model``.
    """
    print("launch cluster engine with cluster to run model: {}".format(args.model))

    # Select the trainer type, then let the optional extras file override it.
    set_runtime_envs({"trainer.trainer": "CtrTraining"}, args.engine_extras)

    return TrainerFactory.create(args.model)


def local_cluster_engine(args):
    """Build a LocalClusterEngine for the given model.

    Args:
        args: parsed command-line namespace; ``args.model`` and
            ``args.engine_extras`` are read.

    Returns:
        A ``LocalClusterEngine`` constructed from the merged runtime settings
        and ``args.model``.
    """
    print("launch cluster engine with cluster to run model: {}".format(args.model))
    # Imported here rather than at module level, so the engine module is only
    # loaded when this launcher is actually used.
    from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine

    # Defaults for the local cluster; any of them can be overridden by the
    # extras file, which set_runtime_envs merges into this dict in place.
    cluster_envs = {
        "server_num": 1,
        "worker_num": 1,
        "start_port": 36001,
        "log_dir": "logs",
        "trainer.trainer": "ClusterTraining",
        "trainer.strategy.mode": "async",
    }
    set_runtime_envs(cluster_envs, args.engine_extras)

    return LocalClusterEngine(cluster_envs, args.model)
T
tangwei 已提交
95

T
tangwei 已提交
96

T
tangwei 已提交
97
def local_mpi_engine(args):
    """Build a LocalMPIEngine for the given model.

    Args:
        args: parsed command-line namespace; ``args.model`` and
            ``args.engine_extras`` are read.

    Returns:
        A ``LocalMPIEngine`` constructed from the merged runtime settings and
        ``args.model``.

    Raises:
        RuntimeError: if ``util.run_which("mpirun")`` returns a falsy value.
    """
    print("launch cluster engine with cluster to run model: {}".format(args.model))
    # Imported here rather than at module level, so the engine module is only
    # loaded when this launcher is actually used.
    from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine

    print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))

    mpirun_path = util.run_which("mpirun")
    if not mpirun_path:
        raise RuntimeError("can not find mpirun, please check environment")

    # Defaults for the local MPI run; the extras file may override them, and
    # set_runtime_envs merges it into this dict in place.
    cluster_envs = {}
    cluster_envs["mpirun"] = mpirun_path
    cluster_envs["trainer.trainer"] = "CtrTraining"
    cluster_envs["log_dir"] = "logs"
    set_runtime_envs(cluster_envs, args.engine_extras)

    return LocalMPIEngine(cluster_envs, args.model)


#
# def yaml_engine(engine_yaml, model_yaml):
#     with open(engine_yaml, 'r') as rb:
#         _config = yaml.load(rb.read(), Loader=yaml.FullLoader)
#     assert _config is not None
#
#     envs.set_global_envs(_config)
#
#     train_location = envs.get_global_env("engine.file")
#     train_dirname = os.path.dirname(train_location)
#     base_name = os.path.splitext(os.path.basename(train_location))[0]
#     sys.path.append(train_dirname)
#     trainer_class = envs.lazy_instance(base_name, "UserDefineTraining")
#     trainer = trainer_class(model_yaml)
#     return trainer
T
tangwei 已提交
128 129 130 131


if __name__ == "__main__":
    # Command-line entry point: validate the arguments, pick the engine
    # launcher for the requested cluster mode, build the engine and run it.
    parser = argparse.ArgumentParser(description='fleet-rec run')
    # --model and --engine are mandatory: without them the checks below would
    # fail on None with an unhelpful TypeError/AttributeError.
    parser.add_argument("-m", "--model", type=str, required=True)
    parser.add_argument("-e", "--engine", type=str, required=True)
    parser.add_argument("-ex", "--engine_extras", default=None, type=str)

    args = parser.parse_args()

    # os.path.isfile is already False for paths that do not exist.
    if not os.path.isfile(args.model):
        raise ValueError("argument model: {} error, must specify an existed YAML file".format(args.model))

    if args.engine.upper() not in clusters:
        raise ValueError("argument engine: {} error, must in {}".format(args.engine, clusters))

    if args.engine_extras is not None and not os.path.isfile(args.engine_extras):
        raise ValueError(
            "argument engine_extras: {} error, must specify an existed YAML file".format(args.engine_extras))

    # The engines table starts empty; it must be populated before the lookup,
    # otherwise get_engine() raises ValueError for every mode.
    engine_registry()

    which_engine = get_engine(args.engine)
    engine = which_engine(args)
    engine.run()