From cdd6437a81593118571bc06b84b60162eedfc335 Mon Sep 17 00:00:00 2001
From: WeiXin
Date: Wed, 9 Jun 2021 12:49:13 +0800
Subject: [PATCH] paddle.save support object save to memory. (#32999)

* support state_dict save to memory.
* Perfect unittest
* perfect unittest.
* support saving binary var to memory
* polish code.
* package save/load files into pybind/io.py
* polish code.
* add example for save to memory; remove useless save/load functions (_load_static_dict, _save_dygraph_dict)
* delete _load_static/dygraph_dict; _save_static/dygraph_dict
* edit example of paddle.save/load
---
 paddle/fluid/framework/lod_tensor.cc          |   2 +-
 paddle/fluid/framework/lod_tensor.h           |   2 +-
 paddle/fluid/framework/selected_rows.cc       |   2 +-
 paddle/fluid/framework/selected_rows.h        |   2 +-
 paddle/fluid/pybind/CMakeLists.txt            |   1 +
 paddle/fluid/pybind/io.cc                     | 111 +++++++++++
 paddle/fluid/pybind/io.h                      |  24 +++
 paddle/fluid/pybind/pybind.cc                 |  86 +--------
 python/paddle/fluid/core.py                   |  16 --
 python/paddle/fluid/io.py                     |  53 +++++-
 .../tests/unittests/test_paddle_save_load.py  |  66 +++++++
 .../unittests/test_paddle_save_load_binary.py |  41 +++-
 python/paddle/framework/io.py                 | 176 +++++++++++++-----
 13 files changed, 430 insertions(+), 152 deletions(-)
 create mode 100644 paddle/fluid/pybind/io.cc
 create mode 100644 paddle/fluid/pybind/io.h

diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc
index 0a6b5e44452..69a2a6eefaf 100644
--- a/paddle/fluid/framework/lod_tensor.cc
+++ b/paddle/fluid/framework/lod_tensor.cc
@@ -276,7 +276,7 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor) {
   SerializeToStream(os, tensor, *dev_ctx);
 }
 
-void DeserializeFromStream(std::ifstream &os, LoDTensor *tensor) {
+void DeserializeFromStream(std::istream &os, LoDTensor *tensor) {
   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
   const platform::DeviceContext *dev_ctx;
   dev_ctx = pool.Get(platform::CPUPlace());
diff --git a/paddle/fluid/framework/lod_tensor.h b/paddle/fluid/framework/lod_tensor.h
index 6b357aba1c5..7dee0f44e38 100644
--- a/paddle/fluid/framework/lod_tensor.h
+++ b/paddle/fluid/framework/lod_tensor.h
@@ -257,7 +257,7 @@ LoD ConvertToOffsetBasedLoD(const LoD& length_lod);
 
 void SerializeToStream(std::ostream& os, const LoDTensor& tensor);
 
-void DeserializeFromStream(std::ifstream& os, LoDTensor* tensor);
+void DeserializeFromStream(std::istream& os, LoDTensor* tensor);
 
 } // namespace framework
 } // namespace paddle
diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc
index 7e48d0dc5f9..c67653953f8 100644
--- a/paddle/fluid/framework/selected_rows.cc
+++ b/paddle/fluid/framework/selected_rows.cc
@@ -121,7 +121,7 @@ void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows) {
   SerializeToStream(os, selected_rows, *dev_ctx);
 }
 
-void DeserializeFromStream(std::ifstream& os, SelectedRows* selected_rows) {
+void DeserializeFromStream(std::istream& os, SelectedRows* selected_rows) {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   const platform::DeviceContext* dev_ctx;
   dev_ctx = pool.Get(platform::CPUPlace());
diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h
index e53e3d973c5..3e4beb9498c 100644
--- a/paddle/fluid/framework/selected_rows.h
+++ b/paddle/fluid/framework/selected_rows.h
@@ -175,7 +175,7 @@ void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows,
 
 void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows);
 
-void DeserializeFromStream(std::ifstream& os, SelectedRows* selected_rows);
+void DeserializeFromStream(std::istream& os, SelectedRows* selected_rows);
 
 } // namespace framework
 } // namespace paddle
diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt
index 5fcb1e30fbe..5e5475da89f 100644
--- a/paddle/fluid/pybind/CMakeLists.txt
+++ b/paddle/fluid/pybind/CMakeLists.txt
@@ -56,6 +56,7 @@ set(PYBIND_SRCS
   ir.cc
   inference_api.cc
   compatible.cc
+  io.cc
   generator_py.cc)
 
 if(WITH_ASCEND)
diff --git a/paddle/fluid/pybind/io.cc b/paddle/fluid/pybind/io.cc
new file mode 100644
index 00000000000..fc49f763054
--- /dev/null
+++ b/paddle/fluid/pybind/io.cc
@@ -0,0 +1,111 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/pybind/io.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/selected_rows.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/pybind/pybind_boost_headers.h"
+
+namespace py = pybind11;
+namespace paddle {
+namespace pybind {
+
+void BindIO(pybind11::module *m) {
+  m->def("save_lod_tensor", [](const paddle::framework::LoDTensor &tensor,
+                               const std::string &str_file_name) {
+    std::ofstream fout(str_file_name, std::ios::binary);
+    PADDLE_ENFORCE_EQ(static_cast<bool>(fout), true,
+                      platform::errors::Unavailable(
+                          "Cannot open %s to save variables.", str_file_name));
+    paddle::framework::SerializeToStream(fout, tensor);
+
+    int64_t tellp = fout.tellp();
+    fout.close();
+    return tellp;
+  });
+
+  m->def("load_lod_tensor", [](paddle::framework::LoDTensor &tensor,
+                               const std::string &str_file_name) {
+    std::ifstream fin(str_file_name, std::ios::binary);
+    PADDLE_ENFORCE_EQ(static_cast<bool>(fin), true,
+                      platform::errors::Unavailable(
+                          "Cannot open %s to load variables.", str_file_name));
+
+    paddle::framework::DeserializeFromStream(fin, &tensor);
+    int64_t tellg = fin.tellg();
+    fin.close();
+    return tellg;
+  });
+
+  m->def("save_selected_rows",
+         [](const paddle::framework::SelectedRows &selected_rows,
+            const std::string &str_file_name) {
+           std::ofstream fout(str_file_name, std::ios::binary);
+           PADDLE_ENFORCE_EQ(
+               static_cast<bool>(fout), true,
+               platform::errors::Unavailable(
+                   "Cannot open %s to save SelectedRows.", str_file_name));
+
+           paddle::framework::SerializeToStream(fout, selected_rows);
+           int64_t tellp = fout.tellp();
+           fout.close();
+           return tellp;
+         });
+
+  m->def("load_selected_rows",
+         [](paddle::framework::SelectedRows &selected_rows,
+            const std::string &str_file_name) {
+           std::ifstream fin(str_file_name, std::ios::binary);
+           PADDLE_ENFORCE_EQ(
+               static_cast<bool>(fin), true,
+               platform::errors::Unavailable(
+                   "Cannot open %s to load SelectedRows.", str_file_name));
+
+           paddle::framework::DeserializeFromStream(fin, &selected_rows);
+           int64_t tellg = fin.tellg();
+           fin.close();
+           return tellg;
+         });
+
+  m->def("save_lod_tensor_to_memory",
+         [](const paddle::framework::LoDTensor &tensor) -> py::bytes {
+           std::ostringstream ss;
+           paddle::framework::SerializeToStream(ss, tensor);
+           return ss.str();
+         });
+
+  m->def("load_lod_tensor_from_memory", [](paddle::framework::LoDTensor &tensor,
+                                           const std::string &tensor_bytes) {
+    std::istringstream fin(tensor_bytes, std::ios::in | std::ios::binary);
+    paddle::framework::DeserializeFromStream(fin, &tensor);
+  });
+
+  m->def("save_selected_rows_to_memory",
+         [](const paddle::framework::SelectedRows &selected_rows) -> py::bytes {
+           std::ostringstream ss;
+           paddle::framework::SerializeToStream(ss, selected_rows);
+           return ss.str();
+         });
+
+  m->def("load_selected_rows_from_memory",
+         [](paddle::framework::SelectedRows &selected_rows,
+            const std::string &selected_rows_bytes) {
+           std::istringstream fin(selected_rows_bytes,
+                                  std::ios::in | std::ios::binary);
+           paddle::framework::DeserializeFromStream(fin, &selected_rows);
+         });
+}
+} // namespace pybind
+} // namespace paddle
diff --git a/paddle/fluid/pybind/io.h b/paddle/fluid/pybind/io.h
new file mode 100644
index 00000000000..dfe3154cb95
--- /dev/null
+++ b/paddle/fluid/pybind/io.h
@@ -0,0 +1,24 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <Python.h>
+#include "paddle/fluid/pybind/pybind_boost_headers.h"
+
+namespace paddle {
+namespace pybind {
+void BindIO(pybind11::module* m);
+} // namespace pybind
+} // namespace paddle
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 6dd08e5dfa4..86084297c4a 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -68,6 +68,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/monitor.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/profiler.h"
+#include "paddle/fluid/pybind/io.h"
 #ifdef PADDLE_WITH_ASCEND
 #include "paddle/fluid/pybind/ascend_wrapper_py.h"
 #endif
@@ -496,70 +497,6 @@ PYBIND11_MODULE(core_noavx, m) {
 #endif
         return tensor;
       });
-  m.def("_save_lod_tensor", [](const LoDTensor &tensor,
-                               const std::string &str_file_name) {
-    std::ofstream fout(str_file_name, std::ios::binary);
-    PADDLE_ENFORCE_EQ(static_cast<bool>(fout), true,
-                      platform::errors::Unavailable(
-                          "Cannot open %s to save variables.", str_file_name));
-    SerializeToStream(fout, tensor);
-
-    int64_t tellp = fout.tellp();
-    fout.close();
-    return tellp;
-  });
-  m.def("_load_lod_tensor", [](LoDTensor &tensor,
-                               const std::string &str_file_name) {
-    std::ifstream fin(str_file_name, std::ios::binary);
-    PADDLE_ENFORCE_EQ(static_cast<bool>(fin), true,
-                      platform::errors::Unavailable(
-                          "Cannot open %s to load variables.", str_file_name));
-
-    DeserializeFromStream(fin, &tensor);
-    int64_t tellg = fin.tellg();
-    fin.close();
-    return tellg;
-  });
-  m.def("_save_selected_rows", [](const SelectedRows &selected_rows,
-                                  const std::string &str_file_name) {
-    std::ofstream fout(str_file_name, std::ios::binary);
-    PADDLE_ENFORCE_EQ(
-        static_cast<bool>(fout), true,
-        platform::errors::Unavailable("Cannot open %s to save SelectedRows.",
-                                      str_file_name));
-
-    SerializeToStream(fout, selected_rows);
-    int64_t tellp = fout.tellp();
-    fout.close();
-    return tellp;
-  });
-  m.def("_load_selected_rows",
-        [](SelectedRows &selected_rows, const std::string &str_file_name) {
-          std::ifstream fin(str_file_name, std::ios::binary);
-          PADDLE_ENFORCE_EQ(
-              static_cast<bool>(fin), true,
-              platform::errors::Unavailable(
-                  "Cannot open %s to load SelectedRows.", str_file_name));
-
-          DeserializeFromStream(fin, &selected_rows);
-          int64_t tellg = fin.tellg();
-          fin.close();
-          return tellg;
-        });
-  m.def("_save_static_dict",
-        [](const std::string &str_file_name, const py::handle &vec_var_list,
-           const Scope &scope) {
-          std::vector<std::string> vec_name_list = GetNameList(vec_var_list);
-          SaveStaticNameListToDisk(str_file_name, vec_name_list, scope);
-        });
-
-  m.def("_load_static_dict",
-        [](const std::string &str_file_name, const py::handle &vec_var_list,
-           const Scope &scope, const Executor *executor) {
-          std::vector<std::string> vec_name_list = GetNameList(vec_var_list);
-          CreateVariableIfNotExit(vec_var_list, scope, executor);
-          LoadStaticNameListFromDisk(str_file_name, vec_name_list, scope);
-        });
 
   m.def("_create_loaded_parameter",
         [](const py::handle &vec_var_list, const Scope &scope,
@@ -567,26 +504,6 @@ PYBIND11_MODULE(core_noavx, m) {
           CreateVariableIfNotExit(vec_var_list, scope, executor);
         });
 
-  m.def("_save_dygraph_dict", [](const std::string &str_file_name,
-                                 const PyNameVarBaseMap &state_dict) {
-    auto vec_var_base_list = GetVarBaseList(state_dict);
-
-    SaveDygraphVarBaseListToDisk(str_file_name, vec_var_base_list);
-  });
-
-  m.def("_load_dygraph_dict", [](const std::string &str_file_name) {
-    auto load_tensor = LoadDygraphVarBaseListFromDisk(str_file_name);
-
-    std::unordered_map<std::string, std::shared_ptr<imperative::VarBase>>
-        map_output;
-
-    for (size_t i = 0; i < load_tensor.size(); ++i) {
-      map_output.emplace(load_tensor[i]->Name(), load_tensor[i]);
-    }
-
-    return map_output;
-  });
-
   m.def("save_op_version_info", [](framework::ProgramDesc &desc) {
     framework::compatible::pb::OpVersionMap pb_vmap{desc.OpVersionMap()};
     framework::compatible::SaveOpVersions(
@@ -3111,6 +3028,7 @@ All parameter, weight, gradient are variables in Paddle.
.def("device_count", &ParallelExecutor::DeviceCount); BindFleetWrapper(&m); + BindIO(&m); #ifdef PADDLE_WITH_PSLIB BindHeterWrapper(&m); diff --git a/python/paddle/fluid/core.py b/python/paddle/fluid/core.py index 9e931ad40c5..7886b6b3f7a 100644 --- a/python/paddle/fluid/core.py +++ b/python/paddle/fluid/core.py @@ -269,14 +269,6 @@ if avx_supported(): from .core_avx import _dygraph_debug_level from .core_avx import _switch_tracer from .core_avx import _set_paddle_lib_path - from .core_avx import _save_static_dict - from .core_avx import _load_static_dict - from .core_avx import _save_dygraph_dict - from .core_avx import _load_dygraph_dict - from .core_avx import _save_lod_tensor - from .core_avx import _load_lod_tensor - from .core_avx import _save_selected_rows - from .core_avx import _load_selected_rows from .core_avx import _create_loaded_parameter from .core_avx import _cuda_synchronize from .core_avx import _promote_types_if_complex_exists @@ -328,14 +320,6 @@ if load_noavx: from .core_noavx import _dygraph_debug_level from .core_noavx import _switch_tracer from .core_noavx import _set_paddle_lib_path - from .core_noavx import _save_static_dict - from .core_noavx import _load_static_dict - from .core_noavx import _save_dygraph_dict - from .core_noavx import _load_dygraph_dict - from .core_noavx import _save_lod_tensor - from .core_noavx import _load_lod_tensor - from .core_noavx import _save_selected_rows - from .core_noavx import _load_selected_rows from .core_noavx import _create_loaded_parameter from .core_noavx import _cuda_synchronize from .core_noavx import _promote_types_if_complex_exists diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 30a0b4053e6..2d3578c6c10 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -23,6 +23,7 @@ import pickle import contextlib from functools import reduce import sys +from io import BytesIO import numpy as np import math @@ -71,6 +72,52 @@ _logger = get_logger( __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') +class _open_buffer(object): + def __init__(self, buffer): + self.buffer = buffer + + def __enter__(self): + return self.buffer + + +class _buffer_reader(_open_buffer): + def __init__(self, buffer): + super(_buffer_reader, self).__init__(buffer) + self.initial_tell = self.buffer.tell() + + def __exit__(self, *args): + # `args[0]` is type of exception. When the `read` is abnormal, the file pointer returns to the initial position. + if args[0] is not None: + self.buffer.seek(self.initial_tell) + + +class _buffer_writer(_open_buffer): + def __exit__(self, *args): + self.buffer.flush() + + +def _is_file_path(path): + return isinstance(path, str) + + +def _open_file_buffer(path_or_buffer, mode): + + if _is_file_path(path_or_buffer): + return open(path_or_buffer, mode) + else: + if 'w' in mode: + return _buffer_writer(path_or_buffer) + elif 'r' in mode: + return _buffer_reader(path_or_buffer) + else: + raise ValueError("Expected 'r' or 'w' in mode but got {}".format( + mode)) + + +def _is_memory_buffer(buffer): + return isinstance(buffer, BytesIO) + + def is_parameter(var): """ Check whether the given variable is an instance of Parameter. 
@@ -1776,14 +1823,16 @@ def _legacy_save(param_dict, model_path, protocol=2):
     param_dict = {name: get_tensor(param_dict[name]) for name in param_dict}
 
     # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
-    if sys.platform == 'darwin' and sys.version_info.major == 3:
+    if _is_file_path(
+            model_path
+    ) and sys.platform == 'darwin' and sys.version_info.major == 3:
         pickle_bytes = pickle.dumps(param_dict, protocol=protocol)
         with open(model_path, 'wb') as f:
             max_bytes = 2**30
             for i in range(0, len(pickle_bytes), max_bytes):
                 f.write(pickle_bytes[i:i + max_bytes])
     else:
-        with open(model_path, 'wb') as f:
+        with _open_file_buffer(model_path, 'wb') as f:
             pickle.dump(param_dict, f, protocol=protocol)
 
diff --git a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py
index be2a6a653cc..594d0db035c 100644
--- a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py
+++ b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py
@@ -19,6 +19,7 @@ import numpy as np
 import os
 import sys
 import six
+from io import BytesIO
 
 import paddle
 import paddle.nn as nn
@@ -760,6 +761,71 @@ class TestSaveLoadAny(unittest.TestCase):
             self.assertTrue(np.array_equal(origin_array, load_tensor_array))
 
 
+class TestSaveLoadToMemory(unittest.TestCase):
+    def test_dygraph_save_to_memory(self):
+        paddle.disable_static()
+        linear = LinearNet()
+        state_dict = linear.state_dict()
+        byio = BytesIO()
+        paddle.save(state_dict, byio)
+        tensor = paddle.randn([2, 3], dtype='float32')
+        paddle.save(tensor, byio)
+        byio.seek(0)
+        # load state_dict
+        dict_load = paddle.load(byio, return_numpy=True)
+        for k, v in state_dict.items():
+            self.assertTrue(np.array_equal(v.numpy(), dict_load[k]))
+        # load tensor
+        tensor_load = paddle.load(byio, return_numpy=True)
+        self.assertTrue(np.array_equal(tensor_load, tensor.numpy()))
+
+        with self.assertRaises(ValueError):
+            paddle.save(4, 3)
+        with self.assertRaises(ValueError):
+            paddle.save(state_dict, '')
+        with self.assertRaises(ValueError):
+            paddle.fluid.io._open_file_buffer('temp', 'b')
+
+    def test_static_save_to_memory(self):
+        paddle.enable_static()
+        with new_program_scope():
+            # create network
+            x = paddle.static.data(
+                name="x", shape=[None, IMAGE_SIZE], dtype='float32')
+            z = paddle.static.nn.fc(x, 10, bias_attr=False)
+            z = paddle.static.nn.fc(z, 128, bias_attr=False)
+            loss = fluid.layers.reduce_mean(z)
+            place = fluid.CPUPlace(
+            ) if not paddle.fluid.core.is_compiled_with_cuda(
+            ) else fluid.CUDAPlace(0)
+            prog = paddle.static.default_main_program()
+            exe = paddle.static.Executor(place)
+            exe.run(paddle.static.default_startup_program())
+
+            state_dict = prog.state_dict()
+            keys = list(state_dict.keys())
+            tensor = state_dict[keys[0]]
+
+            byio = BytesIO()
+            byio2 = BytesIO()
+            paddle.save(prog, byio2)
+            paddle.save(tensor, byio)
+            paddle.save(state_dict, byio)
+            byio.seek(0)
+            byio2.seek(0)
+
+            prog_load = paddle.load(byio2)
+            self.assertTrue(prog.desc.serialize_to_string() ==
+                            prog_load.desc.serialize_to_string())
+
+            tensor_load = paddle.load(byio, return_numpy=True)
+            self.assertTrue(np.array_equal(tensor_load, np.array(tensor)))
+
+            state_dict_load = paddle.load(byio, return_numpy=True)
+            for k, v in state_dict.items():
+                self.assertTrue(np.array_equal(np.array(v), state_dict_load[k]))
+
+
 class TestSaveLoad(unittest.TestCase):
     def setUp(self):
         # enable dygraph mode
diff --git a/python/paddle/fluid/tests/unittests/test_paddle_save_load_binary.py b/python/paddle/fluid/tests/unittests/test_paddle_save_load_binary.py
index 7385da56bea..0b9e038f7cd 100644
--- a/python/paddle/fluid/tests/unittests/test_paddle_save_load_binary.py
+++ b/python/paddle/fluid/tests/unittests/test_paddle_save_load_binary.py
@@ -16,6 +16,7 @@ from __future__ import print_function
 
 import unittest
 import numpy as np
+from io import BytesIO
 import os
 import sys
 import six
@@ -176,13 +177,27 @@ class TestSaveLoadBinaryFormat(unittest.TestCase):
             paddle.save(temp_lod, path, use_binary_format=True)
 
         with self.assertRaises(RuntimeError):
-            fluid.core._save_lod_tensor(
+            fluid.core.save_lod_tensor(
                 temp_lod, 'test_save_load_error_not_exist_file/not_exist_file')
         with self.assertRaises(RuntimeError):
-            fluid.core._load_lod_tensor(
+            fluid.core.load_lod_tensor(
                 temp_lod, 'test_save_load_error_not_exist_file/not_exist_file')
+        # save to memory
+        byio = BytesIO()
+        paddle.save(tensor, byio, use_binary_format=True)
+        byio.seek(0)
+        # load from memory
+        loaded_tensor_mem = paddle.load(byio)
+        to_array_mem = np.array(loaded_tensor_mem)
+        self.assertTrue(np.array_equal(np.array(tensor), to_array_mem))
+
+        with self.assertRaises(NotImplementedError):
+            paddle.framework.io._save_lod_tensor(tensor, 1)
+        with self.assertRaises(NotImplementedError):
+            paddle.framework.io._load_lod_tensor(1)
+
     def test_save_load_selected_rows(self):
         paddle.enable_static()
         place = fluid.CPUPlace() if not paddle.fluid.core.is_compiled_with_cuda(
@@ -210,10 +225,28 @@ class TestSaveLoadBinaryFormat(unittest.TestCase):
             np.array_equal(np.array(load_sr.get_tensor()), np_array))
 
         with self.assertRaises(RuntimeError):
-            fluid.core._save_selected_rows(
+            fluid.core.save_selected_rows(
                 selected_rows,
                 'test_paddle_save_load_selected_rows_not_exist_file/temp')
         with self.assertRaises(RuntimeError):
-            fluid.core._load_selected_rows(
+            fluid.core.load_selected_rows(
                 selected_rows,
                 'test_paddle_save_load_selected_rows_not_exist_file/temp')
+
+        # save to memory
+        byio = BytesIO()
+        paddle.save(selected_rows, byio, use_binary_format=True)
+        byio.seek(0)
+        # load from memory
+        selected_rows_mem = paddle.load(byio)
+        to_array_mem = np.array(selected_rows_mem)
+        self.assertTrue(isinstance(selected_rows_mem, fluid.core.SelectedRows))
+        self.assertTrue(list(selected_rows_mem.rows()) == rows)
+        self.assertTrue(selected_rows_mem.height() == height)
+        self.assertTrue(
+            np.array_equal(np.array(selected_rows_mem.get_tensor()), np_array))
+
+        with self.assertRaises(NotImplementedError):
+            paddle.framework.io._save_selected_rows(selected_rows, 1)
+        with self.assertRaises(NotImplementedError):
+            paddle.framework.io._load_selected_rows(1)
diff --git a/python/paddle/framework/io.py b/python/paddle/framework/io.py
index 1705db50d39..5f1ffa81eab 100644
--- a/python/paddle/framework/io.py
+++ b/python/paddle/framework/io.py
@@ -32,6 +32,7 @@ from paddle import fluid
 from paddle.fluid import core
 from paddle.fluid.io import _unpack_saved_dict, _pack_loaded_dict, _pickle_loads_mac
 from paddle.fluid.io import _legacy_save as _legacy_static_save
+from paddle.fluid.io import _open_file_buffer, _is_file_path, _is_memory_buffer
 from paddle.fluid.framework import Variable, _varbase_creator, _dygraph_tracer, in_dygraph_mode, ParamBase, _current_expected_place, Program
 from paddle.fluid.dygraph.jit import _SaveLoadConfig
 
@@ -450,30 +451,81 @@ def _parse_load_result(obj, return_numpy):
 def _save_lod_tensor(tensor, file_name):
     if not tensor._is_initialized():
         raise ValueError("The saved tensor is not initialized.")
-    _seek = core._save_lod_tensor(tensor, file_name)
-    # '_seek' is the end position of this tensor in the file.
+    if _is_file_path(file_name):
+        _seek = core.save_lod_tensor(tensor, file_name)
+        # '_seek' is the end position of this tensor in the file.
+
+    elif _is_memory_buffer(file_name):
+        tensor_bytes = core.save_lod_tensor_to_memory(tensor)
+
+        with _open_file_buffer(file_name, 'wb') as f:
+            f.write(tensor_bytes)
+            _seek = f.tell()
+
+    else:
+        raise NotImplementedError(
+            'Only supports saving objects to file or BytesIO, but received {}'.
+            format(type(file_name)))
     return _seek
 
 
 def _load_lod_tensor(file_name):
     temp_t = paddle.fluid.core.LoDTensor()
-    # '_seek' is the end position of this tensor in the file.
-    _seek = paddle.fluid.core._load_lod_tensor(temp_t, file_name)
+    if _is_file_path(file_name):
+        # '_seek' is the end position of this tensor in the file.
+        _seek = paddle.fluid.core.load_lod_tensor(temp_t, file_name)
+
+    elif _is_memory_buffer(file_name):
+        with _open_file_buffer(file_name, 'rb') as f:
+            tensor_bytes = f.read()
+            paddle.fluid.core.load_lod_tensor_from_memory(temp_t, tensor_bytes)
+            _seek = f.tell()
+
+    else:
+        raise NotImplementedError(
+            'Only supports load objects from file or BytesIO, but received {}'.
+            format(type(file_name)))
+
+    return temp_t, _seek
 
 
 def _save_selected_rows(selected_rows, file_name):
-    # '_seek' is the end position of this SelectedRows in the file.
     if not selected_rows.get_tensor()._is_initialized():
         raise ValueError("The saved tensor is not initialized.")
-    _seek = core._save_selected_rows(selected_rows, file_name)
+    if _is_file_path(file_name):
+        # '_seek' is the end position of this SelectedRows in the file.
+        _seek = core.save_selected_rows(selected_rows, file_name)
+
+    elif _is_memory_buffer(file_name):
+        selected_rows_bytes = core.save_selected_rows_to_memory(selected_rows)
+        with _open_file_buffer(file_name, 'wb') as f:
+            f.write(selected_rows_bytes)
+            _seek = f.tell()
+    else:
+        raise NotImplementedError(
+            'Only supports saving objects to file or BytesIO, but received {}'.
+            format(type(file_name)))
     return _seek
 
 
 def _load_selected_rows(file_name):
     temp_sr = core.SelectedRows()
-    # '_seek' is the end position of this SelectedRows in the file.
-    _seek = core._load_selected_rows(temp_sr, file_name)
+    if _is_file_path(file_name):
+        # '_seek' is the end position of this SelectedRows in the file.
+        _seek = core.load_selected_rows(temp_sr, file_name)
+
+    elif _is_memory_buffer(file_name):
+        with _open_file_buffer(file_name, 'rb') as f:
+            selected_rows_bytes = f.read()
+            paddle.fluid.core.load_selected_rows_from_memory(
+                temp_sr, selected_rows_bytes)
+            _seek = f.tell()
+
+    else:
+        raise NotImplementedError(
+            'Only supports load objects from file or BytesIO, but received {}'.
+            format(type(file_name)))
+
+    return temp_sr, _seek
@@ -509,7 +561,7 @@ def save(obj, path, protocol=4, **configs):
 
     Args:
         obj(Object) : The object to be saved.
-        path(str) : The path of the object to be saved.
+        path(str|BytesIO) : The path/buffer of the object to be saved.
           If saved in the current directory, the input path string will be used as the file name.
         protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
                                  Default: 4
@@ -593,18 +645,39 @@ def save(obj, path, protocol=4, **configs):
             main_program = paddle.static.default_main_program()
             path = "example/main_program.pdmodel"
             paddle.save(main_program, path)
-    '''
-    # 1. input check
-    filename = os.path.basename(path)
-    if filename == "":
-        raise ValueError("The input path MUST be format of dirname/filename "
-                         "[dirname\\filename in Windows system], but received "
-                         "filename is empty string.")
-
-    # 2. save object
-    dirname = os.path.dirname(path)
-    if dirname and not os.path.exists(dirname):
-        os.makedirs(dirname)
+
+            # example 5: save object to memory
+            from io import BytesIO
+            import paddle
+            from paddle.nn import Linear
+            paddle.disable_static()
+
+            linear = Linear(5, 10)
+            state_dict = linear.state_dict()
+            byio = BytesIO()
+            paddle.save(state_dict, byio)
+            tensor = paddle.randn([2, 3], dtype='float32')
+            paddle.save(tensor, byio)
+
+    '''
+    if _is_file_path(path):
+        # 1. input check
+        filename = os.path.basename(path)
+        if filename == "":
+            raise ValueError(
+                "The input path MUST be format of dirname/filename "
+                "[dirname\\filename in Windows system], but received "
+                "filename is empty string.")
+
+        # 2. save object
+        dirname = os.path.dirname(path)
+        if dirname and not os.path.exists(dirname):
+            os.makedirs(dirname)
+    elif not _is_memory_buffer(path):
+        raise ValueError(
+            "only supports saving objects to file and `BytesIO`, but got {}".
+            format(type(path)))
 
     config = _parse_save_config(configs)
@@ -625,7 +698,7 @@ def save(obj, path, protocol=4, **configs):
 
     if isinstance(obj, Program):
         obj.desc.flush()
-        with open(path, "wb") as f:
+        with _open_file_buffer(path, "wb") as f:
            f.write(obj.desc.serialize_to_string())
 
     elif _is_state_dict(obj):
@@ -634,7 +707,7 @@ def save(obj, path, protocol=4, **configs):
         else:
             _legacy_static_save(obj, path, protocol)
     else:
-        with open(path, 'wb') as f:
+        with _open_file_buffer(path, 'wb') as f:
             _pickle_save(obj, f, protocol)
 
 
@@ -648,12 +721,6 @@ def _legacy_save(obj, path, protocol=2):
     if len(obj) == 0:
         warnings.warn("The input state dict is empty, no need to save.")
 
-    filename = os.path.basename(path)
-    if filename == "":
-        raise ValueError("The input path MUST be format of dirname/filename "
-                         "[dirname\\filename in Windows system], but received "
-                         "filename is empty string.")
-
     if not isinstance(protocol, int):
         raise ValueError("The 'protocol' MUST be `int`, but received {}".format(
             type(protocol)))
@@ -662,26 +729,33 @@ def _legacy_save(obj, path, protocol=2):
         raise ValueError("Expected 1<'protocol'<5, but received protocol={}".
                          format(protocol))
 
-    # 2. save object
-    dirname = os.path.dirname(path)
-    if dirname and not os.path.exists(dirname):
-        os.makedirs(dirname)
+    if _is_file_path(path):
+        filename = os.path.basename(path)
+        if filename == "":
+            raise ValueError(
+                "The input path MUST be format of dirname/filename "
+                "[dirname\\filename in Windows system], but received "
+                "filename is empty string.")
+        # 2. save object
+        dirname = os.path.dirname(path)
+        if dirname and not os.path.exists(dirname):
+            os.makedirs(dirname)
 
-    # TODO(chenweihang): supports save other object
     if isinstance(obj, dict):
         saved_obj = _build_saved_state_dict(obj)
 
     saved_obj = _unpack_saved_dict(saved_obj, protocol)
 
     # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
-    if sys.platform == 'darwin' and sys.version_info.major == 3:
+    if _is_file_path(
+            path) and sys.platform == 'darwin' and sys.version_info.major == 3:
         pickle_bytes = pickle.dumps(saved_obj, protocol=protocol)
         with open(path, 'wb') as f:
             max_bytes = 2**30
             for i in range(0, len(pickle_bytes), max_bytes):
                 f.write(pickle_bytes[i:i + max_bytes])
     else:
-        with open(path, 'wb') as f:
+        with _open_file_buffer(path, 'wb') as f:
            pickle.dump(saved_obj, f, protocol=protocol)
@@ -716,7 +790,7 @@ def load(path, **configs):
     ``Layer.set_state_dict`` later.
 
     Args:
-        path(str) : The path to load the target object. Generally, the path is the target
+        path(str|BytesIO) : The path/buffer to load the target object. Generally, the path is the target
             file path. When loading state_dict from the saved result of the API used to save
             the inference model, the path may be a file prefix or directory.
         **configs (dict, optional): other load configuration options for compatibility. We do not
@@ -822,18 +896,36 @@ def load(path, **configs):
             print(load_main)
 
+            # example 5: save object to memory
+            from io import BytesIO
+            import paddle
+            from paddle.nn import Linear
+            paddle.disable_static()
+
+            linear = Linear(5, 10)
+            state_dict = linear.state_dict()
+            byio = BytesIO()
+            paddle.save(state_dict, byio)
+            tensor = paddle.randn([2, 3], dtype='float32')
+            paddle.save(tensor, byio)
+            byio.seek(0)
+            # load state_dict
+            dict_load = paddle.load(byio)
+
     '''
-    if os.path.isfile(path):
+    if _is_memory_buffer(path) or os.path.isfile(path):
         config = _parse_load_config(configs)
         if six.PY2:
             exception_type = KeyError
         else:
             exception_type = pickle.UnpicklingError
         try:
-            with open(path, 'rb') as f:
+            with _open_file_buffer(path, 'rb') as f:
                 # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
-                if sys.platform == 'darwin' and sys.version_info.major == 3:
+                if _is_file_path(
+                        path
+                ) and sys.platform == 'darwin' and sys.version_info.major == 3:
                     load_result = _pickle_loads_mac(path, f)
                 else:
                     load_result = pickle.load(f) if six.PY2 else pickle.load(
@@ -875,7 +967,7 @@ def load(path, **configs):
                 return tensor
         except:
             try:
-                with open(path, "rb") as f:
+                with _open_file_buffer(path, "rb") as f:
                     program_desc_str = f.read()
                     program = Program.parse_from_string(
                         program_desc_str)
@@ -895,9 +987,9 @@ def _legacy_load(path, **configs):
     load_result = None
     config = _parse_load_config(configs)
 
-    if os.path.isfile(path):
+    if os.path.isfile(path) or _is_memory_buffer(path):
        # we think path is file means this file is created by paddle.save
-        with open(path, 'rb') as f:
+        with _open_file_buffer(path, 'rb') as f:
            load_result = pickle.load(f) if six.PY2 else pickle.load(
                f, encoding='latin1')
        load_result = _pack_loaded_dict(load_result)
--
GitLab
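
A minimal usage sketch of the in-memory save/load path this patch enables, mirroring example 5 in the updated paddle.save/paddle.load docstrings and the new unit tests; the object names (linear, byio) are illustrative only:

    from io import BytesIO

    import paddle
    from paddle.nn import Linear

    paddle.disable_static()

    # Save a state_dict and a tensor into one BytesIO buffer instead of a file.
    linear = Linear(5, 10)
    state_dict = linear.state_dict()
    byio = BytesIO()
    paddle.save(state_dict, byio)
    tensor = paddle.randn([2, 3], dtype='float32')
    paddle.save(tensor, byio)

    # Rewind and load the objects back in the order they were written.
    byio.seek(0)
    dict_load = paddle.load(byio)
    tensor_load = paddle.load(byio, return_numpy=True)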