提交 a987bb78 编写于 作者: X xjqbest

fix

上级 5090515b
...@@ -110,11 +110,13 @@ class SingleTrainer(TranspileTrainer): ...@@ -110,11 +110,13 @@ class SingleTrainer(TranspileTrainer):
reader_class = envs.get_global_env(name + "data_converter") reader_class = envs.get_global_env(name + "data_converter")
abs_dir = os.path.dirname(os.path.abspath(__file__)) abs_dir = os.path.dirname(os.path.abspath(__file__))
if sparse_slots is None and dense_slots is None: if sparse_slots is None and dense_slots is None:
reader = dataloader_instance.dataloader_by_name(reader_class, dataset_name, self._config_yaml) reader = dataloader_instance.dataloader_by_name(
reader_class, dataset_name, self._config_yaml)
reader_class = envs.lazy_instance_by_fliename(reader_class, "TrainReader") reader_class = envs.lazy_instance_by_fliename(reader_class, "TrainReader")
reader_ins = reader_class(self._config_yaml) reader_ins = reader_class(self._config_yaml)
else: else:
reader = dataloader_instance.slotdataloader_by_name("", dataset_name, self._config_yaml) reader = dataloader_instance.slotdataloader_by_name(
"", dataset_name, self._config_yaml)
reader_ins = SlotReader(self._config_yaml) reader_ins = SlotReader(self._config_yaml)
if hasattr(reader_ins, 'generate_batch_from_trainfiles'): if hasattr(reader_ins, 'generate_batch_from_trainfiles'):
dataloader.set_sample_list_generator(reader) dataloader.set_sample_list_generator(reader)
...@@ -122,7 +124,6 @@ class SingleTrainer(TranspileTrainer): ...@@ -122,7 +124,6 @@ class SingleTrainer(TranspileTrainer):
dataloader.set_sample_generator(reader, batch_size) dataloader.set_sample_generator(reader, batch_size)
return dataloader return dataloader
def _create_dataset(self, dataset_name): def _create_dataset(self, dataset_name):
name = "dataset." + dataset_name + "." name = "dataset." + dataset_name + "."
sparse_slots = envs.get_global_env(name + "sparse_slots") sparse_slots = envs.get_global_env(name + "sparse_slots")
...@@ -131,7 +132,8 @@ class SingleTrainer(TranspileTrainer): ...@@ -131,7 +132,8 @@ class SingleTrainer(TranspileTrainer):
batch_size = envs.get_global_env(name + "batch_size") batch_size = envs.get_global_env(name + "batch_size")
type_name = envs.get_global_env(name + "type") type_name = envs.get_global_env(name + "type")
if envs.get_platform() != "LINUX": if envs.get_platform() != "LINUX":
print("platform ", envs.get_platform(), " change reader to DataLoader") print("platform ", envs.get_platform(),
" change reader to DataLoader")
type_name = "DataLoader" type_name = "DataLoader"
padding = 0 padding = 0
...@@ -140,7 +142,6 @@ class SingleTrainer(TranspileTrainer): ...@@ -140,7 +142,6 @@ class SingleTrainer(TranspileTrainer):
else: else:
return self._get_dataset(dataset_name) return self._get_dataset(dataset_name)
def init(self, context): def init(self, context):
for model_dict in self._env["executor"]: for model_dict in self._env["executor"]:
self._model[model_dict["name"]] = [None] * 5 self._model[model_dict["name"]] = [None] * 5
...@@ -187,7 +188,6 @@ class SingleTrainer(TranspileTrainer): ...@@ -187,7 +188,6 @@ class SingleTrainer(TranspileTrainer):
self._model[model_dict["name"]][3] = model self._model[model_dict["name"]][3] = model
self._model[model_dict["name"]][4] = train_program.clone() self._model[model_dict["name"]][4] = train_program.clone()
for dataset in self._env["dataset"]: for dataset in self._env["dataset"]:
if dataset["type"] != "DataLoader": if dataset["type"] != "DataLoader":
self._dataset[dataset["name"]] = self._create_dataset(dataset[ self._dataset[dataset["name"]] = self._create_dataset(dataset[
...@@ -268,7 +268,8 @@ class SingleTrainer(TranspileTrainer): ...@@ -268,7 +268,8 @@ class SingleTrainer(TranspileTrainer):
program = self._model[model_name][0].clone() program = self._model[model_name][0].clone()
if not model_dict["is_infer"]: if not model_dict["is_infer"]:
program = fluid.compiler.CompiledProgram( program = fluid.compiler.CompiledProgram(
program).with_data_parallel(loss_name=model_class.get_avg_cost().name) program).with_data_parallel(
loss_name=model_class.get_avg_cost().name)
fetch_vars = [] fetch_vars = []
fetch_alias = [] fetch_alias = []
fetch_period = 20 fetch_period = 20
...@@ -333,8 +334,10 @@ class SingleTrainer(TranspileTrainer): ...@@ -333,8 +334,10 @@ class SingleTrainer(TranspileTrainer):
envs.get_global_env("epoch.save_inference_interval", -1)) envs.get_global_env("epoch.save_inference_interval", -1))
if not need_save(epoch_id, save_interval, False): if not need_save(epoch_id, save_interval, False):
return return
feed_varnames = envs.get_global_env("epoch.save_inference_feed_varnames", None) feed_varnames = envs.get_global_env(
fetch_varnames = envs.get_global_env("epoch.save_inference_fetch_varnames", None) "epoch.save_inference_feed_varnames", None)
fetch_varnames = envs.get_global_env(
"epoch.save_inference_fetch_varnames", None)
if feed_varnames is None or fetch_varnames is None or feed_varnames == "": if feed_varnames is None or fetch_varnames is None or feed_varnames == "":
return return
fetch_vars = [ fetch_vars = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册