未验证 提交 2fbebfd8 编写于 作者: W wanghuancoder 提交者: GitHub

refine save numpy (#57040)

* refine save numpy
上级 0d845ada
...@@ -62,7 +62,7 @@ def _build_saved_state_dict(state_dict): ...@@ -62,7 +62,7 @@ def _build_saved_state_dict(state_dict):
raise ValueError( raise ValueError(
"The saved tensor is not initialized. If you used group sharded, please use save_group_sharded_model." "The saved tensor is not initialized. If you used group sharded, please use save_group_sharded_model."
) )
save_dict[key] = np.array(value) save_dict[key] = np.array(value.cpu())
name_table[key] = value.name name_table[key] = value.name
else: else:
save_dict[key] = value save_dict[key] = value
...@@ -91,7 +91,9 @@ def _load_state_dict_from_save_inference_model(model_path, config): ...@@ -91,7 +91,9 @@ def _load_state_dict_from_save_inference_model(model_path, config):
# 3. construct state_dict # 3. construct state_dict
load_param_dict = {} load_param_dict = {}
for var_name in persistable_var_dict: for var_name in persistable_var_dict:
load_param_dict[var_name] = np.array(persistable_var_dict[var_name]) load_param_dict[var_name] = np.array(
persistable_var_dict[var_name].cpu()
)
# if *.info exists, we can recover structured_name # if *.info exists, we can recover structured_name
var_info_filename = str(config.params_filename) + ".info" var_info_filename = str(config.params_filename) + ".info"
...@@ -145,7 +147,7 @@ def _load_state_dict_from_save_params(model_path): ...@@ -145,7 +147,7 @@ def _load_state_dict_from_save_params(model_path):
# 3. construct state_dict # 3. construct state_dict
load_param_dict = {} load_param_dict = {}
for var in load_var_list: for var in load_var_list:
load_param_dict[var.name] = np.array(var) load_param_dict[var.name] = np.array(var.cpu())
return load_param_dict return load_param_dict
...@@ -290,13 +292,15 @@ def _pickle_save(obj, f, protocol): ...@@ -290,13 +292,15 @@ def _pickle_save(obj, f, protocol):
) )
def reduce_varbase(self): def reduce_varbase(self):
data = np.array(self) data = np.array(self.cpu())
name = self.name name = self.name
return (tuple, ((name, data),)) return (tuple, ((name, data),))
def reduce_LoDTensor(self): def reduce_LoDTensor(self):
data = np.array(self) p = core.Place()
p.set_place(paddle.CPUPlace())
data = np.array(self._copy(p))
return (eval, ('data', {'data': data})) return (eval, ('data', {'data': data}))
...@@ -1108,7 +1112,9 @@ def load(path, **configs): ...@@ -1108,7 +1112,9 @@ def load(path, **configs):
try: try:
tensor, _ = _load_lod_tensor(path) tensor, _ = _load_lod_tensor(path)
if config.return_numpy: if config.return_numpy:
return np.array(tensor) p = core.Place()
p.set_place(paddle.CPUPlace())
return np.array(tensor._copy(p))
else: else:
if in_dygraph_mode(): if in_dygraph_mode():
return _lod_tensor2varbase(tensor) return _lod_tensor2varbase(tensor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册