提交 c6b1cf3b 编写于 作者: T tangwei

add workspace and lazy import

上级 51f35482
......@@ -50,10 +50,7 @@ class TrainerFactory(object):
trainer_abs = train_mode
train_mode = "UserDefineTrainer"
train_dirname = os.path.dirname(trainer_abs)
base_name = os.path.splitext(os.path.basename(trainer_abs))[0]
sys.path.append(train_dirname)
trainer_class = envs.lazy_instance_by_fliename(base_name, train_mode)
trainer_class = envs.lazy_instance_by_fliename(trainer_abs, train_mode)
trainer = trainer_class(yaml_path)
return trainer
......
......@@ -144,24 +144,21 @@ def pretty_print_envs(envs, header=None):
return _str
def lazy_instance_by_fliename(package, class_name):
def lazy_instance_by_package(package, class_name):
models = get_global_env("train.model.models")
model_package = __import__(package, globals(), locals(), package.split("."))
instance = getattr(model_package, class_name)
return instance
def lazy_instance_by_fliename(package, class_name):
models = get_global_env("train.model.models")
dirname = os.path.dirname(models)
basename = os.path.basename(models)
def lazy_instance_by_fliename(abs, class_name):
dirname = os.path.dirname(abs)
sys.path.append(dirname)
from basename import Model
package = os.path.splitext(os.path.basename(abs))[0]
# model_package = __import__(package, globals(), locals(), package.split("."))
# instance = getattr(model_package, class_name)
return Model
model_package = __import__(package, globals(), locals(), package.split("."))
instance = getattr(model_package, class_name)
return instance
def get_platform():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册