run.py 7.0 KB
Newer Older
T
tangwei 已提交
1 2
import argparse
import os
T
tangwei 已提交
3
import subprocess
T
tangwei 已提交
4
import yaml
T
tangwei 已提交
5

T
rename  
tangwei 已提交
6 7 8
from fleetrec.core.factory import TrainerFactory
from fleetrec.core.utils import envs
from fleetrec.core.utils import util
T
tangwei 已提交
9

T
tangwei 已提交
10 11
# Engine factory table, populated by engine_registry() and read by
# get_engine() as engines[device][transpiler][engine_name].
engines = dict()
# Known device / cluster-mode names (upper-case).
# NOTE(review): neither list is read in the visible code (get_engine's
# ``device`` parameter shadows the module-level name) — confirm before removing.
device = list(("CPU", "GPU"))
clusters = list(("SINGLE", "LOCAL_CLUSTER", "CLUSTER"))


T
tangwei 已提交
15 16 17 18 19
def engine_registry():
    """Fill the module-level ``engines`` table with engine factories.

    The table is keyed ``engines[device][transpiler][engine_name]`` and each
    value is a function taking the parsed CLI args. Under PSLIB both SINGLE
    and LOCAL_CLUSTER map to the local MPI engine. GPU only registers the
    single-process transpiler engine; its PSLIB table is left empty.
    """
    transpiler_cpu = {
        "SINGLE": single_engine,
        "LOCAL_CLUSTER": local_cluster_engine,
        "CLUSTER": cluster_engine,
        "TDM_SINGLE": tdm_single_engine,
    }
    pslib_cpu = {
        "SINGLE": local_mpi_engine,
        "LOCAL_CLUSTER": local_mpi_engine,
        "CLUSTER": cluster_mpi_engine,
    }

    engines["CPU"] = {"TRANSPILER": transpiler_cpu, "PSLIB": pslib_cpu}
    engines["GPU"] = {"TRANSPILER": {"SINGLE": single_engine}, "PSLIB": {}}


def get_engine(engine, device):
    """Return the engine factory registered for ``engine`` on ``device``.

    ``engine_registry()`` must have run first. The transpiler flavour
    ("TRANSPILER" or "PSLIB") is detected on every call via
    ``get_transpiler()``.

    Raises:
        KeyError: if ``device`` has no entry in ``engines``.
        ValueError: if nothing is registered for this engine/device pair.
    """
    by_transpiler = engines[device]
    candidates = by_transpiler[get_transpiler()]
    factory = candidates.get(engine)
    if factory is not None:
        return factory
    raise ValueError(
        "engine {} can not be supported on device: {}".format(engine, device))


def get_transpiler():
    """Detect which distributed backend the installed Paddle supports.

    A probe is run in a child interpreter. It builds ``fluid.core.Fleet()``
    and calls ``copy_table_by_feasign`` on it. If the child ends with exit
    status -11 (killed by signal 11, SIGSEGV on POSIX), the PSLIB backend is
    selected. Any other outcome, including a failed import or a missing
    attribute, selects the transpiler backend.

    Returns:
        "PSLIB" or "TRANSPILER".
    """
    import sys

    probe = ("import paddle.fluid as fluid; fleet_ptr = fluid.core.Fleet(); "
             "[fleet_ptr.copy_table_by_feasign(10, 10, [2020, 1010])];")
    # Probe with the interpreter running this script, so the same Paddle
    # installation is tested, instead of whatever "python" is first on PATH.
    cmd = [sys.executable or "python", "-c", probe]
    # Open the null device in a context manager so the handle is always closed.
    with open(os.devnull, 'w') as fnull:
        ret = subprocess.call(cmd, stdout=fnull, stderr=fnull, cwd=os.getcwd())
    # NOTE(review): presumably only PSLIB builds expose Fleet and crash on
    # this uninitialised call; non-PSLIB builds fail with a normal exit code.
    if ret == -11:
        return "PSLIB"
    return "TRANSPILER"
T
tangwei 已提交
53 54


