提交 2a971f30 编写于 作者: Q Qiao Longfei 提交者: Jeff Wang

Add inferencer infer (#10445)

* add Inference.infer

* optimize code

* update no_test_word2vec_new_api.py

* update trainer

* split check_and_get_place

* use inference_program to save inference model in Trainer

* update demo

* update save_inference_model

* clean code
上级 aaab92b0
......@@ -16,31 +16,42 @@ import core
import framework
import executor
import io
from trainer import check_and_get_place
__all__ = ['Inferencer', ]
class Inferencer(object):
def __init__(self, network_func, param_path=None, place=None):
# 1. we need to generate a framework.Program by calling
# network_func. Reference: fluid.program_guard in test_word2vec.py
# 2. move the default_main_program to self.program.
# 3. run the default_startup program.
# 4. load params from param_path into scope
def __init__(self, param_path, place=None):
"""
:param param_path: the path where the inference model is saved by fluid.io.save_inference_model
:param place: place to do the inference
"""
self.param_path = param_path
self.scope = core.Scope()
self.place = place
self.startup_program = framework.Program()
# TODO: generate the startup_program with network_func
exe = executor.Executor(place)
exe.run(self.startup_program, scope=self.scope)
if param_path:
self.exe = executor.Executor(check_and_get_place(place))
with executor.scope_guard(self.scope):
# load params from param_path into scope
io.load_persistables(exe, dirname=param_path)
def infer(self, inputs):
# run self.program
pass
[self.inference_program, _,
self.fetch_targets] = io.load_inference_model(
executor=self.exe, dirname=param_path)
def infer(self, inputs, return_numpy=True):
"""
:param inputs: a map of {"input_name": input_var} that will be feed into the inference program
to get the predict value
:param return_numpy: if return numpy value for row tensor
:return: the predict value of the inference model
"""
if not isinstance(inputs, dict):
raise ValueError(
"inputs should be a map of {'input_name': input_var}")
with executor.scope_guard(self.scope):
results = self.exe.run(self.inference_program,
feed=inputs,
fetch_list=self.fetch_targets,
return_numpy=return_numpy)
return results
......@@ -263,6 +263,9 @@ def get_inference_program(target_vars, main_program=None):
def prepend_feed_ops(inference_program,
feed_target_names,
feed_holder_name='feed'):
if len(feed_target_names) == 0:
return
global_block = inference_program.global_block()
feed_var = global_block.create_var(
name=feed_holder_name,
......@@ -323,6 +326,7 @@ def save_inference_model(dirname,
if isinstance(feeded_var_names, basestring):
feeded_var_names = [feeded_var_names]
else:
if len(feeded_var_names) > 0:
if not (bool(feeded_var_names) and all(
isinstance(name, basestring) for name in feeded_var_names)):
raise ValueError("'feed_var_names' should be a list of str.")
......
......@@ -99,45 +99,45 @@ def train(use_cuda, is_sparse, save_path):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
def event_handler(event):
# print type(event)
if isinstance(event, fluid.EndEpochEvent):
outs = trainer.test(reader=test_reader)
avg_cost = outs[0]
print("loss= ", avg_cost)
if avg_cost < 5.0:
trainer.save_params(save_path)
trainer.save_inference_model(save_path)
return
if math.isnan(avg_cost):
sys.exit("got NaN loss, training failed.")
trainer = fluid.Trainer(
partial(train_program, is_sparse),
partial(inference_program, is_sparse),
fluid.optimizer.SGD(learning_rate=0.001),
place=place)
trainer.train(
reader=train_reader, num_epochs=100, event_handler=event_handler)
reader=train_reader, num_epochs=1, event_handler=event_handler)
def infer(use_cuda, is_sparse, save_path):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
inferencer = fluid.Inferencer(
partial(inference_program, is_sparse),
param_path=save_path,
place=place)
inferencer = fluid.Inferencer(param_path=save_path, place=place)
lod = [0, 1]
first_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
second_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
third_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
fourth_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
result = inferencer.infer({
result = inferencer.infer(
{
'firstw': first_word,
'secondw': second_word,
'thirdw': third_word,
'forthw': fourth_word
})
print(result)
},
return_numpy=False)
print(np.array(result[0]))
def main(use_cuda, is_sparse):
......
......@@ -19,7 +19,7 @@ import executor
import data_feeder
import contextlib
import io
import transpiler
import unique_name
# optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
import optimizer as opt_module
......@@ -56,26 +56,62 @@ class EndStepEvent(object):
self.step = step_id
def check_and_get_place(place):
"""
Check the type of place or get the default place
Args:
place(None|core.CUDAPlace|core.CPUPlace): the place that trainer will be executed on.
Raises:
TypeError if the type mismatched.
Returns:
the original place if it is not None.
if fluid is compiled with CUDA, returns CUDAPlace(0) by default.
Otherwise returns CPUPlace by default.
"""
if place is None:
if core.is_compiled_with_cuda():
return core.CUDAPlace(0)
else:
return core.CPUPlace()
else:
if not isinstance(place, core.CUDAPlace) and not isinstance(
place, core.CPUPlace):
raise TypeError("Place should be either CUDAPlace or CPUPlace")
return place
class Trainer(object):
"""
Args:
program_func(callable): A function which will return loss. The loss must be a scaler.
train_func(callable): A function which will return loss. The loss must be a scalar.
infer_func(callable): A function which will return predict, used to save inference model
optimizer(optimizer.Optimizer): The optimizer should be an instance of Optimizer
place: The device place of this trainer.
"""
def __init__(self, program_func, optimizer, param_path=None, place=None):
def __init__(self,
train_func,
infer_func,
optimizer,
param_path=None,
place=None):
# 1. we need to generate a framework.Program by calling
# program_func. Reference: fluid.program_guard in
# test_word2vec.py
if not isinstance(optimizer, opt_module.Optimizer):
raise TypeError("The optimizer should be an instance of Optimizer")
self.infer_func = infer_func
self.scope = core.Scope()
self.startup_program = framework.Program()
self.train_program = framework.Program()
with framework.program_guard(self.train_program, self.startup_program):
program_func_outs = program_func()
program_func_outs = train_func()
self.test_outputs = program_func_outs if isinstance(
program_func_outs, list) else [program_func_outs]
self.test_program = self.train_program.clone()
......@@ -86,9 +122,9 @@ class Trainer(object):
loss = self.test_outputs[0]
optimize_ops, params_grads = optimizer.minimize(loss)
self.place = Trainer._check_and_get_place(place)
self.place = check_and_get_place(place)
self.dist_transpile_if_necessary(optimize_ops, params_grads)
self._dist_transpile_if_necessary(optimize_ops, params_grads)
# 2. move the default_main_program to self.program and run the
# default_startup program on an empty core.Scope()
......@@ -101,7 +137,7 @@ class Trainer(object):
# load params from param_path into scope
io.load_persistables(exe, dirname=param_path)
def dist_transpile_if_necessary(self, optimize_ops, params_grads):
def _dist_transpile_if_necessary(self, optimize_ops, params_grads):
if "PADDLE_TRAINING_ROLE" not in os.environ:
return
......@@ -190,31 +226,14 @@ class Trainer(object):
exe = executor.Executor(self.place)
io.save_persistables(exe, dirname=param_path)
@staticmethod
def _check_and_get_place(place):
"""
Check the type of place or get the default place
Args:
place(None|core.CUDAPlace|core.CPUPlace): the place that trainer will be executed on.
Raises:
TypeError if the type mismatched.
Returns:
the original place if it is not None.
if fluid is compiled with CUDA, returns CUDAPlace(0) by default.
Otherwise returns CPUPlace by default.
"""
if place is None:
if core.is_compiled_with_cuda():
return core.CUDAPlace(0)
else:
return core.CPUPlace()
else:
if not isinstance(place, core.CUDAPlace) and not isinstance(
place, core.CPUPlace):
raise TypeError("Place should be either CUDAPlace or CPUPlace")
return place
def save_inference_model(self, model_path):
inference_program = framework.Program()
with framework.program_guard(inference_program):
with unique_name.guard():
predict_var = self.infer_func()
predict_var = self.train_program.block(0).var(predict_var.name)
exe = executor.Executor(self.place)
io.save_inference_model(model_path, [], [predict_var], exe)
@contextlib.contextmanager
def _prog_and_scope_guard(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册