run.py 6.4 KB
Newer Older
T
tangwei 已提交
1 2
import argparse
import os
T
tangwei 已提交
3
import subprocess
T
tangwei 已提交
4
import yaml
T
tangwei 已提交
5

T
rename  
tangwei 已提交
6 7 8
from fleetrec.core.factory import TrainerFactory
from fleetrec.core.utils import envs
from fleetrec.core.utils import util
T
tangwei 已提交
9

T
tangwei 已提交
10 11
# Engine table filled in by engine_registry(): device -> transpiler -> cluster mode -> factory.
engines = dict()
# Device and cluster-mode names, upper-cased to match the keys used in the engine table.
device = ["CPU", "GPU"]
clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"]


T
tangwei 已提交
15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
def engine_registry():
    """Fill the module-level ``engines`` table with the available engine factories.

    Layout: device ("CPU"/"GPU") -> transpiler ("TRANSPILER"/"PSLIB")
    -> cluster mode ("SINGLE"/"LOCAL_CLUSTER"/"CLUSTER") -> factory function.
    GPU only gets single-node TRANSPILER training; its PSLIB table is left
    empty, so get_engine finds no factory for GPU under PSLIB.
    """
    cpu_table = {
        "TRANSPILER": {
            "SINGLE": single_engine,
            "LOCAL_CLUSTER": local_cluster_engine,
            "CLUSTER": cluster_engine,
        },
        # Under PSLIB, SINGLE and LOCAL_CLUSTER both go through the local MPI launcher.
        "PSLIB": {
            "SINGLE": local_mpi_engine,
            "LOCAL_CLUSTER": local_mpi_engine,
            "CLUSTER": cluster_mpi_engine,
        },
    }
    gpu_table = {
        "TRANSPILER": {"SINGLE": single_engine},
        "PSLIB": {},
    }

    engines["CPU"] = cpu_table
    engines["GPU"] = gpu_table


def get_engine(engine, device):
    """Return the engine factory registered for ``engine`` on ``device``.

    Args:
        engine: cluster-mode key, e.g. "SINGLE", "LOCAL_CLUSTER" or "CLUSTER".
        device: device key, e.g. "CPU" or "GPU".

    Returns:
        The factory stored in ``engines[device][transpiler][engine]``, where the
        transpiler is detected by get_transpiler().

    Raises:
        ValueError: if the device has no entry in ``engines``, or no factory is
            registered for this (device, transpiler, engine) combination.
    """
    d_engine = engines.get(device)
    if d_engine is None:
        raise ValueError("device {} is not supported, expected one of: {}".format(
            device, sorted(engines)))

    # Check the device first: the transpiler probe launches a subprocess.
    transpiler = get_transpiler()
    run_engine = d_engine.get(transpiler, {}).get(engine, None)

    if run_engine is None:
        # Include the transpiler: the same engine/device pair can be valid under
        # TRANSPILER but missing under PSLIB (e.g. GPU).
        raise ValueError("engine {} can not be supported on device: {} with transpiler: {}".format(
            engine, device, transpiler))
    return run_engine


def get_transpiler():
    """Detect which distributed backend the installed Paddle build provides.

    Runs a probe in a child interpreter that calls
    ``fluid.core.Fleet().copy_table_by_feasign``. A probe that ends with
    return code -11 (killed by signal 11, SIGSEGV) is treated as a PSLIB
    build. Any other outcome, including a failed ``import paddle``, is treated
    as TRANSPILER.

    Returns:
        "PSLIB" or "TRANSPILER".
    """
    import sys

    # Probe the interpreter running this script so the child sees the same
    # Paddle install; fall back to "python" if sys.executable is unavailable.
    python = sys.executable or "python"
    cmd = [python, "-c",
           "import paddle.fluid as fluid; fleet_ptr = fluid.core.Fleet(); [fleet_ptr.copy_table_by_feasign(10, 10, [2020, 1010])];"]

    # The with-block closes the devnull handle once the probe has finished.
    with open(os.devnull, 'w') as fnull:
        proc = subprocess.Popen(cmd, stdout=fnull, stderr=fnull, cwd=os.getcwd())
        ret = proc.wait()

    # Negative return codes from Popen mean "terminated by signal -ret".
    if ret == -11:
        return "PSLIB"
    return "TRANSPILER"
