diff --git a/run.py b/run.py index cd86302179b64dbcaa530d62301a84b4aca05783..8be9cfebaa1fae428c5de0cf42e674226e27a8ac 100755 --- a/run.py +++ b/run.py @@ -218,29 +218,32 @@ def single_train_engine(args): def single_infer_engine(args): _envs = envs.load_yaml(args.model) - run_extras = get_all_inters_from_yaml(args.model, ["train.", "runner."]) - trainer_class = run_extras.get( - "runner." + _envs["mode"] + ".trainer_class", None) + run_extras = get_all_inters_from_yaml(args.model, ["runner."]) - if trainer_class: - trainer = trainer_class - else: - trainer = "GeneralTrainer" + mode = envs.get_runtime_environ("mode") + trainer_class = ".".join(["runner", mode, "trainer_class"]) + fleet_class = ".".join(["runner", mode, "fleet_mode"]) + device_class = ".".join(["runner", mode, "device"]) + selected_gpus_class = ".".join(["runner", mode, "selected_gpus"]) + trainer = run_extras.get(trainer_class, "GeneralTrainer") + fleet_mode = run_extras.get(fleet_class, "ps") + device = run_extras.get(device_class, "cpu") + selected_gpus = run_extras.get(selected_gpus_class, "0") executor_mode = "infer" - fleet_mode = run_extras.get("runner." + _envs["mode"] + ".fleet_mode", - "ps") - device = run_extras.get("runner." + _envs["mode"] + ".device", "cpu") - selected_gpus = run_extras.get( - "runner." + _envs["mode"] + ".selected_gpus", "0") - selected_gpus_num = len(selected_gpus.split(",")) + single_envs = {} + if device.upper() == "GPU": - assert selected_gpus_num == 1, "Single Mode Only Support One GPU, Set Local Cluster Mode to use Multi-GPUS" + selected_gpus_num = len(selected_gpus.split(",")) + if selected_gpus_num != 1: + raise ValueError( + "Single Mode Only Support One GPU, Set Local Cluster Mode to use Multi-GPUS" + ) + + single_envs["selsected_gpus"] = selected_gpus + single_envs["FLAGS_selected_gpus"] = selected_gpus - single_envs = {} - single_envs["selected_gpus"] = selected_gpus - single_envs["FLAGS_selected_gpus"] = selected_gpus single_envs["train.trainer.trainer"] = trainer single_envs["train.trainer.executor_mode"] = executor_mode single_envs["fleet_mode"] = fleet_mode