提交 07bd7092 编写于 作者: X xjqbest

fix

上级 654ba717
...@@ -64,12 +64,12 @@ class SingleTrainer(TranspileTrainer): ...@@ -64,12 +64,12 @@ class SingleTrainer(TranspileTrainer):
dense_slots = envs.get_global_env(name + "dense_slots") dense_slots = envs.get_global_env(name + "dense_slots")
thread_num = envs.get_global_env(name + "thread_num") thread_num = envs.get_global_env(name + "thread_num")
batch_size = envs.get_global_env(name + "batch_size") batch_size = envs.get_global_env(name + "batch_size")
reader_class = envs.get_global_env("data_convertor") reader_class = envs.get_global_env(name + "data_converter")
abs_dir = os.path.dirname(os.path.abspath(__file__)) abs_dir = os.path.dirname(os.path.abspath(__file__))
reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py') reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py')
if sparse_slots is None and dense_slots is None: if sparse_slots is None and dense_slots is None:
pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "fake", self._config_yaml) pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config_yaml)
else: else:
if sparse_slots is None: if sparse_slots is None:
sparse_slots = "#" sparse_slots = "#"
...@@ -103,7 +103,7 @@ class SingleTrainer(TranspileTrainer): ...@@ -103,7 +103,7 @@ class SingleTrainer(TranspileTrainer):
dense_slots = envs.get_global_env(name + "dense_slots") dense_slots = envs.get_global_env(name + "dense_slots")
thread_num = envs.get_global_env(name + "thread_num") thread_num = envs.get_global_env(name + "thread_num")
batch_size = envs.get_global_env(name + "batch_size") batch_size = envs.get_global_env(name + "batch_size")
reader_class = envs.get_global_env("data_convertor") reader_class = envs.get_global_env(name + "data_converter")
abs_dir = os.path.dirname(os.path.abspath(__file__)) abs_dir = os.path.dirname(os.path.abspath(__file__))
if sparse_slots is None and dense_slots is None: if sparse_slots is None and dense_slots is None:
reader = dataloader_instance.dataloader_by_name(reader_class, dataset_name, self._config_yaml) reader = dataloader_instance.dataloader_by_name(reader_class, dataset_name, self._config_yaml)
...@@ -156,7 +156,7 @@ class SingleTrainer(TranspileTrainer): ...@@ -156,7 +156,7 @@ class SingleTrainer(TranspileTrainer):
if envs.get_global_env("dataset." + dataset_name + ".type") == "DataLoader": if envs.get_global_env("dataset." + dataset_name + ".type") == "DataLoader":
model._init_dataloader() model._init_dataloader()
self._get_dataloader(dataset_name, model._data_loader) self._get_dataloader(dataset_name, model._data_loader)
model.net(model._data_var, is_infer=model_dict["is_infer"]) model.net(model._data_var, is_infer=model_dict.get("is_infer", False))
optimizer = model._build_optimizer(opt_name, opt_lr, opt_strategy) optimizer = model._build_optimizer(opt_name, opt_lr, opt_strategy)
optimizer.minimize(model._cost) optimizer.minimize(model._cost)
self._model[model_dict["name"]][0] = train_program self._model[model_dict["name"]][0] = train_program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册