run.py 3.7 KB
Newer Older
T
tangwei 已提交
1 2
import argparse
import os
T
tangwei 已提交
3 4 5 6
import sys

import yaml
from paddle.fluid.incubate.fleet.parameter_server import version
T
tangwei 已提交
7 8 9 10 11 12 13 14 15 16 17

from fleetrec.core.factory import TrainerFactory
from fleetrec.core.utils import envs
from fleetrec.core.engine import local_engine


def run(model_yaml):
    """Build the trainer described by *model_yaml* and execute it.

    Args:
        model_yaml: path to the model YAML file, passed straight to
            ``TrainerFactory.create``.
    """
    TrainerFactory.create(model_yaml).run()


def single_engine(single_envs, model_yaml):
    """Run *model_yaml* in-process with single-node settings.

    Prints the settings as a table, publishes them through
    ``envs.set_runtime_envions`` and then delegates to ``run``.

    Args:
        single_envs: mapping of runtime setting names to values.
        model_yaml: path to the model YAML file.
    """
    table = envs.pretty_print_envs(single_envs, ("Single Envs", "Value"))
    print(table)

    envs.set_runtime_envions(single_envs)
    run(model_yaml)


def local_cluster_engine(cluster_envs, model_yaml):
    """Launch *model_yaml* on a local (localhost) cluster.

    Prints the cluster settings, publishes them through
    ``envs.set_runtime_envions`` and hands both the settings and the model
    path to ``local_engine.Launch``, whose ``run`` performs the launch.

    Args:
        cluster_envs: mapping of cluster setting names to values.
        model_yaml: path to the model YAML file.
    """
    header = ("Local Cluster Envs", "Value")
    print(envs.pretty_print_envs(cluster_envs, header))
    envs.set_runtime_envions(cluster_envs)

    local_engine.Launch(cluster_envs, model_yaml).run()


def local_mpi_engine(model_yaml):
    """Prepare a 1x1 MPI cluster environment on localhost for *model_yaml*.

    Builds a fixed one-server/one-worker configuration, prints it and
    publishes it through ``envs.set_runtime_envions``. The MPI launch itself
    is not implemented yet: the function currently ends by printing
    "coming soon".

    Args:
        model_yaml: path to the model YAML file.
    """
    # Use the parameter, not the module-level ``args`` (which only exists
    # when this file is executed as a script), so the function also works
    # when imported and called directly.
    print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(model_yaml))

    cluster_envs = {
        "server_num": 1,
        "worker_num": 1,
        "start_port": 36001,
        "log_dir": "logs",
        "train.trainer": "CtrTraining",
    }

    print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value")))
    envs.set_runtime_envions(cluster_envs)
    # TODO: actually launch the MPI cluster; not implemented yet.
    print("coming soon")


def yaml_engine(engine_yaml, model_yaml):
    """Run a user-defined trainer described by an engine YAML file.

    The engine YAML is parsed and published via ``envs.set_global_envs``.
    Its ``engine.file`` entry names a Python file. That file's directory is
    appended to ``sys.path``, and its ``UserDefineTrainer`` class is obtained
    through ``envs.lazy_instance``. The class is then instantiated with
    *model_yaml* and run.

    Args:
        engine_yaml: path to the engine YAML file.
        model_yaml: path to the model YAML file, passed to the trainer.

    Raises:
        ValueError: if the engine YAML is empty or does not provide
            ``engine.file``.
    """
    # NOTE(review): yaml.FullLoader can construct some Python objects; only
    # load engine files from trusted sources (consider yaml.safe_load if the
    # configs use no custom tags).
    with open(engine_yaml, 'r', encoding='utf-8') as rb:
        _config = yaml.load(rb.read(), Loader=yaml.FullLoader)
    # Explicit check rather than ``assert`` so it survives ``python -O``.
    if _config is None:
        raise ValueError("engine yaml: {} is empty".format(engine_yaml))

    envs.set_global_envs(_config)

    train_location = envs.get_global_env("engine.file")
    # Fail with a clear message instead of a TypeError from os.path below.
    if not train_location:
        raise ValueError("engine yaml: {} must specify engine.file".format(engine_yaml))

    train_dirname = os.path.dirname(train_location)
    base_name = os.path.splitext(os.path.basename(train_location))[0]
    # Make the user's trainer module importable by its base name.
    sys.path.append(train_dirname)
    trainer_class = envs.lazy_instance(base_name, "UserDefineTrainer")
    trainer = trainer_class(model_yaml)
    trainer.run()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='fleet-rec run')
    # --model and --engine are mandatory: without them the checks below would
    # crash with TypeError/AttributeError instead of printing a usage error.
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--engine", type=str, required=True)
    # Only used by the USER_DEFINE engine: path to the engine yaml file.
    parser.add_argument("--engine_extras", type=str)

    args = parser.parse_args()

    # os.path.isfile() is False for missing paths, so it also covers existence.
    if not os.path.isfile(args.model):
        raise ValueError("argument model: {} error, must specify a existed yaml file".format(args.model))

    # Normalise once so the dispatch below is case-insensitive.
    engine = args.engine.upper()

    if engine == "SINGLE":
        # In transpiler mode train in-process; otherwise use the local MPI path.
        if version.is_transpiler():
            print("use SingleTraining to run model: {}".format(args.model))
            single_envs = {"train.trainer": "SingleTraining"}
            single_engine(single_envs, args.model)
        else:
            local_mpi_engine(args.model)
    elif engine == "LOCAL_CLUSTER":
        print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model))
        if version.is_transpiler():
            cluster_envs = {
                "server_num": 1,
                "worker_num": 1,
                "start_port": 36001,
                "log_dir": "logs",
                "train.trainer": "ClusterTraining",
                "train.strategy.mode": "async",
            }
            local_cluster_engine(cluster_envs, args.model)
        else:
            local_mpi_engine(args.model)
    elif engine == "CLUSTER":
        print("launch ClusterTraining with cluster to run model: {}".format(args.model))
        run(args.model)
    elif engine == "USER_DEFINE":
        engine_file = args.engine_extras
        # engine_extras is optional on the CLI, so guard against None before
        # touching the filesystem.
        if engine_file is None or not os.path.isfile(engine_file):
            raise ValueError(
                "argument engine: user_define error, --engine_extras must specify a existed yaml file, "
                "got: {}".format(engine_file))
        yaml_engine(engine_file, args.model)
    else:
        raise ValueError("engine only support SINGLE/LOCAL_CLUSTER/CLUSTER/USER_DEFINE")