run.py 5.2 KB
Newer Older
T
tangwei 已提交
1 2
import argparse
import os
T
tangwei 已提交
3
import subprocess
T
tangwei 已提交
4
import yaml
T
tangwei 已提交
5 6 7

from fleetrec.core.factory import TrainerFactory
from fleetrec.core.utils import envs
T
tangwei 已提交
8
from fleetrec.core.utils import util
T
tangwei 已提交
9

T
tangwei 已提交
10
# Engine dispatch table populated by engine_registry(): the outer key is the
# backend chosen by get_engine() ("TRAINSPILER" when is_transpiler() is true,
# otherwise "PSLIB"); each inner dict maps an upper-cased engine name to the
# factory function that builds that engine.
engines = {
    "TRAINSPILER": {},
    "PSLIB": {},
}

# Engine names accepted on the command line (compared after .upper()).
clusters = [
    "SINGLE",
    "LOCAL_CLUSTER",
    "CLUSTER",
]


T
tangwei 已提交
14 15 16 17 18 19 20 21 22 23 24 25
def is_transpiler():
    """Probe the installed Paddle build by running a Fleet snippet in a child process.

    The snippet constructs ``fluid.core.Fleet()`` and calls
    ``copy_table_by_feasign`` on it. Its stdout/stderr are discarded; only the
    child's exit status is inspected.

    Returns:
        False if the child was killed by SIGSEGV (return code
        ``-signal.SIGSEGV``, i.e. -11 on POSIX); True for every other
        outcome. That includes a clean exit and ordinary Python errors such
        as an ImportError when Paddle is not installed. get_engine() uses
        the TRAINSPILER engine table when this returns True and the PSLIB
        table otherwise.
    """
    import signal
    import sys

    cmd = [sys.executable, "-c",
           "import paddle.fluid as fluid; fleet_ptr = fluid.core.Fleet(); [fleet_ptr.copy_table_by_feasign(10, 10, [2020, 1010])];"]
    # Probe with the same interpreter that runs this script (not whatever
    # "python" resolves to on PATH), so the answer matches the Paddle
    # installation that will actually be used. The context manager closes the
    # devnull handle, which the previous open() call leaked.
    with open(os.devnull, 'w') as devnull:
        ret = subprocess.call(cmd, stdout=devnull, stderr=devnull)
    # A negative return code -N means the child was terminated by signal N.
    return ret != -signal.SIGSEGV


T
tangwei 已提交
26
def set_runtime_envs(cluster_envs, engine_yaml):
    """Register engine defaults and YAML ``train.trainer.*`` settings as runtime envs.

    Args:
        cluster_envs: dict of engine-specific defaults, or None for none.
        engine_yaml: path to the model's YAML config file.

    ``cluster_envs`` is passed to ``envs.set_runtime_environs`` first. The
    ``train.trainer.*`` keys of the flattened YAML config are passed second.
    Assuming set_runtime_environs overwrites existing keys (not visible here),
    YAML values therefore take precedence over the engine defaults.

    Afterwards, every ``train.trainer.*`` entry currently in ``os.environ`` is
    printed as a table via ``envs.pretty_print_envs``.
    """

    def get_engine_extras():
        # Load the model config and keep only the ``train.trainer.*`` keys
        # of its flattened form.
        # NOTE(review): yaml.FullLoader is not the safe loader; if engine_yaml
        # can come from an untrusted source, consider yaml.safe_load instead.
        with open(engine_yaml, 'r') as rb:
            _envs = yaml.load(rb.read(), Loader=yaml.FullLoader)

        # An empty YAML document loads as None; treat it as "no overrides"
        # rather than handing None to flatten_environs.
        if _envs is None:
            return {}

        flattens = envs.flatten_environs(_envs)
        return {k: v for k, v in flattens.items()
                if k.startswith("train.trainer.")}

    if cluster_envs is None:
        cluster_envs = {}

    envs.set_runtime_environs(cluster_envs)
    envs.set_runtime_environs(get_engine_extras())

    # presumably set_runtime_environs exports into os.environ, which is why
    # the summary is read back from there — TODO confirm.
    need_print = {k: v for k, v in os.environ.items()
                  if k.startswith("train.trainer.")}

    print(envs.pretty_print_envs(need_print, ("Runtime Envs", "Value")))
T
tangwei 已提交
51 52


T
tangwei 已提交
53 54
def get_engine(engine):
    """Resolve an engine name to its registered factory function.

    Args:
        engine: engine name, matched case-insensitively.

    Returns:
        The factory registered under the upper-cased name. It is taken from
        the "TRAINSPILER" table when is_transpiler() is true, otherwise from
        the "PSLIB" table.

    Raises:
        ValueError: if the selected table has no factory under that name.
    """
    name = engine.upper()
    backend = "TRAINSPILER" if is_transpiler() else "PSLIB"
    factory = engines[backend].get(name)

    if factory is None:
        raise ValueError("engine only support SINGLE/LOCAL_CLUSTER/CLUSTER")
    return factory


