Commit a987bb78 authored by xjqbest

fix

Parent 5090515b
@@ -110,11 +110,13 @@ class SingleTrainer(TranspileTrainer):
         reader_class = envs.get_global_env(name + "data_converter")
         abs_dir = os.path.dirname(os.path.abspath(__file__))
         if sparse_slots is None and dense_slots is None:
-            reader = dataloader_instance.dataloader_by_name(reader_class, dataset_name, self._config_yaml)
+            reader = dataloader_instance.dataloader_by_name(
+                reader_class, dataset_name, self._config_yaml)
             reader_class = envs.lazy_instance_by_fliename(reader_class, "TrainReader")
             reader_ins = reader_class(self._config_yaml)
         else:
-            reader = dataloader_instance.slotdataloader_by_name("", dataset_name, self._config_yaml)
+            reader = dataloader_instance.slotdataloader_by_name(
+                "", dataset_name, self._config_yaml)
             reader_ins = SlotReader(self._config_yaml)
         if hasattr(reader_ins, 'generate_batch_from_trainfiles'):
             dataloader.set_sample_list_generator(reader)
@@ -122,7 +124,6 @@ class SingleTrainer(TranspileTrainer):
             dataloader.set_sample_generator(reader, batch_size)
         return dataloader
 
-
     def _create_dataset(self, dataset_name):
         name = "dataset." + dataset_name + "."
         sparse_slots = envs.get_global_env(name + "sparse_slots")
@@ -131,7 +132,8 @@ class SingleTrainer(TranspileTrainer):
         batch_size = envs.get_global_env(name + "batch_size")
         type_name = envs.get_global_env(name + "type")
         if envs.get_platform() != "LINUX":
-            print("platform ", envs.get_platform(), " change reader to DataLoader")
+            print("platform ", envs.get_platform(),
+                  " change reader to DataLoader")
             type_name = "DataLoader"
         padding = 0
@@ -140,7 +142,6 @@ class SingleTrainer(TranspileTrainer):
         else:
             return self._get_dataset(dataset_name)
 
-
     def init(self, context):
         for model_dict in self._env["executor"]:
             self._model[model_dict["name"]] = [None] * 5
@@ -187,7 +188,6 @@ class SingleTrainer(TranspileTrainer):
             self._model[model_dict["name"]][3] = model
             self._model[model_dict["name"]][4] = train_program.clone()
 
-
         for dataset in self._env["dataset"]:
             if dataset["type"] != "DataLoader":
                 self._dataset[dataset["name"]] = self._create_dataset(dataset[
@@ -268,7 +268,8 @@ class SingleTrainer(TranspileTrainer):
         program = self._model[model_name][0].clone()
         if not model_dict["is_infer"]:
             program = fluid.compiler.CompiledProgram(
-                program).with_data_parallel(loss_name=model_class.get_avg_cost().name)
+                program).with_data_parallel(
+                    loss_name=model_class.get_avg_cost().name)
         fetch_vars = []
         fetch_alias = []
         fetch_period = 20
@@ -333,8 +334,10 @@ class SingleTrainer(TranspileTrainer):
                 envs.get_global_env("epoch.save_inference_interval", -1))
             if not need_save(epoch_id, save_interval, False):
                 return
-            feed_varnames = envs.get_global_env("epoch.save_inference_feed_varnames", None)
-            fetch_varnames = envs.get_global_env("epoch.save_inference_fetch_varnames", None)
+            feed_varnames = envs.get_global_env(
+                "epoch.save_inference_feed_varnames", None)
+            fetch_varnames = envs.get_global_env(
+                "epoch.save_inference_fetch_varnames", None)
             if feed_varnames is None or fetch_varnames is None or feed_varnames == "":
                 return
             fetch_vars = [
......
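For readers skimming this diff, here is a minimal, self-contained sketch of the reader-selection branch in `_create_dataloader` that the first hunk reformats: when neither `sparse_slots` nor `dense_slots` is configured for the dataset, a user-defined `TrainReader` is loaded (via `dataloader_instance.dataloader_by_name` and `envs.lazy_instance_by_fliename`); otherwise the slot-based path (`slotdataloader_by_name` plus `SlotReader`) is taken. The helper `choose_reader` and its example arguments are hypothetical stand-ins, not PaddleRec code.

# Sketch only: mirrors the branching logic in _create_dataloader above;
# "choose_reader" is a hypothetical helper used purely for illustration.
def choose_reader(sparse_slots, dense_slots, reader_class_path):
    """Describe which reader path _create_dataloader would take."""
    if sparse_slots is None and dense_slots is None:
        # Custom-reader path: dataloader_instance.dataloader_by_name(...) and
        # envs.lazy_instance_by_fliename(reader_class, "TrainReader")
        return "custom TrainReader loaded from " + reader_class_path
    # Slot path: dataloader_instance.slotdataloader_by_name(...) and SlotReader
    return "built-in SlotReader"


if __name__ == "__main__":
    # Hypothetical dataset configurations for illustration.
    print(choose_reader(None, None, "my_reader.py"))     # custom reader path
    print(choose_reader("click slot1 slot2", None, ""))  # SlotReader path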