未验证 提交 5844dfe4 编写于 作者: W WeiXin 提交者: GitHub

[cherry pick]perfect 'var_list' of static.load/fluid.load (#30457) (#30479)

完善 static.load 的 var_list 参数:当模型以多个小文件形式保存时,var_list 可以只包含所有已保存 Tensor 的一个子集,仅加载列表中指定的 Tensor。原始 PR:#30457
上级 f15bed11
......@@ -1895,6 +1895,12 @@ def load(program, model_path, executor=None, var_list=None):
raise ValueError(
"executor is required when loading model file saved with [ save_params, save_persistables, save_vars ]"
)
if var_list is not None:
var_list_names = [var.name for var in var_list]
else:
var_list_names = None
if os.path.isdir(model_path):
binary_file_set = set()
for root, dirs, files in os.walk(model_path, topdown=False):
......@@ -1905,7 +1911,8 @@ def load(program, model_path, executor=None, var_list=None):
loaded_var_list = []
for var in program_var_list:
var_path = os.path.join(model_path, var.name).replace("\\", "/")
if var_path in binary_file_set:
load_condition = var_list_names is None or var.name in var_list_names
if var_path in binary_file_set and load_condition:
loaded_var_list.append(var)
binary_file_set.remove(var_path)
if len(binary_file_set) > 0:
......
......@@ -794,6 +794,9 @@ class TestLoadFromOldInterface(unittest.TestCase):
if os.path.exists("test_path.pdparams"):
os.remove("test_path.pdparams")
if os.path.exists("test_static_load_var_list.pdparams"):
os.remove("test_static_load_var_list.pdparams")
def test_load_from_old_interface(self):
seed = 90
hidden_size = 10
......@@ -910,6 +913,117 @@ class TestLoadFromOldInterface(unittest.TestCase):
fluid.load(test_clone_program, "test_path", exe)
def test_load_from_old_interface_var_list(self):
    """Check that ``fluid.load`` with a ``var_list`` restores only the
    listed persistable vars.

    Flow: train a small PtbModel for a few steps, snapshot every
    persistable var, save them with the legacy multi-file API
    (``save_persistables``), zero all of them out, then reload only a
    subset via ``var_list``. Vars in the list must match their pre-save
    values; vars not in the list must stay zero.
    """
    seed = 90
    hidden_size = 10
    vocab_size = 1000
    num_layers = 1
    num_steps = 3
    init_scale = 0.1
    batch_size = 4
    batch_num = 200

    with new_program_scope():
        # Fix seeds so the trained parameter values are deterministic.
        fluid.default_startup_program().random_seed = seed
        fluid.default_main_program().random_seed = seed
        ptb_model = PtbModel(
            "ptb_model",
            hidden_size=hidden_size,
            vocab_size=vocab_size,
            num_layers=num_layers,
            num_steps=num_steps,
            init_scale=init_scale)

        place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
        ) else fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        # NOTE: was named `sgd` although it is an Adam optimizer.
        optimizer = Adam(learning_rate=1e-3)
        x = fluid.layers.data(
            name="x", shape=[-1, num_steps], dtype='int64')
        y = fluid.layers.data(name="y", shape=[-1, 1], dtype='float32')
        init_hidden = fluid.layers.data(
            name="init_hidden", shape=[1], dtype='float32')
        init_cell = fluid.layers.data(
            name="init_cell", shape=[1], dtype='float32')

        static_loss, static_last_hidden, static_last_cell = ptb_model(
            x, y, init_hidden, init_cell)
        optimizer.minimize(static_loss)

        exe.run(framework.default_startup_program())

        # Run a few training steps so every parameter / optimizer var
        # is updated away from zero before we snapshot it.
        for i in range(batch_num):
            x_data = np.arange(12).reshape(4, 3).astype('int64')
            y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
            x_data = x_data.reshape((-1, num_steps, 1))
            y_data = y_data.reshape((-1, 1))
            init_hidden_data = np.zeros(
                (num_layers, batch_size, hidden_size), dtype='float32')
            init_cell_data = np.zeros(
                (num_layers, batch_size, hidden_size), dtype='float32')
            fetch_list = [static_loss, static_last_hidden, static_last_cell]
            exe.run(fluid.default_main_program(),
                    feed={
                        "x": x_data,
                        "y": y_data,
                        "init_hidden": init_hidden_data,
                        "init_cell": init_cell_data
                    },
                    fetch_list=fetch_list)

        # Snapshot every persistable var before saving.
        main_program = framework.default_main_program()
        base_map = {}
        for var in main_program.list_vars():
            if isinstance(var, framework.Parameter) or var.persistable:
                t = np.array(fluid.global_scope().find_var(var.name)
                             .get_tensor())
                # Make sure every parameter / optimizer var was updated.
                self.assertTrue(np.sum(np.abs(t)) != 0)
                base_map[var.name] = t

        # Save with the legacy one-file-per-var API; the file name in
        # this string must stay in sync with the tearDown cleanup.
        fluid.io.save_persistables(exe, "test_static_load_var_list",
                                   main_program)

        # Pick every other persistable var into var_list, then zero
        # ALL persistable vars so we can tell loaded from not-loaded.
        var_list = []
        for i, var in enumerate(main_program.list_vars()):
            if isinstance(var, framework.Parameter) or var.persistable:
                if i % 2 == 0:
                    var_list.append(var)
                ten = fluid.global_scope().find_var(var.name).get_tensor()
                ten.set(np.zeros_like(np.array(ten)), place)
                new_t = np.array(fluid.global_scope().find_var(var.name)
                                 .get_tensor())
                # Make sure every parameter / optimizer var is now zero.
                self.assertTrue(np.sum(np.abs(new_t)) == 0)

        fluid.load(main_program, "test_static_load_var_list", exe, var_list)

        # Only vars in var_list must be restored; the rest stay zero.
        var_list_names = [var.name for var in var_list]
        for var in main_program.list_vars():
            if isinstance(var, framework.Parameter) or var.persistable:
                new_t = np.array(fluid.global_scope().find_var(var.name)
                                 .get_tensor())
                if var.name in var_list_names:
                    # Loaded vars match their pre-save snapshot.
                    base_t = base_map[var.name]
                    self.assertTrue(np.array_equal(new_t, base_t))
                else:
                    # Not-loaded vars were left untouched (still zero).
                    self.assertTrue(np.sum(np.abs(new_t)) == 0)
class TestLoadFromOldInterfaceSingleFile(unittest.TestCase):
def test_load_from_old_interface(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册