From e5bb4edb2ce70625d01d665968b9ea8cb29095ca Mon Sep 17 00:00:00 2001 From: WeiXin Date: Fri, 15 Jan 2021 11:09:56 +0800 Subject: [PATCH] perfect 'var_list' of static.load/fluid.load (#30457) --- python/paddle/fluid/io.py | 9 +- .../tests/unittests/test_static_save_load.py | 114 ++++++++++++++++++ 2 files changed, 122 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 1a7da4add31..d5963675a82 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -1895,6 +1895,12 @@ def load(program, model_path, executor=None, var_list=None): raise ValueError( "executor is required when loading model file saved with [ save_params, save_persistables, save_vars ]" ) + + if var_list is not None: + var_list_names = [var.name for var in var_list] + else: + var_list_names = None + if os.path.isdir(model_path): binary_file_set = set() for root, dirs, files in os.walk(model_path, topdown=False): @@ -1905,7 +1911,8 @@ def load(program, model_path, executor=None, var_list=None): loaded_var_list = [] for var in program_var_list: var_path = os.path.join(model_path, var.name).replace("\\", "/") - if var_path in binary_file_set: + load_condition = var_list_names is None or var.name in var_list_names + if var_path in binary_file_set and load_condition: loaded_var_list.append(var) binary_file_set.remove(var_path) if len(binary_file_set) > 0: diff --git a/python/paddle/fluid/tests/unittests/test_static_save_load.py b/python/paddle/fluid/tests/unittests/test_static_save_load.py index e275cb525bc..0f4fca6d7f8 100644 --- a/python/paddle/fluid/tests/unittests/test_static_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_static_save_load.py @@ -794,6 +794,9 @@ class TestLoadFromOldInterface(unittest.TestCase): if os.path.exists("test_path.pdparams"): os.remove("test_path.pdparams") + if os.path.exists("test_static_load_var_list.pdparams"): + os.remove("test_static_load_var_list.pdparams") + def test_load_from_old_interface(self): seed = 90 hidden_size = 10 @@ -910,6 +913,117 @@ class TestLoadFromOldInterface(unittest.TestCase): fluid.load(test_clone_program, "test_path", exe) + def test_load_from_old_interface_var_list(self): + seed = 90 + hidden_size = 10 + vocab_size = 1000 + num_layers = 1 + num_steps = 3 + init_scale = 0.1 + batch_size = 4 + batch_num = 200 + + with new_program_scope(): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + ptb_model = PtbModel( + "ptb_model", + hidden_size=hidden_size, + vocab_size=vocab_size, + num_layers=num_layers, + num_steps=num_steps, + init_scale=init_scale) + + place = fluid.CPUPlace() if not core.is_compiled_with_cuda( + ) else fluid.CUDAPlace(0) + exe = fluid.Executor(place) + sgd = Adam(learning_rate=1e-3) + x = fluid.layers.data( + name="x", shape=[-1, num_steps], dtype='int64') + y = fluid.layers.data(name="y", shape=[-1, 1], dtype='float32') + init_hidden = fluid.layers.data( + name="init_hidden", shape=[1], dtype='float32') + init_cell = fluid.layers.data( + name="init_cell", shape=[1], dtype='float32') + + static_loss, static_last_hidden, static_last_cell = ptb_model( + x, y, init_hidden, init_cell) + + test_clone_program = fluid.default_main_program().clone() + sgd.minimize(static_loss) + static_param_updated = dict() + static_param_init = dict() + + out = exe.run(framework.default_startup_program()) + + static_loss_value = None + static_last_cell_value = None + static_last_hidden_value = None + for i in range(batch_num): + x_data = np.arange(12).reshape(4, 3).astype('int64') + y_data = np.arange(1, 13).reshape(4, 3).astype('int64') + x_data = x_data.reshape((-1, num_steps, 1)) + y_data = y_data.reshape((-1, 1)) + init_hidden_data = np.zeros( + (num_layers, batch_size, hidden_size), dtype='float32') + init_cell_data = np.zeros( + (num_layers, batch_size, hidden_size), dtype='float32') + fetch_list = [static_loss, static_last_hidden, static_last_cell] + out = exe.run(fluid.default_main_program(), + feed={ + "x": x_data, + "y": y_data, + "init_hidden": init_hidden_data, + "init_cell": init_cell_data + }, + fetch_list=fetch_list) + static_loss_value = out[0] + static_last_hidden_value = out[1] + static_last_cell_value = out[2] + + # get value before save + main_program = framework.default_main_program() + base_map = {} + for var in main_program.list_vars(): + if isinstance(var, framework.Parameter) or var.persistable: + t = np.array(fluid.global_scope().find_var(var.name) + .get_tensor()) + # make sure all the paramerter or optimizer var have been update + self.assertTrue(np.sum(np.abs(t)) != 0) + base_map[var.name] = t + + #fluid.save(main_program, "./test_1") + fluid.io.save_persistables(exe, "test_static_load_var_list", + main_program) + + # set var to zero + var_list = [] + for i, var in enumerate(main_program.list_vars()): + if isinstance(var, framework.Parameter) or var.persistable: + if i % 2 == 0: + var_list.append(var) + ten = fluid.global_scope().find_var(var.name).get_tensor() + ten.set(np.zeros_like(np.array(ten)), place) + + new_t = np.array(fluid.global_scope().find_var(var.name) + .get_tensor()) + # make sure all the paramerter or optimizer var have been set to zero + self.assertTrue(np.sum(np.abs(new_t)) == 0) + + fluid.load(main_program, "test_static_load_var_list", exe, var_list) + var_list_names = [var.name for var in var_list] + for var in main_program.list_vars(): + if isinstance(var, framework.Parameter) or var.persistable: + new_t = np.array(fluid.global_scope().find_var(var.name) + .get_tensor()) + if var.name in var_list_names: + # loaded vars + base_t = base_map[var.name] + self.assertTrue(np.array_equal(new_t, base_t)) + else: + #not loaded vars + self.assertTrue(np.sum(np.abs(new_t)) == 0) + class TestLoadFromOldInterfaceSingleFile(unittest.TestCase): def test_load_from_old_interface(self): -- GitLab