提交 97883259 编写于 作者: M MRXLT

support encryption model

上级 145a26f7
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
#include "cipher.h" // NOLINT
#include "cipher_utils.h" // NOLINT #include "cipher_utils.h" // NOLINT
#include "core/configure/include/configure_parser.h" #include "core/configure/include/configure_parser.h"
#include "core/configure/inferencer_configure.pb.h" #include "core/configure/inferencer_configure.pb.h"
...@@ -531,8 +532,8 @@ class FluidCpuAnalysisDirWithSigmoidCore : public FluidCpuWithSigmoidCore { ...@@ -531,8 +532,8 @@ class FluidCpuAnalysisDirWithSigmoidCore : public FluidCpuWithSigmoidCore {
return 0; return 0;
} }
}; };
#if 1
class FluidCpuAnalysisEncryCore : public FluidFamilyCore { class FluidCpuAnalysisEncryptCore : public FluidFamilyCore {
public: public:
int create(const predictor::InferEngineCreationParams& params) { int create(const predictor::InferEngineCreationParams& params) {
std::string data_path = params.get_path(); std::string data_path = params.get_path();
...@@ -564,11 +565,14 @@ class FluidCpuAnalysisEncryCore : public FluidFamilyCore { ...@@ -564,11 +565,14 @@ class FluidCpuAnalysisEncryCore : public FluidFamilyCore {
key_string = std::string(begin, end); key_string = std::string(begin, end);
key_file.close(); key_file.close();
} }
#if 1
auto cipher = paddle::CipherFactory::CreateCipher(); auto cipher = paddle::framework::CipherFactory::CreateCipher("");
std::string real_model_string = cipher->Decrypt(model_string, key_string); std::string real_model_string = cipher->Decrypt(model_string, key_string);
std::string real_params_string = cipher->Decrypt(params_string, key_string); std::string real_params_string = cipher->Decrypt(params_string, key_string);
#else
std::string real_model_string;
std::string real_params_string;
#endif
const char* real_model_buffer = real_model_string.c_str(); const char* real_model_buffer = real_model_string.c_str();
const char* real_params_buffer = real_params_string.c_str(); const char* real_params_buffer = real_params_string.c_str();
...@@ -595,6 +599,7 @@ class FluidCpuAnalysisEncryCore : public FluidFamilyCore { ...@@ -595,6 +599,7 @@ class FluidCpuAnalysisEncryCore : public FluidFamilyCore {
return 0; return 0;
} }
}; };
#endif
} // namespace fluid_cpu } // namespace fluid_cpu
} // namespace paddle_serving } // namespace paddle_serving
} // namespace baidu } // namespace baidu
...@@ -52,6 +52,13 @@ REGIST_FACTORY_OBJECT_IMPL_WITH_NAME( ...@@ -52,6 +52,13 @@ REGIST_FACTORY_OBJECT_IMPL_WITH_NAME(
::baidu::paddle_serving::predictor::InferEngine, ::baidu::paddle_serving::predictor::InferEngine,
"FLUID_CPU_NATIVE_DIR_SIGMOID"); "FLUID_CPU_NATIVE_DIR_SIGMOID");
#if 1
REGIST_FACTORY_OBJECT_IMPL_WITH_NAME(
::baidu::paddle_serving::predictor::FluidInferEngine<
FluidCpuAnalysisEncryptCore>,
::baidu::paddle_serving::predictor::InferEngine,
"FLUID_CPU_ANALYSIS_ENCRYPT");
#endif
} // namespace fluid_cpu } // namespace fluid_cpu
} // namespace paddle_serving } // namespace paddle_serving
} // namespace baidu } // namespace baidu
...@@ -156,6 +156,7 @@ class Server(object): ...@@ -156,6 +156,7 @@ class Server(object):
self.cur_path = os.getcwd() self.cur_path = os.getcwd()
self.use_local_bin = False self.use_local_bin = False
self.mkl_flag = False self.mkl_flag = False
self.encryption_model = False
self.model_config_paths = None # for multi-model in a workflow self.model_config_paths = None # for multi-model in a workflow
def set_max_concurrency(self, concurrency): def set_max_concurrency(self, concurrency):
...@@ -190,6 +191,9 @@ class Server(object): ...@@ -190,6 +191,9 @@ class Server(object):
def set_ir_optimize(self, flag=False): def set_ir_optimize(self, flag=False):
self.ir_optimization = flag self.ir_optimization = flag
def use_encryption_model(self, flag=False):
self.encryption_model = flag
def check_local_bin(self): def check_local_bin(self):
if "SERVING_BIN" in os.environ: if "SERVING_BIN" in os.environ:
self.use_local_bin = True self.use_local_bin = True
...@@ -215,9 +219,15 @@ class Server(object): ...@@ -215,9 +219,15 @@ class Server(object):
engine.force_update_static_cache = False engine.force_update_static_cache = False
if device == "cpu": if device == "cpu":
engine.type = "FLUID_CPU_ANALYSIS_DIR" if self.encryption_model:
engine.type = "FLUID_CPU_ANALYSIS_ENCRYPT"
else:
engine.type = "FLUID_CPU_ANALYSIS_ENCRYPT"
elif device == "gpu": elif device == "gpu":
engine.type = "FLUID_GPU_ANALYSIS_DIR" if self.encryption_model:
engine.type = "FLUID_GPU_ANALYSIS_ENCRYPT"
else:
engine.type = "FLUID_GPU_ANALYSIS_DIR"
self.model_toolkit_conf.engines.extend([engine]) self.model_toolkit_conf.engines.extend([engine])
......
...@@ -56,6 +56,11 @@ def parse_args(): # pylint: disable=doc-string-missing ...@@ -56,6 +56,11 @@ def parse_args(): # pylint: disable=doc-string-missing
type=int, type=int,
default=512 * 1024 * 1024, default=512 * 1024 * 1024,
help="Limit sizes of messages") help="Limit sizes of messages")
parse.add_argument(
"--use_encryption_model",
default=False,
action="store_true",
help="Use encryption model")
return parser.parse_args() return parser.parse_args()
...@@ -70,6 +75,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing ...@@ -70,6 +75,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing
ir_optim = args.ir_optim ir_optim = args.ir_optim
max_body_size = args.max_body_size max_body_size = args.max_body_size
use_mkl = args.use_mkl use_mkl = args.use_mkl
use_encryption_model = args.use_encryption_model
if model == "": if model == "":
print("You must specify your serving model") print("You must specify your serving model")
...@@ -94,6 +100,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing ...@@ -94,6 +100,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing
server.use_mkl(use_mkl) server.use_mkl(use_mkl)
server.set_max_body_size(max_body_size) server.set_max_body_size(max_body_size)
server.set_port(port) server.set_port(port)
server.use_encryption_model(use_encryption_model)
server.load_model_config(model) server.load_model_config(model)
server.prepare_server(workdir=workdir, port=port, device=device) server.prepare_server(workdir=workdir, port=port, device=device)
......
...@@ -93,7 +93,6 @@ class WebService(object): ...@@ -93,7 +93,6 @@ class WebService(object):
return result return result
def run_rpc_service(self): def run_rpc_service(self):
import socket
localIP = socket.gethostbyname(socket.gethostname()) localIP = socket.gethostbyname(socket.gethostname())
print("web service address:") print("web service address:")
print("http://{}:{}/{}/prediction".format(localIP, self.port, print("http://{}:{}/{}/prediction".format(localIP, self.port,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册