提交 b0b72230 编写于 作者: M MRXLT

support encryption model

上级 97883259
......@@ -98,7 +98,10 @@ SET_PROPERTY(TARGET paddle_fluid PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR
ADD_LIBRARY(xxhash STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET xxhash PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR}/third_party/install/xxhash/lib/libxxhash.a)
ADD_LIBRARY(cryptopp STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET cryptopp PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR}/third_party/install/cryptopp/lib/libcryptopp.a)
LIST(APPEND external_project_dependencies paddle)
LIST(APPEND paddle_depend_libs
xxhash)
xxhash cryptopp)
......@@ -13,7 +13,6 @@
// limitations under the License.
#pragma once
#include <pthread.h>
#include <fstream>
#include <map>
......@@ -535,6 +534,16 @@ class FluidCpuAnalysisDirWithSigmoidCore : public FluidCpuWithSigmoidCore {
#if 1
class FluidCpuAnalysisEncryptCore : public FluidFamilyCore {
public:
void ReadBinaryFile(const std::string& filename, std::string* contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
fin.seekg(0, std::ios::end);
contents->clear();
contents->resize(fin.tellg());
fin.seekg(0, std::ios::beg);
fin.read(&(contents->at(0)), contents->size());
fin.close();
}
int create(const predictor::InferEngineCreationParams& params) {
std::string data_path = params.get_path();
if (access(data_path.c_str(), F_OK) == -1) {
......@@ -542,53 +551,31 @@ class FluidCpuAnalysisEncryptCore : public FluidFamilyCore {
<< data_path;
return -1;
}
std::ifstream model_file(data_path + "encry_model",
std::ios::in | std::ios::binary);
std::string model_string;
if (model_file.is_open()) {
std::istreambuf_iterator<char> begin(model_file), end;
model_string = std::string(begin, end);
model_file.close();
}
std::ifstream params_file(data_path + "encry_params",
std::ios::in | std::ios::binary);
std::string params_string;
if (params_file.is_open()) {
std::istreambuf_iterator<char> begin(params_file), end;
params_string = std::string(begin, end);
params_file.close();
}
std::ifstream key_file(data_path + "key", std::ios::in | std::ios::binary);
std::string key_string;
if (key_file.is_open()) {
std::istreambuf_iterator<char> begin(key_file), end;
key_string = std::string(begin, end);
key_file.close();
}
#if 1
std::string model_buffer, params_buffer, key_buffer;
ReadBinaryFile(data_path + "encry_model", &model_buffer);
ReadBinaryFile(data_path + "encry_params", &params_buffer);
ReadBinaryFile(data_path + "key", &key_buffer);
VLOG(2) << "prepare for encryption model";
auto cipher = paddle::framework::CipherFactory::CreateCipher("");
std::string real_model_string = cipher->Decrypt(model_string, key_string);
std::string real_params_string = cipher->Decrypt(params_string, key_string);
#else
std::string real_model_string;
std::string real_params_string;
#endif
const char* real_model_buffer = real_model_string.c_str();
const char* real_params_buffer = real_params_string.c_str();
std::string real_model_buffer = cipher->Decrypt(model_buffer, key_buffer);
std::string real_params_buffer = cipher->Decrypt(params_buffer, key_buffer);
paddle::AnalysisConfig analysis_config;
analysis_config.SetModelBuffer(real_model_buffer,
real_model_string.size(),
real_params_buffer,
real_model_string.size());
analysis_config.SetModelBuffer(&real_model_buffer[0],
real_model_buffer.size(),
&real_params_buffer[0],
real_params_buffer.size());
analysis_config.DisableGpu();
analysis_config.SetCpuMathLibraryNumThreads(1);
if (params.enable_memory_optimization()) {
analysis_config.EnableMemoryOptim();
}
analysis_config.SwitchSpecifyInputNames(true);
AutoLock lock(GlobalPaddleCreateMutex::instance());
VLOG(2) << "decrypt model file sucess";
_core =
paddle::CreatePaddlePredictor<paddle::AnalysisConfig>(analysis_config);
if (NULL == _core.get()) {
......
......@@ -222,7 +222,7 @@ class Server(object):
if self.encryption_model:
engine.type = "FLUID_CPU_ANALYSIS_ENCRYPT"
else:
engine.type = "FLUID_CPU_ANALYSIS_ENCRYPT"
engine.type = "FLUID_CPU_ANALYSIS_DIR"
elif device == "gpu":
if self.encryption_model:
engine.type = "FLUID_GPU_ANALYSIS_ENCRYPT"
......
......@@ -56,7 +56,7 @@ def parse_args(): # pylint: disable=doc-string-missing
type=int,
default=512 * 1024 * 1024,
help="Limit sizes of messages")
parse.add_argument(
parser.add_argument(
"--use_encryption_model",
default=False,
action="store_true",
......@@ -136,18 +136,14 @@ def start_serving():
class MainService(BaseHTTPRequestHandler):
def setup(self):
BaseHTTPRequestHandler.setup(self)
self.p_flag = False
def start(self):
print(self.p_flag)
if not self.p_flag:
from multiprocessing import Pool
pool = Pool(3)
pool.apply_async(start_serving)
self.p_status = 1
self.p_flag = True
global p_flag
print(p_flag)
if not p_flag:
from multiprocessing import Process
p = Process(target=start_serving)
p.start()
p_flag = True
else:
pass
return True
......@@ -166,6 +162,7 @@ class MainService(BaseHTTPRequestHandler):
if __name__ == "__main__":
p_flag = False
args = parse_args()
server = HTTPServer(('localhost', 8080), MainService)
print('Starting server, use <Ctrl-C> to stop')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册