提交 b0b72230 编写于 作者: M MRXLT

support encryption model

上级 97883259
...@@ -98,7 +98,10 @@ SET_PROPERTY(TARGET paddle_fluid PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR ...@@ -98,7 +98,10 @@ SET_PROPERTY(TARGET paddle_fluid PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR
ADD_LIBRARY(xxhash STATIC IMPORTED GLOBAL) ADD_LIBRARY(xxhash STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET xxhash PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR}/third_party/install/xxhash/lib/libxxhash.a) SET_PROPERTY(TARGET xxhash PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR}/third_party/install/xxhash/lib/libxxhash.a)
ADD_LIBRARY(cryptopp STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET cryptopp PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR}/third_party/install/cryptopp/lib/libcryptopp.a)
LIST(APPEND external_project_dependencies paddle) LIST(APPEND external_project_dependencies paddle)
LIST(APPEND paddle_depend_libs LIST(APPEND paddle_depend_libs
xxhash) xxhash cryptopp)
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <pthread.h> #include <pthread.h>
#include <fstream> #include <fstream>
#include <map> #include <map>
...@@ -535,6 +534,16 @@ class FluidCpuAnalysisDirWithSigmoidCore : public FluidCpuWithSigmoidCore { ...@@ -535,6 +534,16 @@ class FluidCpuAnalysisDirWithSigmoidCore : public FluidCpuWithSigmoidCore {
#if 1 #if 1
class FluidCpuAnalysisEncryptCore : public FluidFamilyCore { class FluidCpuAnalysisEncryptCore : public FluidFamilyCore {
public: public:
void ReadBinaryFile(const std::string& filename, std::string* contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
fin.seekg(0, std::ios::end);
contents->clear();
contents->resize(fin.tellg());
fin.seekg(0, std::ios::beg);
fin.read(&(contents->at(0)), contents->size());
fin.close();
}
int create(const predictor::InferEngineCreationParams& params) { int create(const predictor::InferEngineCreationParams& params) {
std::string data_path = params.get_path(); std::string data_path = params.get_path();
if (access(data_path.c_str(), F_OK) == -1) { if (access(data_path.c_str(), F_OK) == -1) {
...@@ -542,53 +551,31 @@ class FluidCpuAnalysisEncryptCore : public FluidFamilyCore { ...@@ -542,53 +551,31 @@ class FluidCpuAnalysisEncryptCore : public FluidFamilyCore {
<< data_path; << data_path;
return -1; return -1;
} }
std::ifstream model_file(data_path + "encry_model",
std::ios::in | std::ios::binary); std::string model_buffer, params_buffer, key_buffer;
std::string model_string; ReadBinaryFile(data_path + "encry_model", &model_buffer);
if (model_file.is_open()) { ReadBinaryFile(data_path + "encry_params", &params_buffer);
std::istreambuf_iterator<char> begin(model_file), end; ReadBinaryFile(data_path + "key", &key_buffer);
model_string = std::string(begin, end);
model_file.close(); VLOG(2) << "prepare for encryption model";
}
std::ifstream params_file(data_path + "encry_params",
std::ios::in | std::ios::binary);
std::string params_string;
if (params_file.is_open()) {
std::istreambuf_iterator<char> begin(params_file), end;
params_string = std::string(begin, end);
params_file.close();
}
std::ifstream key_file(data_path + "key", std::ios::in | std::ios::binary);
std::string key_string;
if (key_file.is_open()) {
std::istreambuf_iterator<char> begin(key_file), end;
key_string = std::string(begin, end);
key_file.close();
}
#if 1
auto cipher = paddle::framework::CipherFactory::CreateCipher(""); auto cipher = paddle::framework::CipherFactory::CreateCipher("");
std::string real_model_string = cipher->Decrypt(model_string, key_string); std::string real_model_buffer = cipher->Decrypt(model_buffer, key_buffer);
std::string real_params_string = cipher->Decrypt(params_string, key_string); std::string real_params_buffer = cipher->Decrypt(params_buffer, key_buffer);
#else
std::string real_model_string;
std::string real_params_string;
#endif
const char* real_model_buffer = real_model_string.c_str();
const char* real_params_buffer = real_params_string.c_str();
paddle::AnalysisConfig analysis_config; paddle::AnalysisConfig analysis_config;
analysis_config.SetModelBuffer(real_model_buffer, analysis_config.SetModelBuffer(&real_model_buffer[0],
real_model_string.size(), real_model_buffer.size(),
real_params_buffer, &real_params_buffer[0],
real_model_string.size()); real_params_buffer.size());
analysis_config.DisableGpu(); analysis_config.DisableGpu();
analysis_config.SetCpuMathLibraryNumThreads(1); analysis_config.SetCpuMathLibraryNumThreads(1);
if (params.enable_memory_optimization()) { if (params.enable_memory_optimization()) {
analysis_config.EnableMemoryOptim(); analysis_config.EnableMemoryOptim();
} }
analysis_config.SwitchSpecifyInputNames(true); analysis_config.SwitchSpecifyInputNames(true);
AutoLock lock(GlobalPaddleCreateMutex::instance()); AutoLock lock(GlobalPaddleCreateMutex::instance());
VLOG(2) << "decrypt model file sucess";
_core = _core =
paddle::CreatePaddlePredictor<paddle::AnalysisConfig>(analysis_config); paddle::CreatePaddlePredictor<paddle::AnalysisConfig>(analysis_config);
if (NULL == _core.get()) { if (NULL == _core.get()) {
......
...@@ -222,7 +222,7 @@ class Server(object): ...@@ -222,7 +222,7 @@ class Server(object):
if self.encryption_model: if self.encryption_model:
engine.type = "FLUID_CPU_ANALYSIS_ENCRYPT" engine.type = "FLUID_CPU_ANALYSIS_ENCRYPT"
else: else:
engine.type = "FLUID_CPU_ANALYSIS_ENCRYPT" engine.type = "FLUID_CPU_ANALYSIS_DIR"
elif device == "gpu": elif device == "gpu":
if self.encryption_model: if self.encryption_model:
engine.type = "FLUID_GPU_ANALYSIS_ENCRYPT" engine.type = "FLUID_GPU_ANALYSIS_ENCRYPT"
......
...@@ -56,7 +56,7 @@ def parse_args(): # pylint: disable=doc-string-missing ...@@ -56,7 +56,7 @@ def parse_args(): # pylint: disable=doc-string-missing
type=int, type=int,
default=512 * 1024 * 1024, default=512 * 1024 * 1024,
help="Limit sizes of messages") help="Limit sizes of messages")
parse.add_argument( parser.add_argument(
"--use_encryption_model", "--use_encryption_model",
default=False, default=False,
action="store_true", action="store_true",
...@@ -136,18 +136,14 @@ def start_serving(): ...@@ -136,18 +136,14 @@ def start_serving():
class MainService(BaseHTTPRequestHandler): class MainService(BaseHTTPRequestHandler):
def setup(self):
BaseHTTPRequestHandler.setup(self)
self.p_flag = False
def start(self): def start(self):
print(self.p_flag) global p_flag
if not self.p_flag: print(p_flag)
from multiprocessing import Pool if not p_flag:
pool = Pool(3) from multiprocessing import Process
pool.apply_async(start_serving) p = Process(target=start_serving)
self.p_status = 1 p.start()
self.p_flag = True p_flag = True
else: else:
pass pass
return True return True
...@@ -166,6 +162,7 @@ class MainService(BaseHTTPRequestHandler): ...@@ -166,6 +162,7 @@ class MainService(BaseHTTPRequestHandler):
if __name__ == "__main__": if __name__ == "__main__":
p_flag = False
args = parse_args() args = parse_args()
server = HTTPServer(('localhost', 8080), MainService) server = HTTPServer(('localhost', 8080), MainService)
print('Starting server, use <Ctrl-C> to stop') print('Starting server, use <Ctrl-C> to stop')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册