diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 36088aa803cd3d4f097c7b8683f5e71623879bf7..313855b6c55d42293fff99e7f07ae1f6efab0892 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -2180,6 +2180,7 @@ def load_program_state(model_path, var_list=None): with open(parameter_file_name, 'rb') as f: para_dict = pickle.load(f) if six.PY2 else pickle.load( f, encoding='latin1') + para_dict = _pack_loaded_dict(para_dict) opt_file_name = model_prefix + ".pdopt" if os.path.exists(opt_file_name): @@ -2231,6 +2232,7 @@ def set_program_state(program, state_dict): static.set_program_state(prog, program_state) """ + state_dict = _pack_loaded_dict(state_dict) parameter_list = list(filter(is_persistable, program.list_vars())) used_para_list = {} diff --git a/python/paddle/fluid/tests/unittests/test_static_save_load.py b/python/paddle/fluid/tests/unittests/test_static_save_load.py index 257d6e04890ec3dccf601883ba2cfa44a1182e49..68d0e07e0cf2d8c1ca4b626a2afa2eb165a834b1 100644 --- a/python/paddle/fluid/tests/unittests/test_static_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_static_save_load.py @@ -1365,6 +1365,25 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase): base_t = base_map[var.name] self.assertTrue(np.array_equal(new_t, base_t)) + # set var to zero + for var in prog.list_vars(): + if isinstance(var, framework.Parameter) or var.persistable: + ten = fluid.global_scope().find_var(var.name).get_tensor() + ten.set(np.zeros_like(np.array(ten)), place) + + new_t = np.array(fluid.global_scope().find_var(var.name) + .get_tensor()) + self.assertTrue(np.sum(np.abs(new_t)) == 0) + + program_state = fluid.load_program_state(path) + fluid.set_program_state(prog, program_state) + for var in prog.list_vars(): + if isinstance(var, framework.Parameter) or var.persistable: + new_t = np.array(fluid.global_scope().find_var(var.name) + .get_tensor()) + base_t = base_map[var.name] + self.assertTrue(np.array_equal(new_t, base_t)) + class TestProgramStateOldSaveSingleModel(unittest.TestCase): def test_ptb_rnn_cpu_float32(self): diff --git a/tools/static_mode_white_list.pyc b/tools/static_mode_white_list.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9012c233595b6844f54e625972360f5aeeb0d3b Binary files /dev/null and b/tools/static_mode_white_list.pyc differ