From e96dae8b18f6f2a4574a5a2f3af3058abe8c9d23 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Sat, 13 Aug 2022 09:28:41 +0800 Subject: [PATCH] Refine program cache (#45005) * add cached_serialize_str_ * support program hash * add sha * add ut * use hash_str only for new_exe * fix attr order --- paddle/fluid/framework/CMakeLists.txt | 5 ++ paddle/fluid/framework/block_desc.cc | 4 +- paddle/fluid/framework/block_desc.h | 2 +- paddle/fluid/framework/io/crypto/sha.h | 44 ++++++++++++++ .../framework/new_executor/interpretercore.cc | 18 +++++- paddle/fluid/framework/op_desc.cc | 9 ++- paddle/fluid/framework/program_desc.cc | 17 ++++++ paddle/fluid/framework/program_desc.h | 4 ++ paddle/fluid/pybind/protobuf.cc | 10 +++- python/paddle/fluid/executor.py | 11 +++- .../fluid/tests/unittests/test_program.py | 59 +++++++++++++++++++ 11 files changed, 174 insertions(+), 9 deletions(-) create mode 100644 paddle/fluid/framework/io/crypto/sha.h diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index d4d5f4903f..ebee8603f4 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -505,6 +505,11 @@ cc_library( SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc process_mesh_desc.cc DEPS attribute shape_inference op_info operator glog version) +if(WITH_CRYPTO) + add_dependencies(proto_desc cryptopp) + target_link_libraries(proto_desc cryptopp) +endif() + cc_library( op_registry SRCS op_registry.cc diff --git a/paddle/fluid/framework/block_desc.cc b/paddle/fluid/framework/block_desc.cc index e8d26f6728..e971ebd396 100644 --- a/paddle/fluid/framework/block_desc.cc +++ b/paddle/fluid/framework/block_desc.cc @@ -197,10 +197,11 @@ void BlockDesc::Flush() { var_names.emplace_back(var.name()); var_names_set.insert(var.name()); } - + VLOG(4) << "vars in desc " << this->desc_->vars().size(); this->desc_->mutable_vars()->Clear(); for (const auto &name : var_names) { if (vars_.count(name)) { + VLOG(4) << "Flush " << name; this->desc_->mutable_vars()->Add()->CopyFrom(*vars_[name]->Proto()); vars_[name]->SetNeedUpdate(false); } @@ -208,6 +209,7 @@ void BlockDesc::Flush() { for (auto &var_desc : vars_) { if (var_names_set.count(var_desc.first) != 1) { + VLOG(4) << "Flush " << var_desc.first; this->desc_->mutable_vars()->Add()->CopyFrom(*var_desc.second->Proto()); var_desc.second->SetNeedUpdate(false); } diff --git a/paddle/fluid/framework/block_desc.h b/paddle/fluid/framework/block_desc.h index bb7227d071..53bece3b24 100644 --- a/paddle/fluid/framework/block_desc.h +++ b/paddle/fluid/framework/block_desc.h @@ -122,7 +122,7 @@ class BlockDesc { // vars_ std::deque> ops_; - std::unordered_map> vars_; + std::map> vars_; DISABLE_COPY_AND_ASSIGN(BlockDesc); }; diff --git a/paddle/fluid/framework/io/crypto/sha.h b/paddle/fluid/framework/io/crypto/sha.h new file mode 100644 index 0000000000..62a98807f7 --- /dev/null +++ b/paddle/fluid/framework/io/crypto/sha.h @@ -0,0 +1,44 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { + +std::string GetSha1(std::string msg) { + std::string digest; + CryptoPP::SHA1 hash; + hash.Update(reinterpret_cast(&msg.at(0)), msg.size()); + digest.resize(hash.DigestSize()); + hash.Final(reinterpret_cast(&digest.at(0))); + return digest; +} + +std::string HexEncoding(std::string bytes) { + std::string encoded; + // Everything newed is destroyed when the StringSource is destroyed + CryptoPP::StringSource ss( + bytes, true, new CryptoPP::HexEncoder(new CryptoPP::StringSink(encoded))); + return encoded; +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 66e8f93736..4cd0a2c9e1 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -114,6 +114,11 @@ interpreter::CostInfo InterpreterCore::DryRun( // until the second step run. async_work_queue_ = GetWorkQueue(); + // lazy initialization of gc, do not create gc is the program only run once + if (!gc_) { + gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_); + } + ExecuteInstructionList(vec_instruction_); platform::DeviceContextPool::Instance().Get(place_)->Wait(); } @@ -144,6 +149,12 @@ paddle::framework::FetchList InterpreterCore::Run( // create work_queue, so the async_work_queue_ is created // until the second step run. async_work_queue_ = GetWorkQueue(); + + // lazy initialization of gc, do not create gc is the program only run once + if (!gc_) { + gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_); + } + ExecuteInstructionList(vec_instruction_); #ifdef PADDLE_WITH_ASCEND_CL platform::DeviceContextPool::Instance().Get(place_)->Wait(); @@ -193,6 +204,11 @@ paddle::framework::FetchList InterpreterCore::Run( // until the second step run. async_work_queue_ = GetWorkQueue(); + // lazy initialization of gc, do not create gc is the program only run once + if (!gc_) { + gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_); + } + ExecuteInstructionList(vec_instruction_); #ifdef PADDLE_WITH_ASCEND_CL platform::DeviceContextPool::Instance().Get(place_)->Wait(); @@ -495,7 +511,7 @@ void InterpreterCore::Convert( } BuildSkipShareLoDInfo(); - gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_); + bool inplaced = false; for (auto inst : vec_instruction_) { if (inst.OpBase()->Type() == "share_buffer" || diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 204b2c8754..4ae4f88118 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -901,7 +901,14 @@ void OpDesc::Flush() { } this->desc_.mutable_attrs()->Clear(); - for (auto &attr : attrs_) { + std::vector> sorted_attrs{attrs_.begin(), + attrs_.end()}; + std::sort( + sorted_attrs.begin(), + sorted_attrs.end(), + [](std::pair a, + std::pair b) { return a.first < b.first; }); + for (auto &attr : sorted_attrs) { auto *attr_desc = desc_.add_attrs(); attr_desc->set_name(attr.first); attr_desc->set_type( diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc index 1788119490..7b4f0a71a5 100644 --- a/paddle/fluid/framework/program_desc.cc +++ b/paddle/fluid/framework/program_desc.cc @@ -17,6 +17,9 @@ limitations under the License. */ #include #include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/version.h" +#ifdef PADDLE_WITH_CRYPTO +#include "paddle/fluid/framework/io/crypto/sha.h" +#endif namespace paddle { namespace framework { @@ -249,6 +252,20 @@ void ProgramDesc::SetFetchHolderName(const std::string &fetch_holder_name) { fetch_holder->SetPersistable(true); } +std::string ProgramDesc::CachedHashString() { + std::string serialize_str; + if (cached_hash_str_.size() == 0 || NeedUpdate()) { + Flush(); + desc_.SerializePartialToString(&serialize_str); +#ifdef PADDLE_WITH_CRYPTO + cached_hash_str_ = HexEncoding(GetSha1(serialize_str)); +#else + cached_hash_str_ = serialize_str; +#endif + } + return cached_hash_str_; +} + bool ProgramDesc::NeedUpdate() const { bool need = false; for (auto &block : blocks_) { diff --git a/paddle/fluid/framework/program_desc.h b/paddle/fluid/framework/program_desc.h index 86d347caf5..e1dbd85f12 100644 --- a/paddle/fluid/framework/program_desc.h +++ b/paddle/fluid/framework/program_desc.h @@ -85,6 +85,8 @@ class ProgramDesc { // This function is used to change or unify the fetch_holder variables' name. void SetFetchHolderName(const std::string &fetch_holder_name); + std::string CachedHashString(); + bool NeedUpdate() const; private: @@ -93,6 +95,8 @@ class ProgramDesc { proto::ProgramDesc desc_; std::vector> blocks_; + + std::string cached_hash_str_; }; } // namespace framework } // namespace paddle diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index cc16b89544..74debcf888 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -102,8 +102,14 @@ void BindProgramDesc(pybind11::module *m) { pybind11::arg("version") = pd::kCurProgramVersion) .def("_version", [](pd::ProgramDesc &self) -> int64_t { return self.Version(); }) - .def("get_op_deps", [](const framework::ProgramDesc &program) { - return framework::ir::GetOpDependencies(program); + .def("get_op_deps", + [](const framework::ProgramDesc &program) { + return framework::ir::GetOpDependencies(program); + }) + .def("need_update", &pd::ProgramDesc::NeedUpdate) + .def("cached_hash_str", [](pd::ProgramDesc &self) { + return self.CachedHashString(); + // return pybind11::bytes(self.CachedHashString()); }); } diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index f4ee554c19..fe7f65835e 100755 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -426,8 +426,13 @@ def _prepare_fleet_executor(): return fleet_exe +def _get_strong_program_cache_key_for_new_exe(program, feed, fetch_list): + return program.desc.cached_hash_str() + _get_program_cache_key( + feed, fetch_list) + + def _get_strong_program_cache_key(program, feed, fetch_list): - # NOTE(xiongkun) id(proram) may be duplicate. So add addition var_name as cache key. + # TODO(zhiqiu): use hash_str to generate cache key as above def _get_varname_from_block(block): block_str = [] for var_name in list(block.vars.keys()): @@ -1455,8 +1460,8 @@ class Executor(object): % (type(feed))) feed = self._update_feed(program, feed) - key = _get_strong_program_cache_key(inner_program, feed, - fetch_list) + key = _get_strong_program_cache_key_for_new_exe( + inner_program, feed, fetch_list) # a little bit tricy here, use inner_program before _add_feed_fetch_ops to get key # while use program to geet _StandaloneExecutor diff --git a/python/paddle/fluid/tests/unittests/test_program.py b/python/paddle/fluid/tests/unittests/test_program.py index cbfecf0816..80290a9961 100644 --- a/python/paddle/fluid/tests/unittests/test_program.py +++ b/python/paddle/fluid/tests/unittests/test_program.py @@ -241,5 +241,64 @@ class TestProgramProto(unittest.TestCase): self.assertTrue(a == b) # not affected +class TestProgramHash(unittest.TestCase): + + def build_program(self): + main_program = paddle.static.Program() + startuo_program = paddle.static.Program() + with paddle.utils.unique_name.guard(): + with paddle.static.program_guard(main_program, startuo_program): + x = paddle.static.data(name='x', shape=[3, 2, 1]) + out = paddle.static.nn.fc(x=x, size=1, num_flatten_dims=2) + return main_program + + def test_program_need_update(self): + program = self.build_program() + self.assertTrue(program.desc.need_update()) + program.desc.flush() + self.assertFalse(program.desc.need_update()) + + def test_program_hash_equal(self): + programs = [] + for i in range(2): + programs.append(self.build_program()) + program1, program2 = programs[0], programs[1] + # why not write as below? + # since the callstack attribute are not equal + #program1 = self.build_program() + #program2 = self.build_program() + + self.assertTrue(program1.desc.need_update()) + self.assertTrue(program2.desc.need_update()) + # two program with same content + self.assertFalse(id(program1) == id(program2)) + # print(program1, program2) + self.assertTrue( + program1.desc.cached_hash_str() == program2.desc.cached_hash_str()) + + self.assertFalse(program1.desc.need_update()) + self.assertFalse(program2.desc.need_update()) + + def test_program_clone(self): + program = self.build_program() + program_clone = program.clone() + + self.assertFalse(id(program) == id(program_clone)) + self.assertTrue(program.desc.cached_hash_str() == + program_clone.desc.cached_hash_str()) + + def test_program_update(self): + program = self.build_program() + hash1 = program.desc.cached_hash_str() + id1 = id(program) + # change mul's attr + program.current_block().ops[0]._set_attr('use_mkldnn', True) + program.current_block().ops[0]._set_attr('scale_x', 2.0) + hash2 = program.desc.cached_hash_str() + id2 = id(program) + self.assertTrue(id1 == id2) + self.assertFalse(hash1 == hash2) + + if __name__ == '__main__': unittest.main() -- GitLab