未验证 提交 1715631f 编写于 作者: T TeslaZhao 提交者: GitHub

Merge pull request #11 from PaddlePaddle/develop

Pull lastest codes
......@@ -114,7 +114,7 @@ ADD_LIBRARY(openblas STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET openblas PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR}/third_party/install/openblas/lib/libopenblas.a)
ADD_LIBRARY(paddle_fluid SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET paddle_fluid PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR}/lib/libpaddle_fluid.a)
SET_PROPERTY(TARGET paddle_fluid PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR}/lib/libpaddle_fluid.so)
if (WITH_TRT)
ADD_LIBRARY(nvinfer SHARED IMPORTED GLOBAL)
......@@ -127,17 +127,12 @@ endif()
ADD_LIBRARY(xxhash STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET xxhash PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR}/third_party/install/xxhash/lib/libxxhash.a)
ADD_LIBRARY(cryptopp STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET cryptopp PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR}/third_party/install/cryptopp/lib/libcryptopp.a)
LIST(APPEND external_project_dependencies paddle)
LIST(APPEND paddle_depend_libs
xxhash cryptopp)
xxhash)
if(WITH_TRT)
LIST(APPEND paddle_depend_libs
nvinfer nvinfer_plugin)
endif()
......@@ -17,11 +17,11 @@
#include <fstream>
#include <iostream>
#include <memory>
#include <thread> //NOLINT
#include <thread>
#include "core/predictor/framework.pb.h"
#include "quant.h" // NOLINT
#include "seq_file.h" // NOLINT
#include "quant.h"
#include "seq_file.h"
inline uint64_t time_diff(const struct timeval &start_time,
const struct timeval &end_time) {
......@@ -113,15 +113,13 @@ int dump_parameter(const char *input_file, const char *output_file) {
// std::cout << "key_len " << key_len << " value_len " << value_buf_len
// << std::endl;
memcpy(value_buf, tensor_buf + offset, value_buf_len);
seq_file_writer.write(
std::to_string(i).c_str(), sizeof(i), value_buf, value_buf_len);
seq_file_writer.write((char *)&i, sizeof(i), value_buf, value_buf_len);
offset += value_buf_len;
}
return 0;
}
float *read_embedding_table(const char *file1,
std::vector<int64_t> &dims) { // NOLINT
float *read_embedding_table(const char *file1, std::vector<int64_t> &dims) {
std::ifstream is(file1);
// Step 1: is read version, os write version
uint32_t version;
......@@ -244,7 +242,7 @@ int compress_parameter_parallel(const char *file1,
float x = *(emb_table + k * emb_size + e);
int val = round((x - xmin) / scale);
val = std::max(0, val);
val = std::min(static_cast<int>(pow2bits) - 1, val);
val = std::min((int)pow2bits - 1, val);
*(tensor_temp + 2 * sizeof(float) + e) = val;
}
result[k] = tensor_temp;
......@@ -264,8 +262,7 @@ int compress_parameter_parallel(const char *file1,
}
SeqFileWriter seq_file_writer(file2);
for (int64_t i = 0; i < dict_size; i++) {
seq_file_writer.write(
std::to_string(i).c_str(), sizeof(i), result[i], per_line_size);
seq_file_writer.write((char *)&i, sizeof(i), result[i], per_line_size);
}
return 0;
}
......
......@@ -100,14 +100,21 @@ make -j10
you can execute `make install` to put targets under directory `./output`, you need to add`-DCMAKE_INSTALL_PREFIX=./output`to specify output path to cmake command shown above.
### Integrated GPU version paddle inference library
### CUDA_PATH is the cuda install path,use the command(whereis cuda) to check,it should be /usr/local/cuda.
### CUDNN_LIBRARY && CUDA_CUDART_LIBRARY is the lib path, it should be /usr/local/cuda/lib64/
``` shell
export CUDA_PATH='/usr/local'
export CUDNN_LIBRARY='/usr/local/cuda/lib64/'
export CUDA_CUDART_LIBRARY="/usr/local/cuda/lib64/"
mkdir server-build-gpu && cd server-build-gpu
cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python2.7/ \
-DPYTHON_LIBRARIES=$PYTHONROOT/lib/libpython2.7.so \
-DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python \
-DCUDA_TOOLKIT_ROOT_DIR=${CUDA_PATH} \
-DCUDNN_LIBRARY=${CUDNN_LIBRARY} \
-DCUDNN_LIBRARY=${CUDNN_LIBRARY} \
-DCUDA_CUDART_LIBRARY=${CUDA_CUDART_LIBRARY} \
-DSERVER=ON \
-DWITH_GPU=ON ..
make -j10
......@@ -116,6 +123,10 @@ make -j10
### Integrated TRT version paddle inference library
```
export CUDA_PATH='/usr/local'
export CUDNN_LIBRARY='/usr/local/cuda/lib64/'
export CUDA_CUDART_LIBRARY="/usr/local/cuda/lib64/"
mkdir server-build-trt && cd server-build-trt
cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python2.7/ \
-DPYTHON_LIBRARIES=$PYTHONROOT/lib/libpython2.7.so \
......@@ -123,6 +134,7 @@ cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python2.7/ \
-DTENSORRT_ROOT=${TENSORRT_LIBRARY_PATH} \
-DCUDA_TOOLKIT_ROOT_DIR=${CUDA_PATH} \
-DCUDNN_LIBRARY=${CUDNN_LIBRARY} \
-DCUDA_CUDART_LIBRARY=${CUDA_CUDART_LIBRARY} \
-DSERVER=ON \
-DWITH_GPU=ON \
-DWITH_TRT=ON ..
......@@ -166,12 +178,14 @@ make
## Install wheel package
Regardless of the client, server or App part, after compiling, install the whl package in `python/dist/` in the temporary directory(`server-build-cpu`, `server-build-gpu`, `client-build`,`app-build`) of the compilation process.
for example:cd server-build-cpu/python/dist && pip install -U xxxxx.whl
## Note
When running the python server, it will check the `SERVING_BIN` environment variable. If you want to use your own compiled binary file, set the environment variable to the path of the corresponding binary file, usually`export SERVING_BIN=${BUILD_DIR}/core/general-server/serving`.
BUILD_DIR is the absolute path of server build CPU or server build GPU。
for example: cd server-build-cpu && export SERVING_BIN=${PWD}/core/general-server/serving
......
......@@ -97,14 +97,20 @@ make -j10
可以执行`make install`把目标产出放在`./output`目录下,cmake阶段需添加`-DCMAKE_INSTALL_PREFIX=./output`选项来指定存放路径。
### 集成GPU版本Paddle Inference Library
### CUDA_PATH是cuda的安装路径,可以使用命令行whereis cuda命令确认你的cuda安装路径,通常应该是/usr/local/cuda
### CUDNN_LIBRARY CUDA_CUDART_LIBRARY 是cuda库文件的路径,通常应该是/usr/local/cuda/lib64/
``` shell
export CUDA_PATH='/usr/local'
export CUDNN_LIBRARY='/usr/local/cuda/lib64/'
export CUDA_CUDART_LIBRARY="/usr/local/cuda/lib64/"
mkdir server-build-gpu && cd server-build-gpu
cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python2.7/ \
-DPYTHON_LIBRARIES=$PYTHONROOT/lib/libpython2.7.so \
-DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python \
-DCUDA_TOOLKIT_ROOT_DIR=${CUDA_PATH} \
-DCUDNN_LIBRARY=${CUDNN_LIBRARY} \
-DCUDA_CUDART_LIBRARY=${CUDA_CUDART_LIBRARY} \
-DSERVER=ON \
-DWITH_GPU=ON ..
make -j10
......@@ -113,6 +119,10 @@ make -j10
### 集成TensorRT版本Paddle Inference Library
```
export CUDA_PATH='/usr/local'
export CUDNN_LIBRARY='/usr/local/cuda/lib64/'
export CUDA_CUDART_LIBRARY="/usr/local/cuda/lib64/"
mkdir server-build-trt && cd server-build-trt
cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python2.7/ \
-DPYTHON_LIBRARIES=$PYTHONROOT/lib/libpython2.7.so \
......@@ -120,6 +130,7 @@ cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python2.7/ \
-DTENSORRT_ROOT=${TENSORRT_LIBRARY_PATH} \
-DCUDA_TOOLKIT_ROOT_DIR=${CUDA_PATH} \
-DCUDNN_LIBRARY=${CUDNN_LIBRARY} \
-DCUDA_CUDART_LIBRARY=${CUDA_CUDART_LIBRARY} \
-DSERVER=ON \
-DWITH_GPU=ON \
-DWITH_TRT=ON ..
......@@ -162,12 +173,16 @@ make
## 安装wheel包
无论是Client端,Server端还是App部分,编译完成后,安装编译过程临时目录(`server-build-cpu``server-build-gpu``client-build``app-build`)下的`python/dist/` 中的whl包即可。
例如:cd server-build-cpu/python/dist && pip install -U xxxxx.whl
## 注意事项
运行python端Server时,会检查`SERVING_BIN`环境变量,如果想使用自己编译的二进制文件,请将设置该环境变量为对应二进制文件的路径,通常是`export SERVING_BIN=${BUILD_DIR}/core/general-server/serving`
其中BUILD_DIR为server-build-cpu或server-build-gpu的绝对路径。
可以cd server-build-cpu路径下,执行export SERVING_BIN=${PWD}/core/general-server/serving
......
......@@ -28,6 +28,7 @@ You can get images in two ways:
## Image description
Runtime images cannot be used for compilation.
If you want to customize your Serving based on source code, use the version with the suffix - devel.
| Description | OS | TAG | Dockerfile |
| :----------------------------------------------------------: | :-----: | :--------------------------: | :----------------------------------------------------------: |
......
......@@ -28,6 +28,7 @@
## 镜像说明
运行时镜像不能用于开发编译。
若需要基于源代码二次开发编译,请使用后缀为-devel的版本。
| 镜像说明 | 操作系统 | TAG | Dockerfile |
| -------------------------------------------------- | -------- | ---------------------------- | ------------------------------------------------------------ |
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <pthread.h>
#include <fstream>
#include <map>
......@@ -28,6 +29,7 @@ namespace paddle_serving {
namespace fluid_cpu {
using configure::SigmoidConf;
class AutoLock {
public:
explicit AutoLock(pthread_mutex_t& mutex) : _mut(mutex) {
......@@ -528,60 +530,7 @@ class FluidCpuAnalysisDirWithSigmoidCore : public FluidCpuWithSigmoidCore {
return 0;
}
};
class FluidCpuAnalysisEncryptCore : public FluidFamilyCore {
public:
void ReadBinaryFile(const std::string& filename, std::string* contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
fin.seekg(0, std::ios::end);
contents->clear();
contents->resize(fin.tellg());
fin.seekg(0, std::ios::beg);
fin.read(&(contents->at(0)), contents->size());
fin.close();
}
int create(const predictor::InferEngineCreationParams& params) {
std::string data_path = params.get_path();
if (access(data_path.c_str(), F_OK) == -1) {
LOG(ERROR) << "create paddle predictor failed, path note exits: "
<< data_path;
return -1;
}
std::string model_buffer, params_buffer, key_buffer;
ReadBinaryFile(data_path + "encrypt_model", &model_buffer);
ReadBinaryFile(data_path + "encrypt_params", &params_buffer);
ReadBinaryFile(data_path + "key", &key_buffer);
VLOG(2) << "prepare for encryption model";
auto cipher = paddle::MakeCipher("");
std::string real_model_buffer = cipher->Decrypt(model_buffer, key_buffer);
std::string real_params_buffer = cipher->Decrypt(params_buffer, key_buffer);
paddle::AnalysisConfig analysis_config;
analysis_config.SetModelBuffer(&real_model_buffer[0],
real_model_buffer.size(),
&real_params_buffer[0],
real_params_buffer.size());
analysis_config.DisableGpu();
analysis_config.SetCpuMathLibraryNumThreads(1);
if (params.enable_memory_optimization()) {
analysis_config.EnableMemoryOptim();
}
analysis_config.SwitchSpecifyInputNames(true);
AutoLock lock(GlobalPaddleCreateMutex::instance());
VLOG(2) << "decrypt model file sucess";
_core =
paddle::CreatePaddlePredictor<paddle::AnalysisConfig>(analysis_config);
if (NULL == _core.get()) {
LOG(ERROR) << "create paddle predictor failed, path: " << data_path;
return -1;
}
VLOG(2) << "create paddle predictor sucess, path: " << data_path;
return 0;
}
};
} // namespace fluid_cpu
} // namespace paddle_serving
} // namespace baidu
......@@ -52,13 +52,6 @@ REGIST_FACTORY_OBJECT_IMPL_WITH_NAME(
::baidu::paddle_serving::predictor::InferEngine,
"FLUID_CPU_NATIVE_DIR_SIGMOID");
#if 1
REGIST_FACTORY_OBJECT_IMPL_WITH_NAME(
::baidu::paddle_serving::predictor::FluidInferEngine<
FluidCpuAnalysisEncryptCore>,
::baidu::paddle_serving::predictor::InferEngine,
"FLUID_CPU_ANALYSIS_ENCRYPT");
#endif
} // namespace fluid_cpu
} // namespace paddle_serving
} // namespace baidu
......@@ -25,6 +25,7 @@
#include "core/configure/inferencer_configure.pb.h"
#include "core/predictor/framework/infer.h"
#include "paddle_inference_api.h" // NOLINT
DECLARE_int32(gpuid);
namespace baidu {
......@@ -590,60 +591,6 @@ class FluidGpuAnalysisDirWithSigmoidCore : public FluidGpuWithSigmoidCore {
}
};
class FluidGpuAnalysisEncryptCore : public FluidFamilyCore {
public:
void ReadBinaryFile(const std::string& filename, std::string* contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
fin.seekg(0, std::ios::end);
contents->clear();
contents->resize(fin.tellg());
fin.seekg(0, std::ios::beg);
fin.read(&(contents->at(0)), contents->size());
fin.close();
}
int create(const predictor::InferEngineCreationParams& params) {
std::string data_path = params.get_path();
if (access(data_path.c_str(), F_OK) == -1) {
LOG(ERROR) << "create paddle predictor failed, path note exits: "
<< data_path;
return -1;
}
std::string model_buffer, params_buffer, key_buffer;
ReadBinaryFile(data_path + "encrypt_model", &model_buffer);
ReadBinaryFile(data_path + "encrypt_params", &params_buffer);
ReadBinaryFile(data_path + "key", &key_buffer);
VLOG(2) << "prepare for encryption model";
auto cipher = paddle::MakeCipher("");
std::string real_model_buffer = cipher->Decrypt(model_buffer, key_buffer);
std::string real_params_buffer = cipher->Decrypt(params_buffer, key_buffer);
paddle::AnalysisConfig analysis_config;
analysis_config.SetModelBuffer(&real_model_buffer[0],
real_model_buffer.size(),
&real_params_buffer[0],
real_params_buffer.size());
analysis_config.EnableUseGpu(100, FLAGS_gpuid);
analysis_config.SetCpuMathLibraryNumThreads(1);
if (params.enable_memory_optimization()) {
analysis_config.EnableMemoryOptim();
}
analysis_config.SwitchSpecifyInputNames(true);
AutoLock lock(GlobalPaddleCreateMutex::instance());
VLOG(2) << "decrypt model file sucess";
_core =
paddle::CreatePaddlePredictor<paddle::AnalysisConfig>(analysis_config);
if (NULL == _core.get()) {
LOG(ERROR) << "create paddle predictor failed, path: " << data_path;
return -1;
}
VLOG(2) << "create paddle predictor sucess, path: " << data_path;
return 0;
}
};
} // namespace fluid_gpu
} // namespace paddle_serving
} // namespace baidu
......@@ -54,12 +54,6 @@ REGIST_FACTORY_OBJECT_IMPL_WITH_NAME(
::baidu::paddle_serving::predictor::InferEngine,
"FLUID_GPU_NATIVE_DIR_SIGMOID");
REGIST_FACTORY_OBJECT_IMPL_WITH_NAME(
::baidu::paddle_serving::predictor::FluidInferEngine<
FluidGpuAnalysisEncryptCore>,
::baidu::paddle_serving::predictor::InferEngine,
"FLUID_GPU_ANALYSIS_ENCRPT")
} // namespace fluid_gpu
} // namespace paddle_serving
} // namespace baidu
# Encryption Model Prediction
([简体中文](README_CN.md)|English)
## Get Origin Model
The example uses the model file of the fit_a_line example as a origin model
```
sh get_data.sh
```
## Encrypt Model
```
python encrypt.py
```
The key is stored in the `key` file, and the encrypted model file and server-side configuration file are stored in the `encrypt_server` directory.
client-side configuration file are stored in the `encrypt_client` directory.
## Start Encryption Service
CPU Service
```
python -m paddle_serving_server.serve --model encrypt_server/ --port 9300 --use_encryption_model
```
GPU Service
```
python -m paddle_serving_server_gpu.serve --model encrypt_server/ --port 9300 --use_encryption_model --gpu_ids 0
```
## Prediction
```
python test_client.py uci_housing_client/serving_client_conf.prototxt
```
# 加密模型预测
(简体中文|[English](README.md))
## 获取明文模型
示例中使用fit_a_line示例的模型文件作为明文模型
```
sh get_data.sh
```
## 模型加密
```
python encrypt.py
```
密钥保存在`key`文件中,加密模型文件以及server端配置文件保存在`encrypt_server`目录下,client端配置文件保存在`encrypt_client`目录下。
## 启动加密预测服务
CPU预测服务
```
python -m paddle_serving_server.serve --model encrypt_server/ --port 9300 --use_encryption_model
```
GPU预测服务
```
python -m paddle_serving_server_gpu.serve --model encrypt_server/ --port 9300 --use_encryption_model --gpu_ids 0
```
## 预测
```
python test_client.py uci_housing_client/serving_client_conf.prototxt
```
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client.io import inference_model_to_serving
def serving_encryption():
inference_model_to_serving(
dirname="./uci_housing_model",
serving_server="encrypt_server",
serving_client="encrypt_client",
encryption=True)
if __name__ == "__main__":
serving_encryption()
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing_example/encrypt.tar.gz
tar -xzf encrypt.tar.gz
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_client import Client
import sys
client = Client()
client.load_client_config(sys.argv[1])
client.use_key("./key")
client.connect(["127.0.0.1:9300"], encryption=True)
import paddle
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.uci_housing.test(), buf_size=500),
batch_size=1)
for data in test_reader():
fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"])
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
......@@ -13,19 +13,16 @@
# limitations under the License.
# pylint: disable=doc-string-missing
import paddle_serving_client
import os
from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
import numpy as np
import time
import sys
import requests
import json
import base64
import numpy as np
import paddle_serving_client
import google.protobuf.text_format
import grpc
from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config
from .proto import multi_lang_general_model_service_pb2
sys.path.append(
os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
......@@ -164,7 +161,6 @@ class Client(object):
self.fetch_names_to_idx_ = {}
self.lod_tensor_set = set()
self.feed_tensor_len = {}
self.key = None
for i, var in enumerate(model_conf.feed_var):
self.feed_names_to_idx_[var.alias_name] = i
......@@ -197,28 +193,7 @@ class Client(object):
else:
self.rpc_timeout_ms = rpc_timeout
def use_key(self, key_filename):
with open(key_filename, "r") as f:
self.key = f.read()
def get_serving_port(self, endpoints):
if self.key is not None:
req = json.dumps({"key": base64.b64encode(self.key)})
else:
req = json.dumps({})
r = requests.post("http://" + endpoints[0], req)
result = r.json()
print(result)
if "endpoint_list" not in result:
raise ValueError("server not ready")
else:
endpoints = [
endpoints[0].split(":")[0] + ":" +
str(result["endpoint_list"][0])
]
return endpoints
def connect(self, endpoints=None, encryption=False):
def connect(self, endpoints=None):
# check whether current endpoint is available
# init from client config
# create predictor here
......@@ -228,8 +203,6 @@ class Client(object):
"You must set the endpoints parameter or use add_variant function to create a variant."
)
else:
if encryption:
endpoints = self.get_serving_port(endpoints)
if self.predictor_sdk_ is None:
self.add_variant('default_tag_{}'.format(id(self)), endpoints,
100)
......
......@@ -21,9 +21,6 @@ from paddle.fluid.framework import Program
from paddle.fluid import CPUPlace
from paddle.fluid.io import save_inference_model
import paddle.fluid as fluid
from paddle.fluid.core import CipherUtils
from paddle.fluid.core import CipherFactory
from paddle.fluid.core import Cipher
from ..proto import general_model_config_pb2 as model_conf
import os
......@@ -32,10 +29,7 @@ def save_model(server_model_folder,
client_config_folder,
feed_var_dict,
fetch_var_dict,
main_program=None,
encryption=False,
key_len=128,
encrypt_conf=None):
main_program=None):
executor = Executor(place=CPUPlace())
feed_var_names = [feed_var_dict[x].name for x in feed_var_dict]
......@@ -44,29 +38,14 @@ def save_model(server_model_folder,
for key in sorted(fetch_var_dict.keys()):
target_vars.append(fetch_var_dict[key])
target_var_names.append(key)
if not encryption:
save_inference_model(
server_model_folder,
feed_var_names,
target_vars,
executor,
main_program=main_program)
else:
if encrypt_conf == None:
aes_cipher = CipherFactory.create_cipher()
else:
#todo: more encryption algorithms
pass
key = CipherUtils.gen_key_to_file(128, "key")
params = fluid.io.save_persistables(
executor=executor, dirname=None, main_program=main_program)
model = main_program.desc.serialize_to_string()
if not os.path.exists(server_model_folder):
os.makedirs(server_model_folder)
os.chdir(server_model_folder)
aes_cipher.encrypt_to_file(params, key, "encrypt_params")
aes_cipher.encrypt_to_file(model, key, "encrypt_model")
os.chdir("..")
save_inference_model(
server_model_folder,
feed_var_names,
target_vars,
executor,
main_program=main_program)
config = model_conf.GeneralModelConfig()
#int64 = 0; float32 = 1; int32 = 2;
......@@ -137,10 +116,7 @@ def inference_model_to_serving(dirname,
serving_server="serving_server",
serving_client="serving_client",
model_filename=None,
params_filename=None,
encryption=False,
key_len=128,
encrypt_conf=None):
params_filename=None):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_program, feed_target_names, fetch_targets = \
......@@ -151,7 +127,7 @@ def inference_model_to_serving(dirname,
}
fetch_dict = {x.name: x for x in fetch_targets}
save_model(serving_server, serving_client, feed_dict, fetch_dict,
inference_program, encryption, key_len, encrypt_conf)
inference_program)
feed_names = feed_dict.keys()
fetch_names = fetch_dict.keys()
return feed_names, fetch_names
......@@ -157,7 +157,6 @@ class Server(object):
self.cur_path = os.getcwd()
self.use_local_bin = False
self.mkl_flag = False
self.encryption_model = False
self.product_name = None
self.container_id = None
self.model_config_paths = None # for multi-model in a workflow
......@@ -198,9 +197,6 @@ class Server(object):
def set_ir_optimize(self, flag=False):
self.ir_optimization = flag
def use_encryption_model(self, flag=False):
self.encryption_model = flag
def set_product_name(self, product_name=None):
if product_name == None:
raise ValueError("product_name can't be None.")
......@@ -236,15 +232,9 @@ class Server(object):
engine.force_update_static_cache = False
if device == "cpu":
if self.encryption_model:
engine.type = "FLUID_CPU_ANALYSIS_ENCRYPT"
else:
engine.type = "FLUID_CPU_ANALYSIS_DIR"
engine.type = "FLUID_CPU_ANALYSIS_DIR"
elif device == "gpu":
if self.encryption_model:
engine.type = "FLUID_GPU_ANALYSIS_ENCRYPT"
else:
engine.type = "FLUID_GPU_ANALYSIS_DIR"
engine.type = "FLUID_GPU_ANALYSIS_DIR"
self.model_toolkit_conf.engines.extend([engine])
......
......@@ -18,14 +18,8 @@ Usage:
python -m paddle_serving_server.serve --model ./serving_server_model --port 9292
"""
import argparse
import sys
import json
import base64
import time
from multiprocessing import Process
from web_service import WebService, port_is_available
from .web_service import WebService
from flask import Flask, request
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
def parse_args(): # pylint: disable=doc-string-missing
......@@ -59,11 +53,6 @@ def parse_args(): # pylint: disable=doc-string-missing
type=int,
default=512 * 1024 * 1024,
help="Limit sizes of messages")
parser.add_argument(
"--use_encryption_model",
default=False,
action="store_true",
help="Use encryption model")
parser.add_argument(
"--use_multilang",
default=False,
......@@ -82,18 +71,17 @@ def parse_args(): # pylint: disable=doc-string-missing
return parser.parse_args()
def start_standard_model(serving_port): # pylint: disable=doc-string-missing
def start_standard_model(): # pylint: disable=doc-string-missing
args = parse_args()
thread_num = args.thread
model = args.model
port = serving_port
port = args.port
workdir = args.workdir
device = args.device
mem_optim = args.mem_optim_off is False
ir_optim = args.ir_optim
max_body_size = args.max_body_size
use_mkl = args.use_mkl
use_encryption_model = args.use_encryption_model
use_multilang = args.use_multilang
if model == "":
......@@ -123,7 +111,6 @@ def start_standard_model(serving_port): # pylint: disable=doc-string-missing
server.use_mkl(use_mkl)
server.set_max_body_size(max_body_size)
server.set_port(port)
server.use_encryption_model(use_encryption_model)
if args.product_name != None:
server.set_product_name(args.product_name)
if args.container_id != None:
......@@ -134,88 +121,11 @@ def start_standard_model(serving_port): # pylint: disable=doc-string-missing
server.run_server()
class MainService(BaseHTTPRequestHandler):
def get_available_port(self):
default_port = 12000
for i in range(1000):
if port_is_available(default_port + i):
return default_port + i
def start_serving(self):
start_standard_model(serving_port)
def get_key(self, post_data):
if "key" not in post_data:
return False
else:
key = base64.b64decode(post_data["key"])
with open(args.model + "/key", "w") as f:
f.write(key)
return True
def check_key(self, post_data):
if "key" not in post_data:
return False
else:
key = base64.b64decode(post_data["key"])
with open(args.model + "/key", "r") as f:
cur_key = f.read()
return (key == cur_key)
def start(self, post_data):
post_data = json.loads(post_data)
global p_flag
if not p_flag:
if args.use_encryption_model:
print("waiting key for model")
if not self.get_key(post_data):
print("not found key in request")
return False
global serving_port
global p
serving_port = self.get_available_port()
p = Process(target=self.start_serving)
p.start()
time.sleep(3)
if p.is_alive():
p_flag = True
else:
return False
else:
if p.is_alive():
if not self.check_key(post_data):
return False
else:
return False
return True
def do_POST(self):
content_length = int(self.headers['Content-Length'])
post_data = self.rfile.read(content_length)
if self.start(post_data):
response = {"endpoint_list": [serving_port]}
else:
response = {"message": "start serving failed"}
self.send_response(200)
self.send_header('Content-type', 'application/json')
self.end_headers()
self.wfile.write(json.dumps(response))
if __name__ == "__main__":
args = parse_args()
if args.name == "None":
if args.use_encryption_model:
p_flag = False
p = None
serving_port = 0
server = HTTPServer(('localhost', int(args.port)), MainService)
print(
'Starting encryption server, waiting for key from client, use <Ctrl-C> to stop'
)
server.serve_forever()
else:
start_standard_model(args.port)
start_standard_model()
else:
service = WebService(name=args.name)
service.load_model_config(args.model)
......
......@@ -25,16 +25,6 @@ from paddle_serving_server import pipeline
from paddle_serving_server.pipeline import Op
def port_is_available(port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex(('0.0.0.0', port))
if result != 0:
return True
else:
return False
class WebService(object):
def __init__(self, name="default_service"):
self.name = name
......@@ -120,7 +110,7 @@ class WebService(object):
self.mem_optim = mem_optim
self.ir_optim = ir_optim
for i in range(1000):
if port_is_available(default_port + i):
if self.port_is_available(default_port + i):
self.port_list.append(default_port + i)
break
......
......@@ -68,11 +68,6 @@ def serve_args():
type=int,
default=512 * 1024 * 1024,
help="Limit sizes of messages")
parser.add_argument(
"--use_encryption_model",
default=False,
action="store_true",
help="Use encryption model")
parser.add_argument(
"--use_multilang",
default=False,
......@@ -282,8 +277,7 @@ class Server(object):
def set_trt(self):
self.use_trt = True
def _prepare_engine(self, model_config_paths, device, use_encryption_model):
def _prepare_engine(self, model_config_paths, device):
if self.model_toolkit_conf == None:
self.model_toolkit_conf = server_sdk.ModelToolkitConf()
......@@ -305,15 +299,9 @@ class Server(object):
engine.use_trt = self.use_trt
if device == "cpu":
if use_encryption_model:
engine.type = "FLUID_CPU_ANALYSIS_ENCRPT"
else:
engine.type = "FLUID_CPU_ANALYSIS_DIR"
engine.type = "FLUID_CPU_ANALYSIS_DIR"
elif device == "gpu":
if use_encryption_model:
engine.type = "FLUID_GPU_ANALYSIS_ENCRPT"
else:
engine.type = "FLUID_GPU_ANALYSIS_DIR"
engine.type = "FLUID_GPU_ANALYSIS_DIR"
self.model_toolkit_conf.engines.extend([engine])
......@@ -470,7 +458,6 @@ class Server(object):
workdir=None,
port=9292,
device="cpu",
use_encryption_model=False,
cube_conf=None):
if workdir == None:
workdir = "./tmp"
......@@ -484,8 +471,7 @@ class Server(object):
self.set_port(port)
self._prepare_resource(workdir, cube_conf)
self._prepare_engine(self.model_config_paths, device,
use_encryption_model)
self._prepare_engine(self.model_config_paths, device)
self._prepare_infer_service(port)
self.workdir = workdir
......
......@@ -19,21 +19,19 @@ Usage:
"""
import argparse
import os
import json
import base64
from multiprocessing import Pool, Process
from paddle_serving_server_gpu import serve_args
from flask import Flask, request
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-string-missing
def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-missing
gpuid = int(gpuid)
device = "gpu"
port = args.port
if gpuid == -1:
device = "cpu"
elif gpuid >= 0:
port = port + index
port = args.port + index
thread_num = args.thread
model = args.model
mem_optim = args.mem_optim_off is False
......@@ -75,20 +73,14 @@ def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-strin
server.set_container_id(args.container_id)
server.load_model_config(model)
server.prepare_server(
workdir=workdir,
port=port,
device=device,
use_encryption_model=args.use_encryption_model)
server.prepare_server(workdir=workdir, port=port, device=device)
if gpuid >= 0:
server.set_gpuid(gpuid)
server.run_server()
def start_multi_card(args, serving_port=None): # pylint: disable=doc-string-missing
def start_multi_card(args): # pylint: disable=doc-string-missing
gpus = ""
if serving_port == None:
serving_port = args.port
if args.gpu_ids == "":
gpus = []
else:
......@@ -105,16 +97,14 @@ def start_multi_card(args, serving_port=None): # pylint: disable=doc-string-mis
env_gpus = []
if len(gpus) <= 0:
print("gpu_ids not set, going to run cpu service.")
start_gpu_card_model(-1, -1, serving_port, args)
start_gpu_card_model(-1, -1, args)
else:
gpu_processes = []
for i, gpu_id in enumerate(gpus):
p = Process(
target=start_gpu_card_model,
args=(
target=start_gpu_card_model, args=(
i,
gpu_id,
serving_port,
args, ))
gpu_processes.append(p)
for p in gpu_processes:
......@@ -123,89 +113,10 @@ def start_multi_card(args, serving_port=None): # pylint: disable=doc-string-mis
p.join()
class MainService(BaseHTTPRequestHandler):
def get_available_port(self):
default_port = 12000
for i in range(1000):
if port_is_available(default_port + i):
return default_port + i
def start_serving(self):
start_multi_card(args, serving_port)
def get_key(self, post_data):
if "key" not in post_data:
return False
else:
key = base64.b64decode(post_data["key"])
with open(args.model + "/key", "w") as f:
f.write(key)
return True
def check_key(self, post_data):
if "key" not in post_data:
return False
else:
key = base64.b64decode(post_data["key"])
with open(args.model + "/key", "r") as f:
cur_key = f.read()
return (key == cur_key)
def start(self, post_data):
post_data = json.loads(post_data)
global p_flag
if not p_flag:
if args.use_encryption_model:
print("waiting key for model")
if not self.get_key(post_data):
print("not found key in request")
return False
global serving_port
global p
serving_port = self.get_available_port()
p = Process(target=self.start_serving)
p.start()
time.sleep(3)
if p.is_alive():
p_flag = True
else:
return False
else:
if p.is_alive():
if not self.check_key(post_data):
return False
else:
return False
return True
def do_POST(self):
content_length = int(self.headers['Content-Length'])
post_data = self.rfile.read(content_length)
if self.start(post_data):
response = {"endpoint_list": [serving_port]}
else:
response = {"message": "start serving failed"}
self.send_response(200)
self.send_header('Content-type', 'application/json')
self.end_headers()
self.wfile.write(json.dumps(response))
if __name__ == "__main__":
args = serve_args()
if args.name == "None":
from .web_service import port_is_available
if args.use_encryption_model:
p_flag = False
p = None
serving_port = 0
server = HTTPServer(('localhost', int(args.port)), MainService)
print(
'Starting encryption server, waiting for key from client, use <Ctrl-C> to stop'
)
server.serve_forever()
else:
start_multi_card(args)
start_multi_card(args)
else:
from .web_service import WebService
web_service = WebService(name=args.name)
......
......@@ -28,16 +28,6 @@ from paddle_serving_server_gpu import pipeline
from paddle_serving_server_gpu.pipeline import Op
def port_is_available(port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex(('0.0.0.0', port))
if result != 0:
return True
else:
return False
class WebService(object):
def __init__(self, name="default_service"):
self.name = name
......@@ -146,7 +136,7 @@ class WebService(object):
self.port_list = []
default_port = 12000
for i in range(1000):
if port_is_available(default_port + i):
if self.port_is_available(default_port + i):
self.port_list.append(default_port + i)
if len(self.port_list) > len(self.gpus):
break
......
......@@ -39,8 +39,6 @@ RUN yum -y install wget && \
make clean && \
echo 'export PATH=/usr/local/python3.6/bin:$PATH' >> /root/.bashrc && \
echo 'export LD_LIBRARY_PATH=/usr/local/python3.6/lib:$LD_LIBRARY_PATH' >> /root/.bashrc && \
pip install requests && \
pip3 install requests && \
source /root/.bashrc && \
cd .. && rm -rf Python-3.6.8* && \
wget https://github.com/protocolbuffers/protobuf/releases/download/v3.11.2/protobuf-all-3.11.2.tar.gz && \
......
......@@ -49,8 +49,6 @@ RUN yum -y install wget && \
cd .. && rm -rf protobuf-* && \
yum -y install epel-release && yum -y install patchelf libXext libSM libXrender && \
yum clean all && \
pip install requests && \
pip3 install requests && \
localedef -c -i en_US -f UTF-8 en_US.UTF-8 && \
echo "export LANG=en_US.utf8" >> /root/.bashrc && \
echo "export LANGUAGE=en_US.utf8" >> /root/.bashrc
......@@ -23,8 +23,7 @@ RUN wget https://dl.google.com/go/go1.14.linux-amd64.tar.gz >/dev/null \
RUN yum -y install python-devel sqlite-devel >/dev/null \
&& curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py >/dev/null \
&& python get-pip.py >/dev/null \
&& rm get-pip.py \
&& pip install requests
&& rm get-pip.py
RUN wget http://nixos.org/releases/patchelf/patchelf-0.10/patchelf-0.10.tar.bz2 \
&& yum -y install bzip2 >/dev/null \
......@@ -35,9 +34,6 @@ RUN wget http://nixos.org/releases/patchelf/patchelf-0.10/patchelf-0.10.tar.bz2
&& cd .. \
&& rm -rf patchelf-0.10*
RUN yum install -y python3 python3-devel \
&& pip3 install requests
RUN wget https://github.com/protocolbuffers/protobuf/releases/download/v3.11.2/protobuf-all-3.11.2.tar.gz && \
tar zxf protobuf-all-3.11.2.tar.gz && \
cd protobuf-3.11.2 && \
......@@ -45,6 +41,8 @@ RUN wget https://github.com/protocolbuffers/protobuf/releases/download/v3.11.2/p
make clean && \
cd .. && rm -rf protobuf-*
RUN yum install -y python3 python3-devel
RUN yum -y update >/dev/null \
&& yum -y install dnf >/dev/null \
&& yum -y install dnf-plugins-core >/dev/null \
......
......@@ -30,13 +30,11 @@ RUN wget https://dl.google.com/go/go1.14.linux-amd64.tar.gz >/dev/null \
RUN yum -y install python-devel sqlite-devel \
&& curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py >/dev/null \
&& python get-pip.py >/dev/null \
&& rm get-pip.py \
&& pip install requests
&& rm get-pip.py
RUN yum install -y python3 python3-devel \
&& yum -y install epel-release && yum -y install patchelf libXext libSM libXrender\
&& yum clean all \
&& pip3 install requests
&& yum clean all
RUN localedef -c -i en_US -f UTF-8 en_US.UTF-8 \
&& echo "export LANG=en_US.utf8" >> /root/.bashrc \
......
......@@ -29,13 +29,11 @@ RUN wget https://dl.google.com/go/go1.14.linux-amd64.tar.gz >/dev/null \
RUN yum -y install python-devel sqlite-devel \
&& curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py >/dev/null \
&& python get-pip.py >/dev/null \
&& rm get-pip.py \
&& pip install requests
&& rm get-pip.py
RUN yum install -y python3 python3-devel \
&& yum -y install epel-release && yum -y install patchelf libXext libSM libXrender\
&& yum clean all \
&& pip3 install requests
&& yum clean all
RUN localedef -c -i en_US -f UTF-8 en_US.UTF-8 \
&& echo "export LANG=en_US.utf8" >> /root/.bashrc \
......
......@@ -19,13 +19,11 @@ RUN wget https://dl.google.com/go/go1.14.linux-amd64.tar.gz >/dev/null \
RUN yum -y install python-devel sqlite-devel \
&& curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py >/dev/null \
&& python get-pip.py >/dev/null \
&& rm get-pip.py \
&& pip install requests
&& rm get-pip.py
RUN yum install -y python3 python3-devel \
&& yum -y install epel-release && yum -y install patchelf libXext libSM libXrender\
&& yum clean all \
&& pip3 install requests
&& yum clean all
RUN localedef -c -i en_US -f UTF-8 en_US.UTF-8 \
&& echo "export LANG=en_US.utf8" >> /root/.bashrc \
......
......@@ -514,40 +514,6 @@ function python_test_lac() {
cd ..
}
function python_test_encryption(){
#pwd: /Serving/python/examples
cd encryption
sh get_data.sh
local TYPE=$1
export SERVING_BIN=${SERIVNG_WORKDIR}/build-server-${TYPE}/core/general-server/serving
case $TYPE in
CPU)
#check_cmd "python encrypt.py"
#sleep 5
check_cmd "python -m paddle_serving_server.serve --model encrypt_server/ --port 9300 --use_encryption_model > /dev/null &"
sleep 5
check_cmd "python test_client.py encrypt_client/serving_client_conf.prototxt"
kill_server_process
;;
GPU)
#check_cmd "python encrypt.py"
#sleep 5
check_cmd "python -m paddle_serving_server_gpu.serve --model encrypt_server/ --port 9300 --use_encryption_model --gpu_ids 0"
sleep 5
check_cmd "python test_client.py encrypt_client/serving_client_conf.prototxt"
kill_servere_process
;;
*)
echo "error type"
exit 1
;;
esac
echo "encryption $TYPE test finished as expected"
setproxy
unset SERVING_BIN
cd ..
}
function java_run_test() {
# pwd: /Serving
local TYPE=$1
......@@ -563,7 +529,7 @@ function java_run_test() {
cd examples # pwd: /Serving/java/examples
mvn compile > /dev/null
mvn install > /dev/null
# fit_a_line (general, asyn_predict, batch_predict)
cd ../../python/examples/grpc_impl_example/fit_a_line # pwd: /Serving/python/examples/grpc_impl_example/fit_a_line
sh get_data.sh
......@@ -820,7 +786,7 @@ function python_test_pipeline(){
python -m paddle_serving_server.serve --model imdb_cnn_model --port 9292 --workdir test9292 &> cnn.log &
python -m paddle_serving_server.serve --model imdb_bow_model --port 9393 --workdir test9393 &> bow.log &
sleep 5
# test: thread servicer & thread op
cat << EOF > config.yml
rpc_port: 18080
......@@ -994,7 +960,6 @@ function python_run_test() {
python_test_lac $TYPE # pwd: /Serving/python/examples
python_test_multi_process $TYPE # pwd: /Serving/python/examples
python_test_multi_fetch $TYPE # pwd: /Serving/python/examples
python_test_encryption $TYPE # pwd: /Serving/python/examples
python_test_yolov4 $TYPE # pwd: /Serving/python/examples
python_test_grpc_impl $TYPE # pwd: /Serving/python/examples
python_test_resnet50 $TYPE # pwd: /Serving/python/examples
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册