T
tangwei 已提交
51 52


T
tangwei 已提交
53
def set_runtime_envs(cluster_envs, engine_yaml):
    """Register launcher settings plus the yaml's ``train.trainer.*`` keys as runtime envs.

    Args:
        cluster_envs: settings chosen by the calling launcher; ``None`` is
            treated as an empty dict.
        engine_yaml: path to the model's yaml config.

    ``cluster_envs`` is registered first and the yaml's flattened
    ``train.trainer.*`` entries second. Presumably the later call wins on
    duplicate keys; confirm against envs.set_runtime_environs. Finally, every
    ``train.trainer.*`` variable in ``os.environ`` is printed as a table.
    """
    def _trainer_extras():
        # NOTE(review): FullLoader on a user-supplied config file; consider
        # yaml.safe_load if the configs use no Python-specific tags.
        with open(engine_yaml, 'r') as stream:
            parsed = yaml.load(stream.read(), Loader=yaml.FullLoader)
        flat = envs.flatten_environs(parsed)
        return {key: value for key, value in flat.items() if key.startswith("train.trainer.")}

    envs.set_runtime_environs(cluster_envs if cluster_envs is not None else {})
    envs.set_runtime_environs(_trainer_extras())

    shown = {key: value for key, value in os.environ.items() if key.startswith("train.trainer.")}
    print(envs.pretty_print_envs(shown, ("Runtime Envs", "Value")))
T
tangwei 已提交
78 79


T
tangwei 已提交
80
def single_engine(args):
    """Create a single-process trainer for ``args.model``.

    Registers SingleTrainer settings (2 threads, engine "single") together
    with the requested device and the current platform, then builds the
    trainer from the model config.
    """
    print("use single engine to run model: {}".format(args.model))

    settings = {
        "train.trainer.trainer": "SingleTrainer",
        "train.trainer.threads": "2",
        "train.trainer.engine": "single",
        "train.trainer.device": args.device,
        "train.trainer.platform": envs.get_platform(),
    }
    set_runtime_envs(settings, args.model)

    return TrainerFactory.create(args.model)


def cluster_engine(args):
    """Create a ClusterTrainer for ``args.model`` (TRANSPILER, engine "cluster").

    Registers the trainer name, engine, device and current platform as
    runtime envs, then builds the trainer from the model config.
    """
    print("launch cluster engine with cluster to run model: {}".format(args.model))

    settings = {
        "train.trainer.trainer": "ClusterTrainer",
        "train.trainer.engine": "cluster",
        "train.trainer.device": args.device,
        "train.trainer.platform": envs.get_platform(),
    }
    set_runtime_envs(settings, args.model)

    return TrainerFactory.create(args.model)


def cluster_mpi_engine(args):
    """Create a CtrCodingTrainer for ``args.model`` (registered for PSLIB CLUSTER).

    Registers the trainer name, device and current platform as runtime envs,
    then builds the trainer from the model config.
    """
    print("launch cluster engine with cluster to run model: {}".format(args.model))

    # NOTE(review): unlike the other launchers, no "train.trainer.engine" key
    # is set here; confirm this is intended.
    settings = {
        "train.trainer.trainer": "CtrCodingTrainer",
        "train.trainer.device": args.device,
        "train.trainer.platform": envs.get_platform(),
    }
    set_runtime_envs(settings, args.model)

    return TrainerFactory.create(args.model)


def local_cluster_engine(args):
    """Build a LocalClusterEngine that runs ``args.model`` with the async strategy.

    The same settings dict is registered via set_runtime_envs and passed to
    LocalClusterEngine. It mixes plain launcher keys (``server_num``,
    ``worker_num``, ``start_port``, ``log_dir``, ``CPU_NUM``) with
    ``train.trainer.*`` keys. The caller is expected to invoke ``run()`` on
    the returned engine.
    """
    print("launch cluster engine with cluster to run model: {}".format(args.model))
    from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine

    settings = {
        "server_num": 1,
        "worker_num": 1,
        "start_port": 36001,
        "log_dir": "logs",
        "train.trainer.trainer": "ClusterTrainer",
        "train.trainer.strategy": "async",
        "train.trainer.threads": "2",
        "train.trainer.engine": "local_cluster",
        "train.trainer.device": args.device,
        "train.trainer.platform": envs.get_platform(),
        "CPU_NUM": "2",
    }
    set_runtime_envs(settings, args.model)

    return LocalClusterEngine(settings, args.model)
