From a09255fbf73c3b1ec62b441b068a610d716e86e7 Mon Sep 17 00:00:00 2001 From: xjqbest <173596896@qq.com> Date: Thu, 28 May 2020 22:56:13 +0800 Subject: [PATCH] fix --- core/model.py | 32 +++++++++++++++++--------- core/trainers/single_infer.py | 22 +++++++++--------- core/trainers/single_trainer.py | 37 +++++++++++++++++-------------- core/utils/dataloader_instance.py | 16 +++++++++---- models/rank/dnn/config.yaml | 1 + 5 files changed, 67 insertions(+), 41 deletions(-) diff --git a/core/model.py b/core/model.py index cfe71f2a..5aae5c52 100755 --- a/core/model.py +++ b/core/model.py @@ -59,11 +59,17 @@ class Model(object): dataset = i break name = "dataset." + dataset["name"] + "." - sparse_slots = envs.get_global_env(name + "sparse_slots") - dense_slots = envs.get_global_env(name + "dense_slots") - if sparse_slots is not None or dense_slots is not None: - sparse_slots = sparse_slots.strip().split(" ") - dense_slots = dense_slots.strip().split(" ") + sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip() + dense_slots = envs.get_global_env(name + "dense_slots", "").strip() + if sparse_slots != "" or dense_slots != "": + if sparse_slots == "": + sparse_slots = [] + else: + sparse_slots = sparse_slots.strip().split(" ") + if dense_slots == "": + dense_slots = [] + else: + dense_slots = dense_slots.strip().split(" ") dense_slots_shape = [[ int(j) for j in i.split(":")[1].strip("[]").split(",") ] for i in dense_slots] @@ -151,11 +157,17 @@ class Model(object): def input_data(self, is_infer=False, **kwargs): name = "dataset." + kwargs.get("dataset_name") + "." 
- sparse_slots = envs.get_global_env(name + "sparse_slots") - dense_slots = envs.get_global_env(name + "dense_slots") - if sparse_slots is not None or dense_slots is not None: - sparse_slots = sparse_slots.strip().split(" ") - dense_slots = dense_slots.strip().split(" ") + sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip() + dense_slots = envs.get_global_env(name + "dense_slots", "").strip() + if sparse_slots != "" or dense_slots != "": + if sparse_slots == "": + sparse_slots = [] + else: + sparse_slots = sparse_slots.strip().split(" ") + if dense_slots == "": + dense_slots = [] + else: + dense_slots = dense_slots.strip().split(" ") dense_slots_shape = [[ int(j) for j in i.split(":")[1].strip("[]").split(",") ] for i in dense_slots] diff --git a/core/trainers/single_infer.py b/core/trainers/single_infer.py index d830fcbb..7da93bd8 100755 --- a/core/trainers/single_infer.py +++ b/core/trainers/single_infer.py @@ -67,15 +67,14 @@ class SingleInfer(TranspileTrainer): def _get_dataset(self, dataset_name): name = "dataset." + dataset_name + "." - sparse_slots = envs.get_global_env(name + "sparse_slots") - dense_slots = envs.get_global_env(name + "dense_slots") thread_num = envs.get_global_env(name + "thread_num") batch_size = envs.get_global_env(name + "batch_size") reader_class = envs.get_global_env(name + "data_converter") abs_dir = os.path.dirname(os.path.abspath(__file__)) reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py') - - if sparse_slots is None and dense_slots is None: + sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip() + dense_slots = envs.get_global_env(name + "dense_slots", "").strip() + if sparse_slots == "" and dense_slots == "": pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config_yaml) else: @@ -107,13 +106,13 @@ class SingleInfer(TranspileTrainer): def _get_dataloader(self, dataset_name, dataloader): name = "dataset." + dataset_name + "." 
- sparse_slots = envs.get_global_env(name + "sparse_slots") - dense_slots = envs.get_global_env(name + "dense_slots") thread_num = envs.get_global_env(name + "thread_num") batch_size = envs.get_global_env(name + "batch_size") reader_class = envs.get_global_env(name + "data_converter") abs_dir = os.path.dirname(os.path.abspath(__file__)) - if sparse_slots is None and dense_slots is None: + sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip() + dense_slots = envs.get_global_env(name + "dense_slots", "").strip() + if sparse_slots == "" and dense_slots == "": reader = dataloader_instance.dataloader_by_name( reader_class, dataset_name, self._config_yaml) reader_class = envs.lazy_instance_by_fliename(reader_class, @@ -228,7 +227,9 @@ class SingleInfer(TranspileTrainer): model_class = self._model[model_name][3] fetch_vars = [] fetch_alias = [] - fetch_period = 20 + fetch_period = int( + envs.get_global_env("runner." + self._runner_name + + ".fetch_period", 20)) metrics = model_class.get_infer_results() if metrics: fetch_vars = metrics.values() @@ -251,14 +252,15 @@ class SingleInfer(TranspileTrainer): program = self._model[model_name][0].clone() fetch_vars = [] fetch_alias = [] - fetch_period = 20 metrics = model_class.get_infer_results() if metrics: fetch_vars = metrics.values() fetch_alias = metrics.keys() metrics_varnames = [] metrics_format = [] - fetch_period = 20 + fetch_period = int( + envs.get_global_env("runner." + self._runner_name + + ".fetch_period", 20)) metrics_format.append("{}: {{}}".format("batch")) for name, var in metrics.items(): metrics_varnames.append(var.name) diff --git a/core/trainers/single_trainer.py b/core/trainers/single_trainer.py index f7239d39..26462552 100755 --- a/core/trainers/single_trainer.py +++ b/core/trainers/single_trainer.py @@ -61,21 +61,20 @@ class SingleTrainer(TranspileTrainer): def _get_dataset(self, dataset_name): name = "dataset." + dataset_name + "." 
- sparse_slots = envs.get_global_env(name + "sparse_slots") - dense_slots = envs.get_global_env(name + "dense_slots") thread_num = envs.get_global_env(name + "thread_num") batch_size = envs.get_global_env(name + "batch_size") reader_class = envs.get_global_env(name + "data_converter") abs_dir = os.path.dirname(os.path.abspath(__file__)) reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py') - - if sparse_slots is None and dense_slots is None: + sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip() + dense_slots = envs.get_global_env(name + "dense_slots", "").strip() + if sparse_slots == "" and dense_slots == "": pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config_yaml) else: - if sparse_slots is None: + if sparse_slots == "": sparse_slots = "#" - if dense_slots is None: + if dense_slots == "": dense_slots = "#" padding = envs.get_global_env(name + "padding", 0) pipe_cmd = "python {} {} {} {} {} {} {} {}".format( @@ -101,13 +100,13 @@ class SingleTrainer(TranspileTrainer): def _get_dataloader(self, dataset_name, dataloader): name = "dataset." + dataset_name + "." - sparse_slots = envs.get_global_env(name + "sparse_slots") - dense_slots = envs.get_global_env(name + "dense_slots") + sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip() + dense_slots = envs.get_global_env(name + "dense_slots", "").strip() thread_num = envs.get_global_env(name + "thread_num") batch_size = envs.get_global_env(name + "batch_size") reader_class = envs.get_global_env(name + "data_converter") abs_dir = os.path.dirname(os.path.abspath(__file__)) - if sparse_slots is None and dense_slots is None: + if sparse_slots == "" and dense_slots == "": reader = dataloader_instance.dataloader_by_name( reader_class, dataset_name, self._config_yaml) reader_class = envs.lazy_instance_by_fliename(reader_class, @@ -125,8 +124,8 @@ class SingleTrainer(TranspileTrainer): def _create_dataset(self, dataset_name): name = "dataset."
+ dataset_name + "." - sparse_slots = envs.get_global_env(name + "sparse_slots") - dense_slots = envs.get_global_env(name + "dense_slots") + sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip() + dense_slots = envs.get_global_env(name + "dense_slots", "").strip() thread_num = envs.get_global_env(name + "thread_num") batch_size = envs.get_global_env(name + "batch_size") type_name = envs.get_global_env(name + "type") @@ -225,7 +224,9 @@ class SingleTrainer(TranspileTrainer): model_class = self._model[model_name][3] fetch_vars = [] fetch_alias = [] - fetch_period = 20 + fetch_period = int( + envs.get_global_env("runner." + self._runner_name + + ".fetch_period", 20)) metrics = model_class.get_metrics() if metrics: fetch_vars = metrics.values() @@ -250,14 +251,15 @@ class SingleTrainer(TranspileTrainer): loss_name=model_class.get_avg_cost().name) fetch_vars = [] fetch_alias = [] - fetch_period = 20 + fetch_period = int( + envs.get_global_env("runner." + self._runner_name + + ".fetch_period", 20)) metrics = model_class.get_metrics() if metrics: fetch_vars = metrics.values() fetch_alias = metrics.keys() metrics_varnames = [] metrics_format = [] - fetch_period = 20 metrics_format.append("{}: {{}}".format("batch")) for name, var in metrics.items(): metrics_varnames.append(var.name) @@ -312,10 +314,11 @@ class SingleTrainer(TranspileTrainer): if not need_save(epoch_id, save_interval, False): return feed_varnames = envs.get_global_env( - name + "save_inference_feed_varnames", None) + name + "save_inference_feed_varnames", []) fetch_varnames = envs.get_global_env( - name + "save_inference_fetch_varnames", None) - if feed_varnames is None or fetch_varnames is None or feed_varnames == "": + name + "save_inference_fetch_varnames", []) + if feed_varnames is None or fetch_varnames is None or feed_varnames == "" or fetch_varnames == "" or \ + len(feed_varnames) == 0 or len(fetch_varnames) == 0: return fetch_vars = [ 
fluid.default_main_program().global_block().vars[varname] diff --git a/core/utils/dataloader_instance.py b/core/utils/dataloader_instance.py index 23cdfdc5..a26e2df2 100755 --- a/core/utils/dataloader_instance.py +++ b/core/utils/dataloader_instance.py @@ -68,8 +68,12 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file): data_path = os.path.join(package_base, data_path.split("::")[1]) files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] - sparse = get_global_env(name + "sparse_slots") - dense = get_global_env(name + "dense_slots") + sparse = get_global_env(name + "sparse_slots", "#") + if sparse == "": + sparse = "#" + dense = get_global_env(name + "dense_slots", "#") + if dense == "": + dense = "#" padding = get_global_env(name + "padding", 0) reader = SlotReader(yaml_file) reader.init(sparse, dense, int(padding)) @@ -158,8 +162,12 @@ def slotdataloader(readerclass, train, yaml_file): files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] - sparse = get_global_env("sparse_slots", None, namespace) - dense = get_global_env("dense_slots", None, namespace) + sparse = get_global_env("sparse_slots", "#", namespace) + if sparse == "": + sparse = "#" + dense = get_global_env("dense_slots", "#", namespace) + if dense == "": + dense = "#" padding = get_global_env("padding", 0, namespace) reader = SlotReader(yaml_file) reader.init(sparse, dense, int(padding)) diff --git a/models/rank/dnn/config.yaml b/models/rank/dnn/config.yaml index f144d2dd..57bb81d5 100755 --- a/models/rank/dnn/config.yaml +++ b/models/rank/dnn/config.yaml @@ -62,6 +62,7 @@ runner: save_inference_feed_varnames: [] # feed vars of save inference save_inference_fetch_varnames: [] # fetch vars of save inference init_model_path: "" # load model path + fetch_period: 10 - name: runner2 class: single_infer # num of epochs -- GitLab