diff --git a/fleet_rec/core/factory.py b/fleet_rec/core/factory.py index 78821dc27e929e2059d409cd4d1f92aa47b4a633..ebd90c00ca731ec52f02c151b929eb8371714d1a 100644 --- a/fleet_rec/core/factory.py +++ b/fleet_rec/core/factory.py @@ -53,7 +53,7 @@ class TrainerFactory(object): train_dirname = os.path.dirname(trainer_abs) base_name = os.path.splitext(os.path.basename(trainer_abs))[0] sys.path.append(train_dirname) - trainer_class = envs.lazy_instance(base_name, train_mode) + trainer_class = envs.lazy_instance_by_fliename(base_name, train_mode) trainer = trainer_class(yaml_path) return trainer diff --git a/fleet_rec/core/reader.py b/fleet_rec/core/reader.py index 3aa96d2ccf9a14a3c7910dc4d82b8e6ca717f18c..81da9ebf82f5bfa9d409a9a17169f19b13c7716b 100644 --- a/fleet_rec/core/reader.py +++ b/fleet_rec/core/reader.py @@ -35,6 +35,7 @@ class Reader(dg.MultiSlotDataGenerator): raise ValueError("reader config only support yaml") envs.set_global_envs(_config) + envs.update_workspace() @abc.abstractmethod def init(self): diff --git a/fleet_rec/core/trainer.py b/fleet_rec/core/trainer.py index 06b5800ee49fdf51be51d9ca1e9e304c46d174cd..ba2f6e8ddb01cc2e2c3a69b54b965869cc314311 100755 --- a/fleet_rec/core/trainer.py +++ b/fleet_rec/core/trainer.py @@ -95,5 +95,5 @@ def user_define_engine(engine_yaml): train_dirname = os.path.dirname(train_location) base_name = os.path.splitext(os.path.basename(train_location))[0] sys.path.append(train_dirname) - trainer_class = envs.lazy_instance(base_name, "UserDefineTraining") + trainer_class = envs.lazy_instance_by_fliename(base_name, "UserDefineTraining") return trainer_class diff --git a/fleet_rec/core/trainers/ctr_coding_trainer.py b/fleet_rec/core/trainers/ctr_coding_trainer.py index ae6b63ceb94e87f006fb5a8b342bf298954bf3d2..fdfdd1ecb2f40747cc8cb61e3f8ce0e3c1c19d4c 100755 --- a/fleet_rec/core/trainers/ctr_coding_trainer.py +++ b/fleet_rec/core/trainers/ctr_coding_trainer.py @@ -76,7 +76,7 @@ class CtrPaddleTrainer(Trainer): def instance(self, context): models = envs.get_global_env("train.model.models") - model_class = envs.lazy_instance(models, "Model") + model_class = envs.lazy_instance_by_fliename(models, "Model") self.model = model_class(None) context['status'] = 'init_pass' diff --git a/fleet_rec/core/trainers/transpiler_trainer.py b/fleet_rec/core/trainers/transpiler_trainer.py index 5b2a73556ff9aebc8adde485214767b4de0f9d45..a4875044aadd7aa1a58012476cd399e309811729 100644 --- a/fleet_rec/core/trainers/transpiler_trainer.py +++ b/fleet_rec/core/trainers/transpiler_trainer.py @@ -132,7 +132,7 @@ class TranspileTrainer(Trainer): def instance(self, context): models = envs.get_global_env("train.model.models") - model_class = envs.lazy_instance(models, "Model") + model_class = envs.lazy_instance_by_fliename(models, "Model") self.model = model_class(None) context['status'] = 'init_pass' diff --git a/fleet_rec/core/utils/dataloader_instance.py b/fleet_rec/core/utils/dataloader_instance.py index 5af98e7c4ef2ae6e87291805d47716691a737cfd..eb7e5fd6f41c5cb4042e045babda55caa765cc4f 100644 --- a/fleet_rec/core/utils/dataloader_instance.py +++ b/fleet_rec/core/utils/dataloader_instance.py @@ -16,7 +16,7 @@ from __future__ import print_function import os import sys -from fleetrec.core.utils.envs import lazy_instance +from fleetrec.core.utils.envs import lazy_instance_by_fliename from fleetrec.core.utils.envs import get_global_env from fleetrec.core.utils.envs import get_runtime_environ @@ -38,7 +38,7 @@ def dataloader(readerclass, train, yaml_file): files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] - reader_class = lazy_instance(readerclass, reader_name) + reader_class = lazy_instance_by_fliename(readerclass, reader_name) reader = reader_class(yaml_file) reader.init() diff --git a/fleet_rec/core/utils/dataset_instance.py b/fleet_rec/core/utils/dataset_instance.py index 91b6af94e122c882bc7bdbbed48a6f8475f02b50..89e6e45d2c53bb033e0e5fc6e436b149d76cc7c2 100644 --- a/fleet_rec/core/utils/dataset_instance.py +++ b/fleet_rec/core/utils/dataset_instance.py @@ -14,7 +14,7 @@ from __future__ import print_function import sys -from fleetrec.core.utils.envs import lazy_instance +from fleetrec.core.utils.envs import lazy_instance_by_fliename if len(sys.argv) != 4: raise ValueError("reader only accept 3 argument: 1. reader_class 2.train/evaluate 3.yaml_abs_path") @@ -27,7 +27,7 @@ else: reader_name = "EvaluateReader" yaml_abs_path = sys.argv[3] -reader_class = lazy_instance(reader_package, reader_name) +reader_class = lazy_instance_by_fliename(reader_package, reader_name) reader = reader_class(yaml_abs_path) reader.init() reader.run_from_stdin() diff --git a/fleet_rec/core/utils/envs.py b/fleet_rec/core/utils/envs.py index d3452b88be6585ea623d7b35153c79f4d03f26a1..70a63d881d0c7844894a8d40761c8b8eb0652163 100644 --- a/fleet_rec/core/utils/envs.py +++ b/fleet_rec/core/utils/envs.py @@ -14,6 +14,7 @@ import os import copy +import sys global_envs = {} @@ -89,7 +90,6 @@ def update_workspace(): workspace = global_envs.get("train.workspace", None) if not workspace: return - workspace = "" # is fleet inner models if workspace.startswith("fleetrec."): @@ -104,14 +104,14 @@ def update_workspace(): value = value.replace("{workspace}", path) global_envs[name] = value + def pretty_print_envs(envs, header=None): spacing = 5 max_k = 45 - max_v = 20 + max_v = 50 for k, v in envs.items(): max_k = max(max_k, len(k)) - max_v = max(max_v, len(str(v))) h_format = "{{:^{}s}}{}{{:<{}s}}\n".format(max_k, " " * spacing, max_v) l_format = "{{:<{}s}}{{}}{{:<{}s}}\n".format(max_k, max_v) @@ -131,7 +131,12 @@ def pretty_print_envs(envs, header=None): draws += line + "\n" for k, v in envs.items(): - draws += l_format.format(k, " " * spacing, str(v)) + if isinstance(v, str) and len(v) >= max_v: + str_v = "... " + v[-46:] + else: + str_v = v + + draws += l_format.format(k, " " * spacing, str(str_v)) draws += border @@ -139,13 +144,26 @@ def pretty_print_envs(envs, header=None): return _str -def lazy_instance(package, class_name): +def lazy_instance_by_fliename(package, class_name): models = get_global_env("train.model.models") model_package = __import__(package, globals(), locals(), package.split(".")) instance = getattr(model_package, class_name) return instance +def lazy_instance_by_fliename(package, class_name): + models = get_global_env("train.model.models") + + dirname = os.path.dirname(models) + basename = os.path.basename(models) + sys.path.append(dirname) + from basename import Model + +# model_package = __import__(package, globals(), locals(), package.split(".")) +# instance = getattr(model_package, class_name) + return Model + + def get_platform(): import platform plats = platform.platform()