run.py 6.0 KB
Newer Older
T
tangwei 已提交
1 2
import argparse
import os
T
tangwei 已提交
3
import subprocess
T
tangwei 已提交
4
import yaml
T
tangwei 已提交
5

T
tangwei 已提交
6 7 8
from fleet_rec.core.factory import TrainerFactory
from fleet_rec.core.utils import envs
from fleet_rec.core.utils import util
T
tangwei 已提交
9

T
tangwei 已提交
10 11
# Lookup table of run engines, populated by engine_registry() as
# engines[device][transpiler][cluster_mode] -> engine factory function.
engines = dict()
# Device names known to the launcher (upper-case, as used for registry keys).
device = ["CPU", "GPU"]
# Cluster modes an engine can be registered under.
clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"]


T
tangwei 已提交
15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
def engine_registry():
    """Fill the module-level ``engines`` table with engine factories.

    The table is keyed ``engines[device][transpiler][cluster_mode]``.
    On CPU every cluster mode is available for both the TRANSPILER and the
    PSLIB backend. On GPU only SINGLE under TRANSPILER is registered, and
    the GPU PSLIB table is left empty.
    """
    engines["CPU"] = {
        "TRANSPILER": {
            "SINGLE": single_engine,
            "LOCAL_CLUSTER": local_cluster_engine,
            "CLUSTER": cluster_engine,
        },
        "PSLIB": {
            # Both local modes are served by the local MPI launcher.
            "SINGLE": local_mpi_engine,
            "LOCAL_CLUSTER": local_mpi_engine,
            "CLUSTER": cluster_mpi_engine,
        },
    }
    engines["GPU"] = {
        "TRANSPILER": {"SINGLE": single_engine},
        "PSLIB": {},
    }


def get_engine(engine, device):
    """Return the engine factory registered for ``engine`` on ``device``.

    The backend (TRANSPILER or PSLIB) is detected with get_transpiler(),
    which launches a probe subprocess on every call.

    Args:
        engine: upper-case cluster mode, e.g. "SINGLE".
        device: upper-case device name, e.g. "CPU".

    Raises:
        KeyError: if ``device`` is not present in ``engines``.
        ValueError: if no factory is registered for this combination.
    """
    per_backend = engines[device]
    registered = per_backend[get_transpiler()]
    factory = registered.get(engine)
    if factory is not None:
        return factory
    raise ValueError("engine {} can not be supported on device: {}".format(engine, device))


def get_transpiler():
    """Detect which Paddle distributed backend is available.

    A short probe script calling
    ``fluid.core.Fleet().copy_table_by_feasign`` is run in a child
    interpreter with its output discarded. If the child exits with status
    -11 (killed by signal 11, SIGSEGV on POSIX), the PSLIB backend is
    assumed. Any other outcome falls back to TRANSPILER, including an
    import failure or a normal exit.

    Returns:
        str: "PSLIB" or "TRANSPILER".
    """
    import sys

    # Probe with the interpreter that is running this script, so the child
    # sees the same installed paddle package. Fall back to "python" when the
    # executable path is unknown (sys.executable may be empty).
    python = sys.executable or "python"
    cmd = [python, "-c",
           "import paddle.fluid as fluid; fleet_ptr = fluid.core.Fleet(); [fleet_ptr.copy_table_by_feasign(10, 10, [2020, 1010])];"]
    # The devnull handle is closed once the child has finished.
    with open(os.devnull, 'w') as fnull:
        ret = subprocess.call(cmd, stdout=fnull, stderr=fnull, cwd=os.getcwd())
    return "PSLIB" if ret == -11 else "TRANSPILER"
T
tangwei 已提交
51 52


T
tangwei 已提交
53
def set_runtime_envs(cluster_envs, engine_yaml):
    """Apply engine settings to the runtime environment and print them.

    Steps:
    1. Apply ``cluster_envs`` via envs.set_runtime_environs; None becomes {}.
    2. Load the model YAML and flatten it with envs.flatten_environs.
    3. Apply every flattened key starting with "train.trainer.".
    4. Print the "train.trainer.*" entries found in os.environ as a table.

    Args:
        cluster_envs: dict of settings chosen by the calling engine, or None.
        engine_yaml: path to the model's YAML config file.
    """
    prefix = "train.trainer."

    def get_engine_extras():
        # Read the YAML config and keep only the trainer-level keys.
        with open(engine_yaml, 'r') as stream:
            parsed = yaml.load(stream.read(), Loader=yaml.FullLoader)
        flat = envs.flatten_environs(parsed)
        return {key: value for key, value in flat.items() if key.startswith(prefix)}

    envs.set_runtime_environs({} if cluster_envs is None else cluster_envs)
    envs.set_runtime_environs(get_engine_extras())

    # The table is read back from os.environ; presumably set_runtime_environs
    # writes into it (not visible here -- confirm in fleet_rec.core.utils.envs).
    shown = {key: value for key, value in os.environ.items() if key.startswith(prefix)}
    print(envs.pretty_print_envs(shown, ("Runtime Envs", "Value")))
T
tangwei 已提交
78 79


T
tangwei 已提交
80
def single_engine(args):
    """Create a trainer that runs the model in a single local process.

    Args:
        args: parsed CLI namespace. ``args.model`` is the model YAML path and
            ``args.device`` is the upper-cased device name.

    Returns:
        The object produced by TrainerFactory.create for ``args.model``.
    """
    print("use single engine to run model: {}".format(args.model))

    single_envs = {
        "train.trainer.trainer": "SingleTrainer",
        "train.trainer.threads": "2",
        "train.trainer.engine": "single",
        "train.trainer.device": args.device,
        "train.trainer.platform": envs.get_platform(),
    }
    set_runtime_envs(single_envs, args.model)
    return TrainerFactory.create(args.model)