T
tangwei 已提交
55
def set_runtime_envs(cluster_envs, engine_yaml):
    """Publish runtime settings, then print the ``train.trainer.*`` ones.

    Args:
        cluster_envs: dict of settings chosen by the engine launcher. ``None``
            is treated as an empty dict.
        engine_yaml: path to the model's YAML config. Its flattened
            ``train.trainer.*`` entries are passed to
            ``envs.set_runtime_environs`` after ``cluster_envs``.
    """
    prefix = "train.trainer."

    def get_engine_extras():
        # Load the model config and keep only the trainer-level keys.
        # NOTE(review): FullLoader can construct some Python objects; if the
        # config can come from an untrusted source, consider yaml.safe_load.
        with open(engine_yaml, 'r') as rb:
            config = yaml.load(rb.read(), Loader=yaml.FullLoader)
        flat = envs.flatten_environs(config)
        return {key: val for key, val in flat.items() if key.startswith(prefix)}

    envs.set_runtime_environs({} if cluster_envs is None else cluster_envs)
    envs.set_runtime_environs(get_engine_extras())

    # The summary is read back from os.environ; presumably
    # set_runtime_environs exports its entries there — confirm in envs.
    need_print = {
        key: val for key, val in os.environ.items() if key.startswith(prefix)
    }
    print(envs.pretty_print_envs(need_print, ("Runtime Envs", "Value")))
T
tangwei 已提交
80 81


T
tangwei 已提交
82
def single_engine(args):
    """Create a single-process ``SingleTrainer`` for ``args.model``.

    Args:
        args: parsed CLI namespace. ``args.model`` (config path) and
            ``args.device`` are read.

    Returns:
        The trainer produced by ``TrainerFactory.create``.
    """
    print("use single engine to run model: {}".format(args.model))

    runtime = {
        "train.trainer.trainer": "SingleTrainer",
        "train.trainer.threads": "2",
        "train.trainer.engine": "single",
        "train.trainer.device": args.device,
        "train.trainer.platform": envs.get_platform(),
    }
    set_runtime_envs(runtime, args.model)
    return TrainerFactory.create(args.model)


C
chengmo 已提交
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
def tdm_single_engine(args):
    """Create a single-process ``TDMSingleTrainer`` for ``args.model``.

    Args:
        args: parsed CLI namespace. ``args.model`` (config path) and
            ``args.device`` are read.

    Returns:
        The trainer produced by ``TrainerFactory.create``.
    """
    print("use tdm single engine to run model: {}".format(args.model))

    runtime = {
        "train.trainer.trainer": "TDMSingleTrainer",
        "train.trainer.threads": "2",
        "train.trainer.engine": "single",
        "train.trainer.device": args.device,
        "train.trainer.platform": envs.get_platform(),
    }
    set_runtime_envs(runtime, args.model)
    return TrainerFactory.create(args.model)


T
tangwei 已提交
112
def cluster_engine(args):
    """Create a ``ClusterTrainer`` for ``args.model`` in this process.

    Args:
        args: parsed CLI namespace. ``args.model`` and ``args.device`` are read.

    Returns:
        The trainer produced by ``TrainerFactory.create``.
    """
    print("launch cluster engine with cluster to run model: {}".format(args.model))

    settings = {
        "train.trainer.trainer": "ClusterTrainer",
        "train.trainer.engine": "cluster",
        "train.trainer.device": args.device,
        "train.trainer.platform": envs.get_platform(),
    }
    set_runtime_envs(settings, args.model)
    return TrainerFactory.create(args.model)


def cluster_mpi_engine(args):
    """Create a ``CtrCodingTrainer`` for ``args.model`` (PSLIB cluster mode).

    Args:
        args: parsed CLI namespace. ``args.model`` and ``args.device`` are read.

    Returns:
        The trainer produced by ``TrainerFactory.create``.
    """
    print("launch cluster engine with cluster to run model: {}".format(args.model))

    # NOTE(review): unlike the other engines, no "train.trainer.engine" key
    # is set here — confirm the trainer does not need it.
    settings = {
        "train.trainer.trainer": "CtrCodingTrainer",
        "train.trainer.device": args.device,
        "train.trainer.platform": envs.get_platform(),
    }
    set_runtime_envs(settings, args.model)
    return TrainerFactory.create(args.model)