T
tangwei 已提交
147

T
tangwei 已提交
148

T
tangwei 已提交
149
def local_mpi_engine(args):
    """Build a LocalMPIEngine that runs ``args.model`` as a 1x1 MPI job on localhost.

    Locates ``mpirun`` and registers the CtrCodingTrainer settings (engine
    "local_cluster", device, platform) plus the ``mpirun`` path and log
    directory. Passes the same dict to LocalMPIEngine and returns the engine;
    the caller is expected to invoke ``run()`` on it.

    Raises:
        RuntimeError: if ``mpirun`` cannot be found.
    """
    from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine

    # Print only the banner that describes this launcher (previously preceded
    # by a copy of the generic cluster-launcher message).
    print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))

    mpi = util.run_which("mpirun")
    if not mpi:
        raise RuntimeError("can not find mpirun, please check environment")

    cluster_envs = {}
    cluster_envs["mpirun"] = mpi
    cluster_envs["train.trainer.trainer"] = "CtrCodingTrainer"
    cluster_envs["log_dir"] = "logs"
    cluster_envs["train.trainer.engine"] = "local_cluster"

    cluster_envs["train.trainer.device"] = args.device
    cluster_envs["train.trainer.platform"] = envs.get_platform()

    set_runtime_envs(cluster_envs, args.model)
    launch = LocalMPIEngine(cluster_envs, args.model)
    return launch


T
tangwei 已提交
172 173 174 175 176 177 178 179 180 181 182 183
def get_abs_model(model):
    """Resolve the ``-m/--model`` argument to a config file path.

    A value starting with ``fleetrec.`` names a bundled model: the rest of the
    dotted name becomes a directory under the ``PACKAGE_BASE`` runtime env, and
    the result is ``<PACKAGE_BASE>/<a>/<b>/config.yaml``. Any other value is
    used as a file path directly.

    Raises:
        IOError: if ``model`` is empty/None, or the resolved path is not an
            existing file.
    """
    prefix = "fleetrec."
    if not model:
        raise IOError("model config: {} invalid".format(model))

    if model.startswith(prefix):
        fleet_base = envs.get_runtime_environ("PACKAGE_BASE")
        # Strip only the leading prefix; split(prefix)[1] would truncate names
        # in which "fleetrec." appears again.
        workspace_dir = model[len(prefix):].replace(".", "/")
        path = os.path.join(fleet_base, workspace_dir, "config.yaml")
    else:
        path = model

    # Validate bundled and explicit configs the same way.
    if not os.path.isfile(path):
        raise IOError("model config: {} invalid".format(path))
    return path


T
tangwei 已提交
184 185
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='fleet-rec run')
    # -m and -e are dereferenced unconditionally below (get_abs_model, .upper()),
    # so make argparse report a missing value instead of crashing on None.
    parser.add_argument("-m", "--model", type=str, required=True)
    parser.add_argument("-e", "--engine", type=str, required=True,
                        choices=["single", "local_cluster", "cluster"])
    parser.add_argument("-d", "--device", type=str, choices=["cpu", "gpu"], default="cpu")

    # Directory of this script; get_abs_model resolves "fleetrec.*" model names
    # relative to the PACKAGE_BASE runtime env set here.
    abs_dir = os.path.dirname(os.path.abspath(__file__))
    envs.set_runtime_environs({"PACKAGE_BASE": abs_dir})

    args = parser.parse_args()
    # Registry keys are upper-case ("SINGLE", "CPU", ...).
    args.engine = args.engine.upper()
    args.device = args.device.upper()
    args.model = get_abs_model(args.model)
    engine_registry()

    which_engine = get_engine(args.engine, args.device)

    engine = which_engine(args)
    engine.run()