未验证 提交 a412b6fa 编写于 作者: Y yukavio 提交者: GitHub

Complete prune load api (#430)

* add new argument for easily loading a pruned program

* add argument to make loading a pruned program easier

* change save function arg

* fix test

* fix format
上级 81db340f
......@@ -10,7 +10,6 @@ __all__ = ["save_model", "load_model"]
_logger = get_logger(__name__, level=logging.INFO)
_PARAMS_FILE = "__params__"
_SHAPES_FILE = "__shapes__"
......@@ -30,8 +29,8 @@ def save_model(exe, graph, dirname):
executor=exe,
dirname=dirname,
main_program=graph.program,
filename=_PARAMS_FILE)
weights_file = os.path.join(dirname, _PARAMS_FILE)
filename=None)
weights_file = dirname
_logger.info("Save model weights into {}".format(weights_file))
shapes = {}
for var in fluid.io.get_program_persistable_vars(graph.program):
......@@ -57,17 +56,19 @@ def load_model(exe, graph, dirname):
_logger.info("Load shapes of weights from {}".format(SHAPES_FILE))
with open(SHAPES_FILE, "r") as f:
shapes = json.load(f)
for param, shape in shapes.items():
graph.var(param).set_shape(shape)
for param_name, shape in shapes.items():
param = graph.var(param_name)
if param is not None:
param.set_shape(shape)
else:
_logger.info('{} is not loaded'.format(param_name))
_logger.info("Load shapes of weights from {}".format(SHAPES_FILE))
fluid.io.load_persistables(
executor=exe,
dirname=dirname,
main_program=graph.program,
filename=_PARAMS_FILE)
filename=None)
graph.update_groups_of_conv()
graph.infer_shape()
_logger.info("Load weights from {}".format(
os.path.join(dirname, _PARAMS_FILE)))
_logger.info("Load weights from {}".format(dirname))
......@@ -35,6 +35,15 @@ class TestSaveAndLoad(unittest.TestCase):
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
feature = fluid.layers.reshape(conv6, [-1, 128, 16])
predict = fluid.layers.fc(input=feature, size=10, act='softmax')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
print(label.shape)
print(predict.shape)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(cost)
adam_optimizer = fluid.optimizer.AdamOptimizer(0.01)
adam_optimizer.minimize(avg_cost)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
......@@ -55,9 +64,11 @@ class TestSaveAndLoad(unittest.TestCase):
param_shape_backup=None)
x = numpy.random.random(size=(10, 3, 16, 16)).astype('float32')
label = numpy.random.random(size=(10, 1)).astype('int64')
loss_data, = exe.run(train_program,
feed={"image": x},
fetch_list=[conv6.name])
feed={"image": x,
"label": label},
fetch_list=[cost.name])
save_model(exe, main_program, 'model_file')
pruned_program = fluid.Program()
......@@ -72,8 +83,10 @@ class TestSaveAndLoad(unittest.TestCase):
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
pruned_test_program = pruned_program.clone(for_test=True)
exe.run(pruned_startup_program)
load_model(exe, pruned_program, 'model_file')
load_model(exe, pruned_test_program, 'model_file')
shapes = {
"conv1_weights": (4, 3, 3, 3),
"conv2_weights": (4, 4, 3, 3),
......@@ -88,6 +101,11 @@ class TestSaveAndLoad(unittest.TestCase):
print("param: {}; param shape: {}".format(param.name,
param.shape))
self.assertTrue(param.shape == shapes[param.name])
for param in pruned_test_program.global_block().all_parameters():
if "weights" in param.name:
print("param: {}; param shape: {}".format(param.name,
param.shape))
self.assertTrue(param.shape == shapes[param.name])
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册