From 0649918f3aa9527fdb4f49ebaa750db84fc78abd Mon Sep 17 00:00:00 2001 From: ShiningZhang Date: Wed, 5 Jan 2022 14:24:40 +0800 Subject: [PATCH] add request cache --- core/pdcodegen/src/pdcodegen.cpp | 72 +++++-- core/predictor/common/constant.cpp | 5 +- core/predictor/common/constant.h | 1 + core/predictor/common/inner_common.h | 1 + core/predictor/framework/infer.h | 1 + core/predictor/framework/request_cache.cpp | 236 +++++++++++++++++++++ core/predictor/framework/request_cache.h | 98 +++++++++ python/paddle_serving_server/serve.py | 3 + python/paddle_serving_server/server.py | 15 ++ 9 files changed, 412 insertions(+), 20 deletions(-) create mode 100644 core/predictor/framework/request_cache.cpp create mode 100644 core/predictor/framework/request_cache.h diff --git a/core/pdcodegen/src/pdcodegen.cpp b/core/pdcodegen/src/pdcodegen.cpp index 1ad3fe65..be343070 100644 --- a/core/pdcodegen/src/pdcodegen.cpp +++ b/core/pdcodegen/src/pdcodegen.cpp @@ -301,15 +301,33 @@ class PdsCodeGenerator : public CodeGenerator { inference_body += "\"\]\";\n"; inference_body += " LOG(INFO) << \"(logid=\" << log_id << \") "; inference_body += "service_name=\[\" << \"$name$\" << \"\]\";\n"; // NOLINT - inference_body += " int err_code = svr->inference(request, response, log_id);\n"; - inference_body += " if (err_code != 0) {\n"; - inference_body += " LOG(WARNING)\n"; - inference_body += " << \"(logid=\" << log_id << \") Failed call "; - inference_body += "inferservice[$name$], name[$service$]\"\n"; - inference_body += " << \", error_code: \" << err_code;\n"; - inference_body += " cntl->SetFailed(err_code, \"InferService inference "; - inference_body += "failed!\");\n"; - inference_body += " }\n"; + if (service_name == "GeneralModelService") { + inference_body += "uint64_t key = 0;"; + inference_body += "int err_code = 0;"; + inference_body += "if (RequestCache::GetSingleton()->Get(*request, response, &key) != 0) {"; + inference_body += " err_code = svr->inference(request, response, log_id);"; + inference_body += " if (err_code != 0) {"; + inference_body += " LOG(WARNING)"; + inference_body += " << \"(logid=\" << log_id << \") Failed call inferservice[GeneralModelService], name[GeneralModelService]\""; + inference_body += " << \", error_code: \" << err_code;"; + inference_body += " cntl->SetFailed(err_code, \"InferService inference failed!\");"; + inference_body += " } else {"; + inference_body += " RequestCache::GetSingleton()->Put(*request, *response, &key);"; + inference_body += " }"; + inference_body += "} else {"; + inference_body += " LOG(INFO) << \"(logid=\" << log_id << \") Get from cache\";"; + inference_body += "}"; + } else { + inference_body += " int err_code = svr->inference(request, response, log_id);\n"; + inference_body += " if (err_code != 0) {\n"; + inference_body += " LOG(WARNING)\n"; + inference_body += " << \"(logid=\" << log_id << \") Failed call "; + inference_body += "inferservice[$name$], name[$service$]\"\n"; + inference_body += " << \", error_code: \" << err_code;\n"; + inference_body += " cntl->SetFailed(err_code, \"InferService inference "; + inference_body += "failed!\");\n"; + inference_body += " }\n"; + } inference_body += " gettimeofday(&tv, NULL);\n"; inference_body += " long end = tv.tv_sec * 1000000 + tv.tv_usec;\n"; if (service_name == "GeneralModelService") { @@ -1085,15 +1103,33 @@ class PdsCodeGenerator : public CodeGenerator { inference_body += "\"\]\";\n"; inference_body += " LOG(INFO) << \"(logid=\" << log_id << \") "; inference_body += "service_name=\[\" << 
\"$name$\" << \"\]\";\n"; // NOLINT - inference_body += " int err_code = svr->inference(request, response, log_id);\n"; - inference_body += " if (err_code != 0) {\n"; - inference_body += " LOG(WARNING)\n"; - inference_body += " << \"(logid=\" << log_id << \") Failed call "; - inference_body += "inferservice[$name$], name[$service$]\"\n"; - inference_body += " << \", error_code: \" << err_code;\n"; - inference_body += " cntl->SetFailed(err_code, \"InferService inference "; - inference_body += "failed!\");\n"; - inference_body += " }\n"; + if (service_name == "GeneralModelService") { + inference_body += "uint64_t key = 0;"; + inference_body += "int err_code = 0;"; + inference_body += "if (RequestCache::GetSingleton()->Get(*request, response, &key) != 0) {"; + inference_body += " err_code = svr->inference(request, response, log_id);"; + inference_body += " if (err_code != 0) {"; + inference_body += " LOG(WARNING)"; + inference_body += " << \"(logid=\" << log_id << \") Failed call inferservice[GeneralModelService], name[GeneralModelService]\""; + inference_body += " << \", error_code: \" << err_code;"; + inference_body += " cntl->SetFailed(err_code, \"InferService inference failed!\");"; + inference_body += " } else {"; + inference_body += " RequestCache::GetSingleton()->Put(*request, *response, &key);"; + inference_body += " }"; + inference_body += "} else {"; + inference_body += " LOG(INFO) << \"(logid=\" << log_id << \") Get from cache\";"; + inference_body += "}"; + } else { + inference_body += " int err_code = svr->inference(request, response, log_id);\n"; + inference_body += " if (err_code != 0) {\n"; + inference_body += " LOG(WARNING)\n"; + inference_body += " << \"(logid=\" << log_id << \") Failed call "; + inference_body += "inferservice[$name$], name[$service$]\"\n"; + inference_body += " << \", error_code: \" << err_code;\n"; + inference_body += " cntl->SetFailed(err_code, \"InferService inference "; + inference_body += "failed!\");\n"; + inference_body += " }\n"; + } inference_body += " gettimeofday(&tv, NULL);\n"; inference_body += " long end = tv.tv_sec * 1000000 + tv.tv_usec;\n"; if (service_name == "GeneralModelService") { diff --git a/core/predictor/common/constant.cpp b/core/predictor/common/constant.cpp index 8e7044a9..b0acb886 100644 --- a/core/predictor/common/constant.cpp +++ b/core/predictor/common/constant.cpp @@ -44,8 +44,9 @@ DEFINE_bool(enable_cube, false, "enable cube"); DEFINE_string(general_model_path, "./conf", ""); DEFINE_string(general_model_file, "general_model.prototxt", ""); DEFINE_bool(enable_general_model, true, "enable general model"); -DEFINE_bool(enable_prometheus, true, "enable prometheus"); -DEFINE_int32(prometheus_port, 18010, ""); +DEFINE_bool(enable_prometheus, false, "enable prometheus"); +DEFINE_int32(prometheus_port, 19393, ""); +DEFINE_int64(request_cache_size, 0, "request cache size"); const char* START_OP_NAME = "startup_op"; } // namespace predictor diff --git a/core/predictor/common/constant.h b/core/predictor/common/constant.h index b74f6955..e0727ce4 100644 --- a/core/predictor/common/constant.h +++ b/core/predictor/common/constant.h @@ -45,6 +45,7 @@ DECLARE_bool(enable_cube); DECLARE_bool(enable_general_model); DECLARE_bool(enable_prometheus); DECLARE_int32(prometheus_port); +DECLARE_int64(request_cache_size); // STATIC Variables extern const char* START_OP_NAME; diff --git a/core/predictor/common/inner_common.h b/core/predictor/common/inner_common.h index 703f14a5..9a7627ae 100644 --- a/core/predictor/common/inner_common.h +++ 
b/core/predictor/common/inner_common.h @@ -61,6 +61,7 @@ #include "core/predictor/common/utils.h" #include "core/predictor/framework/prometheus_metric.h" +#include "core/predictor/framework/request_cache.h" #ifdef BCLOUD namespace brpc = baidu::rpc; diff --git a/core/predictor/framework/infer.h b/core/predictor/framework/infer.h index 00518145..5c5ef873 100644 --- a/core/predictor/framework/infer.h +++ b/core/predictor/framework/infer.h @@ -236,6 +236,7 @@ class DBReloadableInferEngine : public ReloadableInferEngine { } LOG(WARNING) << "Succ load engine, path: " << conf.model_dir(); + RequestCache::GetSingleton()->Clear(); return 0; } diff --git a/core/predictor/framework/request_cache.cpp b/core/predictor/framework/request_cache.cpp new file mode 100644 index 00000000..8ac9b7e4 --- /dev/null +++ b/core/predictor/framework/request_cache.cpp @@ -0,0 +1,236 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +#include "core/predictor/framework/request_cache.h" + +#include "core/predictor/common/inner_common.h" + +#include "core/sdk-cpp/general_model_service.pb.h" + +namespace baidu { +namespace paddle_serving { +namespace predictor { + +using baidu::paddle_serving::predictor::general_model::Request; +using baidu::paddle_serving::predictor::general_model::Response; + +RequestCache::RequestCache(const int64_t size) + : cache_size_(size), used_size_(0) { + bstop_ = false; + thread_ptr_ = std::unique_ptr<std::thread>( + new std::thread([this]() { this->ThreadLoop(); })); +} + +RequestCache::~RequestCache() { + bstop_ = true; + condition_.notify_all(); + thread_ptr_->join(); +} + +RequestCache* RequestCache::GetSingleton() { + static RequestCache cache(FLAGS_request_cache_size); + return &cache; +} + +int RequestCache::Hash(const Request& req, uint64_t* key) { + uint64_t log_id = req.log_id(); + bool profile_server = req.profile_server(); + Request* r = const_cast<Request*>(&req); + r->clear_log_id(); + r->clear_profile_server(); + std::string buf = req.SerializeAsString(); + *key = std::hash<std::string>{}(buf); + r->set_log_id(log_id); + r->set_profile_server(profile_server); + return 0; +} + +int RequestCache::Get(const Request& req, Response* res, uint64_t* key) { + if (!Enabled()) { + return -1; + } + uint64_t local_key = 0; + Hash(req, &local_key); + if (key != nullptr) { + *key = local_key; + } + std::lock_guard<std::recursive_mutex> lk(cache_mtx_); + auto iter = map_.find(local_key); + if (iter == map_.end()) { + LOG(INFO) << "key not found in cache"; + return -1; + } + auto entry = iter->second; + BuildResponse(entry, res); + UpdateLru(local_key); + + return 0; +} + +int RequestCache::Put(const Request& req, const Response& res, uint64_t* key) { + if (!Enabled()) { + return -1; + } + uint64_t local_key = 0; + if (key != nullptr && *key != 0) { + local_key = *key; + } else { + Hash(req, &local_key); + } + if (key != nullptr) { + *key = local_key; + } + + AddTask(local_key, res); + return 0; +} + +int RequestCache::PutImpl(const Response& res, uint64_t key) { +
std::lock_guard<std::recursive_mutex> lk(cache_mtx_); + auto iter = map_.find(key); + if (iter != map_.end()) { + LOG(WARNING) << "key[" << key << "] already exists in cache"; + return -1; + } + + CacheEntry entry; + if (BuildCacheEntry(res, &entry) != 0) { + LOG(WARNING) << "key[" << key << "] build cache entry failed"; + return -1; + } + map_.insert({key, entry}); + UpdateLru(key); + + return 0; +} + +int RequestCache::BuildResponse(const CacheEntry& entry, + predictor::general_model::Response* res) { + if (res == nullptr) { + return -1; + } + res->ParseFromString(entry.buf_); + res->clear_profile_time(); + return 0; +} + +int RequestCache::BuildCacheEntry(const Response& res, CacheEntry* entry) { + if (entry == nullptr) { + return -1; + } + std::lock_guard<std::recursive_mutex> lk(cache_mtx_); + int size = res.ByteSize(); + if (size >= cache_size_) { + LOG(INFO) << "res size[" << size << "] larger than cache_size[" + << cache_size_ << "]"; + return -1; + } + while (size > GetFreeCacheSize()) { + if (RemoveOne() != 0) { + LOG(ERROR) << "RemoveOne failed so can not build entry"; + return -1; + } + } + entry->buf_ = res.SerializeAsString(); + used_size_ += size; + return 0; +} + +void RequestCache::UpdateLru(uint64_t key) { + std::lock_guard<std::recursive_mutex> lk(cache_mtx_); + auto lru_iter = std::find(lru_.begin(), lru_.end(), key); + if (lru_iter != lru_.end()) { + lru_.erase(lru_iter); + } + lru_.push_front(key); +} + +bool RequestCache::Enabled() { return cache_size_ > 0; } + +int64_t RequestCache::GetFreeCacheSize() { return cache_size_ - used_size_; } + +int RequestCache::RemoveOne() { + std::lock_guard<std::recursive_mutex> lk(cache_mtx_); + uint64_t lru_key = lru_.back(); + VLOG(1) << "Remove key[" << lru_key << "] from cache"; + auto iter = map_.find(lru_key); + if (iter == map_.end()) { + LOG(ERROR) << "Remove key[" << lru_key << "] not find in cache"; + return -1; + } + auto entry = iter->second; + used_size_ -= entry.buf_.size(); + map_.erase(iter); + lru_.pop_back(); + + return 0; +} + +void RequestCache::ThreadLoop() { + std::queue<std::pair<uint64_t, std::shared_ptr<Response>>> exec_task_queue; + for (;;) { + { + std::unique_lock<std::mutex> lock(queue_mutex_); + condition_.wait( + lock, [this]() { return this->bstop_ || this->task_queue_.size(); }); + + if (!task_queue_.size()) { + if (bstop_) { + return; + } + continue; + } + swap(exec_task_queue, task_queue_); + } + while (!exec_task_queue.empty()) { + auto [key, res_ptr] = exec_task_queue.front(); + exec_task_queue.pop(); + PutImpl(*res_ptr, key); + } + } +} + +int RequestCache::AddTask(uint64_t key, const Response& res) { + std::unique_lock<std::mutex> lock(queue_mutex_); + std::shared_ptr<Response> res_ptr = std::make_shared<Response>(res); + task_queue_.push(std::make_pair(key, res_ptr)); + condition_.notify_one(); + return 0; +} + +bool RequestCache::Empty() { + std::lock_guard<std::recursive_mutex> lk(cache_mtx_); + return lru_.empty(); +} + +int RequestCache::Clear() { + { + std::unique_lock<std::mutex> lock(queue_mutex_); + std::queue<std::pair<uint64_t, std::shared_ptr<Response>>> empty; + swap(empty, task_queue_); + } + int count = 0; + { + std::lock_guard<std::recursive_mutex> lk(cache_mtx_); + count = lru_.size(); + lru_.clear(); + map_.clear(); + } + LOG(INFO) << "Clear " << count << " key!"; + return 0; +} + +} // namespace predictor +} // namespace paddle_serving +} // namespace baidu \ No newline at end of file diff --git a/core/predictor/framework/request_cache.h b/core/predictor/framework/request_cache.h new file mode 100644 index 00000000..014775ec --- /dev/null +++ b/core/predictor/framework/request_cache.h @@ -0,0 +1,98 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +#pragma once + +#include <condition_variable> +#include <list> +#include <memory> +#include <mutex> +#include <queue> +#include <string> +#include <thread> +#include <unordered_map> + +namespace baidu { +namespace paddle_serving { +namespace predictor { + +namespace general_model { +class Request; +class Response; +} // namespace general_model + +struct CacheEntry { + explicit CacheEntry() {} + std::string buf_; +}; + +class RequestCache { + public: + explicit RequestCache(const int64_t size); + ~RequestCache(); + + static RequestCache* GetSingleton(); + + int Hash(const predictor::general_model::Request& req, uint64_t* key); + + int Get(const predictor::general_model::Request& req, + predictor::general_model::Response* res, + uint64_t* key = nullptr); + + int Put(const predictor::general_model::Request& req, + const predictor::general_model::Response& res, + uint64_t* key = nullptr); + + void ThreadLoop(); + + bool Empty(); + + int Clear(); + + private: + int BuildResponse(const CacheEntry& entry, + predictor::general_model::Response* res); + + int BuildCacheEntry(const predictor::general_model::Response& res, + CacheEntry* entry); + + void UpdateLru(uint64_t key); + + bool Enabled(); + + int64_t GetFreeCacheSize(); + + int RemoveOne(); + + int AddTask(uint64_t key, const predictor::general_model::Response& res); + + int PutImpl(const predictor::general_model::Response& res, uint64_t key); + + uint64_t cache_size_; + uint64_t used_size_; + std::unordered_map<uint64_t, CacheEntry> map_; + std::list<uint64_t> lru_; + std::recursive_mutex cache_mtx_; + std::atomic<bool> bstop_{false}; + std::condition_variable condition_; + std::mutex queue_mutex_; + std::queue< + std::pair<uint64_t, std::shared_ptr<predictor::general_model::Response>>> + task_queue_; + std::unique_ptr<std::thread> thread_ptr_; +}; + +} // namespace predictor +} // namespace paddle_serving +} // namespace baidu \ No newline at end of file diff --git a/python/paddle_serving_server/serve.py b/python/paddle_serving_server/serve.py index 6e8cb283..1340933d 100755 --- a/python/paddle_serving_server/serve.py +++ b/python/paddle_serving_server/serve.py @@ -208,6 +208,8 @@ def serve_args(): "--enable_prometheus", default=False, action="store_true", help="Use Prometheus") parser.add_argument( "--prometheus_port", type=int, default=19393, help="Port of the Prometheus") + parser.add_argument( + "--request_cache_size", type=int, default=0, help="Max size of request cache in bytes, 0 to disable") return parser.parse_args() @@ -291,6 +293,7 @@ def start_gpu_card_model(gpu_mode, port, args): # pylint: disable=doc-string-mi server.set_max_body_size(max_body_size) server.set_enable_prometheus(args.enable_prometheus) server.set_prometheus_port(args.prometheus_port) + server.set_request_cache_size(args.request_cache_size) if args.use_trt and device == "gpu": server.set_trt() diff --git a/python/paddle_serving_server/server.py b/python/paddle_serving_server/server.py index f1d0b631..e369c57d 100755 --- a/python/paddle_serving_server/server.py +++ b/python/paddle_serving_server/server.py @@ -100,6 +100,7 @@ class Server(object): ] self.enable_prometheus = False self.prometheus_port = 19393 + 
self.request_cache_size = 0 def get_fetch_list(self, infer_node_idx=-1): fetch_names = [ @@ -207,6 +208,9 @@ class Server(object): def set_prometheus_port(self, prometheus_port): self.prometheus_port = prometheus_port + def set_request_cache_size(self, request_cache_size): + self.request_cache_size = request_cache_size + def _prepare_engine(self, model_config_paths, device, use_encryption_model): self.device = device if self.model_toolkit_conf == None: @@ -615,6 +619,17 @@ class Server(object): self.max_body_size, self.enable_prometheus, self.prometheus_port) + if self.enable_prometheus: + command = command + \ + "-enable_prometheus={} " \ + "-prometheus_port {} ".format( + self.enable_prometheus, + self.prometheus_port) + if self.request_cache_size > 0: + command = command + \ + "-request_cache_size {} ".format( + self.request_cache_size + ) print("Going to Run Comand") print(command) -- GitLab
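
For reference, the cache-aware handler body that pdcodegen now emits for GeneralModelService is assembled from string fragments in the patch above; expanded into ordinary C++ it reads roughly as the sketch below. This block is illustrative only and not part of the patch; svr, request, response, cntl and log_id are the variables available inside the generated service method.

uint64_t key = 0;
int err_code = 0;
// Serve from the cache when possible; Get() also reports the request hash key.
if (RequestCache::GetSingleton()->Get(*request, response, &key) != 0) {
  // Cache miss, or cache disabled because -request_cache_size is 0: run inference.
  err_code = svr->inference(request, response, log_id);
  if (err_code != 0) {
    LOG(WARNING) << "(logid=" << log_id << ") Failed call inferservice"
                 << ", error_code: " << err_code;
    cntl->SetFailed(err_code, "InferService inference failed!");
  } else {
    // On success, enqueue the response; PutImpl() inserts it into the LRU map
    // asynchronously on the RequestCache worker thread (ThreadLoop()).
    RequestCache::GetSingleton()->Put(*request, *response, &key);
  }
} else {
  LOG(INFO) << "(logid=" << log_id << ") Get from cache";
}

The cache key is a hash of the serialized request with log_id and profile_server cleared, so repeated requests that differ only in those fields hit the same entry; the total bytes of cached responses are bounded by the -request_cache_size gflag (default 0, which disables caching), and the cache is cleared whenever an engine is reloaded.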