Unverified commit e96dae8b, authored by Leo Chen, committed by GitHub

Refine program cache (#45005)

* add cached_serialize_str_

* support program hash

* add sha

* add ut

* use hash_str only for new_exe

* fix attr order
Parent 3f5c405f
@@ -505,6 +505,11 @@ cc_library(
  SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc process_mesh_desc.cc
  DEPS attribute shape_inference op_info operator glog version)
+if(WITH_CRYPTO)
+  add_dependencies(proto_desc cryptopp)
+  target_link_libraries(proto_desc cryptopp)
+endif()
cc_library(
  op_registry
  SRCS op_registry.cc
......
@@ -197,10 +197,11 @@ void BlockDesc::Flush() {
      var_names.emplace_back(var.name());
      var_names_set.insert(var.name());
    }
+   VLOG(4) << "vars in desc " << this->desc_->vars().size();
    this->desc_->mutable_vars()->Clear();
    for (const auto &name : var_names) {
      if (vars_.count(name)) {
+       VLOG(4) << "Flush " << name;
        this->desc_->mutable_vars()->Add()->CopyFrom(*vars_[name]->Proto());
        vars_[name]->SetNeedUpdate(false);
      }
@@ -208,6 +209,7 @@ void BlockDesc::Flush() {
    for (auto &var_desc : vars_) {
      if (var_names_set.count(var_desc.first) != 1) {
+       VLOG(4) << "Flush " << var_desc.first;
        this->desc_->mutable_vars()->Add()->CopyFrom(*var_desc.second->Proto());
        var_desc.second->SetNeedUpdate(false);
      }
......
@@ -122,7 +122,7 @@ class BlockDesc {
  // vars_
  std::deque<std::unique_ptr<OpDesc>> ops_;
- std::unordered_map<std::string, std::unique_ptr<VarDesc>> vars_;
+ std::map<std::string, std::unique_ptr<VarDesc>> vars_;
  DISABLE_COPY_AND_ASSIGN(BlockDesc);
};
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <cryptopp/cryptlib.h>
#include <cryptopp/filters.h>
#include <cryptopp/hex.h>
#include <cryptopp/sha.h>

#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {

std::string GetSha1(std::string msg) {
  std::string digest;
  CryptoPP::SHA1 hash;
  hash.Update(reinterpret_cast<unsigned char*>(&msg.at(0)), msg.size());
  digest.resize(hash.DigestSize());
  hash.Final(reinterpret_cast<unsigned char*>(&digest.at(0)));
  return digest;
}

std::string HexEncoding(std::string bytes) {
  std::string encoded;
  // Everything newed is destroyed when the StringSource is destroyed.
  CryptoPP::StringSource ss(
      bytes, true, new CryptoPP::HexEncoder(new CryptoPP::StringSink(encoded)));
  return encoded;
}

}  // namespace framework
}  // namespace paddle
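For reference, a minimal Python sketch (hashlib, not the CryptoPP helper above) of what HexEncoding(GetSha1(msg)) yields: a 40-character hex string of the 20-byte SHA1 digest. The .upper() assumes HexEncoder's default uppercase output.

import hashlib

def sha1_hex(msg: bytes) -> str:
    # Rough equivalent of HexEncoding(GetSha1(msg)) above: 20-byte SHA1 digest,
    # hex-encoded; .upper() assumes HexEncoder's default uppercase output.
    return hashlib.sha1(msg).hexdigest().upper()

print(sha1_hex(b"hello"))  # AAF4C61DDCC5E8A2DABEDE0F3B482CD9AEA9434D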
@@ -114,6 +114,11 @@ interpreter::CostInfo InterpreterCore::DryRun(
    // until the second step run.
    async_work_queue_ = GetWorkQueue();
+   // lazy initialization of gc, do not create gc if the program only runs once
+   if (!gc_) {
+     gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_);
+   }
    ExecuteInstructionList(vec_instruction_);
    platform::DeviceContextPool::Instance().Get(place_)->Wait();
  }
@@ -144,6 +149,12 @@ paddle::framework::FetchList InterpreterCore::Run(
    // create work_queue, so the async_work_queue_ is created
    // until the second step run.
    async_work_queue_ = GetWorkQueue();
+   // lazy initialization of gc, do not create gc if the program only runs once
+   if (!gc_) {
+     gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_);
+   }
    ExecuteInstructionList(vec_instruction_);
#ifdef PADDLE_WITH_ASCEND_CL
    platform::DeviceContextPool::Instance().Get(place_)->Wait();
@@ -193,6 +204,11 @@ paddle::framework::FetchList InterpreterCore::Run(
    // until the second step run.
    async_work_queue_ = GetWorkQueue();
+   // lazy initialization of gc, do not create gc if the program only runs once
+   if (!gc_) {
+     gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_);
+   }
    ExecuteInstructionList(vec_instruction_);
#ifdef PADDLE_WITH_ASCEND_CL
    platform::DeviceContextPool::Instance().Get(place_)->Wait();
@@ -495,7 +511,7 @@ void InterpreterCore::Convert(
  }
  BuildSkipShareLoDInfo();
- gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_);
  bool inplaced = false;
  for (auto inst : vec_instruction_) {
    if (inst.OpBase()->Type() == "share_buffer" ||
......
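The hunks above move garbage-collector creation out of Convert and into the run paths, which (per the surrounding comments) are only reached from the second step onward, so a program executed once never constructs the collector. A toy Python sketch of this lazy-initialization pattern; the names are hypothetical, not Paddle API:

class LazyGCRunner:
    """Hypothetical illustration of the lazy-GC pattern, not the Paddle class."""

    def __init__(self, instructions):
        self.instructions = instructions
        self._gc = None                      # not created at Convert/build time

    def run(self):
        if self._gc is None:                 # created only when a run needs it
            self._gc = self._create_gc()
        for inst in self.instructions:       # stands in for ExecuteInstructionList
            pass

    def _create_gc(self):
        return []                            # stand-in for the real collector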
@@ -901,7 +901,14 @@ void OpDesc::Flush() {
  }
  this->desc_.mutable_attrs()->Clear();
- for (auto &attr : attrs_) {
+ std::vector<std::pair<std::string, Attribute>> sorted_attrs{attrs_.begin(),
+                                                             attrs_.end()};
+ std::sort(
+     sorted_attrs.begin(),
+     sorted_attrs.end(),
+     [](std::pair<std::string, Attribute> a,
+        std::pair<std::string, Attribute> b) { return a.first < b.first; });
+ for (auto &attr : sorted_attrs) {
    auto *attr_desc = desc_.add_attrs();
    attr_desc->set_name(attr.first);
    attr_desc->set_type(
......
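Together with the unordered_map to std::map change in BlockDesc above, sorting attributes by name makes the serialized ProgramDesc byte-stable, so logically identical programs hash to the same value. A minimal sketch in plain Python (not the protobuf serializer) of why a canonical order matters:

import hashlib

def flush_attrs(attrs):
    # Emit attributes in name order, mirroring the std::sort over sorted_attrs,
    # so the serialized form does not depend on map iteration order.
    return ";".join("%s=%r" % (k, v) for k, v in sorted(attrs.items())).encode()

a = {"use_mkldnn": False, "scale_x": 1.0, "axis": -1}
b = {"axis": -1, "scale_x": 1.0, "use_mkldnn": False}   # same attrs, other order
assert hashlib.sha1(flush_attrs(a)).hexdigest() == \
       hashlib.sha1(flush_attrs(b)).hexdigest()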
@@ -17,6 +17,9 @@ limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/version.h"
+#ifdef PADDLE_WITH_CRYPTO
+#include "paddle/fluid/framework/io/crypto/sha.h"
+#endif
namespace paddle {
namespace framework {
@@ -249,6 +252,20 @@ void ProgramDesc::SetFetchHolderName(const std::string &fetch_holder_name) {
  fetch_holder->SetPersistable(true);
}
+std::string ProgramDesc::CachedHashString() {
+  std::string serialize_str;
+  if (cached_hash_str_.size() == 0 || NeedUpdate()) {
+    Flush();
+    desc_.SerializePartialToString(&serialize_str);
+#ifdef PADDLE_WITH_CRYPTO
+    cached_hash_str_ = HexEncoding(GetSha1(serialize_str));
+#else
+    cached_hash_str_ = serialize_str;
+#endif
+  }
+  return cached_hash_str_;
+}
bool ProgramDesc::NeedUpdate() const {
  bool need = false;
  for (auto &block : blocks_) {
......
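CachedHashString recomputes the digest only when no hash has been cached yet or the program still has pending changes (NeedUpdate); otherwise it returns the cached value. Without WITH_CRYPTO it falls back to the raw serialized string. A rough Python sketch of that caching policy, as a hypothetical stand-in rather than the C++ class:

import hashlib

class CachedHashSketch:
    """Hypothetical mirror of ProgramDesc's hash caching, for illustration."""

    def __init__(self):
        self._cached_hash = ""
        self._need_update = True             # mirrors NeedUpdate()

    def _flush_and_serialize(self) -> bytes:
        self._need_update = False            # Flush() clears the dirty flags
        return b"serialized-program-bytes"   # stand-in for SerializePartialToString

    def cached_hash_str(self, with_crypto: bool = True) -> str:
        if not self._cached_hash or self._need_update:
            data = self._flush_and_serialize()
            self._cached_hash = (hashlib.sha1(data).hexdigest().upper()
                                 if with_crypto else data.decode())
        return self._cached_hash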
@@ -85,6 +85,8 @@ class ProgramDesc {
  // This function is used to change or unify the fetch_holder variables' name.
  void SetFetchHolderName(const std::string &fetch_holder_name);
+ std::string CachedHashString();
  bool NeedUpdate() const;
 private:
@@ -93,6 +95,8 @@ class ProgramDesc {
  proto::ProgramDesc desc_;
  std::vector<std::unique_ptr<BlockDesc>> blocks_;
+ std::string cached_hash_str_;
};
}  // namespace framework
}  // namespace paddle
@@ -102,8 +102,14 @@ void BindProgramDesc(pybind11::module *m) {
           pybind11::arg("version") = pd::kCurProgramVersion)
      .def("_version",
           [](pd::ProgramDesc &self) -> int64_t { return self.Version(); })
-     .def("get_op_deps", [](const framework::ProgramDesc &program) {
-       return framework::ir::GetOpDependencies(program);
+     .def("get_op_deps",
+          [](const framework::ProgramDesc &program) {
+            return framework::ir::GetOpDependencies(program);
+          })
+     .def("need_update", &pd::ProgramDesc::NeedUpdate)
+     .def("cached_hash_str", [](pd::ProgramDesc &self) {
+       return self.CachedHashString();
+       // return pybind11::bytes(self.CachedHashString());
      });
}
......
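From Python, these bindings are reachable through program.desc; a short usage sketch mirroring the unit test further down (the 40-character length assumes a build with WITH_CRYPTO):

import paddle

paddle.enable_static()

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name='x', shape=[3, 2, 1])
    out = paddle.static.nn.fc(x=x, size=1, num_flatten_dims=2)

print(main.desc.need_update())      # True: ops/vars were added but not flushed
key = main.desc.cached_hash_str()   # flushes, then hashes the serialized desc
print(main.desc.need_update())      # False
print(len(key))                     # 40 hex characters with WITH_CRYPTO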
@@ -426,8 +426,13 @@ def _prepare_fleet_executor():
    return fleet_exe
+def _get_strong_program_cache_key_for_new_exe(program, feed, fetch_list):
+    return program.desc.cached_hash_str() + _get_program_cache_key(
+        feed, fetch_list)
def _get_strong_program_cache_key(program, feed, fetch_list):
-   # NOTE(xiongkun): id(program) may be duplicated, so add var_names as an additional cache key.
+   # TODO(zhiqiu): use hash_str to generate cache key as above
    def _get_varname_from_block(block):
        block_str = []
        for var_name in list(block.vars.keys()):
@@ -1455,8 +1460,8 @@ class Executor(object):
                    % (type(feed)))
                feed = self._update_feed(program, feed)
-               key = _get_strong_program_cache_key(inner_program, feed,
-                                                   fetch_list)
+               key = _get_strong_program_cache_key_for_new_exe(
+                   inner_program, feed, fetch_list)
                # a little bit tricky here, use inner_program before _add_feed_fetch_ops to get key
                # while use program to get _StandaloneExecutor
......
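For the new (standalone) executor, the strong cache key is now the program's content hash concatenated with the feed/fetch part, instead of a key derived from id(program) and variable names. A rough sketch of that composition, where feed_fetch_key is a hypothetical stand-in for _get_program_cache_key:

def feed_fetch_key(feed, fetch_list):
    # Hypothetical stand-in for _get_program_cache_key: names of feeds/fetches.
    feed_names = sorted(feed.keys()) if isinstance(feed, dict) else []
    fetch_names = [str(f) for f in (fetch_list or [])]
    return str(feed_names) + str(fetch_names)

def strong_key_for_new_exe(program, feed, fetch_list):
    # Same shape as _get_strong_program_cache_key_for_new_exe: programs with
    # identical content (equal hash) and identical feeds/fetches share an entry.
    return program.desc.cached_hash_str() + feed_fetch_key(feed, fetch_list)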
@@ -241,5 +241,64 @@ class TestProgramProto(unittest.TestCase):
        self.assertTrue(a == b)  # not affected
+class TestProgramHash(unittest.TestCase):
+
+    def build_program(self):
+        main_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        with paddle.utils.unique_name.guard():
+            with paddle.static.program_guard(main_program, startup_program):
+                x = paddle.static.data(name='x', shape=[3, 2, 1])
+                out = paddle.static.nn.fc(x=x, size=1, num_flatten_dims=2)
+        return main_program
+
+    def test_program_need_update(self):
+        program = self.build_program()
+        self.assertTrue(program.desc.need_update())
+        program.desc.flush()
+        self.assertFalse(program.desc.need_update())
+
+    def test_program_hash_equal(self):
+        programs = []
+        for i in range(2):
+            programs.append(self.build_program())
+        program1, program2 = programs[0], programs[1]
+        # Why not write it as below? Because the callstack attributes
+        # would not be equal:
+        #   program1 = self.build_program()
+        #   program2 = self.build_program()
+        self.assertTrue(program1.desc.need_update())
+        self.assertTrue(program2.desc.need_update())
+        # two programs with the same content
+        self.assertFalse(id(program1) == id(program2))
+        self.assertTrue(
+            program1.desc.cached_hash_str() == program2.desc.cached_hash_str())
+        self.assertFalse(program1.desc.need_update())
+        self.assertFalse(program2.desc.need_update())
+
+    def test_program_clone(self):
+        program = self.build_program()
+        program_clone = program.clone()
+        self.assertFalse(id(program) == id(program_clone))
+        self.assertTrue(program.desc.cached_hash_str() ==
+                        program_clone.desc.cached_hash_str())
+
+    def test_program_update(self):
+        program = self.build_program()
+        hash1 = program.desc.cached_hash_str()
+        id1 = id(program)
+        # change the mul op's attrs
+        program.current_block().ops[0]._set_attr('use_mkldnn', True)
+        program.current_block().ops[0]._set_attr('scale_x', 2.0)
+        hash2 = program.desc.cached_hash_str()
+        id2 = id(program)
+        self.assertTrue(id1 == id2)
+        self.assertFalse(hash1 == hash2)
+
if __name__ == '__main__':
    unittest.main()