未验证 提交 6783dcee 编写于 作者: F fengjiayi 提交者: GitHub

Python API for inference model saving/load (#5020)

* Add `dump_to_file()` for ProgrameDescBind in pybind

* Update

* Add utility.py

* typo

* Fix bugs

* Move add_feed/fetch_components to untility.py

* Compelete dump

* Follow comments

* Change output of Prune() from inference to pointer

* Expose Prune() to Python

* Compelete save/load API of inference model

* Fix errors

* Debuging

* Compelete unit tests

* follow comments
上级 f3ac4d8e
......@@ -28,3 +28,4 @@ cmake_install.cmake
paddle/.timestamp
python/paddlepaddle.egg-info/
paddle/pybind/pybind.h
python/paddle/v2/framework/tests/tmp/*
......@@ -107,6 +107,8 @@ class OpDescBind {
void InferVarType(BlockDescBind *block) const;
void MarkAsTarget() { desc_.set_is_target(true); }
void Flush();
private:
......
......@@ -49,6 +49,13 @@ ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) {
}
}
ProgramDescBind::ProgramDescBind(const ProgramDesc &desc) {
desc_ = desc;
for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDescBind(this, &block_desc));
}
}
ProgramDescBind::ProgramDescBind(const std::string &binary_str) {
PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
"Fail to parse program_desc from binary string.");
......
......@@ -29,6 +29,8 @@ class ProgramDescBind {
public:
ProgramDescBind();
explicit ProgramDescBind(const ProgramDesc &desc);
ProgramDescBind(const ProgramDescBind &o);
explicit ProgramDescBind(const std::string &binary_str);
......
......@@ -46,7 +46,7 @@ bool IsTarget(const OpDesc& op_desc) {
return false;
}
void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) {
void prune_impl(const ProgramDesc& input, ProgramDesc* output, int block_id) {
// TODO(tonyyang-svail):
// - will change to use multiple blocks for RNN op and Cond Op
......@@ -91,8 +91,8 @@ void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) {
// we reverse the should_run vector
std::reverse(should_run.begin(), should_run.end());
output = input;
auto* op_field = output.mutable_blocks(block_id)->mutable_ops();
*output = input;
auto* op_field = output->mutable_blocks(block_id)->mutable_ops();
op_field->Clear();
for (size_t i = 0; i < should_run.size(); ++i) {
if (should_run[i]) {
......@@ -101,7 +101,8 @@ void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) {
}
}
void Prune(const ProgramDesc& input, ProgramDesc& output) {
// TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
void Prune(const ProgramDesc& input, ProgramDesc* output) {
prune_impl(input, output, 0);
}
......
......@@ -20,7 +20,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
void Prune(const ProgramDesc& input, ProgramDesc& output);
void Prune(const ProgramDesc& input, ProgramDesc* output);
} // namespace framework
} // namespace paddle
......@@ -59,11 +59,11 @@ TEST(Prune, one_operator) {
f::ProgramDesc *pdesc = program.Proto();
f::ProgramDesc pruned;
Prune(*pdesc, pruned);
Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0);
pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true);
Prune(*pdesc, pruned);
Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1);
}
......@@ -81,7 +81,7 @@ TEST(Prune, forward) {
for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) {
f::ProgramDesc pruned;
pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true);
Prune(*pdesc, pruned);
Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1);
}
}
......@@ -100,7 +100,7 @@ TEST(Prune, multi_input_op) {
pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true);
f::ProgramDesc pruned;
Prune(*pdesc, pruned);
Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4);
}
......@@ -116,7 +116,7 @@ TEST(Prune, multi_output_op) {
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::ProgramDesc pruned;
Prune(*pdesc, pruned);
Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2);
}
......@@ -133,6 +133,6 @@ TEST(Prune, multi_target) {
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::ProgramDesc pruned;
Prune(*pdesc, pruned);
Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3);
}
if(WITH_PYTHON)
cc_library(paddle_pybind SHARED
SRCS pybind.cc exception.cc protobuf.cc
DEPS pybind python backward proto_desc tensor_array paddle_memory executor
DEPS pybind python backward proto_desc tensor_array paddle_memory executor prune
${GLOB_OP_LIB})
endif(WITH_PYTHON)
......
......@@ -141,6 +141,13 @@ void BindProgramDesc(py::module &m) {
desc->SerializeToString(&res),
"Serialize ProgramDesc Error. This could be a bug of Paddle.");
return res;
})
.def("parse_from_string",
[](ProgramDescBind &program_desc, const std::string &data) {
ProgramDesc *desc = program_desc.Proto();
PADDLE_ENFORCE(desc->ParseFromString(data),
"Fail to parse ProgramDesc from string. This could "
"be a bug of Paddle.");
});
}
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/framework/feed_fetch_method.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/prune.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/tensor_array.h"
#include "paddle/operators/cond_op.h"
......@@ -237,6 +238,16 @@ All parameter, weight, gradient are variables in Paddle.
}
return ret_values;
});
m.def("prune", [](const ProgramDescBind &origin,
const std::vector<std::array<size_t, 2>> &targets) {
ProgramDescBind prog_with_targets(origin);
for (const auto &t : targets) {
prog_with_targets.Block(t[0])->Op(t[1])->MarkAsTarget();
}
ProgramDesc pruned_desc;
Prune(*prog_with_targets.Proto(), &pruned_desc);
return new ProgramDescBind(pruned_desc);
});
m.def_submodule(
"var_names",
"The module will return special predefined variable name in Paddle")
......
......@@ -251,6 +251,8 @@ class Operator(object):
self.desc.set_output(out_proto.name, out_argu_names)
if attrs is not None:
if not isinstance(attrs, dict):
raise TypeError("'attrs' should be a dict.")
for attr in proto.attrs:
attr_name = attr.name
if (not attr_name in attrs) or (attrs[attr_name] is None):
......@@ -291,6 +293,14 @@ class Operator(object):
def output_names(self):
return self.desc.output_names()
@property
def idx(self):
for i, op in enumerate(self.block.ops):
if op == self:
return i
raise ValueError(
"Can't find op itself in it's block. It could be a bug of Paddle.")
def has_attr(self, name):
return self.desc.has_attr(name)
......@@ -440,10 +450,31 @@ class Program(object):
p.sync_with_cpp()
return p
def prune(self, targets):
if not isinstance(targets, list):
targets = [targets]
targets_idx = []
for t in targets:
if not isinstance(t, Operator):
if isinstance(t, Variable):
t = t.op
else:
raise ValueError(
"All targets of prune() can only be Variable or Operator."
)
targets_idx.append([t.block.idx, t.idx])
res = Program()
res.desc = core.prune(self.desc, targets_idx)
res.blocks = [Block(res, i) for i in xrange(res.desc.num_blocks())]
res.sync_with_cpp()
return res
@staticmethod
def parse_from_string(binary_str):
p = Program()
p.desc = core.ProgramDesc(binary_str)
p.blocks = [Block(p, i) for i in xrange(p.desc.num_blocks())]
p.sync_with_cpp()
return p
......
import os
import cPickle as pickle
from paddle.v2.framework.framework import Program, Parameter, g_program, \
Variable
__all__ = [
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables'
'load_persistables', "save_inference_model", "load_inference_model"
]
......@@ -124,6 +125,7 @@ def load_vars(executor, dirname, program=None, vars=None, predicate=None):
inputs={},
outputs={"Out": [new_var]},
attrs={'file_path': os.path.join(dirname, new_var.name)})
executor.run(load_prog)
......@@ -141,3 +143,88 @@ def load_persistables(executor, dirname, program=None):
"""
load_vars(
executor, dirname=dirname, program=program, predicate=is_persistable)
def save_inference_model(dirname,
feeded_var_names,
target_vars,
executor,
program=None):
"""
Build a model especially for inference,
and save it to directory by the executor.
:param dirname: directory path
:param feeded_var_names: Names of variables that need to be feeded data during inference
:param target_vars: Variables from which we can get inference results.
:param executor: executor that save inference model
:param program: original program, which will be pruned to build the inference model.
Default g_program.
:return: None
"""
if program is None:
program = g_program
if not isinstance(target_vars, list):
target_vars = [target_vars]
if not os.path.isdir(dirname):
os.makedirs(dirname)
pruned_program = program.prune(target_vars)
fetch_var_names = [v.name for v in target_vars]
model_file_name = dirname + "/__model__"
with open(model_file_name, "w") as f:
pickle.dump({
"program_desc_str": pruned_program.desc.serialize_to_string(),
"feed_var_names": feeded_var_names,
"fetch_var_names": fetch_var_names
}, f, -1)
save_params(executor, dirname, program)
def load_persistables_if_exist(executor, dirname, program=None):
filenames = next(os.walk(dirname))[2]
filenames = set(filenames)
def _is_presistable_and_exist_(var):
if not is_persistable(var):
return False
else:
return var.name in filenames
load_vars(
executor,
dirname,
program=program,
vars=None,
predicate=_is_presistable_and_exist_)
def load_inference_model(dirname, executor):
"""
Load inference model from a directory
:param dirname: directory path
:param executor: executor that load inference model
:return: [program, feed_var_names, fetch_var_names]
program: program especially for inference.
feeded_var_names: Names of variables that need to feed data
fetch_vars: Variables from which we can get inference results.
"""
if not os.path.isdir(dirname):
raise ValueError("There is no directory named '%s'", dirname)
model_file_name = dirname + "/__model__"
model = pickle.load(open(model_file_name, "r"))
program_desc_str = model["program_desc_str"]
feed_var_names = model["feed_var_names"]
fetch_var_names = model["fetch_var_names"]
program = Program.parse_from_string(program_desc_str)
load_persistables_if_exist(executor, dirname, program)
fetch_vars = [program.global_block().var(name) for name in fetch_var_names]
return [program, feed_var_names, fetch_vars]
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
from paddle.v2.framework.framework import Program, g_program
from paddle.v2.framework.io import save_inference_model, load_inference_model
import paddle.v2.framework.executor as executor
import unittest
import numpy as np
class TestBook(unittest.TestCase):
def test_fit_line_inference_model(self):
MODEL_DIR = "./tmp/inference_model"
init_program = Program()
program = Program()
x = layers.data(
name='x',
shape=[2],
data_type='float32',
program=program,
init_program=init_program)
y = layers.data(
name='y',
shape=[1],
data_type='float32',
program=program,
init_program=init_program)
y_predict = layers.fc(input=x,
size=1,
act=None,
program=program,
init_program=init_program)
cost = layers.square_error_cost(
input=y_predict,
label=y,
program=program,
init_program=init_program)
avg_cost = layers.mean(
x=cost, program=program, init_program=init_program)
sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001)
opts = sgd_optimizer.minimize(avg_cost)
place = core.CPUPlace()
exe = executor.Executor(place)
exe.run(init_program, feed={}, fetch_list=[])
for i in xrange(100):
x_data = np.array(
[[1, 1], [1, 2], [3, 4], [5, 2]]).astype("float32")
y_data = np.array([[-2], [-3], [-7], [-7]]).astype("float32")
tensor_x = core.LoDTensor()
tensor_x.set(x_data, place)
tensor_y = core.LoDTensor()
tensor_y.set(y_data, place)
exe.run(program,
feed={'x': tensor_x,
'y': tensor_y},
fetch_list=[avg_cost])
save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe, program)
outs = exe.run(program,
feed={'x': tensor_x,
'y': tensor_y},
fetch_list=[avg_cost])
expected = np.array(outs[0])
reload(executor) # reload to build a new scope
exe = executor.Executor(place)
[infer_prog, feed_var_names, fetch_vars] = load_inference_model(
MODEL_DIR, exe)
outs = exe.run(
infer_prog,
feed={feed_var_names[0]: tensor_x,
feed_var_names[1]: tensor_y},
fetch_list=fetch_vars)
actual = np.array(outs[0])
self.assertEqual(feed_var_names, ["x", "y"])
self.assertEqual(len(fetch_vars), 1)
self.assertEqual(str(fetch_vars[0]), str(avg_cost))
self.assertEqual(expected, actual)
if __name__ == '__main__':
unittest.main()
import unittest
from paddle.v2.framework.framework import Variable, g_program
from paddle.v2.framework.framework import Variable, Program, g_program
import paddle.v2.framework.core as core
......@@ -21,7 +21,8 @@ class TestOperator(unittest.TestCase):
"Operator \"no_such_op\" has not been registered.")
def test_op_desc_creation(self):
block = g_program.current_block()
program = Program()
block = program.current_block()
mul_x = block.create_var(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
mul_y = block.create_var(
......@@ -50,10 +51,12 @@ class TestOperator(unittest.TestCase):
self.assertEqual(mul_op.has_attr("y_num_col_dims"), True)
self.assertEqual(mul_op.attr_type("y_num_col_dims"), core.AttrType.INT)
self.assertEqual(mul_op.attr("y_num_col_dims"), 1)
self.assertEqual(mul_op.idx, 0)
self.assertEqual(mul_out.op, mul_op)
def test_mult_input(self):
block = g_program.current_block()
program = Program()
block = program.current_block()
sum_x1 = block.create_var(
dtype="int", shape=[3, 4], lod_level=0, name="sum.x1")
sum_x2 = block.create_var(
......@@ -71,6 +74,7 @@ class TestOperator(unittest.TestCase):
self.assertEqual(sum_op.input("X"), ["sum.x1", "sum.x2", "sum.x3"])
self.assertEqual(sum_op.output_names, ["Out"])
self.assertEqual(sum_op.output("Out"), ["sum.out"])
self.assertEqual(sum_op.idx, 0)
self.assertEqual(sum_out.op, sum_op)
......
......@@ -99,6 +99,8 @@ class TestProgram(unittest.TestCase):
outputs={"Out": add_out},
attrs={"x_num_col_dims": 1})
self.assertEqual(mul_op.idx, 0)
self.assertEqual(add_op.idx, 1)
param_to_grad = prog.append_backward(add_out, set())
def grad_name(name):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册