diff --git a/core/trainers/single_trainer.py b/core/trainers/single_trainer.py index ef304cf46ee5775949910431603e34466015971d..74d340242755c8b5198a48b76eec50a267530dce 100755 --- a/core/trainers/single_trainer.py +++ b/core/trainers/single_trainer.py @@ -55,7 +55,12 @@ class SingleTrainer(TranspileTrainer): if metrics: self.fetch_vars = metrics.values() self.fetch_alias = metrics.keys() - context['status'] = 'startup_pass' + evaluate_only = envs.get_global_env( + 'evaluate_only', False, namespace='evaluate') + if evaluate_only: + context['status'] = 'infer_pass' + else: + context['status'] = 'startup_pass' def startup(self, context): self._exe.run(fluid.default_startup_program()) diff --git a/core/trainers/transpiler_trainer.py b/core/trainers/transpiler_trainer.py index 2aed5bfbedfd264b89671ba0fa776ff716042cf5..44db24b883e0accb438d416dbfe73f9755beffb7 100755 --- a/core/trainers/transpiler_trainer.py +++ b/core/trainers/transpiler_trainer.py @@ -229,8 +229,17 @@ class TranspileTrainer(Trainer): metrics_format = ", ".join(metrics_format) self._exe.run(startup_program) - for (epoch, model_dir) in self.increment_models: - print("Begin to infer epoch {}, model_dir: {}".format(epoch, model_dir)) + model_list = self.increment_models + + evaluate_only = envs.get_global_env( + 'evaluate_only', False, namespace='evaluate') + if evaluate_only: + model_list = [(0, envs.get_global_env( + 'evaluate_model_path', "", namespace='evaluate'))] + + for (epoch, model_dir) in model_list: + print("Begin to infer No.{} model, model_dir: {}".format( + epoch, model_dir)) program = infer_program.clone() fluid.io.load_persistables(self._exe, model_dir, program) reader.start() diff --git a/models/recall/word2vec/config.yaml b/models/recall/word2vec/config.yaml index fd48ef056a1dad7bcb0b992a7f3619e86d9f1cfc..9bb5c4d3fe42bc2385ff22ad150924ad5a3a59cd 100755 --- a/models/recall/word2vec/config.yaml +++ b/models/recall/word2vec/config.yaml @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. evaluate: - workspace: "paddlerec.models.recall.word2vec" + workspace: "paddlerec.models.recall.word2vec" + + evaluate_only: False + evaluate_model_path: "" + reader: batch_size: 50 class: "{workspace}/w2v_evaluate_reader.py"