未验证 提交 cdd6437a 编写于 作者: W WeiXin 提交者: GitHub

paddle.save support object save to memory. (#32999)

* support state_dict save to memory.

* Perfect unittest

* perfect unittest.

* suport saving binary var to memory

* polish code.

* packag save/load files into pybind/io.py

* polish code .

* add example for save to memory; remove useless save load function(_load_static_dict,_save_dygraph_dict)

* delete _load_static/dygraph_dict;_save_static/dygraph_dict

* edit example of paddle.save/load
上级 cda893fc
......@@ -276,7 +276,7 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor) {
SerializeToStream(os, tensor, *dev_ctx);
}
void DeserializeFromStream(std::ifstream &os, LoDTensor *tensor) {
void DeserializeFromStream(std::istream &os, LoDTensor *tensor) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
const platform::DeviceContext *dev_ctx;
dev_ctx = pool.Get(platform::CPUPlace());
......
......@@ -257,7 +257,7 @@ LoD ConvertToOffsetBasedLoD(const LoD& length_lod);
void SerializeToStream(std::ostream& os, const LoDTensor& tensor);
void DeserializeFromStream(std::ifstream& os, LoDTensor* tensor);
void DeserializeFromStream(std::istream& os, LoDTensor* tensor);
} // namespace framework
} // namespace paddle
......@@ -121,7 +121,7 @@ void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows) {
SerializeToStream(os, selected_rows, *dev_ctx);
}
void DeserializeFromStream(std::ifstream& os, SelectedRows* selected_rows) {
void DeserializeFromStream(std::istream& os, SelectedRows* selected_rows) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
const platform::DeviceContext* dev_ctx;
dev_ctx = pool.Get(platform::CPUPlace());
......
......@@ -175,7 +175,7 @@ void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows,
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows);
void DeserializeFromStream(std::ifstream& os, SelectedRows* selected_rows);
void DeserializeFromStream(std::istream& os, SelectedRows* selected_rows);
} // namespace framework
} // namespace paddle
......@@ -56,6 +56,7 @@ set(PYBIND_SRCS
ir.cc
inference_api.cc
compatible.cc
io.cc
generator_py.cc)
if(WITH_ASCEND)
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/pybind/io.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/pybind/pybind_boost_headers.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindIO(pybind11::module *m) {
m->def("save_lod_tensor", [](const paddle::framework::LoDTensor &tensor,
const std::string &str_file_name) {
std::ofstream fout(str_file_name, std::ios::binary);
PADDLE_ENFORCE_EQ(static_cast<bool>(fout), true,
platform::errors::Unavailable(
"Cannot open %s to save variables.", str_file_name));
paddle::framework::SerializeToStream(fout, tensor);
int64_t tellp = fout.tellp();
fout.close();
return tellp;
});
m->def("load_lod_tensor", [](paddle::framework::LoDTensor &tensor,
const std::string &str_file_name) {
std::ifstream fin(str_file_name, std::ios::binary);
PADDLE_ENFORCE_EQ(static_cast<bool>(fin), true,
platform::errors::Unavailable(
"Cannot open %s to load variables.", str_file_name));
paddle::framework::DeserializeFromStream(fin, &tensor);
int64_t tellg = fin.tellg();
fin.close();
return tellg;
});
m->def("save_selected_rows",
[](const paddle::framework::SelectedRows &selected_rows,
const std::string &str_file_name) {
std::ofstream fout(str_file_name, std::ios::binary);
PADDLE_ENFORCE_EQ(
static_cast<bool>(fout), true,
platform::errors::Unavailable(
"Cannot open %s to save SelectedRows.", str_file_name));
paddle::framework::SerializeToStream(fout, selected_rows);
int64_t tellp = fout.tellp();
fout.close();
return tellp;
});
m->def("load_selected_rows",
[](paddle::framework::SelectedRows &selected_rows,
const std::string &str_file_name) {
std::ifstream fin(str_file_name, std::ios::binary);
PADDLE_ENFORCE_EQ(
static_cast<bool>(fin), true,
platform::errors::Unavailable(
"Cannot open %s to load SelectedRows.", str_file_name));
paddle::framework::DeserializeFromStream(fin, &selected_rows);
int64_t tellg = fin.tellg();
fin.close();
return tellg;
});
m->def("save_lod_tensor_to_memory",
[](const paddle::framework::LoDTensor &tensor) -> py::bytes {
std::ostringstream ss;
paddle::framework::SerializeToStream(ss, tensor);
return ss.str();
});
m->def("load_lod_tensor_from_memory", [](paddle::framework::LoDTensor &tensor,
const std::string &tensor_bytes) {
std::istringstream fin(tensor_bytes, std::ios::in | std::ios::binary);
paddle::framework::DeserializeFromStream(fin, &tensor);
});
m->def("save_selected_rows_to_memory",
[](const paddle::framework::SelectedRows &selected_rows) -> py::bytes {
std::ostringstream ss;
paddle::framework::SerializeToStream(ss, selected_rows);
return ss.str();
});
m->def("load_selected_rows_from_memory",
[](paddle::framework::SelectedRows &selected_rows,
const std::string &selected_rows_bytes) {
std::istringstream fin(selected_rows_bytes,
std::ios::in | std::ios::binary);
paddle::framework::DeserializeFromStream(fin, &selected_rows);
});
}
} // namespace pybind
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <Python.h>
#include "paddle/fluid/pybind/pybind_boost_headers.h"
namespace paddle {
namespace pybind {
void BindIO(pybind11::module* m);
} // namespace pybind
} // namespace paddle
......@@ -68,6 +68,7 @@ limitations under the License. */
#include "paddle/fluid/platform/monitor.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/pybind/io.h"
#ifdef PADDLE_WITH_ASCEND
#include "paddle/fluid/pybind/ascend_wrapper_py.h"
#endif
......@@ -496,70 +497,6 @@ PYBIND11_MODULE(core_noavx, m) {
#endif
return tensor;
});
m.def("_save_lod_tensor", [](const LoDTensor &tensor,
const std::string &str_file_name) {
std::ofstream fout(str_file_name, std::ios::binary);
PADDLE_ENFORCE_EQ(static_cast<bool>(fout), true,
platform::errors::Unavailable(
"Cannot open %s to save variables.", str_file_name));
SerializeToStream(fout, tensor);
int64_t tellp = fout.tellp();
fout.close();
return tellp;
});
m.def("_load_lod_tensor", [](LoDTensor &tensor,
const std::string &str_file_name) {
std::ifstream fin(str_file_name, std::ios::binary);
PADDLE_ENFORCE_EQ(static_cast<bool>(fin), true,
platform::errors::Unavailable(
"Cannot open %s to load variables.", str_file_name));
DeserializeFromStream(fin, &tensor);
int64_t tellg = fin.tellg();
fin.close();
return tellg;
});
m.def("_save_selected_rows", [](const SelectedRows &selected_rows,
const std::string &str_file_name) {
std::ofstream fout(str_file_name, std::ios::binary);
PADDLE_ENFORCE_EQ(
static_cast<bool>(fout), true,
platform::errors::Unavailable("Cannot open %s to save SelectedRows.",
str_file_name));
SerializeToStream(fout, selected_rows);
int64_t tellp = fout.tellp();
fout.close();
return tellp;
});
m.def("_load_selected_rows",
[](SelectedRows &selected_rows, const std::string &str_file_name) {
std::ifstream fin(str_file_name, std::ios::binary);
PADDLE_ENFORCE_EQ(
static_cast<bool>(fin), true,
platform::errors::Unavailable(
"Cannot open %s to load SelectedRows.", str_file_name));
DeserializeFromStream(fin, &selected_rows);
int64_t tellg = fin.tellg();
fin.close();
return tellg;
});
m.def("_save_static_dict",
[](const std::string &str_file_name, const py::handle &vec_var_list,
const Scope &scope) {
std::vector<std::string> vec_name_list = GetNameList(vec_var_list);
SaveStaticNameListToDisk(str_file_name, vec_name_list, scope);
});
m.def("_load_static_dict",
[](const std::string &str_file_name, const py::handle &vec_var_list,
const Scope &scope, const Executor *executor) {
std::vector<std::string> vec_name_list = GetNameList(vec_var_list);
CreateVariableIfNotExit(vec_var_list, scope, executor);
LoadStaticNameListFromDisk(str_file_name, vec_name_list, scope);
});
m.def("_create_loaded_parameter",
[](const py::handle &vec_var_list, const Scope &scope,
......@@ -567,26 +504,6 @@ PYBIND11_MODULE(core_noavx, m) {
CreateVariableIfNotExit(vec_var_list, scope, executor);
});
m.def("_save_dygraph_dict", [](const std::string &str_file_name,
const PyNameVarBaseMap &state_dict) {
auto vec_var_base_list = GetVarBaseList(state_dict);
SaveDygraphVarBaseListToDisk(str_file_name, vec_var_base_list);
});
m.def("_load_dygraph_dict", [](const std::string &str_file_name) {
auto load_tensor = LoadDygraphVarBaseListFromDisk(str_file_name);
std::unordered_map<std::string, std::shared_ptr<imperative::VarBase>>
map_output;
for (size_t i = 0; i < load_tensor.size(); ++i) {
map_output.emplace(load_tensor[i]->Name(), load_tensor[i]);
}
return map_output;
});
m.def("save_op_version_info", [](framework::ProgramDesc &desc) {
framework::compatible::pb::OpVersionMap pb_vmap{desc.OpVersionMap()};
framework::compatible::SaveOpVersions(
......@@ -3111,6 +3028,7 @@ All parameter, weight, gradient are variables in Paddle.
.def("device_count", &ParallelExecutor::DeviceCount);
BindFleetWrapper(&m);
BindIO(&m);
#ifdef PADDLE_WITH_PSLIB
BindHeterWrapper(&m);
......
......@@ -269,14 +269,6 @@ if avx_supported():
from .core_avx import _dygraph_debug_level
from .core_avx import _switch_tracer
from .core_avx import _set_paddle_lib_path
from .core_avx import _save_static_dict
from .core_avx import _load_static_dict
from .core_avx import _save_dygraph_dict
from .core_avx import _load_dygraph_dict
from .core_avx import _save_lod_tensor
from .core_avx import _load_lod_tensor
from .core_avx import _save_selected_rows
from .core_avx import _load_selected_rows
from .core_avx import _create_loaded_parameter
from .core_avx import _cuda_synchronize
from .core_avx import _promote_types_if_complex_exists
......@@ -328,14 +320,6 @@ if load_noavx:
from .core_noavx import _dygraph_debug_level
from .core_noavx import _switch_tracer
from .core_noavx import _set_paddle_lib_path
from .core_noavx import _save_static_dict
from .core_noavx import _load_static_dict
from .core_noavx import _save_dygraph_dict
from .core_noavx import _load_dygraph_dict
from .core_noavx import _save_lod_tensor
from .core_noavx import _load_lod_tensor
from .core_noavx import _save_selected_rows
from .core_noavx import _load_selected_rows
from .core_noavx import _create_loaded_parameter
from .core_noavx import _cuda_synchronize
from .core_noavx import _promote_types_if_complex_exists
......
......@@ -23,6 +23,7 @@ import pickle
import contextlib
from functools import reduce
import sys
from io import BytesIO
import numpy as np
import math
......@@ -71,6 +72,52 @@ _logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
class _open_buffer(object):
def __init__(self, buffer):
self.buffer = buffer
def __enter__(self):
return self.buffer
class _buffer_reader(_open_buffer):
def __init__(self, buffer):
super(_buffer_reader, self).__init__(buffer)
self.initial_tell = self.buffer.tell()
def __exit__(self, *args):
# `args[0]` is type of exception. When the `read` is abnormal, the file pointer returns to the initial position.
if args[0] is not None:
self.buffer.seek(self.initial_tell)
class _buffer_writer(_open_buffer):
def __exit__(self, *args):
self.buffer.flush()
def _is_file_path(path):
return isinstance(path, str)
def _open_file_buffer(path_or_buffer, mode):
if _is_file_path(path_or_buffer):
return open(path_or_buffer, mode)
else:
if 'w' in mode:
return _buffer_writer(path_or_buffer)
elif 'r' in mode:
return _buffer_reader(path_or_buffer)
else:
raise ValueError("Expected 'r' or 'w' in mode but got {}".format(
mode))
def _is_memory_buffer(buffer):
return isinstance(buffer, BytesIO)
def is_parameter(var):
"""
Check whether the given variable is an instance of Parameter.
......@@ -1776,14 +1823,16 @@ def _legacy_save(param_dict, model_path, protocol=2):
param_dict = {name: get_tensor(param_dict[name]) for name in param_dict}
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
if sys.platform == 'darwin' and sys.version_info.major == 3:
if _is_file_path(
model_path
) and sys.platform == 'darwin' and sys.version_info.major == 3:
pickle_bytes = pickle.dumps(param_dict, protocol=protocol)
with open(model_path, 'wb') as f:
max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes])
else:
with open(model_path, 'wb') as f:
with _open_file_buffer(model_path, 'wb') as f:
pickle.dump(param_dict, f, protocol=protocol)
......
......@@ -19,6 +19,7 @@ import numpy as np
import os
import sys
import six
from io import BytesIO
import paddle
import paddle.nn as nn
......@@ -760,6 +761,71 @@ class TestSaveLoadAny(unittest.TestCase):
self.assertTrue(np.array_equal(origin_array, load_tensor_array))
class TestSaveLoadToMemory(unittest.TestCase):
def test_dygraph_save_to_memory(self):
paddle.disable_static()
linear = LinearNet()
state_dict = linear.state_dict()
byio = BytesIO()
paddle.save(state_dict, byio)
tensor = paddle.randn([2, 3], dtype='float32')
paddle.save(tensor, byio)
byio.seek(0)
# load state_dict
dict_load = paddle.load(byio, return_numpy=True)
for k, v in state_dict.items():
self.assertTrue(np.array_equal(v.numpy(), dict_load[k]))
# load tensor
tensor_load = paddle.load(byio, return_numpy=True)
self.assertTrue(np.array_equal(tensor_load, tensor.numpy()))
with self.assertRaises(ValueError):
paddle.save(4, 3)
with self.assertRaises(ValueError):
paddle.save(state_dict, '')
with self.assertRaises(ValueError):
paddle.fluid.io._open_file_buffer('temp', 'b')
def test_static_save_to_memory(self):
paddle.enable_static()
with new_program_scope():
# create network
x = paddle.static.data(
name="x", shape=[None, IMAGE_SIZE], dtype='float32')
z = paddle.static.nn.fc(x, 10, bias_attr=False)
z = paddle.static.nn.fc(z, 128, bias_attr=False)
loss = fluid.layers.reduce_mean(z)
place = fluid.CPUPlace(
) if not paddle.fluid.core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
prog = paddle.static.default_main_program()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
state_dict = prog.state_dict()
keys = list(state_dict.keys())
tensor = state_dict[keys[0]]
byio = BytesIO()
byio2 = BytesIO()
paddle.save(prog, byio2)
paddle.save(tensor, byio)
paddle.save(state_dict, byio)
byio.seek(0)
byio2.seek(0)
prog_load = paddle.load(byio2)
self.assertTrue(prog.desc.serialize_to_string() ==
prog_load.desc.serialize_to_string())
tensor_load = paddle.load(byio, return_numpy=True)
self.assertTrue(np.array_equal(tensor_load, np.array(tensor)))
state_dict_load = paddle.load(byio, return_numpy=True)
for k, v in state_dict.items():
self.assertTrue(np.array_equal(np.array(v), state_dict_load[k]))
class TestSaveLoad(unittest.TestCase):
def setUp(self):
# enable dygraph mode
......
......@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
from io import BytesIO
import os
import sys
import six
......@@ -176,13 +177,27 @@ class TestSaveLoadBinaryFormat(unittest.TestCase):
paddle.save(temp_lod, path, use_binary_format=True)
with self.assertRaises(RuntimeError):
fluid.core._save_lod_tensor(
fluid.core.save_lod_tensor(
temp_lod, 'test_save_load_error_not_exist_file/not_exist_file')
with self.assertRaises(RuntimeError):
fluid.core._load_lod_tensor(
fluid.core.load_lod_tensor(
temp_lod, 'test_save_load_error_not_exist_file/not_exist_file')
# save to memory
byio = BytesIO()
paddle.save(tensor, byio, use_binary_format=True)
byio.seek(0)
# load from memory
loaded_tensor_mem = paddle.load(byio)
to_array_mem = np.array(loaded_tensor_mem)
self.assertTrue(np.array_equal(np.array(tensor), to_array_mem))
with self.assertRaises(NotImplementedError):
paddle.framework.io._save_lod_tensor(tensor, 1)
with self.assertRaises(NotImplementedError):
paddle.framework.io._load_lod_tensor(1)
def test_save_load_selected_rows(self):
paddle.enable_static()
place = fluid.CPUPlace() if not paddle.fluid.core.is_compiled_with_cuda(
......@@ -210,10 +225,28 @@ class TestSaveLoadBinaryFormat(unittest.TestCase):
np.array_equal(np.array(load_sr.get_tensor()), np_array))
with self.assertRaises(RuntimeError):
fluid.core._save_selected_rows(
fluid.core.save_selected_rows(
selected_rows,
'test_paddle_save_load_selected_rows_not_exist_file/temp')
with self.assertRaises(RuntimeError):
fluid.core._load_selected_rows(
fluid.core.load_selected_rows(
selected_rows,
'test_paddle_save_load_selected_rows_not_exist_file/temp')
# save to memory
byio = BytesIO()
paddle.save(selected_rows, byio, use_binary_format=True)
byio.seek(0)
# load from memory
selected_rows_mem = paddle.load(byio)
to_array_mem = np.array(selected_rows_mem)
self.assertTrue(isinstance(selected_rows_mem, fluid.core.SelectedRows))
self.assertTrue(list(selected_rows_mem.rows()) == rows)
self.assertTrue(selected_rows_mem.height() == height)
self.assertTrue(
np.array_equal(np.array(selected_rows_mem.get_tensor()), np_array))
with self.assertRaises(NotImplementedError):
paddle.framework.io._save_selected_rows(selected_rows, 1)
with self.assertRaises(NotImplementedError):
paddle.framework.io._load_selected_rows(1)
......@@ -32,6 +32,7 @@ from paddle import fluid
from paddle.fluid import core
from paddle.fluid.io import _unpack_saved_dict, _pack_loaded_dict, _pickle_loads_mac
from paddle.fluid.io import _legacy_save as _legacy_static_save
from paddle.fluid.io import _open_file_buffer, _is_file_path, _is_memory_buffer
from paddle.fluid.framework import Variable, _varbase_creator, _dygraph_tracer, in_dygraph_mode, ParamBase, _current_expected_place, Program
from paddle.fluid.dygraph.jit import _SaveLoadConfig
......@@ -450,30 +451,81 @@ def _parse_load_result(obj, return_numpy):
def _save_lod_tensor(tensor, file_name):
if not tensor._is_initialized():
raise ValueError("The saved tensor is not initialized.")
_seek = core._save_lod_tensor(tensor, file_name)
# '_seek' is the end position of this tensor in the file.
if _is_file_path(file_name):
_seek = core.save_lod_tensor(tensor, file_name)
# '_seek' is the end position of this tensor in the file.
elif _is_memory_buffer(file_name):
tensor_bytes = core.save_lod_tensor_to_memory(tensor)
with _open_file_buffer(file_name, 'wb') as f:
f.write(tensor_bytes)
_seek = f.tell()
else:
raise NotImplementedError(
'Only supports saving objects to file or BytesIO, but received {}'.
format(type(file_name)))
return _seek
def _load_lod_tensor(file_name):
temp_t = paddle.fluid.core.LoDTensor()
# '_seek' is the end position of this tensor in the file.
_seek = paddle.fluid.core._load_lod_tensor(temp_t, file_name)
if _is_file_path(file_name):
# '_seek' is the end position of this tensor in the file.
_seek = paddle.fluid.core.load_lod_tensor(temp_t, file_name)
elif _is_memory_buffer(file_name):
with _open_file_buffer(file_name, 'rb') as f:
tensor_bytes = f.read()
paddle.fluid.core.load_lod_tensor_from_memory(temp_t, tensor_bytes)
_seek = f.tell()
else:
raise NotImplementedError(
'Only supports load objects from file or BytesIO, but received {}'.
format(type(file_name)))
return temp_t, _seek
def _save_selected_rows(selected_rows, file_name):
# '_seek' is the end position of this SelectedRows in the file.
if not selected_rows.get_tensor()._is_initialized():
raise ValueError("The saved tensor is not initialized.")
_seek = core._save_selected_rows(selected_rows, file_name)
if _is_file_path(file_name):
# '_seek' is the end position of this SelectedRows in the file.
_seek = core.save_selected_rows(selected_rows, file_name)
elif _is_memory_buffer(file_name):
selected_rows_bytes = core.save_selected_rows_to_memory(selected_rows)
with _open_file_buffer(file_name, 'wb') as f:
f.write(selected_rows_bytes)
_seek = f.tell()
else:
raise NotImplementedError(
'Only supports saving objects to file or BytesIO, but received {}'.
format(type(file_name)))
return _seek
def _load_selected_rows(file_name):
temp_sr = core.SelectedRows()
# '_seek' is the end position of this SelectedRows in the file.
_seek = core._load_selected_rows(temp_sr, file_name)
if _is_file_path(file_name):
# '_seek' is the end position of this SelectedRows in the file.
_seek = core.load_selected_rows(temp_sr, file_name)
elif _is_memory_buffer(file_name):
with _open_file_buffer(file_name, 'rb') as f:
selected_rows_bytes = f.read()
paddle.fluid.core.load_selected_rows_from_memory(
temp_sr, selected_rows_bytes)
_seek = f.tell()
else:
raise NotImplementedError(
'Only supports load objects from file or BytesIO, but received {}'.
format(type(file_name)))
return temp_sr, _seek
......@@ -509,7 +561,7 @@ def save(obj, path, protocol=4, **configs):
Args:
obj(Object) : The object to be saved.
path(str) : The path of the object to be saved.
path(str|BytesIO) : The path/buffer of the object to be saved.
If saved in the current directory, the input path string will be used as the file name.
protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
Default: 4
......@@ -593,18 +645,39 @@ def save(obj, path, protocol=4, **configs):
main_program = paddle.static.default_main_program()
path = "example/main_program.pdmodel"
paddle.save(main_program, path)
'''
# 1. input check
filename = os.path.basename(path)
if filename == "":
raise ValueError("The input path MUST be format of dirname/filename "
"[dirname\\filename in Windows system], but received "
"filename is empty string.")
# 2. save object
dirname = os.path.dirname(path)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname)
# example 5: save object to memory
from io import BytesIO
import paddle
from paddle.nn import Linear
paddle.disable_static()
linear = Linear(5, 10)
state_dict = linear.state_dict()
byio = BytesIO()
paddle.save(state_dict, byio)
tensor = paddle.randn([2, 3], dtype='float32')
paddle.save(tensor, byio)
'''
if _is_file_path(path):
# 1. input check
filename = os.path.basename(path)
if filename == "":
raise ValueError(
"The input path MUST be format of dirname/filename "
"[dirname\\filename in Windows system], but received "
"filename is empty string.")
# 2. save object
dirname = os.path.dirname(path)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname)
elif not _is_memory_buffer(path):
raise ValueError(
"only supports saving objects to file and `BytesIO`, but got {}".
format(type(path)))
config = _parse_save_config(configs)
......@@ -625,7 +698,7 @@ def save(obj, path, protocol=4, **configs):
if isinstance(obj, Program):
obj.desc.flush()
with open(path, "wb") as f:
with _open_file_buffer(path, "wb") as f:
f.write(obj.desc.serialize_to_string())
elif _is_state_dict(obj):
......@@ -634,7 +707,7 @@ def save(obj, path, protocol=4, **configs):
else:
_legacy_static_save(obj, path, protocol)
else:
with open(path, 'wb') as f:
with _open_file_buffer(path, 'wb') as f:
_pickle_save(obj, f, protocol)
......@@ -648,12 +721,6 @@ def _legacy_save(obj, path, protocol=2):
if len(obj) == 0:
warnings.warn("The input state dict is empty, no need to save.")
filename = os.path.basename(path)
if filename == "":
raise ValueError("The input path MUST be format of dirname/filename "
"[dirname\\filename in Windows system], but received "
"filename is empty string.")
if not isinstance(protocol, int):
raise ValueError("The 'protocol' MUST be `int`, but received {}".format(
type(protocol)))
......@@ -662,26 +729,33 @@ def _legacy_save(obj, path, protocol=2):
raise ValueError("Expected 1<'protocol'<5, but received protocol={}".
format(protocol))
# 2. save object
dirname = os.path.dirname(path)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname)
if _is_file_path(path):
filename = os.path.basename(path)
if filename == "":
raise ValueError(
"The input path MUST be format of dirname/filename "
"[dirname\\filename in Windows system], but received "
"filename is empty string.")
# 2. save object
dirname = os.path.dirname(path)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname)
# TODO(chenweihang): supports save other object
if isinstance(obj, dict):
saved_obj = _build_saved_state_dict(obj)
saved_obj = _unpack_saved_dict(saved_obj, protocol)
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
if sys.platform == 'darwin' and sys.version_info.major == 3:
if _is_file_path(
path) and sys.platform == 'darwin' and sys.version_info.major == 3:
pickle_bytes = pickle.dumps(saved_obj, protocol=protocol)
with open(path, 'wb') as f:
max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes])
else:
with open(path, 'wb') as f:
with _open_file_buffer(path, 'wb') as f:
pickle.dump(saved_obj, f, protocol=protocol)
......@@ -716,7 +790,7 @@ def load(path, **configs):
``Layer.set_state_dict`` later.
Args:
path(str) : The path to load the target object. Generally, the path is the target
path(str|BytesIO) : The path/buffer to load the target object. Generally, the path is the target
file path. When loading state_dict from the saved result of the API used to save
the inference model, the path may be a file prefix or directory.
**configs (dict, optional): other load configuration options for compatibility. We do not
......@@ -822,18 +896,36 @@ def load(path, **configs):
print(load_main)
# example 5: save object to memory
from io import BytesIO
import paddle
from paddle.nn import Linear
paddle.disable_static()
linear = Linear(5, 10)
state_dict = linear.state_dict()
byio = BytesIO()
paddle.save(state_dict, byio)
tensor = paddle.randn([2, 3], dtype='float32')
paddle.save(tensor, byio)
byio.seek(0)
# load state_dict
dict_load = paddle.load(byio)
'''
if os.path.isfile(path):
if _is_memory_buffer(path) or os.path.isfile(path):
config = _parse_load_config(configs)
if six.PY2:
exception_type = KeyError
else:
exception_type = pickle.UnpicklingError
try:
with open(path, 'rb') as f:
with _open_file_buffer(path, 'rb') as f:
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
if sys.platform == 'darwin' and sys.version_info.major == 3:
if _is_file_path(
path
) and sys.platform == 'darwin' and sys.version_info.major == 3:
load_result = _pickle_loads_mac(path, f)
else:
load_result = pickle.load(f) if six.PY2 else pickle.load(
......@@ -875,7 +967,7 @@ def load(path, **configs):
return tensor
except:
try:
with open(path, "rb") as f:
with _open_file_buffer(path, "rb") as f:
program_desc_str = f.read()
program = Program.parse_from_string(
program_desc_str)
......@@ -895,9 +987,9 @@ def _legacy_load(path, **configs):
load_result = None
config = _parse_load_config(configs)
if os.path.isfile(path):
if os.path.isfile(path) or _is_memory_buffer(path):
# we think path is file means this file is created by paddle.save
with open(path, 'rb') as f:
with _open_file_buffer(path, 'rb') as f:
load_result = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
load_result = _pack_loaded_dict(load_result)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册