提交 c6a3a9fd 编写于 作者: T tangwei

fix windows adapter

上级 fc724787
...@@ -59,7 +59,7 @@ class TrainerFactory(object): ...@@ -59,7 +59,7 @@ class TrainerFactory(object):
@staticmethod @staticmethod
def create(config): def create(config):
_config = envs.load_yaml(config) _config = envs.load_yaml(config)
envs.set_global_envs(_config, True) envs.set_global_envs(_config)
trainer = TrainerFactory._build_trainer(config) trainer = TrainerFactory._build_trainer(config)
return trainer return trainer
......
...@@ -26,7 +26,7 @@ class ReaderBase(dg.MultiSlotDataGenerator): ...@@ -26,7 +26,7 @@ class ReaderBase(dg.MultiSlotDataGenerator):
def __init__(self, config): def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self) dg.MultiSlotDataGenerator.__init__(self)
_config = envs.load_yaml(config) _config = envs.load_yaml(config)
envs.set_global_envs(_config, True) envs.set_global_envs(_config)
@abc.abstractmethod @abc.abstractmethod
def init(self): def init(self):
...@@ -44,7 +44,7 @@ class SlotReader(dg.MultiSlotDataGenerator): ...@@ -44,7 +44,7 @@ class SlotReader(dg.MultiSlotDataGenerator):
def __init__(self, config): def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self) dg.MultiSlotDataGenerator.__init__(self)
_config = envs.load_yaml(config) _config = envs.load_yaml(config)
envs.set_global_envs(_config, True) envs.set_global_envs(_config)
def init(self, sparse_slots, dense_slots, padding=0): def init(self, sparse_slots, dense_slots, padding=0):
from operator import mul from operator import mul
......
...@@ -16,7 +16,6 @@ import abc ...@@ -16,7 +16,6 @@ import abc
import os import os
import time import time
import sys import sys
import yaml
import traceback import traceback
from paddle import fluid from paddle import fluid
...@@ -74,11 +73,14 @@ class Trainer(object): ...@@ -74,11 +73,14 @@ class Trainer(object):
phase_names = envs.get_global_env( phase_names = envs.get_global_env(
"runner." + self._runner_name + ".phases", None) "runner." + self._runner_name + ".phases", None)
_config = envs.load_yaml(config)
phases = [] phases = []
if phase_names is None: if phase_names is None:
phases = envs.get_global_env("phase") phases = _config.get("phase")
else: else:
for phase in envs.get_global_env("phase"): for phase in _config.get("phase"):
if phase["name"] in phase_names: if phase["name"] in phase_names:
phases.append(phase) phases.append(phase)
...@@ -244,15 +246,3 @@ class Trainer(object): ...@@ -244,15 +246,3 @@ class Trainer(object):
self.context_process(self._context) self.context_process(self._context)
if self._context['is_exit']: if self._context['is_exit']:
break break
def user_define_engine(engine_yaml):
_config = envs.load_yaml(engine_yaml)
envs.set_runtime_environs(_config)
train_location = envs.get_global_env("engine.file")
train_dirname = os.path.dirname(train_location)
base_name = os.path.splitext(os.path.basename(train_location))[0]
sys.path.append(train_dirname)
trainer_class = envs.lazy_instance_by_fliename(base_name,
"UserDefineTraining")
return trainer_class
...@@ -20,9 +20,8 @@ import socket ...@@ -20,9 +20,8 @@ import socket
import sys import sys
import traceback import traceback
import yaml
global_envs = {} global_envs = {}
global_envs_flatten = {}
def flatten_environs(envs, separator="."): def flatten_environs(envs, separator="."):
...@@ -68,7 +67,7 @@ def get_fleet_mode(): ...@@ -68,7 +67,7 @@ def get_fleet_mode():
return fleet_mode return fleet_mode
def set_global_envs(envs, adapter): def set_global_envs(envs):
assert isinstance(envs, dict) assert isinstance(envs, dict)
def fatten_env_namespace(namespace_nests, local_envs): def fatten_env_namespace(namespace_nests, local_envs):
...@@ -92,10 +91,9 @@ def set_global_envs(envs, adapter): ...@@ -92,10 +91,9 @@ def set_global_envs(envs, adapter):
fatten_env_namespace([], envs) fatten_env_namespace([], envs)
if adapter: workspace_adapter()
workspace_adapter() os_path_adapter()
os_path_adapter() reader_adapter()
reader_adapter()
def get_global_env(env_name, default_value=None, namespace=None): def get_global_env(env_name, default_value=None, namespace=None):
...@@ -134,6 +132,7 @@ def workspace_adapter(): ...@@ -134,6 +132,7 @@ def workspace_adapter():
workspace = global_envs.get("workspace") workspace = global_envs.get("workspace")
if not workspace: if not workspace:
return return
workspace = paddlerec_adapter(workspace) workspace = paddlerec_adapter(workspace)
for name, value in global_envs.items(): for name, value in global_envs.items():
......
...@@ -197,7 +197,7 @@ class Reader(dg.MultiSlotDataGenerator): ...@@ -197,7 +197,7 @@ class Reader(dg.MultiSlotDataGenerator):
def __init__(self, config): def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self) dg.MultiSlotDataGenerator.__init__(self)
_config = envs.load_yaml(config) _config = envs.load_yaml(config)
envs.set_global_envs(_config, True) envs.set_global_envs(_config)
@abc.abstractmethod @abc.abstractmethod
def init(self): def init(self):
......
...@@ -110,7 +110,6 @@ def get_modes(running_config): ...@@ -110,7 +110,6 @@ def get_modes(running_config):
def get_engine(args, running_config, mode): def get_engine(args, running_config, mode):
transpiler = get_transpiler() transpiler = get_transpiler()
_envs = envs.load_yaml(args.model)
engine_class = ".".join(["runner", mode, "class"]) engine_class = ".".join(["runner", mode, "class"])
engine_device = ".".join(["runner", mode, "device"]) engine_device = ".".join(["runner", mode, "device"])
...@@ -122,11 +121,14 @@ def get_engine(args, running_config, mode): ...@@ -122,11 +121,14 @@ def get_engine(args, running_config, mode):
mode, engine_class)) mode, engine_class))
device = running_config.get(engine_device, None) device = running_config.get(engine_device, None)
engine = engine.upper()
device = device.upper()
if device is None: if device is None:
print("not find device be specified in yaml, set CPU as default") print("not find device be specified in yaml, set CPU as default")
device = "CPU" device = "CPU"
if device.upper() == "GPU": if device == "GPU":
selected_gpus = running_config.get(device_gpu_choices, None) selected_gpus = running_config.get(device_gpu_choices, None)
if selected_gpus is None: if selected_gpus is None:
...@@ -142,7 +144,6 @@ def get_engine(args, running_config, mode): ...@@ -142,7 +144,6 @@ def get_engine(args, running_config, mode):
if selected_gpus_num > 1: if selected_gpus_num > 1:
engine = "LOCAL_CLUSTER" engine = "LOCAL_CLUSTER"
engine = engine.upper()
if engine not in engine_choices: if engine not in engine_choices:
raise ValueError("{} can not be chosen in {}".format(engine_class, raise ValueError("{} can not be chosen in {}".format(engine_class,
engine_choices)) engine_choices))
...@@ -180,9 +181,7 @@ def set_runtime_envs(cluster_envs, engine_yaml): ...@@ -180,9 +181,7 @@ def set_runtime_envs(cluster_envs, engine_yaml):
def single_train_engine(args): def single_train_engine(args):
_envs = envs.load_yaml(args.model)
run_extras = get_all_inters_from_yaml(args.model, ["runner."]) run_extras = get_all_inters_from_yaml(args.model, ["runner."])
mode = envs.get_runtime_environ("mode") mode = envs.get_runtime_environ("mode")
trainer_class = ".".join(["runner", mode, "trainer_class"]) trainer_class = ".".join(["runner", mode, "trainer_class"])
fleet_class = ".".join(["runner", mode, "fleet_mode"]) fleet_class = ".".join(["runner", mode, "fleet_mode"])
...@@ -435,7 +434,7 @@ def local_mpi_engine(args): ...@@ -435,7 +434,7 @@ def local_mpi_engine(args):
def get_abs_model(model): def get_abs_model(model):
if model.startswith("paddlerec."): if model.startswith("paddlerec."):
dir = envs.path_adapter(model) dir = envs.paddlerec_adapter(model)
path = os.path.join(dir, "config.yaml") path = os.path.join(dir, "config.yaml")
else: else:
if not os.path.isfile(model): if not os.path.isfile(model):
...@@ -453,13 +452,12 @@ if __name__ == "__main__": ...@@ -453,13 +452,12 @@ if __name__ == "__main__":
envs.set_runtime_environs({"PACKAGE_BASE": abs_dir}) envs.set_runtime_environs({"PACKAGE_BASE": abs_dir})
args = parser.parse_args() args = parser.parse_args()
model_name = args.model.split('.')[-1]
args.model = get_abs_model(args.model) args.model = get_abs_model(args.model)
if not validation.yaml_validation(args.model): if not validation.yaml_validation(args.model):
sys.exit(-1) sys.exit(-1)
engine_registry()
engine_registry()
running_config = get_all_inters_from_yaml(args.model, ["mode", "runner."]) running_config = get_all_inters_from_yaml(args.model, ["mode", "runner."])
modes = get_modes(running_config) modes = get_modes(running_config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册