From a412b6fa8539101f9d92adfb552164cb2c9e69eb Mon Sep 17 00:00:00 2001 From: yukavio <67678385+yukavio@users.noreply.github.com> Date: Mon, 31 Aug 2020 19:33:46 +0800 Subject: [PATCH] Complete prune load api (#430) * add new argument for easy load prune program * add argument make load pruned program more easier * change save function arg * fix test * fix format --- paddleslim/prune/prune_io.py | 19 ++++++++++--------- tests/test_pruned_model_save_load.py | 22 ++++++++++++++++++++-- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/paddleslim/prune/prune_io.py b/paddleslim/prune/prune_io.py index 2c6fad61..c0e2c975 100644 --- a/paddleslim/prune/prune_io.py +++ b/paddleslim/prune/prune_io.py @@ -10,7 +10,6 @@ __all__ = ["save_model", "load_model"] _logger = get_logger(__name__, level=logging.INFO) -_PARAMS_FILE = "__params__" _SHAPES_FILE = "__shapes__" @@ -30,8 +29,8 @@ def save_model(exe, graph, dirname): executor=exe, dirname=dirname, main_program=graph.program, - filename=_PARAMS_FILE) - weights_file = os.path.join(dirname, _PARAMS_FILE) + filename=None) + weights_file = dirname _logger.info("Save model weights into {}".format(weights_file)) shapes = {} for var in fluid.io.get_program_persistable_vars(graph.program): @@ -57,17 +56,19 @@ def load_model(exe, graph, dirname): _logger.info("Load shapes of weights from {}".format(SHAPES_FILE)) with open(SHAPES_FILE, "r") as f: shapes = json.load(f) - for param, shape in shapes.items(): - graph.var(param).set_shape(shape) + for param_name, shape in shapes.items(): + param = graph.var(param_name) + if param is not None: + param.set_shape(shape) + else: + _logger.info('{} is not loaded'.format(param_name)) _logger.info("Load shapes of weights from {}".format(SHAPES_FILE)) - fluid.io.load_persistables( executor=exe, dirname=dirname, main_program=graph.program, - filename=_PARAMS_FILE) + filename=None) graph.update_groups_of_conv() graph.infer_shape() - _logger.info("Load weights from {}".format( - os.path.join(dirname, _PARAMS_FILE))) + _logger.info("Load weights from {}".format(dirname)) diff --git a/tests/test_pruned_model_save_load.py b/tests/test_pruned_model_save_load.py index 399ec18e..b3ee10c8 100644 --- a/tests/test_pruned_model_save_load.py +++ b/tests/test_pruned_model_save_load.py @@ -35,6 +35,15 @@ class TestSaveAndLoad(unittest.TestCase): sum2 = conv4 + sum1 conv5 = conv_bn_layer(sum2, 8, 3, "conv5") conv6 = conv_bn_layer(conv5, 8, 3, "conv6") + feature = fluid.layers.reshape(conv6, [-1, 128, 16]) + predict = fluid.layers.fc(input=feature, size=10, act='softmax') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + print(label.shape) + print(predict.shape) + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(cost) + adam_optimizer = fluid.optimizer.AdamOptimizer(0.01) + adam_optimizer.minimize(avg_cost) place = fluid.CPUPlace() exe = fluid.Executor(place) @@ -55,9 +64,11 @@ class TestSaveAndLoad(unittest.TestCase): param_shape_backup=None) x = numpy.random.random(size=(10, 3, 16, 16)).astype('float32') + label = numpy.random.random(size=(10, 1)).astype('int64') loss_data, = exe.run(train_program, - feed={"image": x}, - fetch_list=[conv6.name]) + feed={"image": x, + "label": label}, + fetch_list=[cost.name]) save_model(exe, main_program, 'model_file') pruned_program = fluid.Program() @@ -72,8 +83,10 @@ class TestSaveAndLoad(unittest.TestCase): sum2 = conv4 + sum1 conv5 = conv_bn_layer(sum2, 8, 3, "conv5") conv6 = conv_bn_layer(conv5, 8, 3, "conv6") + pruned_test_program = pruned_program.clone(for_test=True) exe.run(pruned_startup_program) load_model(exe, pruned_program, 'model_file') + load_model(exe, pruned_test_program, 'model_file') shapes = { "conv1_weights": (4, 3, 3, 3), "conv2_weights": (4, 4, 3, 3), @@ -88,6 +101,11 @@ class TestSaveAndLoad(unittest.TestCase): print("param: {}; param shape: {}".format(param.name, param.shape)) self.assertTrue(param.shape == shapes[param.name]) + for param in pruned_test_program.global_block().all_parameters(): + if "weights" in param.name: + print("param: {}; param shape: {}".format(param.name, + param.shape)) + self.assertTrue(param.shape == shapes[param.name]) if __name__ == '__main__': -- GitLab