diff --git a/paddlehub/finetune/task/base_task.py b/paddlehub/finetune/task/base_task.py index 5f0bf59b31b379368b085473231535f1d6318f56..a545eff20a3966de90620eac2cbbdda7f936f63c 100644 --- a/paddlehub/finetune/task/base_task.py +++ b/paddlehub/finetune/task/base_task.py @@ -30,12 +30,13 @@ if six.PY2: else: from inspect import getfullargspec as get_args import numpy as np +import paddle import paddle.fluid as fluid from tb_paddle import SummaryWriter import paddlehub as hub from paddlehub.common.paddle_helper import dtype_map, clone_program -from paddlehub.common.utils import mkdir +from paddlehub.common.utils import mkdir, version_compare from paddlehub.common.dir import tmp_dir from paddlehub.common.logger import logger from paddlehub.finetune.checkpoint import load_checkpoint, save_checkpoint @@ -983,7 +984,7 @@ class BaseTask(object): data, load_best_model=True, return_result=False, - accelerate_mode=False): + accelerate_mode=True): """ make prediction for the input data. @@ -996,6 +997,11 @@ class BaseTask(object): Returns: RunState: the running result of predict phase """ + if not version_compare(paddle.__version__, "1.6.2") and accelerate_mode: + logger.warning( + "Fail to open predict accelerate mode as it does not support paddle < 1.6.2. Please update PaddlePaddle." + ) + accelerate_mode = False self.accelerate_mode = accelerate_mode with self.phase_guard(phase="predict"):