From 07bd7092335e6e009e9a61d28951c51575302f3e Mon Sep 17 00:00:00 2001 From: xjqbest <173596896@qq.com> Date: Thu, 28 May 2020 10:40:51 +0800 Subject: [PATCH] fix: read reader class from dataset-scoped data_converter key and use TRAIN stage --- core/trainers/single_trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/trainers/single_trainer.py b/core/trainers/single_trainer.py index 3ddc48e2..c4a5fed2 100755 --- a/core/trainers/single_trainer.py +++ b/core/trainers/single_trainer.py @@ -64,12 +64,12 @@ class SingleTrainer(TranspileTrainer): dense_slots = envs.get_global_env(name + "dense_slots") thread_num = envs.get_global_env(name + "thread_num") batch_size = envs.get_global_env(name + "batch_size") - reader_class = envs.get_global_env("data_convertor") + reader_class = envs.get_global_env(name + "data_converter") abs_dir = os.path.dirname(os.path.abspath(__file__)) reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py') if sparse_slots is None and dense_slots is None: - pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "fake", self._config_yaml) + pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config_yaml) else: if sparse_slots is None: sparse_slots = "#" @@ -103,7 +103,7 @@ class SingleTrainer(TranspileTrainer): dense_slots = envs.get_global_env(name + "dense_slots") thread_num = envs.get_global_env(name + "thread_num") batch_size = envs.get_global_env(name + "batch_size") - reader_class = envs.get_global_env("data_convertor") + reader_class = envs.get_global_env(name + "data_converter") abs_dir = os.path.dirname(os.path.abspath(__file__)) if sparse_slots is None and dense_slots is None: reader = dataloader_instance.dataloader_by_name(reader_class, dataset_name, self._config_yaml) @@ -156,7 +156,7 @@ class SingleTrainer(TranspileTrainer): if envs.get_global_env("dataset." 
+ dataset_name + ".type") == "DataLoader": model._init_dataloader() self._get_dataloader(dataset_name, model._data_loader) - model.net(model._data_var, is_infer=model_dict["is_infer"]) + model.net(model._data_var, is_infer=model_dict.get("is_infer", False)) optimizer = model._build_optimizer(opt_name, opt_lr, opt_strategy) optimizer.minimize(model._cost) self._model[model_dict["name"]][0] = train_program -- GitLab