Commit 0649918f authored by S ShiningZhang

add request cache

Parent b3daf1ba
......@@ -301,15 +301,33 @@ class PdsCodeGenerator : public CodeGenerator {
inference_body += "\"\]\";\n";
inference_body += " LOG(INFO) << \"(logid=\" << log_id << \") ";
inference_body += "service_name=\[\" << \"$name$\" << \"\]\";\n"; // NOLINT
inference_body += " int err_code = svr->inference(request, response, log_id);\n";
inference_body += " if (err_code != 0) {\n";
inference_body += " LOG(WARNING)\n";
inference_body += " << \"(logid=\" << log_id << \") Failed call ";
inference_body += "inferservice[$name$], name[$service$]\"\n";
inference_body += " << \", error_code: \" << err_code;\n";
inference_body += " cntl->SetFailed(err_code, \"InferService inference ";
inference_body += "failed!\");\n";
inference_body += " }\n";
if (service_name == "GeneralModelService") {
inference_body += "uint64_t key = 0;";
inference_body += "int err_code = 0;";
inference_body += "if (RequestCache::GetSingleton()->Get(*request, response, &key) != 0) {";
inference_body += " err_code = svr->inference(request, response, log_id);";
inference_body += " if (err_code != 0) {";
inference_body += " LOG(WARNING)";
inference_body += " << \"(logid=\" << log_id << \") Failed call inferservice[GeneralModelService], name[GeneralModelService]\"";
inference_body += " << \", error_code: \" << err_code;";
inference_body += " cntl->SetFailed(err_code, \"InferService inference failed!\");";
inference_body += " } else {";
inference_body += " RequestCache::GetSingleton()->Put(*request, *response, &key);";
inference_body += " }";
inference_body += "} else {";
inference_body += " LOG(INFO) << \"(logid=\" << log_id << \") Get from cache\";";
inference_body += "}";
} else {
inference_body += " int err_code = svr->inference(request, response, log_id);\n";
inference_body += " if (err_code != 0) {\n";
inference_body += " LOG(WARNING)\n";
inference_body += " << \"(logid=\" << log_id << \") Failed call ";
inference_body += "inferservice[$name$], name[$service$]\"\n";
inference_body += " << \", error_code: \" << err_code;\n";
inference_body += " cntl->SetFailed(err_code, \"InferService inference ";
inference_body += "failed!\");\n";
inference_body += " }\n";
}
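// For readability, the GeneralModelService branch above emits roughly the
// following handler body (reconstructed from the concatenated strings; the
// warning-log detail is elided):
//
//   uint64_t key = 0;
//   int err_code = 0;
//   if (RequestCache::GetSingleton()->Get(*request, response, &key) != 0) {
//     err_code = svr->inference(request, response, log_id);
//     if (err_code != 0) {
//       LOG(WARNING) << ...;  // failure log elided
//       cntl->SetFailed(err_code, "InferService inference failed!");
//     } else {
//       RequestCache::GetSingleton()->Put(*request, *response, &key);
//     }
//   } else {
//     LOG(INFO) << "(logid=" << log_id << ") Get from cache";
//   }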
inference_body += " gettimeofday(&tv, NULL);\n";
inference_body += " long end = tv.tv_sec * 1000000 + tv.tv_usec;\n";
if (service_name == "GeneralModelService") {
......@@ -1085,15 +1103,33 @@ class PdsCodeGenerator : public CodeGenerator {
inference_body += "\"\]\";\n";
inference_body += " LOG(INFO) << \"(logid=\" << log_id << \") ";
inference_body += "service_name=\[\" << \"$name$\" << \"\]\";\n"; // NOLINT
inference_body += " int err_code = svr->inference(request, response, log_id);\n";
inference_body += " if (err_code != 0) {\n";
inference_body += " LOG(WARNING)\n";
inference_body += " << \"(logid=\" << log_id << \") Failed call ";
inference_body += "inferservice[$name$], name[$service$]\"\n";
inference_body += " << \", error_code: \" << err_code;\n";
inference_body += " cntl->SetFailed(err_code, \"InferService inference ";
inference_body += "failed!\");\n";
inference_body += " }\n";
if (service_name == "GeneralModelService") {
inference_body += "uint64_t key = 0;";
inference_body += "int err_code = 0;";
inference_body += "if (RequestCache::GetSingleton()->Get(*request, response, &key) != 0) {";
inference_body += " err_code = svr->inference(request, response, log_id);";
inference_body += " if (err_code != 0) {";
inference_body += " LOG(WARNING)";
inference_body += " << \"(logid=\" << log_id << \") Failed call inferservice[GeneralModelService], name[GeneralModelService]\"";
inference_body += " << \", error_code: \" << err_code;";
inference_body += " cntl->SetFailed(err_code, \"InferService inference failed!\");";
inference_body += " } else {";
inference_body += " RequestCache::GetSingleton()->Put(*request, *response, &key);";
inference_body += " }";
inference_body += "} else {";
inference_body += " LOG(INFO) << \"(logid=\" << log_id << \") Get from cache\";";
inference_body += "}";
} else {
inference_body += " int err_code = svr->inference(request, response, log_id);\n";
inference_body += " if (err_code != 0) {\n";
inference_body += " LOG(WARNING)\n";
inference_body += " << \"(logid=\" << log_id << \") Failed call ";
inference_body += "inferservice[$name$], name[$service$]\"\n";
inference_body += " << \", error_code: \" << err_code;\n";
inference_body += " cntl->SetFailed(err_code, \"InferService inference ";
inference_body += "failed!\");\n";
inference_body += " }\n";
}
inference_body += " gettimeofday(&tv, NULL);\n";
inference_body += " long end = tv.tv_sec * 1000000 + tv.tv_usec;\n";
if (service_name == "GeneralModelService") {
......
......@@ -44,8 +44,9 @@ DEFINE_bool(enable_cube, false, "enable cube");
DEFINE_string(general_model_path, "./conf", "");
DEFINE_string(general_model_file, "general_model.prototxt", "");
DEFINE_bool(enable_general_model, true, "enable general model");
DEFINE_bool(enable_prometheus, true, "enable prometheus");
DEFINE_int32(prometheus_port, 18010, "");
DEFINE_bool(enable_prometheus, false, "enable prometheus");
DEFINE_int32(prometheus_port, 19393, "");
DEFINE_int64(request_cache_size, 0, "request cache size");
const char* START_OP_NAME = "startup_op";
} // namespace predictor
......
......@@ -45,6 +45,7 @@ DECLARE_bool(enable_cube);
DECLARE_bool(enable_general_model);
DECLARE_bool(enable_prometheus);
DECLARE_int32(prometheus_port);
DECLARE_int64(request_cache_size);
// STATIC Variables
extern const char* START_OP_NAME;
......
......@@ -61,6 +61,7 @@
#include "core/predictor/common/utils.h"
#include "core/predictor/framework/prometheus_metric.h"
#include "core/predictor/framework/request_cache.h"
#ifdef BCLOUD
namespace brpc = baidu::rpc;
......
......@@ -236,6 +236,7 @@ class DBReloadableInferEngine : public ReloadableInferEngine {
}
LOG(WARNING) << "Succ load engine, path: " << conf.model_dir();
RequestCache::GetSingleton()->Clear();
return 0;
}
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "core/predictor/framework/request_cache.h"
#include "core/predictor/common/inner_common.h"
#include "core/sdk-cpp/general_model_service.pb.h"
namespace baidu {
namespace paddle_serving {
namespace predictor {
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::Response;
RequestCache::RequestCache(const int64_t size)
: cache_size_(size), used_size_(0) {
bstop_ = false;
thread_ptr_ = std::unique_ptr<std::thread>(
new std::thread([this]() { this->ThreadLoop(); }));
}
RequestCache::~RequestCache() {
bstop_ = true;
condition_.notify_all();
thread_ptr_->join();
}
RequestCache* RequestCache::GetSingleton() {
static RequestCache cache(FLAGS_request_cache_size);
return &cache;
}
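// Hash the request into a cache key. log_id and profile_server are cleared
// before serialization so they do not affect the key, then restored.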
int RequestCache::Hash(const Request& req, uint64_t* key) {
uint64_t log_id = req.log_id();
bool profile_server = req.profile_server();
Request* r = const_cast<Request*>(&req);
r->clear_log_id();
r->clear_profile_server();
std::string buf = req.SerializeAsString();
*key = std::hash<std::string>{}(buf);
r->set_log_id(log_id);
r->set_profile_server(profile_server);
return 0;
}
int RequestCache::Get(const Request& req, Response* res, uint64_t* key) {
if (!Enabled()) {
return -1;
}
uint64_t local_key = 0;
Hash(req, &local_key);
if (key != nullptr) {
*key = local_key;
}
std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
auto iter = map_.find(local_key);
if (iter == map_.end()) {
LOG(INFO) << "key not found in cache";
return -1;
}
auto entry = iter->second;
BuildResponse(entry, res);
UpdateLru(local_key);
return 0;
}
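// Publish a response for the given request. The entry is only enqueued here;
// the worker thread inserts it later, so it may not be visible immediately.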
int RequestCache::Put(const Request& req, const Response& res, uint64_t* key) {
if (!Enabled()) {
return -1;
}
uint64_t local_key = 0;
if (key != nullptr && *key != 0) {
local_key = *key;
} else {
Hash(req, &local_key);
}
if (key != nullptr) {
*key = local_key;
}
AddTask(local_key, res);
return 0;
}
int RequestCache::PutImpl(const Response& res, uint64_t key) {
std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
auto iter = map_.find(key);
if (iter != map_.end()) {
LOG(WARNING) << "key[" << key << "] already exists in cache";
return -1;
}
CacheEntry entry;
if (BuildCacheEntry(res, &entry) != 0) {
LOG(WARNING) << "key[" << key << "] build cache entry failed";
return -1;
}
map_.insert({key, entry});
UpdateLru(key);
return 0;
}
int RequestCache::BuildResponse(const CacheEntry& entry,
predictor::general_model::Response* res) {
if (res == nullptr) {
return -1;
}
res->ParseFromString(entry.buf_);
res->clear_profile_time();
return 0;
}
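// Serialize the response into a cache entry, evicting least recently used
// entries until it fits. Fails when the response alone exceeds cache_size_.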
int RequestCache::BuildCacheEntry(const Response& res, CacheEntry* entry) {
if (entry == nullptr) {
return -1;
}
std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
int size = res.ByteSize();
if (size >= cache_size_) {
LOG(INFO) << "res size[" << size << "] larger than cache_size["
<< cache_size_ << "]";
return -1;
}
while (size > GetFreeCacheSize()) {
if (RemoveOne() != 0) {
LOG(ERROR) << "RemoveOne failed so can not build entry";
return -1;
}
}
entry->buf_ = res.SerializeAsString();
used_size_ += size;
return 0;
}
void RequestCache::UpdateLru(uint64_t key) {
std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
auto lru_iter = std::find(lru_.begin(), lru_.end(), key);
if (lru_iter != lru_.end()) {
lru_.erase(lru_iter);
}
lru_.push_front(key);
}
bool RequestCache::Enabled() { return cache_size_ > 0; }
int64_t RequestCache::GetFreeCacheSize() { return cache_size_ - used_size_; }
int RequestCache::RemoveOne() {
std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
uint64_t lru_key = lru_.back();
VLOG(1) << "Remove key[" << lru_key << "] from cache";
auto iter = map_.find(lru_key);
if (iter == map_.end()) {
LOG(ERROR) << "Remove key[" << lru_key << "] not find in cache";
return -1;
}
auto entry = iter->second;
used_size_ -= entry.buf_.size();
map_.erase(iter);
lru_.pop_back();
return 0;
}
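// Background worker: waits for (key, response) pairs queued by Put(), inserts
// them into the cache, and exits once bstop_ is set and the queue is drained.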
void RequestCache::ThreadLoop() {
std::queue<std::pair<uint64_t, std::shared_ptr<Response>>> exec_task_queue;
for (;;) {
{
std::unique_lock<std::mutex> lock(queue_mutex_);
condition_.wait(
lock, [this]() { return this->bstop_ || this->task_queue_.size(); });
if (!task_queue_.size()) {
if (bstop_) {
return;
}
continue;
}
swap(exec_task_queue, task_queue_);
}
while (!exec_task_queue.empty()) {
auto [key, res_ptr] = exec_task_queue.front();
exec_task_queue.pop();
PutImpl(*res_ptr, key);
}
}
}
int RequestCache::AddTask(uint64_t key, const Response& res) {
std::unique_lock<std::mutex> lock(queue_mutex_);
std::shared_ptr<Response> res_ptr = std::make_shared<Response>(res);
task_queue_.push(std::make_pair(key, res_ptr));
condition_.notify_one();
return 0;
}
bool RequestCache::Empty() {
std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
return lru_.empty();
}
int RequestCache::Clear() {
{
std::unique_lock<std::mutex> lock(queue_mutex_);
std::queue<std::pair<uint64_t, std::shared_ptr<Response>>> empty;
swap(empty, task_queue_);
}
int count = 0;
{
std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
count = lru_.size();
lru_.clear();
map_.clear();
}
LOG(INFO) << "Clear " << count << " key!";
return 0;
}
} // namespace predictor
} // namespace paddle_serving
} // namespace baidu
\ No newline at end of file
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <atomic>
#include <condition_variable>
#include <list>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
namespace baidu {
namespace paddle_serving {
namespace predictor {
namespace general_model {
class Request;
class Response;
} // namespace general_model
struct CacheEntry {
explicit CacheEntry() {}
std::string buf_;
};
class RequestCache {
public:
explicit RequestCache(const int64_t size);
~RequestCache();
static RequestCache* GetSingleton();
int Hash(const predictor::general_model::Request& req, uint64_t* key);
int Get(const predictor::general_model::Request& req,
predictor::general_model::Response* res,
uint64_t* key = nullptr);
int Put(const predictor::general_model::Request& req,
const predictor::general_model::Response& res,
uint64_t* key = nullptr);
void ThreadLoop();
bool Empty();
int Clear();
private:
int BuildResponse(const CacheEntry& entry,
predictor::general_model::Response* res);
int BuildCacheEntry(const predictor::general_model::Response& res,
CacheEntry* entry);
void UpdateLru(uint64_t key);
bool Enabled();
int64_t GetFreeCacheSize();
int RemoveOne();
int AddTask(uint64_t key, const predictor::general_model::Response& res);
int PutImpl(const predictor::general_model::Response& res, uint64_t key);
uint64_t cache_size_;
uint64_t used_size_;
std::unordered_map<uint64_t, CacheEntry> map_;
std::list<uint64_t> lru_;
std::recursive_mutex cache_mtx_;
std::atomic<bool> bstop_{false};
std::condition_variable condition_;
std::mutex queue_mutex_;
std::queue<
std::pair<uint64_t, std::shared_ptr<predictor::general_model::Response>>>
task_queue_;
std::unique_ptr<std::thread> thread_ptr_;
};
} // namespace predictor
} // namespace paddle_serving
} // namespace baidu
\ No newline at end of file
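For orientation, a minimal smoke-test sketch of the cache added above (not part of this commit). It uses only the API declared in request_cache.h and assumes FLAGS_request_cache_size was parsed with a positive value before the first GetSingleton() call; note that Put() only enqueues the entry, so the worker thread must run before a later Get() can hit.

// Hedged smoke-test sketch for RequestCache; not part of the commit.
// Assumes FLAGS_request_cache_size > 0 (otherwise Enabled() is false and
// Get/Put return -1) and uses only the API declared in request_cache.h.
#include <chrono>
#include <thread>
#include "core/predictor/framework/request_cache.h"
#include "core/sdk-cpp/general_model_service.pb.h"

using baidu::paddle_serving::predictor::RequestCache;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::Response;

void RequestCacheSmokeTest() {
  RequestCache* cache = RequestCache::GetSingleton();
  Request req;
  req.set_log_id(42);  // log_id/profile_server are excluded from the hash key
  Response res;
  uint64_t key = 0;
  if (cache->Get(req, &res, &key) != 0) {
    // Cache miss: normally inference would fill `res` here before publishing.
    cache->Put(req, res, &key);  // enqueues; inserted by the worker thread
  }
  // Give the worker thread a moment in this toy example before retrying.
  std::this_thread::sleep_for(std::chrono::milliseconds(10));
  Response cached;
  if (cache->Get(req, &cached, &key) == 0) {
    // Hit: `cached` now holds the previously stored response.
  }
}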
......@@ -208,6 +208,8 @@ def serve_args():
"--enable_prometheus", default=False, action="store_true", help="Use Prometheus")
parser.add_argument(
"--prometheus_port", type=int, default=19393, help="Port of the Prometheus")
parser.add_argument(
"--request_cache_size", type=int, default=0, help="Port of the Prometheus")
return parser.parse_args()
......@@ -291,6 +293,7 @@ def start_gpu_card_model(gpu_mode, port, args): # pylint: disable=doc-string-mi
server.set_max_body_size(max_body_size)
server.set_enable_prometheus(args.enable_prometheus)
server.set_prometheus_port(args.prometheus_port)
server.set_request_cache_size(args.request_cache_size)
if args.use_trt and device == "gpu":
server.set_trt()
......
......@@ -100,6 +100,7 @@ class Server(object):
]
self.enable_prometheus = False
self.prometheus_port = 19393
self.request_cache_size = 0
def get_fetch_list(self, infer_node_idx=-1):
fetch_names = [
......@@ -207,6 +208,9 @@ class Server(object):
def set_prometheus_port(self, prometheus_port):
self.prometheus_port = prometheus_port
def set_request_cache_size(self, request_cache_size):
self.request_cache_size = request_cache_size
def _prepare_engine(self, model_config_paths, device, use_encryption_model):
self.device = device
if self.model_toolkit_conf == None:
......@@ -615,6 +619,17 @@ class Server(object):
self.max_body_size,
self.enable_prometheus,
self.prometheus_port)
if self.enable_prometheus:
command = command + \
"-enable_prometheus={} " \
"-prometheus_port {} ".format(
self.enable_prometheus,
self.prometheus_port)
if self.request_cache_size > 0:
command = command + \
"-request_cache_size {} ".format(
self.request_cache_size
)
print("Going to Run Comand")
print(command)
......
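For the Python side, a hedged usage sketch (not part of this commit): only set_request_cache_size() and the --request_cache_size flag come from this diff; the other Server/OpMaker calls are assumed from the existing paddle_serving_server API, and the model directory, workdir, and port are placeholders.

# Hedged sketch: enable the request cache from the Python Server API.
# set_request_cache_size() is added in this commit; the op/server wiring is
# assumed from the existing paddle_serving_server API, and "uci_housing_model",
# "workdir" and 9393 are placeholders.
from paddle_serving_server import OpMaker, OpSeqMaker, Server

op_maker = OpMaker()
op_seq_maker = OpSeqMaker()
op_seq_maker.add_op(op_maker.create('general_reader'))
op_seq_maker.add_op(op_maker.create('general_infer'))
op_seq_maker.add_op(op_maker.create('general_response'))

server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.load_model_config("uci_housing_model")
server.set_request_cache_size(1024 * 1024)  # cache up to ~1 MiB of responses
server.prepare_server(workdir="workdir", port=9393, device="cpu")
server.run_server()

# Roughly equivalent command-line launch using the flag added in serve_args:
#   python -m paddle_serving_server.serve --model uci_housing_model \
#       --port 9393 --request_cache_size 1048576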