def cluster_engine(args):
    """Create a ClusterTrainer for running the model on a cluster.

    Args:
        args: parsed CLI namespace. ``args.model`` is the model YAML path and
            ``args.device`` is the upper-cased device name.

    Returns:
        The object produced by TrainerFactory.create for ``args.model``.
    """
    print("launch cluster engine with cluster to run model: {}".format(args.model))

    cluster_envs = {
        "train.trainer.trainer": "ClusterTrainer",
        "train.trainer.engine": "cluster",
        "train.trainer.device": args.device,
        "train.trainer.platform": envs.get_platform(),
    }
    set_runtime_envs(cluster_envs, args.model)
    return TrainerFactory.create(args.model)


def cluster_mpi_engine(args):
    """Create a CtrCodingTrainer for the PSLIB cluster mode.

    Args:
        args: parsed CLI namespace. ``args.model`` is the model YAML path and
            ``args.device`` is the upper-cased device name.

    Returns:
        The object produced by TrainerFactory.create for ``args.model``.
    """
    print("launch cluster engine with cluster to run model: {}".format(args.model))

    # NOTE(review): unlike the other engines, no "train.trainer.engine" key
    # is set here -- confirm this is intended.
    cluster_envs = {
        "train.trainer.trainer": "CtrCodingTrainer",
        "train.trainer.device": args.device,
        "train.trainer.platform": envs.get_platform(),
    }
    set_runtime_envs(cluster_envs, args.model)
    return TrainerFactory.create(args.model)


def local_cluster_engine(args):
    """Build a LocalClusterEngine that runs the model as a local cluster.

    The same settings dict is passed both to set_runtime_envs and to the
    LocalClusterEngine constructor.

    Args:
        args: parsed CLI namespace. ``args.model`` is the model YAML path and
            ``args.device`` is the upper-cased device name.

    Returns:
        A LocalClusterEngine instance for ``args.model``.
    """
    print("launch cluster engine with cluster to run model: {}".format(args.model))
    from fleet_rec.core.engine.local_cluster_engine import LocalClusterEngine

    cluster_envs = {
        # Launch parameters consumed by LocalClusterEngine (presumably the
        # server/worker counts, first port and log directory -- confirm there).
        "server_num": 1,
        "worker_num": 1,
        "start_port": 36001,
        "log_dir": "logs",
        "train.trainer.trainer": "ClusterTrainer",
        "train.trainer.strategy": "async",
        "train.trainer.threads": "2",
        "train.trainer.engine": "local_cluster",
        "train.trainer.device": args.device,
        "train.trainer.platform": envs.get_platform(),
        "CPU_NUM": "2",
    }
    set_runtime_envs(cluster_envs, args.model)
    return LocalClusterEngine(cluster_envs, args.model)
T
tangwei 已提交
147

T
tangwei 已提交
148

T
tangwei 已提交
149
def local_mpi_engine(args):
    """Build a LocalMPIEngine that runs the model via a local ``mpirun``.

    Args:
        args: parsed CLI namespace. ``args.model`` is the model YAML path and
            ``args.device`` is the upper-cased device name.

    Returns:
        A LocalMPIEngine instance for ``args.model``.

    Raises:
        RuntimeError: if util.run_which cannot locate ``mpirun``.
    """
    print("launch cluster engine with cluster to run model: {}".format(args.model))
    from fleet_rec.core.engine.local_mpi_engine import LocalMPIEngine

    print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))

    mpirun_path = util.run_which("mpirun")
    if not mpirun_path:
        raise RuntimeError("can not find mpirun, please check environment")

    cluster_envs = {
        "mpirun": mpirun_path,
        "train.trainer.trainer": "CtrCodingTrainer",
        "log_dir": "logs",
        "train.trainer.engine": "local_cluster",
        "train.trainer.device": args.device,
        "train.trainer.platform": envs.get_platform(),
    }
    set_runtime_envs(cluster_envs, args.model)
    return LocalMPIEngine(cluster_envs, args.model)


T
tangwei 已提交
172 173
if __name__ == "__main__":
    # Command-line entry point: pick the engine registered for the requested
    # run mode and device, build it for the given model config and run it.
    parser = argparse.ArgumentParser(description='fleet-rec run')
    # --model and --engine have no usable default. Without required=True an
    # omitted flag leaves None, and os.path.isfile(None) / None.upper() below
    # would crash with a traceback instead of a usage message.
    parser.add_argument("-m", "--model", type=str, required=True,
                        help="path to the model YAML config file")
    parser.add_argument("-e", "--engine", type=str, required=True,
                        choices=["single", "local_cluster", "cluster"],
                        help="run mode: single, local_cluster or cluster")
    parser.add_argument("-d", "--device", type=str, choices=["cpu", "gpu"], default="cpu",
                        help="device to train on (default: cpu)")

    args = parser.parse_args()
    # The engine registry is keyed by upper-case names.
    args.engine = args.engine.upper()
    args.device = args.device.upper()

    if not os.path.isfile(args.model):
        raise IOError("argument model: {} do not exist".format(args.model))
    engine_registry()

    which_engine = get_engine(args.engine, args.device)

    engine = which_engine(args)
    engine.run()