def local_cluster_engine(args):
    """Build a ``LocalClusterEngine`` for ``args.model``.

    The settings request one server and one worker, starting at port 36001.
    They also select the async strategy and write logs under "logs".

    Args:
        args: parsed CLI namespace. ``args.model`` and ``args.device`` are read.

    Returns:
        The ``LocalClusterEngine`` instance (its ``run()`` is called by main).
    """
    print("launch cluster engine with cluster to run model: {}".format(args.model))
    from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine

    # Process topology and logging.
    settings = {
        "server_num": 1,
        "worker_num": 1,
        "start_port": 36001,
        "log_dir": "logs",
    }
    # Trainer-level runtime options.
    settings.update({
        "train.trainer.trainer": "ClusterTrainer",
        "train.trainer.strategy": "async",
        "train.trainer.threads": "2",
        "train.trainer.engine": "local_cluster",
        "train.trainer.device": args.device,
        "train.trainer.platform": envs.get_platform(),
    })
    settings["CPU_NUM"] = "2"

    set_runtime_envs(settings, args.model)
    return LocalClusterEngine(settings, args.model)
T
tangwei 已提交
164

T
tangwei 已提交
165

T
tangwei 已提交
166
def local_mpi_engine(args):
    """Build a ``LocalMPIEngine`` that trains ``args.model`` via mpirun.

    Args:
        args: parsed CLI namespace. ``args.model`` and ``args.device`` are read.

    Returns:
        The ``LocalMPIEngine`` instance (its ``run()`` is called by main).

    Raises:
        RuntimeError: if ``mpirun`` cannot be located.
    """
    print("launch cluster engine with cluster to run model: {}".format(args.model))
    from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine

    print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))

    mpirun_path = util.run_which("mpirun")
    if not mpirun_path:
        raise RuntimeError("can not find mpirun, please check environment")

    launch_envs = {
        "mpirun": mpirun_path,
        "train.trainer.trainer": "CtrCodingTrainer",
        "log_dir": "logs",
        "train.trainer.engine": "local_cluster",
        "train.trainer.device": args.device,
        "train.trainer.platform": envs.get_platform(),
    }
    set_runtime_envs(launch_envs, args.model)
    return LocalMPIEngine(launch_envs, args.model)


T
tangwei 已提交
189 190 191 192 193 194 195 196 197 198 199 200
def get_abs_model(model):
    """Resolve ``model`` to the path of its YAML config file.

    ``model`` is either:
    * a dotted name starting with ``fleetrec.``. For example,
      ``fleetrec.models.rank.dnn`` resolves to
      ``<PACKAGE_BASE>/models/rank/dnn/config.yaml``, where ``PACKAGE_BASE``
      is read from the runtime environs; or
    * a filesystem path to an existing config file, returned unchanged.

    Raises:
        IOError: if ``model`` is not a ``fleetrec.`` name and is not an
            existing file.
    """
    prefix = "fleetrec."
    if model.startswith(prefix):
        fleet_base = envs.get_runtime_environ("PACKAGE_BASE")
        # Strip only the leading prefix. str.split would also cut at any
        # later occurrence of "fleetrec." and silently drop the rest.
        workspace_dir = model[len(prefix):].replace(".", "/")
        return os.path.join(fleet_base, workspace_dir, "config.yaml")

    if not os.path.isfile(model):
        raise IOError("model config: {} invalid".format(model))
    return model


T
tangwei 已提交
201 202
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='fleet-rec run')
    # --model and --engine are dereferenced unconditionally below
    # (get_abs_model, .upper()), so argparse must reject a missing value
    # instead of letting it surface as an AttributeError on None.
    parser.add_argument("-m", "--model", type=str, required=True)
    # "tdm_single" maps to the TDM_SINGLE engine registered in
    # engine_registry(); without it in choices that engine was unreachable.
    parser.add_argument("-e", "--engine", type=str, required=True,
                        choices=["single", "local_cluster", "cluster", "tdm_single"])
    parser.add_argument("-d", "--device", type=str,
                        choices=["cpu", "gpu"], default="cpu")

    # Directory of this script, used to resolve "fleetrec.*" model names.
    abs_dir = os.path.dirname(os.path.abspath(__file__))
    envs.set_runtime_environs({"PACKAGE_BASE": abs_dir})

    args = parser.parse_args()
    # Registry keys are upper-case.
    args.engine = args.engine.upper()
    args.device = args.device.upper()
    args.model = get_abs_model(args.model)
    engine_registry()

    which_engine = get_engine(args.engine, args.device)

    engine = which_engine(args)
    engine.run()