未验证 提交 29543da5 编写于 作者: S Shibo Tao 提交者: GitHub

export paddle.static.normalize_program method. test=develop (#31080)

上级 1d2bd35e
......@@ -356,6 +356,48 @@ class TestSaveInferenceModelNew(unittest.TestCase):
self.assertRaises(TypeError, paddle.static.io.deserialize_persistables,
None, None, None)
def test_normalize_program(self):
init_program = fluid.default_startup_program()
program = fluid.default_main_program()
# fake program without feed/fetch
with program_guard(program, init_program):
x = layers.data(name='x', shape=[2], dtype='float32')
y = layers.data(name='y', shape=[1], dtype='float32')
y_predict = layers.fc(input=x, size=1, act=None)
cost = layers.square_error_cost(input=y_predict, label=y)
avg_cost = layers.mean(cost)
sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost, init_program)
place = core.CPUPlace()
exe = executor.Executor(place)
exe.run(init_program, feed={}, fetch_list=[])
tensor_x = np.array([[1, 1], [1, 2], [5, 2]]).astype("float32")
tensor_y = np.array([[-2], [-3], [-7]]).astype("float32")
for i in six.moves.xrange(3):
exe.run(program,
feed={'x': tensor_x,
'y': tensor_y},
fetch_list=[avg_cost])
# test if return type of serialize_program is bytes
res = paddle.static.normalize_program(program, [x, y], [avg_cost])
self.assertTrue(isinstance(res, Program))
# test program type
self.assertRaises(TypeError, paddle.static.normalize_program, None,
[x, y], [avg_cost])
# test feed_vars type
self.assertRaises(TypeError, paddle.static.normalize_program, program,
['x', 'y'], [avg_cost])
# test fetch_vars type
self.assertRaises(TypeError, paddle.static.normalize_program, program,
[x, y], ['avg_cost'])
class TestLoadInferenceModelError(unittest.TestCase):
def test_load_model_not_exist(self):
......
......@@ -59,6 +59,7 @@ from .io import deserialize_program #DEFINE_ALIAS
from .io import serialize_program #DEFINE_ALIAS
from .io import load_from_file #DEFINE_ALIAS
from .io import save_to_file #DEFINE_ALIAS
from .io import normalize_program #DEFINE_ALIAS
from ..fluid import Scope #DEFINE_ALIAS
from .input import data #DEFINE_ALIAS
from .input import InputSpec #DEFINE_ALIAS
......
......@@ -46,6 +46,7 @@ __all__ = [
'deserialize_program',
'deserialize_persistables',
'load_from_file',
'normalize_program',
]
_logger = get_logger(
......@@ -127,10 +128,64 @@ def _clone_var_in_block(block, var):
persistable=True)
def _normalize_program(program, feed_vars, fetch_vars):
def normalize_program(program, feed_vars, fetch_vars):
"""
optimize program according feed_vars and fetch_vars.
:api_attr: Static Graph
Normalize/Optimize a program according to feed_vars and fetch_vars.
Args:
program(Program): Specify a program you want to optimize.
feed_vars(Variable | list[Variable]): Variables needed by inference.
fetch_vars(Variable | list[Variable]): Variables returned by inference.
Returns:
Program: Normalized/Optimized program.
Raises:
TypeError: If `program` is not a Program, an exception is thrown.
TypeError: If `feed_vars` is not a Variable or a list of Variable, an exception is thrown.
TypeError: If `fetch_vars` is not a Variable or a list of Variable, an exception is thrown.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
path_prefix = "./infer_model"
# User defined network, here a softmax regession example
image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
# normalize main program.
program = default_main_program()
normalized_program = paddle.static.normalize_program(program, [image], [predict])
"""
if not isinstance(program, Program):
raise TypeError(
"program type must be `fluid.Program`, but received `%s`" %
type(program))
if not isinstance(feed_vars, list):
feed_vars = [feed_vars]
if not all(isinstance(v, Variable) for v in feed_vars):
raise TypeError(
"feed_vars type must be a Variable or a list of Variable.")
if not isinstance(fetch_vars, list):
fetch_vars = [fetch_vars]
if not all(isinstance(v, Variable) for v in fetch_vars):
raise TypeError(
"fetch_vars type must be a Variable or a list of Variable.")
# remind users to set auc_states to 0 if auc op were found.
for op in program.global_block().ops:
# clear device of Op
......@@ -255,7 +310,7 @@ def serialize_program(feed_vars, fetch_vars, **kwargs):
_check_vars('fetch_vars', fetch_vars)
program = _get_valid_program(kwargs.get('program', None))
program = _normalize_program(program, feed_vars, fetch_vars)
program = normalize_program(program, feed_vars, fetch_vars)
return _serialize_program(program)
......@@ -319,7 +374,7 @@ def serialize_persistables(feed_vars, fetch_vars, executor, **kwargs):
_check_vars('fetch_vars', fetch_vars)
program = _get_valid_program(kwargs.get('program', None))
program = _normalize_program(program, feed_vars, fetch_vars)
program = normalize_program(program, feed_vars, fetch_vars)
return _serialize_persistables(program, executor)
......@@ -463,7 +518,7 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor,
_check_vars('fetch_vars', fetch_vars)
program = _get_valid_program(kwargs.get('program', None))
program = _normalize_program(program, feed_vars, fetch_vars)
program = normalize_program(program, feed_vars, fetch_vars)
# serialize and save program
program_bytes = _serialize_program(program)
save_to_file(model_path, program_bytes)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册