def single_engine(args):
    """Create a single-process trainer for the model config at ``args.model``.

    Registers SingleTrainer defaults (2 threads, engine "single") as runtime
    envs, then returns whatever ``TrainerFactory.create`` builds from the YAML.
    """
    print("use single engine to run model: {}".format(args.model))

    defaults = {
        "train.trainer.trainer": "SingleTrainer",
        "train.trainer.threads": "2",
        "train.trainer.engine": "single",
    }
    set_runtime_envs(defaults, args.model)

    return TrainerFactory.create(args.model)


def cluster_engine(args):
    """Create a ClusterTrainer for the model config at ``args.model``.

    Registers the ClusterTrainer defaults (engine "cluster") as runtime envs,
    then returns the trainer built by ``TrainerFactory.create``.
    """
    print("launch cluster engine with cluster to run model: {}".format(args.model))

    defaults = {
        "train.trainer.trainer": "ClusterTrainer",
        "train.trainer.engine": "cluster",
    }
    set_runtime_envs(defaults, args.model)

    return TrainerFactory.create(args.model)


def cluster_mpi_engine(args):
    """Create a CtrCodingTrainer for the model config at ``args.model``.

    Registers CtrCodingTrainer as the trainer runtime env, then returns the
    trainer built by ``TrainerFactory.create``.
    """
    # NOTE(review): unlike the other engine factories, this one does not set
    # "train.trainer.engine" — confirm that is intended.
    print("launch cluster engine with cluster to run model: {}".format(args.model))

    defaults = {"train.trainer.trainer": "CtrCodingTrainer"}
    set_runtime_envs(defaults, args.model)

    return TrainerFactory.create(args.model)


def local_cluster_engine(args):
    """Build a LocalClusterEngine for the model config at ``args.model``.

    The cluster layout is 1 server and 1 worker, starting at port 36001,
    with logs under "logs". The trainer settings are an async ClusterTrainer
    with 2 threads and CPU_NUM "2". All of these are registered as runtime
    envs and also handed to the engine.
    """
    print("launch cluster engine with cluster to run model: {}".format(args.model))
    from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine

    settings = {
        "server_num": 1,
        "worker_num": 1,
        "start_port": 36001,
        "log_dir": "logs",
        "train.trainer.trainer": "ClusterTrainer",
        "train.trainer.strategy": "async",
        "train.trainer.threads": "2",
        "train.trainer.engine": "local_cluster",
        "CPU_NUM": "2",
    }
    set_runtime_envs(settings, args.model)

    return LocalClusterEngine(settings, args.model)
T
tangwei 已提交
119

T
tangwei 已提交
120

T
tangwei 已提交
121
def local_mpi_engine(args):
    """Build a LocalMPIEngine for the model config at ``args.model``.

    Locates ``mpirun`` via ``util.run_which`` and registers a CtrCodingTrainer
    setup (logs under "logs", engine "local_cluster") as runtime envs. The
    same settings are passed to the engine.

    Raises:
        RuntimeError: if ``mpirun`` cannot be found.
    """
    print("launch cluster engine with cluster to run model: {}".format(args.model))
    from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine

    print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))

    mpirun_path = util.run_which("mpirun")
    if not mpirun_path:
        raise RuntimeError("can not find mpirun, please check environment")

    settings = {
        "mpirun": mpirun_path,
        "train.trainer.trainer": "CtrCodingTrainer",
        "log_dir": "logs",
        "train.trainer.engine": "local_cluster",
    }
    set_runtime_envs(settings, args.model)

    return LocalMPIEngine(settings, args.model)


T
tangwei 已提交
141 142 143 144 145 146 147 148
def engine_registry():
    """Fill the module-level ``engines`` table with the engine factories.

    With the transpiler backend, each engine name has its own factory. With
    PSLIB, both SINGLE and LOCAL_CLUSTER map to local_mpi_engine, and
    CLUSTER maps to cluster_mpi_engine.
    """
    engines["TRAINSPILER"].update({
        "SINGLE": single_engine,
        "LOCAL_CLUSTER": local_cluster_engine,
        "CLUSTER": cluster_engine,
    })
    engines["PSLIB"].update({
        "SINGLE": local_mpi_engine,
        "LOCAL_CLUSTER": local_mpi_engine,
        "CLUSTER": cluster_mpi_engine,
    })

T
tangwei 已提交
149

T
tangwei 已提交
150
# Populate the dispatch table at import time so get_engine() works for
# importers as well as for the script entry point below.
engine_registry()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='fleet-rec run')
    # Both options are mandatory. Previously, omitting either one made the
    # checks below crash on None (TypeError from os.path / AttributeError
    # from None.upper()) instead of printing a usage error.
    parser.add_argument("-m", "--model", type=str, required=True,
                        help="path to the model's YAML config file")
    parser.add_argument("-e", "--engine", type=str, required=True,
                        help="engine to run, one of {} (case-insensitive)".format(
                            "/".join(clusters)))

    args = parser.parse_args()

    # isfile() is already False for paths that do not exist.
    if not os.path.isfile(args.model):
        raise ValueError("argument model: {} error, must specify an existed YAML file".format(args.model))

    if args.engine.upper() not in clusters:
        raise ValueError("argument engine: {} error, must in {}".format(args.engine, clusters))

    which_engine = get_engine(args.engine)
    engine = which_engine(args)
    engine.run()