From 7a7843c713a21433b09c893d53b9cb6df5120c0d Mon Sep 17 00:00:00 2001 From: chengmo Date: Thu, 14 May 2020 14:52:26 +0800 Subject: [PATCH] add evaluate only choice --- core/trainers/single_trainer.py | 7 ++++++- core/trainers/transpiler_trainer.py | 13 +++++++++++-- models/recall/word2vec/config.yaml | 6 +++++- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/core/trainers/single_trainer.py b/core/trainers/single_trainer.py index ef304cf4..74d34024 100755 --- a/core/trainers/single_trainer.py +++ b/core/trainers/single_trainer.py @@ -55,7 +55,12 @@ class SingleTrainer(TranspileTrainer): if metrics: self.fetch_vars = metrics.values() self.fetch_alias = metrics.keys() - context['status'] = 'startup_pass' + evaluate_only = envs.get_global_env( + 'evaluate_only', False, namespace='evaluate') + if evaluate_only: + context['status'] = 'infer_pass' + else: + context['status'] = 'startup_pass' def startup(self, context): self._exe.run(fluid.default_startup_program()) diff --git a/core/trainers/transpiler_trainer.py b/core/trainers/transpiler_trainer.py index 2aed5bfb..44db24b8 100755 --- a/core/trainers/transpiler_trainer.py +++ b/core/trainers/transpiler_trainer.py @@ -229,8 +229,17 @@ class TranspileTrainer(Trainer): metrics_format = ", ".join(metrics_format) self._exe.run(startup_program) - for (epoch, model_dir) in self.increment_models: - print("Begin to infer epoch {}, model_dir: {}".format(epoch, model_dir)) + model_list = self.increment_models + + evaluate_only = envs.get_global_env( + 'evaluate_only', False, namespace='evaluate') + if evaluate_only: + model_list = [(0, envs.get_global_env( + 'evaluate_model_path', "", namespace='evaluate'))] + + for (epoch, model_dir) in model_list: + print("Begin to infer No.{} model, model_dir: {}".format( + epoch, model_dir)) program = infer_program.clone() fluid.io.load_persistables(self._exe, model_dir, program) reader.start() diff --git a/models/recall/word2vec/config.yaml b/models/recall/word2vec/config.yaml index fd48ef05..9bb5c4d3 100755 --- a/models/recall/word2vec/config.yaml +++ b/models/recall/word2vec/config.yaml @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. evaluate: - workspace: "paddlerec.models.recall.word2vec" + workspace: "paddlerec.models.recall.word2vec" + + evaluate_only: False + evaluate_model_path: "" + reader: batch_size: 50 class: "{workspace}/w2v_evaluate_reader.py" -- GitLab