提交 51f35482 编写于 作者: T tangwei

add workspace and lazy import

上级 74ee4f3f
......@@ -53,7 +53,7 @@ class TrainerFactory(object):
train_dirname = os.path.dirname(trainer_abs)
base_name = os.path.splitext(os.path.basename(trainer_abs))[0]
sys.path.append(train_dirname)
trainer_class = envs.lazy_instance(base_name, train_mode)
trainer_class = envs.lazy_instance_by_fliename(base_name, train_mode)
trainer = trainer_class(yaml_path)
return trainer
......
......@@ -35,6 +35,7 @@ class Reader(dg.MultiSlotDataGenerator):
raise ValueError("reader config only support yaml")
envs.set_global_envs(_config)
envs.update_workspace()
@abc.abstractmethod
def init(self):
......
......@@ -95,5 +95,5 @@ def user_define_engine(engine_yaml):
train_dirname = os.path.dirname(train_location)
base_name = os.path.splitext(os.path.basename(train_location))[0]
sys.path.append(train_dirname)
trainer_class = envs.lazy_instance(base_name, "UserDefineTraining")
trainer_class = envs.lazy_instance_by_fliename(base_name, "UserDefineTraining")
return trainer_class
......@@ -76,7 +76,7 @@ class CtrPaddleTrainer(Trainer):
def instance(self, context):
models = envs.get_global_env("train.model.models")
model_class = envs.lazy_instance(models, "Model")
model_class = envs.lazy_instance_by_fliename(models, "Model")
self.model = model_class(None)
context['status'] = 'init_pass'
......
......@@ -132,7 +132,7 @@ class TranspileTrainer(Trainer):
def instance(self, context):
models = envs.get_global_env("train.model.models")
model_class = envs.lazy_instance(models, "Model")
model_class = envs.lazy_instance_by_fliename(models, "Model")
self.model = model_class(None)
context['status'] = 'init_pass'
......
......@@ -16,7 +16,7 @@ from __future__ import print_function
import os
import sys
from fleetrec.core.utils.envs import lazy_instance
from fleetrec.core.utils.envs import lazy_instance_by_fliename
from fleetrec.core.utils.envs import get_global_env
from fleetrec.core.utils.envs import get_runtime_environ
......@@ -38,7 +38,7 @@ def dataloader(readerclass, train, yaml_file):
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
reader_class = lazy_instance(readerclass, reader_name)
reader_class = lazy_instance_by_fliename(readerclass, reader_name)
reader = reader_class(yaml_file)
reader.init()
......
......@@ -14,7 +14,7 @@
from __future__ import print_function
import sys
from fleetrec.core.utils.envs import lazy_instance
from fleetrec.core.utils.envs import lazy_instance_by_fliename
if len(sys.argv) != 4:
raise ValueError("reader only accept 3 argument: 1. reader_class 2.train/evaluate 3.yaml_abs_path")
......@@ -27,7 +27,7 @@ else:
reader_name = "EvaluateReader"
yaml_abs_path = sys.argv[3]
reader_class = lazy_instance(reader_package, reader_name)
reader_class = lazy_instance_by_fliename(reader_package, reader_name)
reader = reader_class(yaml_abs_path)
reader.init()
reader.run_from_stdin()
......@@ -14,6 +14,7 @@
import os
import copy
import sys
global_envs = {}
......@@ -89,7 +90,6 @@ def update_workspace():
workspace = global_envs.get("train.workspace", None)
if not workspace:
return
workspace = ""
# is fleet inner models
if workspace.startswith("fleetrec."):
......@@ -104,14 +104,14 @@ def update_workspace():
value = value.replace("{workspace}", path)
global_envs[name] = value
def pretty_print_envs(envs, header=None):
spacing = 5
max_k = 45
max_v = 20
max_v = 50
for k, v in envs.items():
max_k = max(max_k, len(k))
max_v = max(max_v, len(str(v)))
h_format = "{{:^{}s}}{}{{:<{}s}}\n".format(max_k, " " * spacing, max_v)
l_format = "{{:<{}s}}{{}}{{:<{}s}}\n".format(max_k, max_v)
......@@ -131,7 +131,12 @@ def pretty_print_envs(envs, header=None):
draws += line + "\n"
for k, v in envs.items():
draws += l_format.format(k, " " * spacing, str(v))
if isinstance(v, str) and len(v) >= max_v:
str_v = "... " + v[-46:]
else:
str_v = v
draws += l_format.format(k, " " * spacing, str(str_v))
draws += border
......@@ -139,13 +144,26 @@ def pretty_print_envs(envs, header=None):
return _str
def lazy_instance(package, class_name):
def lazy_instance_by_fliename(package, class_name):
models = get_global_env("train.model.models")
model_package = __import__(package, globals(), locals(), package.split("."))
instance = getattr(model_package, class_name)
return instance
def lazy_instance_by_fliename(package, class_name):
models = get_global_env("train.model.models")
dirname = os.path.dirname(models)
basename = os.path.basename(models)
sys.path.append(dirname)
from basename import Model
# model_package = __import__(package, globals(), locals(), package.split("."))
# instance = getattr(model_package, class_name)
return Model
def get_platform():
import platform
plats = platform.platform()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册