未验证 提交 db412585 编写于 作者: S Shibo Tao 提交者: GitHub

add API serialize_program, serialize_persistables, save_to_file,...

add API serialize_program, serialize_persistables, save_to_file, deserialize_program, deserialize_persistables, load_from_file. (#29034)
上级 14013a2e
......@@ -281,7 +281,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
PADDLE_ENFORCE_EQ(
version, 0U,
platform::errors::InvalidArgument(
"Tensor version %u is not supported, only version 0 is supported.",
"Deserialize to tensor failed, maybe the loaded file is "
"not a paddle model(expected file format: 0, but %u found).",
version));
}
{
......@@ -307,7 +308,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
PADDLE_ENFORCE_EQ(
version, 0U,
platform::errors::InvalidArgument(
"Tensor version %u is not supported, only version 0 is supported.",
"Deserialize to tensor failed, maybe the loaded file is "
"not a paddle model(expected file format: 0, but %u found).",
version));
}
{
......
......@@ -226,8 +226,8 @@ class TestSaveInferenceModelNew(unittest.TestCase):
'y': tensor_y},
fetch_list=[avg_cost])
self.assertRaises(ValueError, paddle.static.save_inference_model,
None, ['x', 'y'], [avg_cost], exe)
self.assertRaises(ValueError, paddle.static.save_inference_model, None,
['x', 'y'], [avg_cost], exe)
self.assertRaises(ValueError, paddle.static.save_inference_model,
MODEL_DIR + "/", [x, y], [avg_cost], exe)
self.assertRaises(ValueError, paddle.static.save_inference_model,
......@@ -251,7 +251,8 @@ class TestSaveInferenceModelNew(unittest.TestCase):
MODEL_DIR + "_isdir", [x, y], [avg_cost], exe)
os.rmdir(params_path)
paddle.static.io.save_inference_model(MODEL_DIR, [x, y], [avg_cost], exe)
paddle.static.io.save_inference_model(MODEL_DIR, [x, y], [avg_cost],
exe)
self.assertTrue(os.path.exists(MODEL_DIR + ".pdmodel"))
self.assertTrue(os.path.exists(MODEL_DIR + ".pdiparams"))
......@@ -263,20 +264,34 @@ class TestSaveInferenceModelNew(unittest.TestCase):
six.moves.reload_module(executor) # reload to build a new scope
self.assertRaises(ValueError, paddle.static.load_inference_model,
None, exe)
self.assertRaises(ValueError, paddle.static.load_inference_model, None,
exe)
self.assertRaises(ValueError, paddle.static.load_inference_model,
MODEL_DIR + "/", exe)
self.assertRaises(ValueError, paddle.static.load_inference_model,
[MODEL_DIR], exe)
self.assertRaises(ValueError, paddle.static.load_inference_model,
MODEL_DIR, exe, pserver_endpoints=None)
self.assertRaises(ValueError, paddle.static.load_inference_model,
MODEL_DIR, exe, unsupported_param=None)
self.assertRaises((TypeError, ValueError), paddle.static.load_inference_model,
None, exe, model_filename="illegal", params_filename="illegal")
model = InferModel(paddle.static.io.load_inference_model(MODEL_DIR, exe))
self.assertRaises(
ValueError,
paddle.static.load_inference_model,
MODEL_DIR,
exe,
pserver_endpoints=None)
self.assertRaises(
ValueError,
paddle.static.load_inference_model,
MODEL_DIR,
exe,
unsupported_param=None)
self.assertRaises(
(TypeError, ValueError),
paddle.static.load_inference_model,
None,
exe,
model_filename="illegal",
params_filename="illegal")
model = InferModel(
paddle.static.io.load_inference_model(MODEL_DIR, exe))
outs = exe.run(model.program,
feed={
......@@ -289,7 +304,57 @@ class TestSaveInferenceModelNew(unittest.TestCase):
self.assertEqual(model.feed_var_names, ["x", "y"])
self.assertEqual(len(model.fetch_vars), 1)
self.assertEqual(expected, actual)
# test save_to_file content type should be bytes
self.assertRaises(ValueError, paddle.static.io.save_to_file, '', 123)
# test _get_valid_program
self.assertRaises(TypeError, paddle.static.io._get_valid_program, 0)
p = Program()
cp = CompiledProgram(p)
paddle.static.io._get_valid_program(cp)
self.assertTrue(paddle.static.io._get_valid_program(cp) is p)
cp._program = None
self.assertRaises(TypeError, paddle.static.io._get_valid_program, cp)
def test_serialize_program_and_persistables(self):
    """Check serialize_program / serialize_persistables return bytes and
    that the private (de)serialize helpers reject degenerate inputs.

    Builds a small linear-regression program (no feed/fetch ops), runs a
    few training steps, then exercises the public serialize APIs plus the
    empty-program and wrong-type error paths.
    """
    startup = fluid.default_startup_program()
    main = fluid.default_main_program()
    # Assemble a fake program without feed/fetch operators.
    with program_guard(main, startup):
        x = layers.data(name='x', shape=[2], dtype='float32')
        y = layers.data(name='y', shape=[1], dtype='float32')
        y_predict = layers.fc(input=x, size=1, act=None)
        cost = layers.square_error_cost(input=y_predict, label=y)
        avg_cost = layers.mean(cost)
        sgd = optimizer.SGDOptimizer(learning_rate=0.001)
        sgd.minimize(avg_cost, startup)

    exe = executor.Executor(core.CPUPlace())
    exe.run(startup, feed={}, fetch_list=[])

    x_data = np.array([[1, 1], [1, 2], [5, 2]]).astype("float32")
    y_data = np.array([[-2], [-3], [-7]]).astype("float32")
    for _ in six.moves.xrange(3):
        exe.run(main,
                feed={'x': x_data,
                      'y': y_data},
                fetch_list=[avg_cost])

    # serialize_program must hand back raw bytes.
    serialized_prog = paddle.static.io.serialize_program([x, y], [avg_cost])
    self.assertTrue(isinstance(serialized_prog, bytes))
    # serialize_persistables must hand back raw bytes as well.
    serialized_params = paddle.static.io.serialize_persistables(
        [x, y], [avg_cost], exe)
    self.assertTrue(isinstance(serialized_params, bytes))
    # A program with no variables serializes to None.
    self.assertEqual(
        paddle.static.io._serialize_persistables(Program(), None), None)
    # deserialize_persistables rejects a non-Program first argument.
    self.assertRaises(TypeError, paddle.static.io.deserialize_persistables,
                      None, None, None)
class TestLoadInferenceModelError(unittest.TestCase):
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册