提交 51f35482 编写于 作者: T tangwei

add workspace and lazy import

上级 74ee4f3f
...@@ -53,7 +53,7 @@ class TrainerFactory(object): ...@@ -53,7 +53,7 @@ class TrainerFactory(object):
train_dirname = os.path.dirname(trainer_abs) train_dirname = os.path.dirname(trainer_abs)
base_name = os.path.splitext(os.path.basename(trainer_abs))[0] base_name = os.path.splitext(os.path.basename(trainer_abs))[0]
sys.path.append(train_dirname) sys.path.append(train_dirname)
trainer_class = envs.lazy_instance(base_name, train_mode) trainer_class = envs.lazy_instance_by_fliename(base_name, train_mode)
trainer = trainer_class(yaml_path) trainer = trainer_class(yaml_path)
return trainer return trainer
......
...@@ -35,6 +35,7 @@ class Reader(dg.MultiSlotDataGenerator): ...@@ -35,6 +35,7 @@ class Reader(dg.MultiSlotDataGenerator):
raise ValueError("reader config only support yaml") raise ValueError("reader config only support yaml")
envs.set_global_envs(_config) envs.set_global_envs(_config)
envs.update_workspace()
@abc.abstractmethod @abc.abstractmethod
def init(self): def init(self):
......
...@@ -95,5 +95,5 @@ def user_define_engine(engine_yaml): ...@@ -95,5 +95,5 @@ def user_define_engine(engine_yaml):
train_dirname = os.path.dirname(train_location) train_dirname = os.path.dirname(train_location)
base_name = os.path.splitext(os.path.basename(train_location))[0] base_name = os.path.splitext(os.path.basename(train_location))[0]
sys.path.append(train_dirname) sys.path.append(train_dirname)
trainer_class = envs.lazy_instance(base_name, "UserDefineTraining") trainer_class = envs.lazy_instance_by_fliename(base_name, "UserDefineTraining")
return trainer_class return trainer_class
...@@ -76,7 +76,7 @@ class CtrPaddleTrainer(Trainer): ...@@ -76,7 +76,7 @@ class CtrPaddleTrainer(Trainer):
def instance(self, context): def instance(self, context):
models = envs.get_global_env("train.model.models") models = envs.get_global_env("train.model.models")
model_class = envs.lazy_instance(models, "Model") model_class = envs.lazy_instance_by_fliename(models, "Model")
self.model = model_class(None) self.model = model_class(None)
context['status'] = 'init_pass' context['status'] = 'init_pass'
......
...@@ -132,7 +132,7 @@ class TranspileTrainer(Trainer): ...@@ -132,7 +132,7 @@ class TranspileTrainer(Trainer):
def instance(self, context): def instance(self, context):
models = envs.get_global_env("train.model.models") models = envs.get_global_env("train.model.models")
model_class = envs.lazy_instance(models, "Model") model_class = envs.lazy_instance_by_fliename(models, "Model")
self.model = model_class(None) self.model = model_class(None)
context['status'] = 'init_pass' context['status'] = 'init_pass'
......
...@@ -16,7 +16,7 @@ from __future__ import print_function ...@@ -16,7 +16,7 @@ from __future__ import print_function
import os import os
import sys import sys
from fleetrec.core.utils.envs import lazy_instance from fleetrec.core.utils.envs import lazy_instance_by_fliename
from fleetrec.core.utils.envs import get_global_env from fleetrec.core.utils.envs import get_global_env
from fleetrec.core.utils.envs import get_runtime_environ from fleetrec.core.utils.envs import get_runtime_environ
...@@ -38,7 +38,7 @@ def dataloader(readerclass, train, yaml_file): ...@@ -38,7 +38,7 @@ def dataloader(readerclass, train, yaml_file):
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
reader_class = lazy_instance(readerclass, reader_name) reader_class = lazy_instance_by_fliename(readerclass, reader_name)
reader = reader_class(yaml_file) reader = reader_class(yaml_file)
reader.init() reader.init()
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from __future__ import print_function from __future__ import print_function
import sys import sys
from fleetrec.core.utils.envs import lazy_instance from fleetrec.core.utils.envs import lazy_instance_by_fliename
if len(sys.argv) != 4: if len(sys.argv) != 4:
raise ValueError("reader only accept 3 argument: 1. reader_class 2.train/evaluate 3.yaml_abs_path") raise ValueError("reader only accept 3 argument: 1. reader_class 2.train/evaluate 3.yaml_abs_path")
...@@ -27,7 +27,7 @@ else: ...@@ -27,7 +27,7 @@ else:
reader_name = "EvaluateReader" reader_name = "EvaluateReader"
yaml_abs_path = sys.argv[3] yaml_abs_path = sys.argv[3]
reader_class = lazy_instance(reader_package, reader_name) reader_class = lazy_instance_by_fliename(reader_package, reader_name)
reader = reader_class(yaml_abs_path) reader = reader_class(yaml_abs_path)
reader.init() reader.init()
reader.run_from_stdin() reader.run_from_stdin()
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import os import os
import copy import copy
import sys
global_envs = {} global_envs = {}
...@@ -89,7 +90,6 @@ def update_workspace(): ...@@ -89,7 +90,6 @@ def update_workspace():
workspace = global_envs.get("train.workspace", None) workspace = global_envs.get("train.workspace", None)
if not workspace: if not workspace:
return return
workspace = ""
# is fleet inner models # is fleet inner models
if workspace.startswith("fleetrec."): if workspace.startswith("fleetrec."):
...@@ -104,14 +104,14 @@ def update_workspace(): ...@@ -104,14 +104,14 @@ def update_workspace():
value = value.replace("{workspace}", path) value = value.replace("{workspace}", path)
global_envs[name] = value global_envs[name] = value
def pretty_print_envs(envs, header=None): def pretty_print_envs(envs, header=None):
spacing = 5 spacing = 5
max_k = 45 max_k = 45
max_v = 20 max_v = 50
for k, v in envs.items(): for k, v in envs.items():
max_k = max(max_k, len(k)) max_k = max(max_k, len(k))
max_v = max(max_v, len(str(v)))
h_format = "{{:^{}s}}{}{{:<{}s}}\n".format(max_k, " " * spacing, max_v) h_format = "{{:^{}s}}{}{{:<{}s}}\n".format(max_k, " " * spacing, max_v)
l_format = "{{:<{}s}}{{}}{{:<{}s}}\n".format(max_k, max_v) l_format = "{{:<{}s}}{{}}{{:<{}s}}\n".format(max_k, max_v)
...@@ -131,7 +131,12 @@ def pretty_print_envs(envs, header=None): ...@@ -131,7 +131,12 @@ def pretty_print_envs(envs, header=None):
draws += line + "\n" draws += line + "\n"
for k, v in envs.items(): for k, v in envs.items():
draws += l_format.format(k, " " * spacing, str(v)) if isinstance(v, str) and len(v) >= max_v:
str_v = "... " + v[-46:]
else:
str_v = v
draws += l_format.format(k, " " * spacing, str(str_v))
draws += border draws += border
...@@ -139,13 +144,26 @@ def pretty_print_envs(envs, header=None): ...@@ -139,13 +144,26 @@ def pretty_print_envs(envs, header=None):
return _str return _str
def lazy_instance(package, class_name): def lazy_instance_by_fliename(package, class_name):
models = get_global_env("train.model.models") models = get_global_env("train.model.models")
model_package = __import__(package, globals(), locals(), package.split(".")) model_package = __import__(package, globals(), locals(), package.split("."))
instance = getattr(model_package, class_name) instance = getattr(model_package, class_name)
return instance return instance
def lazy_instance_by_fliename(package, class_name):
models = get_global_env("train.model.models")
dirname = os.path.dirname(models)
basename = os.path.basename(models)
sys.path.append(dirname)
from basename import Model
# model_package = __import__(package, globals(), locals(), package.split("."))
# instance = getattr(model_package, class_name)
return Model
def get_platform(): def get_platform():
import platform import platform
plats = platform.platform() plats = platform.platform()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册