提交 e8d24aa1 编写于 作者: Q qiaolongfei

Inferencer supports parallel_executor

上级 2a636529
...@@ -17,6 +17,7 @@ import core ...@@ -17,6 +17,7 @@ import core
import executor import executor
import framework import framework
import io import io
import parallel_executor
import unique_name import unique_name
from trainer import check_and_get_place from trainer import check_and_get_place
...@@ -24,7 +25,7 @@ __all__ = ['Inferencer', ] ...@@ -24,7 +25,7 @@ __all__ = ['Inferencer', ]
class Inferencer(object):
    def __init__(self, infer_func, param_path, place=None, parallel=False):
        """Load a saved inference model and prepare an executor to run it.

        :param infer_func: a function that will return the predict Variable
            when called under this Inferencer's program/name guards.
        :param param_path: the path where the inference model was saved by
            fluid.io.save_params.
        :param place: the device place to run on; resolved through
            check_and_get_place (CPU place when None).
        :param parallel: if True, run inference with a ParallelExecutor
            (multi CPU/GPU); otherwise use a plain Executor.
        """
        # NOTE(review): part of the original docstring is hidden by the diff
        # hunk (@@ -32,13 +33,20 @@) — confirm no extra documented params.
        self.param_path = param_path
        self.scope = core.Scope()
        self.parallel = parallel
        self.place = check_and_get_place(place)

        # Build the inference program in isolation so variable names do not
        # collide with any training program.
        self.inference_program = framework.Program()
        with framework.program_guard(self.inference_program):
            with unique_name.guard():
                self.predict_var = infer_func()

        if parallel:
            # ParallelExecutor needs to know whether CUDA is the target and
            # which variable drives execution (the predict output here).
            self.exe = parallel_executor.ParallelExecutor(
                use_cuda=isinstance(self.place, core.CUDAPlace),
                loss_name=self.predict_var.name)
        else:
            self.exe = executor.Executor(self.place)
        with executor.scope_guard(self.scope):
            # load params from param_path into scope
            io.load_params(self.exe, param_path, self.inference_program)
......
...@@ -12,18 +12,18 @@ ...@@ -12,18 +12,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib
import os import os
import core import core
import framework
import executor
import data_feeder import data_feeder
import contextlib import executor
import framework
import io import io
import unique_name
import parallel_executor
# optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module # optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
import optimizer as opt_module import optimizer as opt_module
import parallel_executor
from transpiler import distribute_transpiler from transpiler import distribute_transpiler
__all__ = [ __all__ = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册