From c6b1cf3b4c4c3a90bf5e01d3e732ba0c7a7b9f70 Mon Sep 17 00:00:00 2001 From: tangwei Date: Thu, 30 Apr 2020 18:50:10 +0800 Subject: [PATCH] add workspace and lazy import --- fleet_rec/core/factory.py | 5 +---- fleet_rec/core/utils/envs.py | 17 +++++++---------- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/fleet_rec/core/factory.py b/fleet_rec/core/factory.py index ebd90c00..bbcc05f8 100644 --- a/fleet_rec/core/factory.py +++ b/fleet_rec/core/factory.py @@ -50,10 +50,7 @@ class TrainerFactory(object): trainer_abs = train_mode train_mode = "UserDefineTrainer" - train_dirname = os.path.dirname(trainer_abs) - base_name = os.path.splitext(os.path.basename(trainer_abs))[0] - sys.path.append(train_dirname) - trainer_class = envs.lazy_instance_by_fliename(base_name, train_mode) + trainer_class = envs.lazy_instance_by_fliename(trainer_abs, train_mode) trainer = trainer_class(yaml_path) return trainer diff --git a/fleet_rec/core/utils/envs.py b/fleet_rec/core/utils/envs.py index 70a63d88..22550a32 100644 --- a/fleet_rec/core/utils/envs.py +++ b/fleet_rec/core/utils/envs.py @@ -144,24 +144,21 @@ def pretty_print_envs(envs, header=None): return _str -def lazy_instance_by_fliename(package, class_name): +def lazy_instance_by_package(package, class_name): models = get_global_env("train.model.models") model_package = __import__(package, globals(), locals(), package.split(".")) instance = getattr(model_package, class_name) return instance -def lazy_instance_by_fliename(package, class_name): - models = get_global_env("train.model.models") - - dirname = os.path.dirname(models) - basename = os.path.basename(models) +def lazy_instance_by_fliename(abs, class_name): + dirname = os.path.dirname(abs) sys.path.append(dirname) - from basename import Model + package = os.path.splitext(os.path.basename(abs))[0] -# model_package = __import__(package, globals(), locals(), package.split(".")) -# instance = getattr(model_package, class_name) - return Model + model_package = __import__(package, globals(), locals(), package.split(".")) + instance = getattr(model_package, class_name) + return instance def get_platform(): -- GitLab