提交 bb7af6ed 编写于 作者: X xjqbest

fix

上级 95f2364b
...@@ -212,8 +212,11 @@ class SingleTrainer(TranspileTrainer): ...@@ -212,8 +212,11 @@ class SingleTrainer(TranspileTrainer):
self._executor_dataloader_train(model_dict) self._executor_dataloader_train(model_dict)
else: else:
self._executor_dataset_train(model_dict) self._executor_dataset_train(model_dict)
with fluid.scope_guard(self._model[model_name][2]): with fluid.scope_guard(self._model[model_dict["name"]][2]):
self.save(self, j) train_prog = self._model[model_dict["name"]][0]
startup_prog = self._model[model_dict["name"]][1]
with fluid.program_guard(train_prog, startup_prog):
self.save(j)
end_time = time.time() end_time = time.time()
seconds = end_time - begin_time seconds = end_time - begin_time
print("epoch {} done, time elasped: {}".format(j, seconds)) print("epoch {} done, time elasped: {}".format(j, seconds))
...@@ -318,10 +321,9 @@ class SingleTrainer(TranspileTrainer): ...@@ -318,10 +321,9 @@ class SingleTrainer(TranspileTrainer):
else: else:
fluid.io.save_inference_model(dirname, feed_varnames, fluid.io.save_inference_model(dirname, feed_varnames,
fetch_vars, self._exe) fetch_vars, self._exe)
self.inference_models.append((epoch_id, dirname))
def save_persistables(): def save_persistables():
save_interval = envs.get_global_env("epoch.save_checkpoint_interval", -1) save_interval = int(envs.get_global_env("epoch.save_checkpoint_interval", -1))
if not need_save(epoch_id, save_interval, False): if not need_save(epoch_id, save_interval, False):
return return
dirname = envs.get_global_env("epoch.save_checkpoint_path", None) dirname = envs.get_global_env("epoch.save_checkpoint_path", None)
...@@ -331,7 +333,6 @@ class SingleTrainer(TranspileTrainer): ...@@ -331,7 +333,6 @@ class SingleTrainer(TranspileTrainer):
fleet.save_persistables(self._exe, dirname) fleet.save_persistables(self._exe, dirname)
else: else:
fluid.io.save_persistables(self._exe, dirname) fluid.io.save_persistables(self._exe, dirname)
self.increment_models.append((epoch_id, dirname))
save_persistables() save_persistables()
save_inference_model() save_inference_model()
...@@ -21,8 +21,8 @@ workspace: "paddlerec.models.rank.dnn" ...@@ -21,8 +21,8 @@ workspace: "paddlerec.models.rank.dnn"
dataset: dataset:
- name: dataset_2 - name: dataset_2
batch_size: 2 batch_size: 2
#type: QueueDataset type: QueueDataset
type: DataLoader #type: DataLoader
data_path: "{workspace}/data/sample_data/train" data_path: "{workspace}/data/sample_data/train"
sparse_slots: "click 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26" sparse_slots: "click 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26"
dense_slots: "dense_var:13" dense_slots: "dense_var:13"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册