diff --git a/fleet_rec/core/trainers/single_trainer.py b/fleet_rec/core/trainers/single_trainer.py index 9cf970827b08a5b71f036ad1fab192ce1b80d03b..040334fc72e800a5db63efffc3e883baf518a715 100644 --- a/fleet_rec/core/trainers/single_trainer.py +++ b/fleet_rec/core/trainers/single_trainer.py @@ -122,6 +122,10 @@ class SingleTrainer(TranspileTrainer): with fluid.program_guard(infer_program, startup_program): self.model.infer_net() + if self.model._infer_data_loader is None: + context['status'] = 'terminal_pass' + return + reader = self._get_dataloader("Evaluate") metrics_varnames = [] diff --git a/fleet_rec/core/trainers/transpiler_trainer.py b/fleet_rec/core/trainers/transpiler_trainer.py index eb7d8b0b8f3a6c52b9fdf0a9a63b49e08f062141..84bf3f89b0840e8d8e635e2f3be32e8e601416aa 100644 --- a/fleet_rec/core/trainers/transpiler_trainer.py +++ b/fleet_rec/core/trainers/transpiler_trainer.py @@ -102,8 +102,8 @@ class TranspileTrainer(Trainer): if not need_save(epoch_id, save_interval, False): return - # print("save inference model is not supported now.") - # return + print("save inference model is not supported now.") + return feed_varnames = envs.get_global_env("save.inference.feed_varnames", None, namespace) fetch_varnames = envs.get_global_env("save.inference.fetch_varnames", None, namespace)