From bb7af6ed8a0368292d1c5c520f71bcc640800516 Mon Sep 17 00:00:00 2001
From: xjqbest <173596896@qq.com>
Date: Wed, 27 May 2020 23:34:56 +0800
Subject: [PATCH] fix

---
 core/trainers/single_trainer.py | 11 ++++++-----
 models/rank/dnn/config.yaml     |  4 ++--
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/core/trainers/single_trainer.py b/core/trainers/single_trainer.py
index 1cd4ec84..d630c21f 100755
--- a/core/trainers/single_trainer.py
+++ b/core/trainers/single_trainer.py
@@ -212,8 +212,11 @@ class SingleTrainer(TranspileTrainer):
                     self._executor_dataloader_train(model_dict)
                 else:
                     self._executor_dataset_train(model_dict)
-                with fluid.scope_guard(self._model[model_name][2]):
-                    self.save(self, j)
+                with fluid.scope_guard(self._model[model_dict["name"]][2]):
+                    train_prog = self._model[model_dict["name"]][0]
+                    startup_prog = self._model[model_dict["name"]][1]
+                    with fluid.program_guard(train_prog, startup_prog):
+                        self.save(j)
                 end_time = time.time()
                 seconds = end_time - begin_time
                 print("epoch {} done, time elasped: {}".format(j, seconds))
@@ -318,10 +321,9 @@ class SingleTrainer(TranspileTrainer):
             else:
                 fluid.io.save_inference_model(dirname, feed_varnames,
                                               fetch_vars, self._exe)
-            self.inference_models.append((epoch_id, dirname))
 
         def save_persistables():
-            save_interval = envs.get_global_env("epoch.save_checkpoint_interval", -1)
+            save_interval = int(envs.get_global_env("epoch.save_checkpoint_interval", -1))
             if not need_save(epoch_id, save_interval, False):
                 return
             dirname = envs.get_global_env("epoch.save_checkpoint_path", None)
@@ -331,7 +333,6 @@ class SingleTrainer(TranspileTrainer):
                 fleet.save_persistables(self._exe, dirname)
             else:
                 fluid.io.save_persistables(self._exe, dirname)
-            self.increment_models.append((epoch_id, dirname))
 
         save_persistables()
         save_inference_model()
diff --git a/models/rank/dnn/config.yaml b/models/rank/dnn/config.yaml
index a51647d0..f15e67e9 100755
--- a/models/rank/dnn/config.yaml
+++ b/models/rank/dnn/config.yaml
@@ -21,8 +21,8 @@ workspace: "paddlerec.models.rank.dnn"
 dataset:
 - name: dataset_2
   batch_size: 2
-  #type: QueueDataset
-  type: DataLoader
+  type: QueueDataset
+  #type: DataLoader
   data_path: "{workspace}/data/sample_data/train"
   sparse_slots: "click 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26"
   dense_slots: "dense_var:13"
--
GitLab
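
Note on the main change in single_trainer.py: fluid.io.save_persistables and
fluid.io.save_inference_model resolve variables from whichever program and
scope are active at call time, which is why the save call is now wrapped in
both fluid.scope_guard and fluid.program_guard for the model being trained.
Below is a minimal, self-contained sketch of that pattern with PaddlePaddle
1.x's fluid API; the toy network, checkpoint directory name, and variable
names are illustrative assumptions, not PaddleRec's actual trainer code.

    # Sketch of the scope_guard + program_guard save pattern (assumptions:
    # toy network and "increment_model" directory, not SingleTrainer itself).
    import paddle.fluid as fluid

    scope = fluid.Scope()
    train_prog = fluid.Program()
    startup_prog = fluid.Program()

    # Build a toy network inside this model's own programs.
    with fluid.program_guard(train_prog, startup_prog):
        x = fluid.data(name="x", shape=[None, 13], dtype="float32")
        y = fluid.data(name="y", shape=[None, 1], dtype="float32")
        pred = fluid.layers.fc(input=x, size=1)
        loss = fluid.layers.reduce_mean(fluid.layers.square_error_cost(pred, y))
        fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

    exe = fluid.Executor(fluid.CPUPlace())
    with fluid.scope_guard(scope):
        exe.run(startup_prog)
        # ... training iterations would run here ...
        with fluid.program_guard(train_prog, startup_prog):
            # save_persistables picks up the program and scope that are
            # current here, so only this model's variables are written
            # to the checkpoint directory.
            fluid.io.save_persistables(exe, "increment_model")

With several models (phases) each holding its own (program, startup program,
scope) triple, saving outside these guards would resolve variables against
the global default program and scope instead of the model just trained.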