diff --git a/paddle_hub/finetune/task.py b/paddle_hub/finetune/task.py index 1d75b6665cf7aa28a741b50d589732e7a3eec001..9275e95d08276133699d22f7ff1f8669a79c9a38 100644 --- a/paddle_hub/finetune/task.py +++ b/paddle_hub/finetune/task.py @@ -30,6 +30,7 @@ class Task(object): self.graph_var_dict = graph_var_dict self._main_program = main_program self._startup_program = startup_program + self._inference_program = main_program.clone(for_test=True) def variable(self, var_name): if var_name in self.graph_var_dict: @@ -42,3 +43,6 @@ class Task(object): def startup_program(self): return self._startup_program + + def inference_program(self): + return self._inference_program