From e8d24aa144b1f68436e98f4b343aa9d975e67717 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Thu, 17 May 2018 15:13:17 +0800 Subject: [PATCH] Inferencer support parallel_executor --- python/paddle/fluid/inferencer.py | 12 ++++++++++-- python/paddle/fluid/trainer.py | 12 ++++++------ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/inferencer.py b/python/paddle/fluid/inferencer.py index 56c008d1af7..8c6dbd3b5ae 100644 --- a/python/paddle/fluid/inferencer.py +++ b/python/paddle/fluid/inferencer.py @@ -17,6 +17,7 @@ import core import executor import framework import io +import parallel_executor import unique_name from trainer import check_and_get_place @@ -24,7 +25,7 @@ __all__ = ['Inferencer', ] class Inferencer(object): - def __init__(self, infer_func, param_path, place=None): + def __init__(self, infer_func, param_path, place=None, parallel=False): """ :param infer_func: a function that will return predict Variable :param param_path: the path where the inference model is saved by fluid.io.save_params @@ -32,13 +33,20 @@ class Inferencer(object): """ self.param_path = param_path self.scope = core.Scope() + self.parallel = parallel + self.place = check_and_get_place(place) self.inference_program = framework.Program() with framework.program_guard(self.inference_program): with unique_name.guard(): self.predict_var = infer_func() - self.exe = executor.Executor(check_and_get_place(place)) + if parallel: + self.exe = parallel_executor.ParallelExecutor( + use_cuda=isinstance(self.place, core.CUDAPlace), + loss_name=self.predict_var.name) + else: + self.exe = executor.Executor(self.place) with executor.scope_guard(self.scope): # load params from param_path into scope io.load_params(self.exe, param_path, self.inference_program) diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index d158d586321..f4292208c94 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -12,18 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import os + import core -import framework -import executor + import data_feeder -import contextlib +import executor +import framework import io -import unique_name -import parallel_executor - # optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module import optimizer as opt_module +import parallel_executor from transpiler import distribute_transpiler __all__ = [ -- GitLab