未验证 提交 5699ab64 编写于 作者: J Jiawei Wang 提交者: GitHub

Merge branch 'develop' into cube_062

......@@ -30,7 +30,7 @@ find_package(Threads REQUIRED)
find_package(CUDA QUIET)
include(simd)
# SET(CMAKE_BUILD_TYPE "Debug")
# CMAKE_BUILD_TYPE
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING
......
......@@ -175,9 +175,12 @@ python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --p
| Argument | Type | Default | Description |
| ---------------------------------------------- | ---- | ------- | ----------------------------------------------------- |
| `thread` | int | `4` | Concurrency of current service |
| `thread` | int | `2` | Number of brpc service thread |
| `op_num` | int[]| `0` | Thread Number for each model in asynchronous mode |
| `op_max_batch` | int[]| `0` | Batch Number for each model in asynchronous mode |
| `gpu_ids` | str[]| `"-1"` | Gpu card id for each model |
| `port` | int | `9292` | Exposed port of current service to users |
| `model` | str | `""` | Path of paddle model directory to be served |
| `model` | str[]| `""` | Path of paddle model directory to be served |
| `mem_optim_off` | - | - | Disable memory / graphic memory optimization |
| `ir_optim` | bool | False | Enable analysis and optimization of calculation graph |
| `use_mkl` (Only for cpu version) | - | - | Run inference with MKL |
......@@ -186,7 +189,24 @@ python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --p
| `use_xpu` | - | - | Run PaddleLite inference with Baidu Kunlun XPU |
| `precision` | str | FP32 | Precision Mode, support FP32, FP16, INT8 |
| `use_calib` | bool | False | Only for deployment with TensorRT |
| `gpu_multi_stream` | bool | False | EnableGpuMultiStream to get larger QPS |
#### Description of asynchronous model
Asynchronous mode is suitable for 1. When the number of requests is very large, 2. When multiple models are concatenated and you want to specify the concurrency number of each model.
Asynchronous mode helps to improve the throughput (QPS) of service, but for a single request, the delay will increase slightly.
In asynchronous mode, each model will start n threads of the number you specify, and each thread contains a model instance. In other words, each model is equivalent to a thread pool containing N threads, and the task is taken from the task queue of the thread pool to execute.
In asynchronous mode, each RPC server thread is only responsible for putting the request into the task queue of the model thread pool. After the task is executed, the completed task is removed from the task queue.
In the above table, the number of RPC server threads is specified by --thread, and the default value is 2.
--op_num specifies the number of threads in the thread pool of each model. The default value is 0, indicating that asynchronous mode is not used.
--op_max_batch specifies the number of batches for each model. The default value is 32. It takes effect when --op_num is not 0.
#### When you want a model to use multiple GPU cards.
python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --port 9292 --gpu_ids 0,1,2
#### When you want 2 models.
python3 -m paddle_serving_server.serve --model uci_housing_model_1 uci_housing_model_2 --thread 10 --port 9292
#### When you want 2 models, and want each of them use multiple GPU cards.
python3 -m paddle_serving_server.serve --model uci_housing_model_1 uci_housing_model_2 --thread 10 --port 9292 --gpu_ids 0,1 1,2
#### When a service contains two models, and each model needs to specify multiple GPU cards, and needs asynchronous mode, each model specifies different concurrency number.
python3 -m paddle_serving_server.serve --model uci_housing_model_1 uci_housing_model_2 --thread 10 --port 9292 --gpu_ids 0,1 1,2 --op_num 4 8
</center>
```python
......
......@@ -172,19 +172,40 @@ python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --p
```
<center>
| Argument | Type | Default | Description |
| ---------------------------------------------- | ---- | ------- | ------------------------------------------------------ |
| `thread` | int | `4` | Concurrency of current service |
| `port` | int | `9292` | Exposed port of current service to users |
| `name` | str | `""` | Service name, can be used to generate HTTP request url |
| `model` | str | `""` | Path of paddle model directory to be served |
| `mem_optim_off` | - | - | Disable memory optimization |
| `ir_optim` | bool | False | Enable analysis and optimization of calculation graph |
| `use_mkl` (Only for cpu version) | - | - | Run inference with MKL |
| `use_trt` (Only for Cuda>=10.1 version) | - | - | Run inference with TensorRT |
| `use_lite` (Only for Intel x86 CPU or ARM CPU) | - | - | Run PaddleLite inference |
| `use_xpu` | - | - | Run PaddleLite inference with Baidu Kunlun XPU |
| `precision` | str | FP32 | Precision Mode, support FP32, FP16, INT8 |
| Argument | Type | Default | Description |
| ---------------------------------------------- | ---- | ------- | ----------------------------------------------------- |
| `thread` | int | `2` | Number of brpc service thread |
| `op_num` | int[]| `0` | Thread Number for each model in asynchronous mode |
| `op_max_batch` | int[]| `32` | Batch Number for each model in asynchronous mode |
| `gpu_ids` | str[]| `"-1"` | Gpu card id for each model |
| `port` | int | `9292` | Exposed port of current service to users |
| `model` | str[]| `""` | Path of paddle model directory to be served |
| `mem_optim_off` | - | - | Disable memory / graphic memory optimization |
| `ir_optim` | bool | False | Enable analysis and optimization of calculation graph |
| `use_mkl` (Only for cpu version) | - | - | Run inference with MKL |
| `use_trt` (Only for trt version) | - | - | Run inference with TensorRT |
| `use_lite` (Only for Intel x86 CPU or ARM CPU) | - | - | Run PaddleLite inference |
| `use_xpu` | - | - | Run PaddleLite inference with Baidu Kunlun XPU |
| `precision` | str | FP32 | Precision Mode, support FP32, FP16, INT8 |
| `use_calib` | bool | False | Only for deployment with TensorRT |
| `gpu_multi_stream` | bool | False | EnableGpuMultiStream to get larger QPS |
#### 异步模型的说明
异步模式适用于1、请求数量非常大的情况,2、多模型串联,想要分别指定每个模型的并发数的情况。
异步模式有助于提高Service服务的吞吐(QPS),但对于单次请求而言,时延会有少量增加。
异步模式中,每个模型会启动您指定个数的N个线程,每个线程中包含一个模型实例,换句话说每个模型相当于包含N个线程的线程池,从线程池的任务队列中取任务来执行。
异步模式中,各个RPC Server的线程只负责将Request请求放入模型线程池的任务队列中,等任务被执行完毕后,再从任务队列中取出已完成的任务。
上表中通过 --thread 10 指定的是RPC Server的线程数量,默认值为2,--op_num 指定的是各个模型的线程池中线程数N,默认值为0,表示不使用异步模式。
--op_max_batch 指定的各个模型的batch数量,默认值为32,该参数只有当--op_num不为0时才生效。
#### 当您的某个模型想使用多张GPU卡部署时.
python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --port 9292 --gpu_ids 0,1,2
#### 当您的一个服务包含两个模型部署时.
python3 -m paddle_serving_server.serve --model uci_housing_model_1 uci_housing_model_2 --thread 10 --port 9292
#### 当您的一个服务包含两个模型,且每个模型都需要指定多张GPU卡部署时.
python3 -m paddle_serving_server.serve --model uci_housing_model_1 uci_housing_model_2 --thread 10 --port 9292 --gpu_ids 0,1 1,2
#### 当您的一个服务包含两个模型,且每个模型都需要指定多张GPU卡,且需要异步模式每个模型指定不同的并发数时.
python3 -m paddle_serving_server.serve --model uci_housing_model_1 uci_housing_model_2 --thread 10 --port 9292 --gpu_ids 0,1 1,2 --op_num 4 8
</center>
......
文件模式从 100755 更改为 100644
......@@ -33,9 +33,7 @@ if (WITH_PYTHON)
add_custom_target(general_model_config_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(general_model_config_py_proto general_model_config_py_proto_init)
py_grpc_proto_compile(multi_lang_general_model_service_py_proto SRCS proto/multi_lang_general_model_service.proto)
add_custom_target(multi_lang_general_model_service_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(multi_lang_general_model_service_py_proto multi_lang_general_model_service_py_proto_init)
if (CLIENT)
py_proto_compile(sdk_configure_py_proto SRCS proto/sdk_configure.proto)
......@@ -53,11 +51,7 @@ if (WITH_PYTHON)
COMMENT "Copy generated general_model_config proto file into directory paddle_serving_client/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
add_custom_command(TARGET multi_lang_general_model_service_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/proto
COMMAND cp -f *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/proto
COMMENT "Copy generated multi_lang_general_model_service proto file into directory paddle_serving_client/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif()
if (APP)
......@@ -84,11 +78,6 @@ if (WITH_PYTHON)
COMMENT "Copy generated general_model_config proto file into directory paddle_serving_server/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
add_custom_command(TARGET multi_lang_general_model_service_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
COMMAND cp -f *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
COMMENT "Copy generated multi_lang_general_model_service proto file into directory paddle_serving_server/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif()
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package baidu.paddle_serving.multi_lang;
option java_multiple_files = true;
option java_package = "io.paddle.serving.grpc";
option java_outer_classname = "ServingProto";
message Tensor {
optional bytes data = 1;
repeated int32 int_data = 2;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type = 5;
repeated int32 shape = 6;
repeated int32 lod = 7; // only for fetch tensor currently
};
message FeedInst { repeated Tensor tensor_array = 1; };
message FetchInst { repeated Tensor tensor_array = 1; };
message InferenceRequest {
repeated FeedInst insts = 1;
repeated string feed_var_names = 2;
repeated string fetch_var_names = 3;
required bool is_python = 4 [ default = false ];
required uint64 log_id = 5 [ default = 0 ];
};
message InferenceResponse {
repeated ModelOutput outputs = 1;
optional string tag = 2;
required int32 err_code = 3;
};
message ModelOutput {
repeated FetchInst insts = 1;
optional string engine_name = 2;
}
message SetTimeoutRequest { required int32 timeout_ms = 1; }
message SimpleResponse { required int32 err_code = 1; }
message GetClientConfigRequest {}
message GetClientConfigResponse { required string client_config_str = 1; }
service MultiLangGeneralModelService {
rpc Inference(InferenceRequest) returns (InferenceResponse) {}
rpc SetTimeout(SetTimeoutRequest) returns (SimpleResponse) {}
rpc GetClientConfig(GetClientConfigRequest)
returns (GetClientConfigResponse) {}
};
......@@ -21,11 +21,12 @@ message EngineDesc {
required string reloadable_meta = 3;
required string reloadable_type = 4;
required string model_dir = 5;
required int32 runtime_thread_num = 6;
required int32 batch_infer_size = 7;
required int32 enable_batch_align = 8;
optional string version_file = 9;
optional string version_type = 10;
repeated int32 gpu_ids = 6;
required int32 runtime_thread_num = 7;
required int32 batch_infer_size = 8;
required int32 enable_batch_align = 9;
optional string version_file = 10;
optional string version_type = 11;
/*
* Sparse Parameter Service type. Valid types are:
......@@ -38,16 +39,17 @@ message EngineDesc {
LOCAL = 1;
REMOTE = 2;
}
optional SparseParamServiceType sparse_param_service_type = 11;
optional string sparse_param_service_table_name = 12;
optional bool enable_memory_optimization = 13;
optional bool enable_ir_optimization = 14;
optional bool use_trt = 15;
optional bool use_lite = 16;
optional bool use_xpu = 17;
optional bool use_gpu = 18;
optional bool combined_model = 19;
optional bool encrypted_model = 20;
optional SparseParamServiceType sparse_param_service_type = 12;
optional string sparse_param_service_table_name = 13;
optional bool enable_memory_optimization = 14;
optional bool enable_ir_optimization = 15;
optional bool use_trt = 16;
optional bool use_lite = 17;
optional bool use_xpu = 18;
optional bool use_gpu = 19;
optional bool combined_model = 20;
optional bool encrypted_model = 21;
optional bool gpu_multi_stream = 22;
};
// model_toolkit conf
......
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
......@@ -207,7 +207,7 @@ class PredictorClient {
void init_gflags(std::vector<std::string> argv);
int init(const std::vector<std::string> &client_conf);
int init(const std::vector<std::string>& client_conf);
void set_predictor_conf(const std::string& conf_path,
const std::string& conf_file);
......@@ -218,23 +218,22 @@ class PredictorClient {
int destroy_predictor();
int numpy_predict(
const std::vector<std::vector<py::array_t<float>>>& float_feed_batch,
const std::vector<std::string>& float_feed_name,
const std::vector<std::vector<int>>& float_shape,
const std::vector<std::vector<int>>& float_lod_slot_batch,
const std::vector<std::vector<py::array_t<int64_t>>>& int_feed_batch,
const std::vector<std::string>& int_feed_name,
const std::vector<std::vector<int>>& int_shape,
const std::vector<std::vector<int>>& int_lod_slot_batch,
const std::vector<std::vector<std::string>>& string_feed_batch,
const std::vector<std::string>& string_feed_name,
const std::vector<std::vector<int>>& string_shape,
const std::vector<std::vector<int>>& string_lod_slot_batch,
const std::vector<std::string>& fetch_name,
PredictorRes& predict_res_batch, // NOLINT
const int& pid,
const uint64_t log_id);
int numpy_predict(const std::vector<py::array_t<float>>& float_feed,
const std::vector<std::string>& float_feed_name,
const std::vector<std::vector<int>>& float_shape,
const std::vector<std::vector<int>>& float_lod_slot_batch,
const std::vector<py::array_t<int64_t>>& int_feed,
const std::vector<std::string>& int_feed_name,
const std::vector<std::vector<int>>& int_shape,
const std::vector<std::vector<int>>& int_lod_slot_batch,
const std::vector<std::string>& string_feed,
const std::vector<std::string>& string_feed_name,
const std::vector<std::vector<int>>& string_shape,
const std::vector<std::vector<int>>& string_lod_slot_batch,
const std::vector<std::string>& fetch_name,
PredictorRes& predict_res_batch, // NOLINT
const int& pid,
const uint64_t log_id);
private:
PredictorApi _api;
......@@ -243,6 +242,7 @@ class PredictorClient {
std::string _predictor_path;
std::string _conf_file;
std::map<std::string, int> _feed_name_to_idx;
std::vector<std::string> _feed_name;
std::map<std::string, int> _fetch_name_to_idx;
std::map<std::string, std::string> _fetch_name_to_var_name;
std::map<std::string, int> _fetch_name_to_type;
......
......@@ -25,8 +25,6 @@ using baidu::paddle_serving::Timer;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::FeedInst;
using baidu::paddle_serving::predictor::general_model::FetchInst;
enum ProtoDataType { P_INT64, P_FLOAT32, P_INT32, P_STRING };
std::once_flag gflags_init_flag;
namespace py = pybind11;
......@@ -68,9 +66,13 @@ int PredictorClient::init(const std::vector<std::string> &conf_file) {
_fetch_name_to_idx.clear();
_shape.clear();
int feed_var_num = model_config.feed_var_size();
_feed_name.clear();
VLOG(2) << "feed var num: " << feed_var_num;
for (int i = 0; i < feed_var_num; ++i) {
_feed_name_to_idx[model_config.feed_var(i).alias_name()] = i;
VLOG(2) << "feed [" << i << "]"
<< " name: " << model_config.feed_var(i).name();
_feed_name.push_back(model_config.feed_var(i).name());
VLOG(2) << "feed alias name: " << model_config.feed_var(i).alias_name()
<< " index: " << i;
std::vector<int> tmp_feed_shape;
......@@ -146,15 +148,15 @@ int PredictorClient::create_predictor() {
}
int PredictorClient::numpy_predict(
const std::vector<std::vector<py::array_t<float>>> &float_feed_batch,
const std::vector<py::array_t<float>> &float_feed,
const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<int>> &float_shape,
const std::vector<std::vector<int>> &float_lod_slot_batch,
const std::vector<std::vector<py::array_t<int64_t>>> &int_feed_batch,
const std::vector<py::array_t<int64_t>> &int_feed,
const std::vector<std::string> &int_feed_name,
const std::vector<std::vector<int>> &int_shape,
const std::vector<std::vector<int>> &int_lod_slot_batch,
const std::vector<std::vector<std::string>> &string_feed_batch,
const std::vector<std::string> &string_feed,
const std::vector<std::string> &string_feed_name,
const std::vector<std::vector<int>> &string_shape,
const std::vector<std::vector<int>> &string_lod_slot_batch,
......@@ -162,10 +164,6 @@ int PredictorClient::numpy_predict(
PredictorRes &predict_res_batch,
const int &pid,
const uint64_t log_id) {
int batch_size = std::max(float_feed_batch.size(), int_feed_batch.size());
batch_size = batch_size > string_feed_batch.size() ? batch_size
: string_feed_batch.size();
VLOG(2) << "batch size: " << batch_size;
predict_res_batch.clear();
Timer timeline;
int64_t preprocess_start = timeline.TimeStampUS();
......@@ -188,134 +186,122 @@ int PredictorClient::numpy_predict(
}
int vec_idx = 0;
for (int bi = 0; bi < batch_size; bi++) {
VLOG(2) << "prepare batch " << bi;
std::vector<Tensor *> tensor_vec;
FeedInst *inst = req.add_insts();
std::vector<py::array_t<float>> float_feed = float_feed_batch[bi];
std::vector<py::array_t<int64_t>> int_feed = int_feed_batch[bi];
std::vector<std::string> string_feed = string_feed_batch[bi];
for (auto &name : float_feed_name) {
tensor_vec.push_back(inst->add_tensor_array());
}
for (auto &name : int_feed_name) {
tensor_vec.push_back(inst->add_tensor_array());
}
// batch is already in Tensor.
std::vector<Tensor *> tensor_vec;
for (auto &name : string_feed_name) {
tensor_vec.push_back(inst->add_tensor_array());
}
for (auto &name : float_feed_name) {
tensor_vec.push_back(req.add_tensor());
}
VLOG(2) << "batch [" << bi << "] "
<< "prepared";
for (auto &name : int_feed_name) {
tensor_vec.push_back(req.add_tensor());
}
vec_idx = 0;
for (auto &name : float_feed_name) {
int idx = _feed_name_to_idx[name];
if (idx >= tensor_vec.size()) {
LOG(ERROR) << "idx > tensor_vec.size()";
return -1;
}
int nbytes = float_feed[vec_idx].nbytes();
void *rawdata_ptr = (void *)(float_feed[vec_idx].data(0));
int total_number = float_feed[vec_idx].size();
Tensor *tensor = tensor_vec[idx];
VLOG(2) << "prepare float feed " << name << " shape size "
<< float_shape[vec_idx].size();
for (uint32_t j = 0; j < float_shape[vec_idx].size(); ++j) {
tensor->add_shape(float_shape[vec_idx][j]);
}
for (uint32_t j = 0; j < float_lod_slot_batch[vec_idx].size(); ++j) {
tensor->add_lod(float_lod_slot_batch[vec_idx][j]);
}
tensor->set_elem_type(P_FLOAT32);
for (auto &name : string_feed_name) {
tensor_vec.push_back(req.add_tensor());
}
tensor->mutable_float_data()->Resize(total_number, 0);
memcpy(tensor->mutable_float_data()->mutable_data(), rawdata_ptr, nbytes);
vec_idx++;
vec_idx = 0;
for (auto &name : float_feed_name) {
int idx = _feed_name_to_idx[name];
if (idx >= tensor_vec.size()) {
LOG(ERROR) << "idx > tensor_vec.size()";
return -1;
}
VLOG(2) << "prepare float feed " << name << " idx " << idx;
int nbytes = float_feed[vec_idx].nbytes();
void *rawdata_ptr = (void *)(float_feed[vec_idx].data(0));
int total_number = float_feed[vec_idx].size();
Tensor *tensor = tensor_vec[idx];
VLOG(2) << "prepare float feed " << name << " shape size "
<< float_shape[vec_idx].size();
for (uint32_t j = 0; j < float_shape[vec_idx].size(); ++j) {
tensor->add_shape(float_shape[vec_idx][j]);
}
for (uint32_t j = 0; j < float_lod_slot_batch[vec_idx].size(); ++j) {
tensor->add_lod(float_lod_slot_batch[vec_idx][j]);
}
tensor->set_elem_type(P_FLOAT32);
VLOG(2) << "batch [" << bi << "] "
<< "float feed value prepared";
tensor->set_name(_feed_name[idx]);
tensor->set_alias_name(name);
vec_idx = 0;
for (auto &name : int_feed_name) {
int idx = _feed_name_to_idx[name];
if (idx >= tensor_vec.size()) {
LOG(ERROR) << "idx > tensor_vec.size()";
return -1;
}
Tensor *tensor = tensor_vec[idx];
int nbytes = int_feed[vec_idx].nbytes();
void *rawdata_ptr = (void *)(int_feed[vec_idx].data(0));
int total_number = int_feed[vec_idx].size();
tensor->mutable_float_data()->Resize(total_number, 0);
memcpy(tensor->mutable_float_data()->mutable_data(), rawdata_ptr, nbytes);
vec_idx++;
}
for (uint32_t j = 0; j < int_shape[vec_idx].size(); ++j) {
tensor->add_shape(int_shape[vec_idx][j]);
}
for (uint32_t j = 0; j < int_lod_slot_batch[vec_idx].size(); ++j) {
tensor->add_lod(int_lod_slot_batch[vec_idx][j]);
}
tensor->set_elem_type(_type[idx]);
if (_type[idx] == P_INT64) {
tensor->mutable_int64_data()->Resize(total_number, 0);
memcpy(
tensor->mutable_int64_data()->mutable_data(), rawdata_ptr, nbytes);
} else {
tensor->mutable_int_data()->Resize(total_number, 0);
memcpy(tensor->mutable_int_data()->mutable_data(), rawdata_ptr, nbytes);
}
vec_idx++;
vec_idx = 0;
for (auto &name : int_feed_name) {
int idx = _feed_name_to_idx[name];
if (idx >= tensor_vec.size()) {
LOG(ERROR) << "idx > tensor_vec.size()";
return -1;
}
Tensor *tensor = tensor_vec[idx];
int nbytes = int_feed[vec_idx].nbytes();
void *rawdata_ptr = (void *)(int_feed[vec_idx].data(0));
int total_number = int_feed[vec_idx].size();
VLOG(2) << "batch [" << bi << "] "
<< "int feed value prepared";
for (uint32_t j = 0; j < int_shape[vec_idx].size(); ++j) {
tensor->add_shape(int_shape[vec_idx][j]);
}
for (uint32_t j = 0; j < int_lod_slot_batch[vec_idx].size(); ++j) {
tensor->add_lod(int_lod_slot_batch[vec_idx][j]);
}
tensor->set_elem_type(_type[idx]);
tensor->set_name(_feed_name[idx]);
tensor->set_alias_name(name);
if (_type[idx] == P_INT64) {
tensor->mutable_int64_data()->Resize(total_number, 0);
memcpy(tensor->mutable_int64_data()->mutable_data(), rawdata_ptr, nbytes);
} else {
tensor->mutable_int_data()->Resize(total_number, 0);
memcpy(tensor->mutable_int_data()->mutable_data(), rawdata_ptr, nbytes);
}
vec_idx++;
}
vec_idx = 0;
for (auto &name : string_feed_name) {
int idx = _feed_name_to_idx[name];
if (idx >= tensor_vec.size()) {
LOG(ERROR) << "idx > tensor_vec.size()";
return -1;
}
Tensor *tensor = tensor_vec[idx];
vec_idx = 0;
for (auto &name : string_feed_name) {
int idx = _feed_name_to_idx[name];
if (idx >= tensor_vec.size()) {
LOG(ERROR) << "idx > tensor_vec.size()";
return -1;
}
Tensor *tensor = tensor_vec[idx];
for (uint32_t j = 0; j < string_shape[vec_idx].size(); ++j) {
tensor->add_shape(string_shape[vec_idx][j]);
}
for (uint32_t j = 0; j < string_lod_slot_batch[vec_idx].size(); ++j) {
tensor->add_lod(string_lod_slot_batch[vec_idx][j]);
}
tensor->set_elem_type(P_STRING);
const int string_shape_size = string_shape[vec_idx].size();
// string_shape[vec_idx] = [1];cause numpy has no datatype of string.
// we pass string via vector<vector<string> >.
if (string_shape_size != 1) {
LOG(ERROR) << "string_shape_size should be 1-D, but received is : "
<< string_shape_size;
return -1;
}
switch (string_shape_size) {
case 1: {
tensor->add_data(string_feed[vec_idx]);
break;
}
for (uint32_t j = 0; j < string_shape[vec_idx].size(); ++j) {
tensor->add_shape(string_shape[vec_idx][j]);
}
for (uint32_t j = 0; j < string_lod_slot_batch[vec_idx].size(); ++j) {
tensor->add_lod(string_lod_slot_batch[vec_idx][j]);
}
tensor->set_elem_type(P_STRING);
tensor->set_name(_feed_name[idx]);
tensor->set_alias_name(name);
const int string_shape_size = string_shape[vec_idx].size();
// string_shape[vec_idx] = [1];cause numpy has no datatype of string.
// we pass string via vector<vector<string> >.
if (string_shape_size != 1) {
LOG(ERROR) << "string_shape_size should be 1-D, but received is : "
<< string_shape_size;
return -1;
}
switch (string_shape_size) {
case 1: {
tensor->add_data(string_feed[vec_idx]);
break;
}
vec_idx++;
}
VLOG(2) << "batch [" << bi << "] "
<< "string feed value prepared";
vec_idx++;
}
int64_t preprocess_end = timeline.TimeStampUS();
int64_t client_infer_start = timeline.TimeStampUS();
Response res;
int64_t client_infer_end = 0;
......@@ -347,19 +333,18 @@ int PredictorClient::numpy_predict(
int idx = 0;
for (auto &name : fetch_name) {
// int idx = _fetch_name_to_idx[name];
int shape_size = output.insts(0).tensor_array(idx).shape_size();
int shape_size = output.tensor(idx).shape_size();
VLOG(2) << "fetch var " << name << " index " << idx << " shape size "
<< shape_size;
model._shape_map[name].resize(shape_size);
for (int i = 0; i < shape_size; ++i) {
model._shape_map[name][i] =
output.insts(0).tensor_array(idx).shape(i);
model._shape_map[name][i] = output.tensor(idx).shape(i);
}
int lod_size = output.insts(0).tensor_array(idx).lod_size();
int lod_size = output.tensor(idx).lod_size();
if (lod_size > 0) {
model._lod_map[name].resize(lod_size);
for (int i = 0; i < lod_size; ++i) {
model._lod_map[name][i] = output.insts(0).tensor_array(idx).lod(i);
model._lod_map[name][i] = output.tensor(idx).lod(i);
}
}
idx += 1;
......@@ -371,22 +356,22 @@ int PredictorClient::numpy_predict(
// int idx = _fetch_name_to_idx[name];
if (_fetch_name_to_type[name] == P_INT64) {
VLOG(2) << "ferch var " << name << "type int64";
int size = output.insts(0).tensor_array(idx).int64_data_size();
int size = output.tensor(idx).int64_data_size();
model._int64_value_map[name] = std::vector<int64_t>(
output.insts(0).tensor_array(idx).int64_data().begin(),
output.insts(0).tensor_array(idx).int64_data().begin() + size);
output.tensor(idx).int64_data().begin(),
output.tensor(idx).int64_data().begin() + size);
} else if (_fetch_name_to_type[name] == P_FLOAT32) {
VLOG(2) << "fetch var " << name << "type float";
int size = output.insts(0).tensor_array(idx).float_data_size();
int size = output.tensor(idx).float_data_size();
model._float_value_map[name] = std::vector<float>(
output.insts(0).tensor_array(idx).float_data().begin(),
output.insts(0).tensor_array(idx).float_data().begin() + size);
output.tensor(idx).float_data().begin(),
output.tensor(idx).float_data().begin() + size);
} else if (_fetch_name_to_type[name] == P_INT32) {
VLOG(2) << "fetch var " << name << "type int32";
int size = output.insts(0).tensor_array(idx).int_data_size();
int size = output.tensor(idx).int_data_size();
model._int32_value_map[name] = std::vector<int32_t>(
output.insts(0).tensor_array(idx).int_data().begin(),
output.insts(0).tensor_array(idx).int_data().begin() + size);
output.tensor(idx).int_data().begin(),
output.tensor(idx).int_data().begin() + size);
}
idx += 1;
}
......
......@@ -97,33 +97,31 @@ PYBIND11_MODULE(serving_client, m) {
[](PredictorClient &self) { self.destroy_predictor(); })
.def("numpy_predict",
[](PredictorClient &self,
const std::vector<std::vector<py::array_t<float>>>
&float_feed_batch,
const std::vector<py::array_t<float>> &float_feed,
const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<int>> &float_shape,
const std::vector<std::vector<int>> &float_lod_slot_batch,
const std::vector<std::vector<py::array_t<int64_t>>>
&int_feed_batch,
const std::vector<py::array_t<int64_t>> &int_feed,
const std::vector<std::string> &int_feed_name,
const std::vector<std::vector<int>> &int_shape,
const std::vector<std::vector<int>> &int_lod_slot_batch,
const std::vector<std::vector<std::string>>& string_feed_batch,
const std::vector<std::string>& string_feed_name,
const std::vector<std::vector<int>>& string_shape,
const std::vector<std::vector<int>>& string_lod_slot_batch,
const std::vector<std::string> &string_feed,
const std::vector<std::string> &string_feed_name,
const std::vector<std::vector<int>> &string_shape,
const std::vector<std::vector<int>> &string_lod_slot_batch,
const std::vector<std::string> &fetch_name,
PredictorRes &predict_res_batch,
const int &pid,
const uint64_t log_id) {
return self.numpy_predict(float_feed_batch,
return self.numpy_predict(float_feed,
float_feed_name,
float_shape,
float_lod_slot_batch,
int_feed_batch,
int_feed,
int_feed_name,
int_shape,
int_lod_slot_batch,
string_feed_batch,
string_feed,
string_feed_name,
string_shape,
string_lod_slot_batch,
......
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/general-server/op/general_copy_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/general-server/op/general_infer_helper.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/util/include/timer.h"
namespace baidu {
namespace paddle_serving {
namespace serving {
using baidu::paddle_serving::Timer;
using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::FeedInst;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
int GeneralCopyOp::inference() {
// reade request from client
const std::vector<std::string> pre_node_names = pre_names();
if (pre_node_names.size() != 1) {
LOG(ERROR) << "This op(" << op_name()
<< ") can only have one predecessor op, but received "
<< pre_node_names.size();
return -1;
}
const std::string pre_name = pre_node_names[0];
const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name);
uint64_t log_id = input_blob->GetLogId();
VLOG(2) << "(logid=" << log_id << ") precedent name: " << pre_name;
const TensorVector *in = &input_blob->tensor_vector;
VLOG(2) << "(logid=" << log_id << ") input size: " << in->size();
int batch_size = input_blob->GetBatchSize();
int input_var_num = 0;
GeneralBlob *res = mutable_data<GeneralBlob>();
res->SetLogId(log_id);
TensorVector *out = &res->tensor_vector;
VLOG(2) << "(logid=" << log_id << ") input batch size: " << batch_size;
res->SetBatchSize(batch_size);
if (!res) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed get op tls reader object output";
}
Timer timeline;
int64_t start = timeline.TimeStampUS();
VLOG(2) << "(logid=" << log_id << ") Going to init lod tensor";
for (int i = 0; i < in->size(); ++i) {
paddle::PaddleTensor lod_tensor;
CopyLod(&in->at(i), &lod_tensor);
lod_tensor.dtype = in->at(i).dtype;
lod_tensor.name = in->at(i).name;
VLOG(2) << "(logid=" << log_id << ") lod tensor [" << i
<< "].name = " << lod_tensor.name;
out->push_back(lod_tensor);
}
VLOG(2) << "(logid=" << log_id << ") pack done.";
for (int i = 0; i < out->size(); ++i) {
int64_t *src_ptr = static_cast<int64_t *>(in->at(i).data.data());
out->at(i).data.Resize(out->at(i).lod[0].back() * sizeof(int64_t));
out->at(i).shape = {out->at(i).lod[0].back(), 1};
int64_t *tgt_ptr = static_cast<int64_t *>(out->at(i).data.data());
for (int j = 0; j < out->at(i).lod[0].back(); ++j) {
tgt_ptr[j] = src_ptr[j];
}
}
VLOG(2) << "(logid=" << log_id << ") output done.";
timeline.Pause();
int64_t end = timeline.TimeStampUS();
CopyBlobInfo(input_blob, res);
AddBlobInfo(res, start);
AddBlobInfo(res, end);
VLOG(2) << "(logid=" << log_id << ") read data from client success";
return 0;
}
DEFINE_OP(GeneralCopyOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "core/general-server/general_model_service.pb.h"
#include "core/general-server/op/general_infer_helper.h"
#include "core/predictor/framework/resource.h"
#include "paddle_inference_api.h" // NOLINT
namespace baidu {
namespace paddle_serving {
namespace serving {
class GeneralCopyOp
: public baidu::paddle_serving::predictor::OpWithChannel<GeneralBlob> {
public:
typedef std::vector<paddle::PaddleTensor> TensorVector;
DECLARE_OP(GeneralCopyOp);
int inference();
};
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
......@@ -36,7 +36,6 @@ using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::FetchInst;
using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
......
文件模式从 100755 更改为 100644
......@@ -34,7 +34,6 @@ using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::FetchInst;
using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
......
......@@ -35,7 +35,6 @@ using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::FetchInst;
using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
......@@ -117,9 +116,6 @@ int GeneralDistKVQuantInferOp::inference() {
std::unordered_map<int, int> in_out_map;
baidu::paddle_serving::predictor::Resource &resource =
baidu::paddle_serving::predictor::Resource::instance();
//TODO:Temporary addition, specific details to be studied by HexToString
std::shared_ptr<PaddleGeneralModelConfig> model_config =
resource.get_general_model_config()[0];
int cube_quant_bits = resource.get_cube_quant_bits();
size_t EMBEDDING_SIZE = 0;
if (cube_quant_bits == 0) {
......@@ -146,7 +142,7 @@ int GeneralDistKVQuantInferOp::inference() {
sparse_out[sparse_idx].shape.push_back(
sparse_out[sparse_idx].lod[0].back());
sparse_out[sparse_idx].shape.push_back(EMBEDDING_SIZE);
sparse_out[sparse_idx].name = model_config->_feed_name[i];
sparse_out[sparse_idx].name = in->at(i).name;
sparse_out[sparse_idx].data.Resize(sparse_out[sparse_idx].lod[0].back() *
EMBEDDING_SIZE * sizeof(float));
// END HERE
......
......@@ -31,7 +31,6 @@ using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::FetchInst;
using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
......@@ -49,7 +48,7 @@ int GeneralInferOp::inference() {
const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name);
if (!input_blob) {
LOG(ERROR) << "input_blob is nullptr,error";
return -1;
return -1;
}
uint64_t log_id = input_blob->GetLogId();
VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name;
......@@ -57,7 +56,7 @@ int GeneralInferOp::inference() {
GeneralBlob *output_blob = mutable_data<GeneralBlob>();
if (!output_blob) {
LOG(ERROR) << "output_blob is nullptr,error";
return -1;
return -1;
}
output_blob->SetLogId(log_id);
......
......@@ -30,42 +30,8 @@ using baidu::paddle_serving::Timer;
using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::FeedInst;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
enum ProtoDataType { P_INT64, P_FLOAT32, P_INT32, P_STRING };
int conf_check(const Request *req,
const std::shared_ptr<PaddleGeneralModelConfig> &model_config) {
int var_num = req->insts(0).tensor_array_size();
if (var_num != model_config->_feed_type.size()) {
LOG(ERROR) << "feed var number not match: model config["
<< model_config->_feed_type.size() << "] vs. actual[" << var_num
<< "]";
return -1;
}
VLOG(2) << "fetch var num in reader op: " << req->fetch_var_names_size();
for (int i = 0; i < var_num; ++i) {
const Tensor &tensor = req->insts(0).tensor_array(i);
if (model_config->_feed_type[i] != tensor.elem_type()) {
LOG(ERROR) << "feed type not match.";
return -1;
}
if (model_config->_feed_shape[i].size() == tensor.shape_size()) {
for (int j = 0; j < model_config->_feed_shape[i].size(); ++j) {
tensor.shape(j);
if (model_config->_feed_shape[i][j] != tensor.shape(j)) {
LOG(ERROR) << "feed shape not match.";
return -1;
}
}
} else {
LOG(ERROR) << "feed shape not match.";
return -1;
}
}
return 0;
}
int GeneralReaderOp::inference() {
// read request from client
......@@ -93,7 +59,8 @@ int GeneralReaderOp::inference() {
res->SetLogId(log_id);
Timer timeline;
int64_t start = timeline.TimeStampUS();
int var_num = req->insts(0).tensor_array_size();
// var_num means the number of feed_var.
int var_num = req->tensor_size();
VLOG(2) << "(logid=" << log_id << ") var num: " << var_num
<< ") start to call load general model_conf op";
......@@ -102,19 +69,7 @@ int GeneralReaderOp::inference() {
baidu::paddle_serving::predictor::Resource::instance();
VLOG(2) << "(logid=" << log_id << ") get resource pointer done.";
// get the first InferOP's model_config as ReaderOp's model_config by default.
std::shared_ptr<PaddleGeneralModelConfig> model_config =
resource.get_general_model_config().front();
// TODO(guru4elephant): how to do conditional check?
/*
int ret = conf_check(req, model_config);
if (ret != 0) {
LOG(ERROR) << "model conf of server:";
resource.print_general_model_config(model_config);
return 0;
}
*/
// package tensor
// prepare basic information for input
// specify the memory needed for output tensor_vector
......@@ -125,7 +80,7 @@ int GeneralReaderOp::inference() {
int64_t databuf_size = 0;
for (int i = 0; i < var_num; ++i) {
paddle::PaddleTensor paddleTensor;
const Tensor &tensor = req->insts(0).tensor_array(i);
const Tensor &tensor = req->tensor(i);
data_len = 0;
elem_type = 0;
elem_size = 0;
......@@ -172,13 +127,16 @@ int GeneralReaderOp::inference() {
VLOG(2) << "(logid=" << log_id << ") shape for var[" << i << "]: " << dim;
paddleTensor.shape.push_back(dim);
}
paddleTensor.name = model_config->_feed_name[i];
paddleTensor.name = tensor.name();
out->push_back(paddleTensor);
VLOG(2) << "(logid=" << log_id << ") tensor size for var[" << i
<< "]: " << data_len;
databuf_size = data_len * elem_size;
out->at(i).data.Resize(databuf_size);
void *databuf_char = MempoolWrapper::instance().malloc(databuf_size);
paddle::PaddleBuf paddleBuf(databuf_char, databuf_size);
out->at(i).data = paddleBuf;
// out->at(i).data.Resize(databuf_size);
if (out->at(i).lod.size() > 0) {
VLOG(2) << "(logid=" << log_id << ") var[" << i
<< "] has lod_tensor and len=" << out->at(i).lod[0].back();
......
......@@ -34,7 +34,6 @@ using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::FetchInst;
using baidu::paddle_serving::predictor::general_model::ModelOutput;
using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
......@@ -49,7 +48,6 @@ int GeneralResponseOp::inference() {
get_depend_argument<GeneralBlob>(pre_node_names[0])->GetLogId();
const Request *req = dynamic_cast<const Request *>(get_request_message());
// response inst with only fetch_var_names
Response *res = mutable_data<Response>();
Timer timeline;
......@@ -63,7 +61,8 @@ int GeneralResponseOp::inference() {
baidu::paddle_serving::predictor::Resource::instance();
VLOG(2) << "(logid=" << log_id << ") get resource pointer done.";
//get the last InferOP's model_config as ResponseOp's model_config by default.
// get the last InferOP's model_config as ResponseOp's model_config by
// default.
std::shared_ptr<PaddleGeneralModelConfig> model_config =
resource.get_general_model_config().back();
......@@ -71,6 +70,10 @@ int GeneralResponseOp::inference() {
<< ") max body size : " << brpc::fLU64::FLAGS_max_body_size;
std::vector<int> fetch_index;
// this is based on GetOutPutNames() is ordered map.
// and the order of Output is the same as the prototxt FetchVar.
// otherwise, you can only get the Output by the corresponding of
// Name -- Alias_name.
fetch_index.resize(req->fetch_var_names_size());
for (int i = 0; i < req->fetch_var_names_size(); ++i) {
fetch_index[i] =
......@@ -95,40 +98,41 @@ int GeneralResponseOp::inference() {
ModelOutput *output = res->add_outputs();
// To get the order of model return values
output->set_engine_name(pre_name);
FetchInst *fetch_inst = output->add_insts();
var_idx = 0;
// idx is the real index of FetchVar.
// idx is not the index of FetchList.
// fetch_index is the real index in FetchVar of Fetchlist
// for example, FetchVar = {0:A, 1:B, 2:C}
// FetchList = {0:C,1:A}, at this situation.
// fetch_index = [2,0], C`index = 2 and A`index = 0
for (auto &idx : fetch_index) {
Tensor *tensor = fetch_inst->add_tensor_array();
//tensor->set_elem_type(1);
if (model_config->_is_lod_fetch[idx]) {
VLOG(2) << "(logid=" << log_id << ") out[" << idx << "] "
<< model_config->_fetch_name[idx] << " is lod_tensor";
for (int k = 0; k < in->at(idx).shape.size(); ++k) {
VLOG(2) << "(logid=" << log_id << ") shape[" << k
<< "]: " << in->at(idx).shape[k];
tensor->add_shape(in->at(idx).shape[k]);
}
} else {
VLOG(2) << "(logid=" << log_id << ") out[" << idx << "] "
<< model_config->_fetch_name[idx] << " is tensor";
for (int k = 0; k < in->at(idx).shape.size(); ++k) {
VLOG(2) << "(logid=" << log_id << ") shape[" << k
<< "]: " << in->at(idx).shape[k];
tensor->add_shape(in->at(idx).shape[k]);
Tensor *tensor = output->add_tensor();
tensor->set_name(in->at(idx).name);
tensor->set_alias_name(model_config->_fetch_alias_name[idx]);
for (int k = 0; k < in->at(idx).shape.size(); ++k) {
VLOG(2) << "(logid=" << log_id << ") shape[" << k
<< "]: " << in->at(idx).shape[k];
tensor->add_shape(in->at(idx).shape[k]);
}
std::string str_tensor_type = "is tensor";
if (model_config->_is_lod_fetch[idx] && in->at(idx).lod.size() > 0) {
str_tensor_type = "is lod_tensor";
for (int j = 0; j < in->at(idx).lod[0].size(); ++j) {
tensor->add_lod(in->at(idx).lod[0][j]);
}
}
}
VLOG(2) << "(logid=" << log_id << ") out[" << idx << "] "
<< model_config->_fetch_name[idx] << str_tensor_type;
var_idx = 0;
for (auto &idx : fetch_index) {
cap = 1;
for (int j = 0; j < in->at(idx).shape.size(); ++j) {
cap *= in->at(idx).shape[j];
}
FetchInst *fetch_p = output->mutable_insts(0);
auto dtype = in->at(idx).dtype;
if (dtype == paddle::PaddleDType::INT64) {
tensor->set_elem_type(0);
VLOG(2) << "(logid=" << log_id << ") Prepare int64 var ["
<< model_config->_fetch_name[idx] << "].";
int64_t *data_ptr = static_cast<int64_t *>(in->at(idx).data.data());
......@@ -137,35 +141,24 @@ int GeneralResponseOp::inference() {
// `Swap` method is faster than `{}` method.
google::protobuf::RepeatedField<int64_t> tmp_data(data_ptr,
data_ptr + cap);
fetch_p->mutable_tensor_array(var_idx)->mutable_int64_data()->Swap(
&tmp_data);
output->mutable_tensor(var_idx)->mutable_int64_data()->Swap(&tmp_data);
} else if (dtype == paddle::PaddleDType::FLOAT32) {
tensor->set_elem_type(1);
VLOG(2) << "(logid=" << log_id << ") Prepare float var ["
<< model_config->_fetch_name[idx] << "].";
float *data_ptr = static_cast<float *>(in->at(idx).data.data());
google::protobuf::RepeatedField<float> tmp_data(data_ptr,
data_ptr + cap);
fetch_p->mutable_tensor_array(var_idx)->mutable_float_data()->Swap(
&tmp_data);
output->mutable_tensor(var_idx)->mutable_float_data()->Swap(&tmp_data);
} else if (dtype == paddle::PaddleDType::INT32) {
tensor->set_elem_type(2);
VLOG(2) << "(logid=" << log_id << ")Prepare int32 var ["
<< model_config->_fetch_name[idx] << "].";
int32_t *data_ptr = static_cast<int32_t *>(in->at(idx).data.data());
google::protobuf::RepeatedField<int32_t> tmp_data(data_ptr,
data_ptr + cap);
fetch_p->mutable_tensor_array(var_idx)->mutable_int_data()->Swap(
&tmp_data);
}
if (model_config->_is_lod_fetch[idx]) {
if (in->at(idx).lod.size() > 0) {
for (int j = 0; j < in->at(idx).lod[0].size(); ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_lod(
in->at(idx).lod[0][j]);
}
}
output->mutable_tensor(var_idx)->mutable_int_data()->Swap(&tmp_data);
}
VLOG(2) << "(logid=" << log_id << ") fetch var ["
......@@ -205,4 +198,4 @@ DEFINE_OP(GeneralResponseOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
\ No newline at end of file
} // namespace baidu
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/general-server/op/general_text_reader_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/util/include/timer.h"
namespace baidu {
namespace paddle_serving {
namespace serving {
using baidu::paddle_serving::Timer;
using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::FeedInst;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
int GeneralTextReaderOp::inference() {
// reade request from client
const Request *req = dynamic_cast<const Request *>(get_request_message());
uint64_t log_id = req->log_id();
int batch_size = req->insts_size();
int input_var_num = 0;
std::vector<int64_t> elem_type;
std::vector<int64_t> elem_size;
std::vector<int64_t> capacity;
GeneralBlob *res = mutable_data<GeneralBlob>();
if (!res) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed get op tls reader object output";
}
TensorVector *out = &res->tensor_vector;
res->SetBatchSize(batch_size);
res->SetLogId(log_id);
if (batch_size <= 0) {
LOG(ERROR) << "(logid=" << log_id << ") Batch size < 0";
return -1;
}
Timer timeline;
int64_t start = timeline.TimeStampUS();
int var_num = req->insts(0).tensor_array_size();
VLOG(2) << "(logid=" << log_id << ") var num: " << var_num;
VLOG(2) << "(logid=" << log_id
<< ") start to call load general model_conf op";
baidu::paddle_serving::predictor::Resource &resource =
baidu::paddle_serving::predictor::Resource::instance();
VLOG(2) << "(logid=" << log_id << ") get resource pointer done.";
std::shared_ptr<PaddleGeneralModelConfig> model_config =
resource.get_general_model_config()[0];
VLOG(2) << "(logid=" << log_id << ") print general model config done.";
elem_type.resize(var_num);
elem_size.resize(var_num);
capacity.resize(var_num);
for (int i = 0; i < var_num; ++i) {
paddle::PaddleTensor lod_tensor;
elem_type[i] = req->insts(0).tensor_array(i).elem_type();
VLOG(2) << "(logid=" << log_id << ") var[" << i
<< "] has elem type: " << elem_type[i];
if (elem_type[i] == 0) { // int64
elem_size[i] = sizeof(int64_t);
lod_tensor.dtype = paddle::PaddleDType::INT64;
} else {
elem_size[i] = sizeof(float);
lod_tensor.dtype = paddle::PaddleDType::FLOAT32;
}
if (req->insts(0).tensor_array(i).shape(0) == -1) {
lod_tensor.lod.resize(1);
lod_tensor.lod[0].push_back(0);
VLOG(2) << "(logid=" << log_id << ") var[" << i << "] is lod_tensor";
} else {
lod_tensor.shape.push_back(batch_size);
capacity[i] = 1;
for (int k = 0; k < req->insts(0).tensor_array(i).shape_size(); ++k) {
int dim = req->insts(0).tensor_array(i).shape(k);
VLOG(2) << "(logid=" << log_id << ") shape for var[" << i
<< "]: " << dim;
capacity[i] *= dim;
lod_tensor.shape.push_back(dim);
}
VLOG(2) << "(logid=" << log_id << ") var[" << i
<< "] is tensor, capacity: " << capacity[i];
}
lod_tensor.name = model_config->_feed_name[i];
out->push_back(lod_tensor);
}
for (int i = 0; i < var_num; ++i) {
if (out->at(i).lod.size() == 1) {
for (int j = 0; j < batch_size; ++j) {
const Tensor &tensor = req->insts(j).tensor_array(i);
int data_len = tensor.int_data_size();
int cur_len = out->at(i).lod[0].back();
out->at(i).lod[0].push_back(cur_len + data_len);
}
out->at(i).data.Resize(out->at(i).lod[0].back() * elem_size[i]);
out->at(i).shape = {out->at(i).lod[0].back(), 1};
VLOG(2) << "(logid=" << log_id << ") var[" << i
<< "] is lod_tensor and len=" << out->at(i).lod[0].back();
} else {
out->at(i).data.Resize(batch_size * capacity[i] * elem_size[i]);
VLOG(2) << "(logid=" << log_id << ") var[" << i
<< "] is tensor and capacity=" << batch_size * capacity[i];
}
}
for (int i = 0; i < var_num; ++i) {
if (elem_type[i] == 0) {
int64_t *dst_ptr = static_cast<int64_t *>(out->at(i).data.data());
int offset = 0;
for (int j = 0; j < batch_size; ++j) {
for (int k = 0; k < req->insts(j).tensor_array(i).int_data_size();
++k) {
dst_ptr[offset + k] = req->insts(j).tensor_array(i).int_data(k);
}
if (out->at(i).lod.size() == 1) {
offset = out->at(i).lod[0][j + 1];
} else {
offset += capacity[i];
}
}
} else {
float *dst_ptr = static_cast<float *>(out->at(i).data.data());
int offset = 0;
for (int j = 0; j < batch_size; ++j) {
for (int k = 0; k < req->insts(j).tensor_array(i).int_data_size();
++k) {
dst_ptr[offset + k] = req->insts(j).tensor_array(i).int_data(k);
}
if (out->at(i).lod.size() == 1) {
offset = out->at(i).lod[0][j + 1];
} else {
offset += capacity[i];
}
}
}
}
int64_t end = timeline.TimeStampUS();
res->p_size = 0;
AddBlobInfo(res, start);
AddBlobInfo(res, end);
VLOG(2) << "(logid=" << log_id << ") read data from client success";
return 0;
}
DEFINE_OP(GeneralTextReaderOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "core/general-server/general_model_service.pb.h"
#include "core/general-server/load_general_model_service.pb.h"
#include "core/general-server/op/general_infer_helper.h"
#include "core/predictor/framework/resource.h"
#include "paddle_inference_api.h" // NOLINT
namespace baidu {
namespace paddle_serving {
namespace serving {
class GeneralTextReaderOp
: public baidu::paddle_serving::predictor::OpWithChannel<GeneralBlob> {
public:
typedef std::vector<paddle::PaddleTensor> TensorVector;
DECLARE_OP(GeneralTextReaderOp);
int inference();
};
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/general-server/op/general_text_response_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h"
#include "core/util/include/timer.h"
namespace baidu {
namespace paddle_serving {
namespace serving {
using baidu::paddle_serving::Timer;
using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::FetchInst;
using baidu::paddle_serving::predictor::general_model::ModelOutput;
using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
int GeneralTextResponseOp::inference() {
VLOG(2) << "Going to run inference";
const std::vector<std::string> pre_node_names = pre_names();
VLOG(2) << "pre node names size: " << pre_node_names.size();
const GeneralBlob *input_blob;
uint64_t log_id =
get_depend_argument<GeneralBlob>(pre_node_names[0])->GetLogId();
const Request *req = dynamic_cast<const Request *>(get_request_message());
// response inst with only fetch_var_names
Response *res = mutable_data<Response>();
Timer timeline;
int64_t start = timeline.TimeStampUS();
VLOG(2) << "(logid=" << log_id
<< ") start to call load general model_conf op";
baidu::paddle_serving::predictor::Resource &resource =
baidu::paddle_serving::predictor::Resource::instance();
VLOG(2) << "(logid=" << log_id << ") get resource pointer done.";
std::shared_ptr<PaddleGeneralModelConfig> model_config =
resource.get_general_model_config().back();
std::vector<int> fetch_index;
fetch_index.resize(req->fetch_var_names_size());
for (int i = 0; i < req->fetch_var_names_size(); ++i) {
fetch_index[i] =
model_config->_fetch_alias_name_to_index[req->fetch_var_names(i)];
}
for (uint32_t pi = 0; pi < pre_node_names.size(); ++pi) {
const std::string &pre_name = pre_node_names[pi];
VLOG(2) << "(logid=" << log_id << ") pre names[" << pi << "]: " << pre_name
<< " (" << pre_node_names.size() << ")";
input_blob = get_depend_argument<GeneralBlob>(pre_name);
if (!input_blob) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed mutable depended argument, op: " << pre_name;
return -1;
}
const TensorVector *in = &input_blob->tensor_vector;
int batch_size = input_blob->GetBatchSize();
VLOG(2) << "(logid=" << log_id << ") input batch size: " << batch_size;
ModelOutput *output = res->add_outputs();
output->set_engine_name(
pre_name); // To get the order of model return values
for (int i = 0; i < batch_size; ++i) {
FetchInst *fetch_inst = output->add_insts();
for (auto &idx : fetch_index) {
Tensor *tensor = fetch_inst->add_tensor_array();
// currently only response float tensor or lod_tensor
tensor->set_elem_type(1);
if (model_config->_is_lod_fetch[idx]) {
VLOG(2) << "(logid=" << log_id << ") out[" << idx << " is lod_tensor";
tensor->add_shape(-1);
} else {
VLOG(2) << "(logid=" << log_id << ") out[" << idx << "] is tensor";
for (int k = 1; k < in->at(idx).shape.size(); ++k) {
VLOG(2) << "(logid=" << log_id << ") shape[" << k - 1
<< "]: " << in->at(idx).shape[k];
tensor->add_shape(in->at(idx).shape[k]);
}
}
}
}
int var_idx = 0;
for (auto &idx : fetch_index) {
float *data_ptr = static_cast<float *>(in->at(idx).data.data());
int cap = 1;
for (int j = 1; j < in->at(idx).shape.size(); ++j) {
cap *= in->at(idx).shape[j];
}
if (model_config->_is_lod_fetch[idx]) {
for (int j = 0; j < batch_size; ++j) {
for (int k = in->at(idx).lod[0][j]; k < in->at(idx).lod[0][j + 1];
k++) {
output->mutable_insts(j)
->mutable_tensor_array(var_idx)
->add_float_data(data_ptr[k]);
}
}
} else {
for (int j = 0; j < batch_size; ++j) {
for (int k = j * cap; k < (j + 1) * cap; ++k) {
output->mutable_insts(j)
->mutable_tensor_array(var_idx)
->add_float_data(data_ptr[k]);
}
}
}
var_idx++;
}
}
if (req->profile_server()) {
int64_t end = timeline.TimeStampUS();
// TODO(barriery): multi-model profile_time.
// At present, only the response_op is multi-input, so here we get
// the profile_time by hard coding. It needs to be replaced with
// a more elegant way.
for (uint32_t pi = 0; pi < pre_node_names.size(); ++pi) {
input_blob = get_depend_argument<GeneralBlob>(pre_node_names[pi]);
VLOG(2) << "(logid=" << log_id
<< ") p size for input blob: " << input_blob->p_size;
int profile_time_idx = -1;
if (pi == 0) {
profile_time_idx = 0;
} else {
profile_time_idx = input_blob->p_size - 2;
}
for (; profile_time_idx < input_blob->p_size; ++profile_time_idx) {
res->add_profile_time(input_blob->time_stamp[profile_time_idx]);
}
}
// TODO(guru4elephant): find more elegant way to do this
res->add_profile_time(start);
res->add_profile_time(end);
}
return 0;
}
DEFINE_OP(GeneralTextResponseOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "core/general-server/general_model_service.pb.h"
#include "core/general-server/op/general_infer_helper.h"
#include "paddle_inference_api.h" // NOLINT
namespace baidu {
namespace paddle_serving {
namespace serving {
class GeneralTextResponseOp
: public baidu::paddle_serving::predictor::OpWithChannel<
baidu::paddle_serving::predictor::general_model::Response> {
public:
typedef std::vector<paddle::PaddleTensor> TensorVector;
DECLARE_OP(GeneralTextResponseOp);
int inference();
};
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
......@@ -24,17 +24,16 @@ message Tensor {
repeated int32 int_data = 2;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type = 5;
repeated int32 shape = 6;
repeated int32 lod = 7; // only for fetch tensor currently
optional int32 elem_type =
5; // 0 means int64, 1 means float32, 2 means int32, 3 means bytes(string)
repeated int32 shape = 6; // shape should include batch
repeated int32 lod = 7; // only for fetch tensor currently
optional string name = 8; // get from the Model prototxt
optional string alias_name = 9; // get from the Model prototxt
};
message FeedInst { repeated Tensor tensor_array = 1; };
message FetchInst { repeated Tensor tensor_array = 1; };
message Request {
repeated FeedInst insts = 1;
repeated Tensor tensor = 1;
repeated string fetch_var_names = 2;
optional bool profile_server = 3 [ default = false ];
required uint64 log_id = 4 [ default = 0 ];
......@@ -46,7 +45,7 @@ message Response {
};
message ModelOutput {
repeated FetchInst insts = 1;
repeated Tensor tensor = 1;
optional string engine_name = 2;
}
......
......@@ -280,6 +280,7 @@ class PdsCodeGenerator : public CodeGenerator {
" baidu::rpc::ClosureGuard done_guard(done);\n"
" baidu::rpc::Controller* cntl = \n"
" static_cast<baidu::rpc::Controller*>(cntl_base);\n"
" cntl->set_response_compress_type(brpc::COMPRESS_TYPE_GZIP);\n"
" uint64_t log_id = request->log_id();\n"
" cntl->set_log_id(log_id);\n"
" ::baidu::paddle_serving::predictor::InferService* svr = \n"
......@@ -322,6 +323,7 @@ class PdsCodeGenerator : public CodeGenerator {
" baidu::rpc::ClosureGuard done_guard(done);\n"
" baidu::rpc::Controller* cntl = \n"
" static_cast<baidu::rpc::Controller*>(cntl_base);\n"
" cntl->set_response_compress_type(brpc::COMPRESS_TYPE_GZIP);\n"
" uint64_t log_id = equest->log_id();\n"
" cntl->set_log_id(log_id);\n"
" ::baidu::paddle_serving::predictor::InferService* svr = \n"
......@@ -1023,6 +1025,7 @@ class PdsCodeGenerator : public CodeGenerator {
" brpc::ClosureGuard done_guard(done);\n"
" brpc::Controller* cntl = \n"
" static_cast<brpc::Controller*>(cntl_base);\n"
" cntl->set_response_compress_type(brpc::COMPRESS_TYPE_GZIP);\n"
" uint64_t log_id = request->log_id();\n"
" cntl->set_log_id(log_id);\n"
" ::baidu::paddle_serving::predictor::InferService* svr = \n"
......@@ -1067,6 +1070,7 @@ class PdsCodeGenerator : public CodeGenerator {
" brpc::ClosureGuard done_guard(done);\n"
" brpc::Controller* cntl = \n"
" static_cast<brpc::Controller*>(cntl_base);\n"
" cntl->set_response_compress_type(brpc::COMPRESS_TYPE_GZIP);\n"
" uint64_t log_id = request->log_id();\n"
" cntl->set_log_id(log_id);\n"
" ::baidu::paddle_serving::predictor::InferService* svr = \n"
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef BCLOUD
#include <base/atomicops.h>
#else
#include <butil/atomicops.h>
#endif
#include <errno.h>
#include <algorithm>
#include <deque>
#include <vector>
#include "core/predictor/common/inner_common.h"
#include "core/predictor/framework/infer_data.h"
#include "core/predictor/framework/memory.h"
#include <boost/function.hpp>
namespace im {
namespace bsf {
template <>
struct Task<baidu::paddle_serving::predictor::Tensor,
baidu::paddle_serving::predictor::Tensor> {
typedef Task<baidu::paddle_serving::predictor::Tensor,
baidu::paddle_serving::predictor::Tensor>
TaskT;
typedef baidu::paddle_serving::predictor::Tensor Tensor;
typedef baidu::paddle_serving::predictor::Tensor InType;
typedef baidu::paddle_serving::predictor::Tensor OutType;
typedef baidu::paddle_serving::predictor::BatchTensor BatchTensor;
typedef baidu::paddle_serving::predictor::BatchTensor InArrayT;
typedef baidu::paddle_serving::predictor::BatchTensor OutArrayT;
struct Segment {
Segment(void* p, size_t b, size_t s) : ptr(p), begin(b), size(s) {}
void* ptr;
size_t begin;
size_t size;
};
int read_fd;
int write_fd;
pid_t owner_tid;
const InArrayT* in;
OutArrayT* out;
size_t rem;
size_t size;
butil::atomic<size_t> index;
const BatchTensor* get(bool is_in) const {
if (is_in) {
return in;
} else {
return out;
}
}
BatchTensor* get(bool is_in) {
if (is_in) {
return const_cast<BatchTensor*>(in);
} else {
return out;
}
}
Task() {
read_fd = -1;
write_fd = -1;
owner_tid = -1;
in = NULL;
out = NULL;
rem = -1;
size = -1;
index.store(0, butil::memory_order_relaxed);
}
};
template <>
class BatchTasks<Task<baidu::paddle_serving::predictor::Tensor,
baidu::paddle_serving::predictor::Tensor>> {
public:
typedef baidu::paddle_serving::predictor::Tensor Tensor;
typedef baidu::paddle_serving::predictor::Tensor InType;
typedef baidu::paddle_serving::predictor::Tensor OutType;
typedef baidu::paddle_serving::predictor::DataBuf DataBuf;
typedef baidu::paddle_serving::predictor::MempoolWrapper MempoolWrapper;
typedef Task<baidu::paddle_serving::predictor::Tensor,
baidu::paddle_serving::predictor::Tensor>
TaskT;
typedef TaskMeta<TaskT> TaskMetaT;
typedef TaskT::InArrayT InArrayT;
typedef TaskT::OutArrayT OutArrayT;
explicit BatchTasks(size_t batch_size, bool batch_align = false)
: _batch_size(batch_size),
_rem_size(batch_size),
_batch_align(batch_align) {
_batch_in.clear();
_batch_out.clear();
_tasks.clear();
}
~BatchTasks() {
_batch_in.clear();
_batch_out.clear();
_tasks.clear();
}
static bool check_valid(const InArrayT& in,
OutArrayT& out, // NOLINT
bool align) { // NOLINT
if (align) {
if (out.count() <= 0 || out.size() <= 0) {
LOG(ERROR) << "Out tensor is empty, when aligned";
return false;
}
if (out.size() != in.size()) {
LOG(ERROR) << "In/Out tensor size not eq: " << out.size()
<< "!=" << in.size();
return false;
}
for (size_t fi = 0, shape0 = 0; fi < out.count(); ++fi) {
if (!out[fi].valid()) {
LOG(ERROR) << "Out[" << fi << "] tensor not valid";
return false;
}
if (out.size() != out[fi].shape0()) {
LOG(ERROR) << "Shape0 not consistency, " << out.size()
<< "!=" << out[fi].shape0() << ", " << fi;
return false;
}
}
}
return true;
}
size_t append_task(TaskT* task) {
size_t add = std::min(task->rem, _rem_size);
if (!_batch_align) {
add = task->rem;
}
TaskMetaT tm(task, task->in->size() - task->rem, add);
_tasks.push_back(tm);
task->rem -= add;
_rem_size -= add;
return _rem_size;
}
void merge_tasks() {
merge_input();
merge_output();
}
void merge_input() {
if (_tasks.size() <= 0 || _tasks[0].task->in->count() <= 0) {
return;
}
if (_tasks.size() == 1 && !_batch_align) {
TaskMetaT& tm = _tasks[0];
_batch_in = *(tm.task->in);
return;
}
merge_tensor(true);
}
void merge_output() {
if (_batch_align) {
if (_tasks.size() <= 0 || _tasks[0].task->out->count() <= 0) {
return;
}
}
if (_tasks.size() <= 0 || _tasks[0].task->out->count() <= 0) {
return;
}
TaskMetaT& tm = _tasks[0];
if (_tasks.size() == 1 && !_batch_align) {
_batch_out = *(tm.task->out);
return;
}
if (tm.task->out->size() <= 0) {
// shape is empty
_batch_out = *(tm.task->out);
return;
}
if ((*tm.task->out)[0].data.data() == 0 ||
(*tm.task->out)[0].data.size() == 0) {
_batch_out = *(tm.task->out);
return;
}
merge_tensor(false);
}
void merge_tensor(bool is_in) {
// accumulate batch size from fetched tasks
size_t batch_size = 0;
for (size_t ti = 0; ti < _tasks.size(); ++ti) {
TaskMetaT& tm = _tasks[ti];
size_t add = tm.end - tm.begin;
batch_size += add;
}
// merge all instanses in each tensor data
size_t tensor_count = _tasks[0].task->get(is_in)->count();
for (size_t fi = 0; fi < tensor_count; ++fi) {
const Tensor& head = (*(_tasks[0].task->get(is_in)))[fi];
Tensor batch_tensor;
batch_tensor.name = head.name;
batch_tensor.type = head.type;
batch_tensor.shape.push_back(batch_size);
size_t ins_ele_count = 1;
for (size_t si = 1; si < head.shape.size(); ++si) {
batch_tensor.shape.push_back(head.shape[si]);
ins_ele_count *= head.shape[si];
}
size_t tensor_ele_count = ins_ele_count * batch_size;
size_t ins_byte = ins_ele_count * head.ele_byte();
size_t tensor_byte = tensor_ele_count * head.ele_byte();
void* data_buf = MempoolWrapper::instance().malloc(tensor_byte);
if (!data_buf) {
LOG(ERROR) << "Malloc failed, size: " << tensor_byte;
return;
}
size_t data_byte = 0;
for (size_t ti = 0; ti < _tasks.size(); ++ti) {
TaskMetaT& tm = _tasks[ti];
size_t acc_byte = ins_byte * (tm.end - tm.begin);
if (data_byte + acc_byte > tensor_byte) {
LOG(ERROR) << "Invalid bytes: " << data_byte << " + " << acc_byte
<< " >= " << tensor_byte;
return;
}
const Tensor& tensor = (*(tm.task->get(is_in)))[fi];
memcpy(
reinterpret_cast<char*>(data_buf) + data_byte,
reinterpret_cast<char*>(tensor.data.data()) + tm.begin * ins_byte,
acc_byte);
data_byte += acc_byte;
}
if (data_byte != tensor_byte) {
LOG(ERROR) << "Invalid tensor byte: " << data_byte
<< " != " << tensor_byte;
return;
}
batch_tensor.data =
DataBuf(reinterpret_cast<char*>(data_buf), tensor_byte);
if (is_in) {
_batch_in.push_back(batch_tensor);
} else {
_batch_out.push_back(batch_tensor);
}
}
LOG(INFO) << "merge input(" << is_in << ") samples: " << batch_size
<< " from " << _tasks.size() << " pvs";
}
void notify_tasks() {
if (_batch_out.size() != _batch_in.size()) {
LOG(ERROR) << "batch size not consistency: " << _batch_out.size()
<< " != " << _batch_in.size();
return;
}
size_t tensor_count = _batch_out.count();
size_t batch_size = _batch_out.size();
for (size_t fi = 0; fi < tensor_count; ++fi) {
const Tensor& tensor = _batch_out[fi];
size_t ins_byte = tensor.ele_byte();
for (size_t si = 1; si < tensor.shape.size(); ++si) {
ins_byte *= tensor.shape[si];
}
for (size_t ti = 0, bi = 0, add = 0; ti < _tasks.size();
++ti, bi += add) {
OutArrayT* dst = _tasks[ti].task->out;
add = _tasks[ti].end - _tasks[ti].begin;
size_t offset_src = ins_byte * bi;
size_t add_byte = add * ins_byte;
if (_batch_align) { // merge all batchs
size_t offset_dst = ins_byte * _tasks[ti].begin;
void* ptr = const_cast<void*>((*dst)[fi].data.data());
memcpy(
reinterpret_cast<char*>(ptr) + offset_dst,
reinterpret_cast<char*>(_batch_out[fi].data.data()) + offset_src,
add_byte);
} else { // overwrite
if (dst->count() <= 0) {
dst->push_back(_batch_out[fi]);
} else {
(*dst)[fi] = _batch_out[fi];
}
(*dst)[fi].shape[0] = add;
(*dst)[fi].data = DataBuf(
reinterpret_cast<char*>(_batch_out[fi].data.data()) + offset_src,
add_byte);
}
}
}
for (size_t ti = 0; ti < _tasks.size(); ++ti) {
TaskT* task = _tasks[ti].task;
size_t begin = _tasks[ti].begin;
size_t end = _tasks[ti].end;
size_t add = end - begin;
size_t index = task->index.fetch_add(add);
if ((index + add) >= task->in->size()) {
char c = 0;
while (write(task->write_fd, &c, 1) != 1 && errno == EINTR) {
}
butil::return_object(task);
}
}
}
const typename TaskT::InArrayT& in() const { return _batch_in; }
typename TaskT::OutArrayT& out() { return _batch_out; }
size_t task_size() { return _tasks.size(); }
private:
std::vector<TaskMetaT> _tasks;
InArrayT _batch_in;
OutArrayT _batch_out;
size_t _batch_size;
size_t _rem_size;
bool _batch_align;
};
} // namespace bsf
} // namespace im
......@@ -24,6 +24,7 @@
#include <boost/bind.hpp>
#include "core/predictor/common/inner_common.h"
#include "core/predictor/framework/memory.h"
namespace im {
namespace bsf {
......@@ -35,7 +36,7 @@ void* TaskExecutor<TaskT>::thread_entry(void* args) {
static_cast<TaskExecutor<TaskT>*>(context->executor);
executor->work(context);
return NULL;
return nullptr;
}
template <typename TaskT>
......@@ -70,7 +71,7 @@ int TaskExecutor<TaskT>::start(uint32_t thread_num, uint32_t init_timeout_sec) {
_thread_contexts.push_back(&contexts[i]);
}
int init_timeout = init_timeout_sec * 1000 * 1000;
size_t init_timeout = init_timeout_sec * 1000 * 1000;
bool has_error = false;
bool has_timeout = true;
......@@ -102,7 +103,7 @@ int TaskExecutor<TaskT>::start(uint32_t thread_num, uint32_t init_timeout_sec) {
}
// 100ms
const int sleep_interval = 100 * 1000;
const size_t sleep_interval = 100 * 1000;
usleep(sleep_interval);
init_timeout -= sleep_interval;
}
......@@ -125,18 +126,21 @@ void TaskExecutor<TaskT>::stop() {
}
template <typename TaskT>
TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(const InArrayT& in,
OutArrayT& out) { // NOLINT
TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(
const void* inVectorT_ptr,
void* outVectorT_ptr) { // NOLINT
TaskT* task = butil::get_object<TaskT>();
if (!task) {
LOG(ERROR) << "Failed get TaskT from object pool";
return TaskHandler<TaskT>::valid_handle();
}
/*
if (!BatchTasks<TaskT>::check_valid(in, out, _batch_align)) {
LOG(ERROR) << "Invalid input & output";
return TaskHandler<TaskT>::valid_handle();
}
*/
int fds[2];
int rc = pipe(fds);
......@@ -150,10 +154,9 @@ TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(const InArrayT& in,
task->write_fd = fds[1];
task->owner_tid = ::syscall(SYS_gettid);
task->in = &in;
task->out = &out;
task->rem = in.size();
task->size = in.size();
task->inVectorT_ptr = (const InVectorT*)inVectorT_ptr;
task->outVectorT_ptr = (OutVectorT*)outVectorT_ptr;
task->rem = task->batch_size();
task->index.store(0, butil::memory_order_relaxed);
AutoMutex lock(_mut);
......@@ -163,8 +166,13 @@ TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(const InArrayT& in,
return TaskHandler<TaskT>(*task);
}
// this function is accessed by multi thread.
// so AutoMutex at first.
// so batch.append_task is thread safe.
// you dont need to add extra lock in append_task()
template <typename TaskT>
bool TaskExecutor<TaskT>::fetch_batch(BatchTasks<TaskT>& batch) { // NOLINT
bool TaskExecutor<TaskT>::move_task_to_batch(
BatchTasks<TaskT>& batch) { // NOLINT
AutoMutex lock(_mut);
while (_task_queue.empty()) {
THREAD_COND_WAIT(&_cond, &_mut);
......@@ -187,8 +195,30 @@ bool TaskExecutor<TaskT>::fetch_batch(BatchTasks<TaskT>& batch) { // NOLINT
return true;
}
// this function is accessed by multi thread.
// move_task_to_batch have add lock inside the function.
// Packaging 1 TaskT as 1 or Several TaskMeta.
// TaskT is from the SingleTon TaskExecutor`s _task_queue
// although TaskMeta is a local variable, but several TaskMeta may points to
// the same TaskT which is get from the SingleTon TaskExecutor`s _task_queue.
// put TaskMeta to the local variable BatchTasks<TaskT> batch.
// batch.merge_tasks() and batch.notify_tasks() has no lock.
// BatchTasks<TaskT> batch itself is a local variable, it`s thread safe.
// If batch.merge_tasks() and batch.notify_tasks() do something to TaskMeta
// you need to pay attention to that.
// Multi-Thread deal with different TaskMeta(cause it`s created as local
// variable)
// But different TaskMeta may points to the same TaskT
// which is get from the SingleTon TaskExecutor`s _task_queue.
template <typename TaskT>
int TaskExecutor<TaskT>::work(ThreadContext<TaskT>* context) {
if (MempoolWrapper::instance().thread_initialize() != 0) {
LOG(ERROR) << "Failed thread initialize mempool";
return -1;
}
if (_thread_init_fn != NULL) {
if (_thread_init_fn(context->user_thread_context) != 0) {
LOG(ERROR) << "execute thread init thunk failed, BSF thread will exit";
......@@ -207,10 +237,15 @@ int TaskExecutor<TaskT>::work(ThreadContext<TaskT>* context) {
}
}
if (MempoolWrapper::instance().thread_clear() != 0) {
LOG(ERROR) << "Failed thread clear mempool";
return -1;
}
BatchTasks<TaskT> batch(_batch_size, _batch_align);
if (fetch_batch(batch)) {
if (move_task_to_batch(batch)) {
batch.merge_tasks();
_fn(batch.in(), batch.out());
_fn(&batch.in(), &batch.out());
batch.notify_tasks();
}
}
......@@ -219,9 +254,10 @@ int TaskExecutor<TaskT>::work(ThreadContext<TaskT>* context) {
}
template <typename InItemT, typename OutItemT>
bool TaskManager<InItemT, OutItemT>::schedule(const InArrayT& in,
OutArrayT& out) { // NOLINT
TaskHandler<TaskT> handler = _executor.schedule(in, out);
bool TaskManager<InItemT, OutItemT>::schedule(const void* in,
void* out) { // NOLINT
TaskHandler<TaskT> handler =
TaskExecutorVector<TaskT>::instance()[_model_index].schedule(in, out);
if (handler.valid()) {
_task_owned = handler;
......
此差异已折叠。
......@@ -56,15 +56,23 @@ int ReloadableInferEngine::proc_initialize(const configure::EngineDesc& conf,
}
// init bsf framework
im::bsf::TaskExecutor<TaskT>::instance()->set_thread_init_fn(
boost::bind(&InferEngine::thrd_initialize_impl, this));
im::bsf::TaskExecutor<TaskT>::instance()->set_thread_reset_fn(
boost::bind(&InferEngine::thrd_clear_impl, this));
im::bsf::TaskExecutor<TaskT>::instance()->set_thread_callback_fn(
boost::bind(&InferEngine::task_infer_impl, this, _1, _2));
im::bsf::TaskExecutor<TaskT>::instance()->set_batch_size(_infer_batch_size);
im::bsf::TaskExecutor<TaskT>::instance()->set_batch_align(_infer_batch_align);
if (im::bsf::TaskExecutor<TaskT>::instance()->start(_infer_thread_num) != 0) {
im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index]
.set_thread_init_fn(
boost::bind(&InferEngine::thrd_initialize_impl, this));
im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index]
.set_thread_init_fn(
boost::bind(&InferEngine::thrd_initialize_impl, this));
im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index]
.set_thread_reset_fn(boost::bind(&InferEngine::thrd_clear_impl, this));
im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index]
.set_thread_callback_fn(
boost::bind(&InferEngine::task_infer_impl, this, _1, _2));
im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index].set_batch_size(
_infer_batch_size);
im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index].set_batch_align(
_infer_batch_align);
if (im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index].start(
_infer_thread_num) != 0) {
LOG(ERROR) << "Failed start bsf executor, threads:" << _infer_thread_num;
return -1;
}
......@@ -75,6 +83,11 @@ int ReloadableInferEngine::proc_initialize(const configure::EngineDesc& conf,
return 0;
}
// Multiple threads will enter this method of the same object
// One Model corresponds to One ReloadableInferEngine object.
// ReloadableInferEngine object is Process object.
// One ReloadableInferEngine object can have several ModelData<EngineCore>
// ModelData<EngineCore> is Thread object.
int ReloadableInferEngine::infer(const void* in,
void* out,
uint32_t batch_size) {
......@@ -82,9 +95,10 @@ int ReloadableInferEngine::infer(const void* in,
return infer_impl(in, out, batch_size);
}
im::bsf::TaskManager<Tensor, Tensor> task_manager;
task_manager.schedule(*(reinterpret_cast<const BatchTensor*>(in)),
*(reinterpret_cast<BatchTensor*>(out)));
im::bsf::TaskManager<paddle::PaddleTensor, paddle::PaddleTensor> task_manager(
_model_index);
task_manager.schedule(in, out);
task_manager.wait();
return 0;
}
......@@ -110,7 +124,7 @@ int ReloadableInferEngine::proc_finalize() {
}
if (_infer_thread_num > 0) {
im::bsf::TaskExecutor<TaskT>::instance()->stop();
im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index].stop();
}
return 0;
}
......@@ -191,6 +205,7 @@ int VersionedInferEngine::proc_initialize(const configure::EngineDesc& conf,
std::string engine_type = conf.type();
InferEngine* engine =
StaticInferFactory::instance().generate_object(engine_type);
engine->set_model_index(_model_index);
if (!engine) {
LOG(ERROR) << "Failed generate engine with type:" << engine_type;
return -1;
......@@ -362,23 +377,30 @@ int VersionedInferEngine::infer_impl(const void* in,
uint32_t batch_size) {
return -1;
}
int VersionedInferEngine::task_infer_impl(const BatchTensor& in,
BatchTensor& out) { // NOLINT
int VersionedInferEngine::task_infer_impl(const void* in,
void* out) { // NOLINT
return -1;
}
int InferManager::proc_initialize(const char* path, const char* file) {
int InferManager::proc_initialize(const char* path,
const char* file,
std::shared_ptr<int> engine_index_ptr) {
ModelToolkitConf model_toolkit_conf;
if (configure::read_proto_conf(path, file, &model_toolkit_conf) != 0) {
LOG(ERROR) << "failed load infer config, path: " << path << "/" << file;
return -1;
}
size_t engine_num = model_toolkit_conf.engines_size();
for (size_t ei = 0; ei < engine_num; ++ei) {
uint32_t engine_num = model_toolkit_conf.engines_size();
im::bsf::TaskExecutorVector<TaskT>::instance().resize(*engine_index_ptr +
engine_num);
for (uint32_t ei = 0; ei < engine_num; ++ei) {
LOG(INFO) << "model_toolkit_conf.engines(" << ei
<< ").name: " << model_toolkit_conf.engines(ei).name();
std::string engine_name = model_toolkit_conf.engines(ei).name();
VersionedInferEngine* engine = new (std::nothrow) VersionedInferEngine();
int temp_engine_index_ptr = *engine_index_ptr;
engine->set_model_index(temp_engine_index_ptr);
*engine_index_ptr = temp_engine_index_ptr + 1;
if (!engine) {
LOG(ERROR) << "Failed generate versioned engine: " << engine_name;
return -1;
......
......@@ -17,6 +17,8 @@
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
......@@ -25,6 +27,7 @@
#include "core/predictor/framework/bsf.h"
#include "core/predictor/framework/factory.h"
#include "core/predictor/framework/infer_data.h"
#include "core/predictor/framework/memory.h"
#include "paddle_inference_api.h" // NOLINT
namespace baidu {
namespace paddle_serving {
......@@ -71,7 +74,7 @@ class InferEngine {
virtual int infer(const void* in, void* out, uint32_t batch_size = -1) {
return infer_impl(in, out, batch_size);
}
virtual void set_model_index(uint32_t index) { _model_index = index; }
virtual int reload() = 0;
virtual uint64_t version() const = 0;
......@@ -86,12 +89,13 @@ class InferEngine {
virtual int infer_impl(const void* in,
void* out,
uint32_t batch_size = -1) = 0;
virtual int task_infer_impl(const BatchTensor& in,
BatchTensor& out) = 0; // NOLINT
virtual int task_infer_impl(const void* in, void* out) = 0; // NOLINT
protected:
uint32_t _model_index;
// end: framework inner call
};
typedef im::bsf::Task<paddle::PaddleTensor, paddle::PaddleTensor> TaskT;
class ReloadableInferEngine : public InferEngine {
public:
virtual ~ReloadableInferEngine() {}
......@@ -104,7 +108,6 @@ class ReloadableInferEngine : public InferEngine {
};
virtual int load(const configure::EngineDesc& conf) = 0;
typedef im::bsf::Task<Tensor, Tensor> TaskT;
int proc_initialize_impl(const configure::EngineDesc& conf, bool version);
......@@ -179,6 +182,8 @@ struct ModelData {
delete cores[1];
}
void* get() { return cores[current_idx]->get(); }
EngineCore* cores[2];
uint32_t current_idx;
};
......@@ -191,14 +196,20 @@ class DBReloadableInferEngine : public ReloadableInferEngine {
int proc_initialize(const configure::EngineDesc& conf, bool version) {
THREAD_KEY_CREATE(&_skey, NULL);
THREAD_MUTEX_INIT(&_mutex, NULL);
gpu_index = 0;
return ReloadableInferEngine::proc_initialize(conf, version);
}
// 进程初始化会调用load,但由于未执行线程初始化,所以_reload_vec为空,不再继续执行。
// 热加载的话会调用load,由于线程已经初始化,_reload_vec不为空,所以继续执行load_data操作加载数据。
// 线程初始化会执行load_data操作加载数据,然后将engine加入_reload_vec中。
// 每个模型只有一个CloneDBReloadableInferEngine对象。
// 但一个CloneDBReloadableInferEngine对象,可以包含N个EngineCore。
virtual int load(const configure::EngineDesc& conf) {
if (_reload_vec.empty()) {
return 0;
}
gpu_index = 0;
for (uint32_t ti = 0; ti < _reload_vec.size(); ++ti) {
if (load_data(_reload_vec[ti], conf) != 0) {
LOG(ERROR) << "Failed reload engine model: " << ti;
......@@ -210,7 +221,8 @@ class DBReloadableInferEngine : public ReloadableInferEngine {
return 0;
}
int load_data(ModelData<EngineCore>* md, const configure::EngineDesc& conf) {
virtual int load_data(ModelData<EngineCore>* md,
const configure::EngineDesc& conf) {
uint32_t next_idx = (md->current_idx + 1) % 2;
if (md->cores[next_idx]) {
delete md->cores[next_idx];
......@@ -219,28 +231,29 @@ class DBReloadableInferEngine : public ReloadableInferEngine {
md->cores[next_idx] = new (std::nothrow) EngineCore;
// params.dump();
if (!md->cores[next_idx] || md->cores[next_idx]->create(conf) != 0) {
size_t gpu_ids_num = conf.gpu_ids_size();
im::bsf::AutoMutex lock(_mutex);
int gpu_id = -1;
if (gpu_ids_num > 0) {
gpu_id = conf.gpu_ids(gpu_index % gpu_ids_num);
}
if (!md->cores[next_idx] ||
md->cores[next_idx]->create(conf, gpu_id) != 0) {
LOG(ERROR) << "Failed create model, path: " << conf.model_dir();
return -1;
}
gpu_index++;
md->current_idx = next_idx;
return 0;
}
virtual int thrd_initialize_impl() {
// memory pool to be inited in non-serving-threads
if (MempoolWrapper::instance().thread_initialize() != 0) {
LOG(ERROR) << "Failed thread initialize mempool";
return -1;
}
ModelData<EngineCore>* md = new (std::nothrow) ModelData<EngineCore>;
if (!md || load_data(md, _conf) != 0) {
LOG(ERROR) << "Failed create thread data from " << _conf.model_dir();
return -1;
}
LOG(ERROR) << "THREAD_SETSPECIFIC _skey = md";
THREAD_SETSPECIFIC(_skey, md);
im::bsf::AutoMutex lock(_mutex);
_reload_vec.push_back(md);
......@@ -248,11 +261,33 @@ class DBReloadableInferEngine : public ReloadableInferEngine {
}
int thrd_clear_impl() {
// for non-serving-threads
if (MempoolWrapper::instance().thread_clear() != 0) {
LOG(ERROR) << "Failed thread clear mempool";
return -1;
}
// actually, there are 2 kinds of multi-thread.
// 1. brpc thread 2. bsf Task thread
// each request is in 1-single brpc thread.
// IF (bsf Task thread is not used)
// every single brpc thread corresponds to all the DBReloadableInferEngines.
// each request runs all models in 1-single brpc thread.
// every single brpc thread will create or clone N predictor.
// N = the number of Model.
// so if there are 2 models, and --thread 10.
// each brpc thread will create predictor of Model-1 and Model-2.
// there are totally 10 predictors of Model-1 and 10 predictors of Model-2
// cause there are 10 brpc threads.
// IF bsf Task thread is used。
// there will be a ThreadPool called bsf TaskExecutor.
// TaskExecutorVector is the vector of TaskExecutor.
// the number of TaskExecutor equals to the number of Model.
// 1 TaskExecutor corresponding to 1 Model.
// 1 TaskExecutor have N bsf threads.
// 1 bsf thread corresponds to 1 predictor of
// the Model corresponding to the TaskExecutor.
// brpc thread only put the data into the task_queue(which is in
// TaskExecutor)
// EngineCore->infer() is running in bsf Task thread.
// MempoolWrapper::instance() is actually a Thread-Local Mempool.
// so it belongs to a single Thread.
return 0;
}
......@@ -278,6 +313,7 @@ class DBReloadableInferEngine : public ReloadableInferEngine {
THREAD_KEY_T _skey;
THREAD_MUTEX_T _mutex;
std::vector<ModelData<EngineCore>*> _reload_vec;
int gpu_index = 0;
};
// 多个EngineCore共用同一份模型数据
......@@ -287,88 +323,76 @@ class CloneDBReloadableInferEngine
public:
virtual ~CloneDBReloadableInferEngine() {}
virtual int proc_initialize(const configure::EngineDesc& conf, bool version) {
_pd = new (std::nothrow) ModelData<EngineCore>;
if (!_pd) {
LOG(ERROR) << "Failed to allocate for ProcData";
return -1;
}
return DBReloadableInferEngine<EngineCore>::proc_initialize(conf, version);
}
// 进程初始化会调用load,但由于未执行线程初始化,所以_reload_vec为空,不再继续执行。
// 热加载的话会调用load,由于线程已经初始化,_reload_vec不为空,所以继续执行load_data操作加载数据。
// 线程初始化会执行load_data操作加载数据,然后将engine加入_reload_vec中。
// 每个模型只有一个CloneDBReloadableInferEngine对象。
// 但一个CloneDBReloadableInferEngine对象,可以包含N个EngineCore。
virtual int load(const configure::EngineDesc& conf) {
// 加载进程级模型数据
if (!_pd ||
DBReloadableInferEngine<EngineCore>::load_data(_pd, conf) != 0) {
LOG(ERROR) << "Failed to create common model from [" << conf.model_dir()
<< "].";
return -1;
virtual int load_data(ModelData<EngineCore>* md,
const configure::EngineDesc& conf) {
uint32_t next_idx = (md->current_idx + 1) % 2;
if (md->cores[next_idx]) {
delete md->cores[next_idx];
}
LOG(WARNING) << "Succ load common model[" << _pd->cores[_pd->current_idx]
<< "], path[" << conf.model_dir() << "].";
md->cores[next_idx] = new (std::nothrow) EngineCore;
if (DBReloadableInferEngine<EngineCore>::_reload_vec.empty()) {
return 0;
// params.dump();
// gpu_ids_num > 0 is always true.
// if use CPU, gpu_ids = [-1].
// if gpu_ids_num = 0, which means no gpuid is given.
// so we should set gpu_ids_num = 1, and gpu_id = -1.
// so that we can create at least 1 predictor.
size_t gpu_ids_num = conf.gpu_ids_size();
im::bsf::AutoMutex lock(DBReloadableInferEngine<EngineCore>::_mutex);
int gpu_id = -1;
if (gpu_ids_num > 0) {
gpu_id = conf.gpu_ids(DBReloadableInferEngine<EngineCore>::gpu_index %
gpu_ids_num);
} else {
gpu_ids_num = 1;
}
for (uint32_t ti = 0;
ti < DBReloadableInferEngine<EngineCore>::_reload_vec.size();
++ti) {
if (load_data(DBReloadableInferEngine<EngineCore>::_reload_vec[ti],
_pd->cores[_pd->current_idx]) != 0) {
LOG(ERROR) << "Failed reload engine model: " << ti;
// gpu_index will be set to be 0, when load() or proc_initial() is called.
// gpu_index < gpu_ids_num, means there are predictors still not create
// on some GPU card.
// so we need to create the predictor.
// gpu_index >= gpu_ids_num, means each GPU card has already create one.
// so we need to clone the predictor.
if (DBReloadableInferEngine<EngineCore>::gpu_index < gpu_ids_num) {
if (!md->cores[next_idx] ||
md->cores[next_idx]->create(conf, gpu_id) != 0) {
LOG(ERROR) << "Failed create model, path: " << conf.model_dir();
return -1;
}
DBReloadableInferEngine<EngineCore>::gpu_index++;
md->current_idx = next_idx;
if (_cloneTemplate.size() <
DBReloadableInferEngine<EngineCore>::gpu_index) {
_cloneTemplate.push_back(md);
} else {
_cloneTemplate[DBReloadableInferEngine<EngineCore>::gpu_index - 1] = md;
}
} else {
int template_index = DBReloadableInferEngine<EngineCore>::gpu_index %
_cloneTemplate.size();
if (!md->cores[next_idx] ||
md->cores[next_idx]->clone(_cloneTemplate[template_index]->get()) !=
0) {
LOG(ERROR) << "Failed clone model from core";
return -1;
}
DBReloadableInferEngine<EngineCore>::gpu_index++;
md->current_idx = next_idx;
LOG(WARNING) << "core clone model succ, cur_idx[" << md->current_idx
<< "].";
}
LOG(WARNING) << "Succ load clone model, path[" << conf.model_dir() << "]";
return 0;
}
// 加载线程级对象,多个线程级对象共用pd_core的模型数据
int load_data(ModelData<EngineCore>* td, EngineCore* pd_core) {
uint32_t next_idx = (td->current_idx + 1) % 2;
if (td->cores[next_idx]) {
delete td->cores[next_idx];
}
td->cores[next_idx] = new (std::nothrow) EngineCore;
if (!td->cores[next_idx] ||
td->cores[next_idx]->clone(pd_core->get()) != 0) {
LOG(ERROR) << "Failed clone model from pd_core[ " << pd_core << "], idx["
<< next_idx << "]";
return -1;
}
td->current_idx = next_idx;
LOG(WARNING) << "td_core[" << td->cores[td->current_idx]
<< "] clone model from pd_core[" << pd_core
<< "] succ, cur_idx[" << td->current_idx << "].";
return 0;
}
virtual int thrd_initialize_impl() {
// memory pool to be inited in non-serving-threads
if (MempoolWrapper::instance().thread_initialize() != 0) {
LOG(ERROR) << "Failed thread initialize mempool";
return -1;
}
ModelData<EngineCore>* md = new (std::nothrow) ModelData<EngineCore>;
if (!md || load_data(md, _pd->cores[_pd->current_idx]) != 0) {
LOG(ERROR) << "Failed clone thread data, origin_core["
<< _pd->cores[_pd->current_idx] << "].";
return -1;
}
THREAD_SETSPECIFIC(DBReloadableInferEngine<EngineCore>::_skey, md);
im::bsf::AutoMutex lock(DBReloadableInferEngine<EngineCore>::_mutex);
DBReloadableInferEngine<EngineCore>::_reload_vec.push_back(md);
return 0;
}
protected:
ModelData<EngineCore>*
_pd; // 进程级EngineCore,多个线程级EngineCore共用该对象的模型数据
// 模板EngineCore,如果已创建,则多个线程级EngineCore共用该对象的模型数据
std::vector<ModelData<EngineCore>*> _cloneTemplate;
};
template <typename EngineCore>
......@@ -505,8 +529,8 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<EngineCore> {
return 0;
}
int task_infer_impl(const BatchTensor& in, BatchTensor& out) { // NOLINT
return infer_impl(&in, &out);
int task_infer_impl(const void* in, void* out) { // NOLINT
return infer_impl(in, out);
}
};
......@@ -559,7 +583,7 @@ class VersionedInferEngine : public InferEngine {
int infer_impl(const void* in, void* out, uint32_t batch_size = -1);
int task_infer_impl(const BatchTensor& in, BatchTensor& out);
int task_infer_impl(const void* in, void* out);
private:
boost::unordered_map<uint64_t, InferEngine*> _versions;
......@@ -572,7 +596,9 @@ class InferManager {
return ins;
}
int proc_initialize(const char* path, const char* file);
int proc_initialize(const char* path,
const char* file,
std::shared_ptr<int> engine_index_ptr);
int thrd_initialize();
......
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
......@@ -135,12 +135,14 @@ int Resource::initialize(const std::string& path, const std::string& file) {
if (FLAGS_enable_model_toolkit) {
size_t model_toolkit_num = resource_conf.model_toolkit_path_size();
std::shared_ptr<int> engine_index_ptr(new int(0));
for (size_t mi = 0; mi < model_toolkit_num; ++mi) {
std::string model_toolkit_path = resource_conf.model_toolkit_path(mi);
std::string model_toolkit_file = resource_conf.model_toolkit_file(mi);
if (InferManager::instance().proc_initialize(
model_toolkit_path.c_str(), model_toolkit_file.c_str()) != 0) {
if (InferManager::instance().proc_initialize(model_toolkit_path.c_str(),
model_toolkit_file.c_str(),
engine_index_ptr) != 0) {
LOG(ERROR) << "failed proc initialize modeltoolkit, config: "
<< model_toolkit_path << "/" << model_toolkit_file;
return -1;
......
......@@ -16,6 +16,7 @@
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "core/cube/cube-api/include/cube_api.h"
#include "core/predictor/common/inner_common.h"
......
......@@ -91,6 +91,7 @@ int ServerManager::start_and_wait() {
}
}
// rpc multi-thread start from here.
if (_server.Start(FLAGS_port, &_options) != 0) {
LOG(ERROR) << "Failed to start Paddle Inference Server";
return -1;
......
文件模式从 100755 更改为 100644
......@@ -24,7 +24,7 @@ namespace fugue {
namespace memory {
void Region::init() {
_big_mem_capacity = 64 * 1024 * 1024; // 64MB
_big_mem_capacity = 128 * 1024 * 1024; // 128MB
_big_mem_start = new char[_big_mem_capacity];
}
......
......@@ -129,7 +129,7 @@ class FreeList {
to get the class Pointer
for example
T is the member of class Node, T data, 'data' is the name.
T* value is the member(pointer type) class Node
T* value is the member(pointer type) of class Node
so we can get the Node* by calling container_of(value, Node, data)
*/
Node* node = container_of(value, Node, data);
......@@ -261,7 +261,11 @@ struct BlockReference {
// because BlockFreeList is a threal-safe Singleton.
// so we don`t release Block, it is global memory.
// total number is 32*1024
// total number is 256*1024.
// the MAX_BLOCK_COUNT of Region(one thread one Region) is 1024.
// so BlockFreeList allow 256 Region(means 256 thread).
// the memory used by BlockFreeListType is sizeof(void*)*256*1024.
// Block(2MB) memory is created only when get() is called.
class BlockFreeList {
public:
static const int MAX_BLOCK_COUNT = 256 * 1024;
......@@ -341,9 +345,10 @@ class Region {
2 * 1024 *
1024; // 2MB,means when you need less than 2M, get memory from Block.
// 64MB,means when you need less than 64MB, get memory from BigMemory instead
// 128MB,means when you need less than 128MB, get memory from BigMemory
// instead
// of BigNode
static const int BIGNODE_MEM_THRESHOLD = (64 * 1024 * 1024 + 1);
static const int BIGNODE_MEM_THRESHOLD = (128 * 1024 * 1024 + 1);
static const int COUNTER_SIZE =
BIGNODE_MEM_THRESHOLD / BIG_MEM_THRESHOLD + 1; // this is not used
......@@ -374,7 +379,8 @@ class Mempool {
void* malloc(size_t size) {
size = _align(size);
// It does not enter the if statement the first time.
// Because the block has not been used up, it will enter.
// The if statement may enter after the block is created.
// If the block has not been used up, it will enter.
if (size <= _free_size) {
void* p = _free_cursor;
_free_size -= size;
......@@ -392,7 +398,7 @@ class Mempool {
return;
}
// memory in Block,update the pointer.
// memory in _block,update the pointer.
if (_free_cursor - size == static_cast<char*>(p)) {
// for example, you need to release -(8+1)bytes
// you can only release -8bytes,cause -(8+2)byte is used by other.
......@@ -424,9 +430,8 @@ class Mempool {
}
// 可能返回的是单独Region中malloc的内存。
// 也可能是Block,例如new_size=1M, old_data原本的指针头就在1.2M处,old_size
// =
// 0.5M
// 也可能是Block,例如new_size=1M, old_data原本的指针头就在1.2M处
// old_size = 0.5M
// 此时,_free_size = 0.3M,new_size<2M,但是required = 1-0.5 >0.3
// 分配出来的就是Block,但是该Block没有并很完美的利用完全。
void* p = this->malloc_from_region(new_size);
......
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
......@@ -68,13 +68,14 @@ static bvar::PassiveStatus<std::string> s_predictor_revision(
DEFINE_bool(V, false, "print version, bool");
DEFINE_bool(g, false, "user defined gflag path");
DECLARE_string(flagfile);
/*
namespace bthread {
extern pthread_mutex_t g_task_control_mutex;
}
pthread_mutex_t g_worker_start_fn_mutex = PTHREAD_MUTEX_INITIALIZER;
*/
void pthread_worker_start_fn() {
/*
while (pthread_mutex_lock(&g_worker_start_fn_mutex) != 0) {
}
......@@ -83,15 +84,18 @@ void pthread_worker_start_fn() {
if (lock_status == EBUSY || lock_status == EAGAIN) {
pthread_mutex_unlock(&bthread::g_task_control_mutex);
}
*/
Resource::instance().thread_initialize();
// Try to avoid deadlock in bthread
/*
if (lock_status == EBUSY || lock_status == EAGAIN) {
while (pthread_mutex_lock(&bthread::g_task_control_mutex) != 0) {
}
}
pthread_mutex_unlock(&g_worker_start_fn_mutex);
*/
}
static void g_change_server_port() {
......@@ -126,7 +130,7 @@ int main(int argc, char** argv) {
return 0;
}
//google::ParseCommandLineFlags(&argc, &argv, true);
// google::ParseCommandLineFlags(&argc, &argv, true);
g_change_server_port();
......@@ -202,7 +206,7 @@ int main(int argc, char** argv) {
}
VLOG(2) << "Succ call pthread worker start function";
//this is not used by any code segment,which can be cancelled.
// this is not used by any code segment,which can be cancelled.
if (Resource::instance().general_model_initialize(FLAGS_resource_path,
FLAGS_resource_file) != 0) {
LOG(ERROR) << "Failed to initialize general model conf: "
......
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
......@@ -24,17 +24,16 @@ message Tensor {
repeated int32 int_data = 2;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type = 5;
repeated int32 shape = 6;
repeated int32 lod = 7; // only for fetch tensor currently
optional int32 elem_type =
5; // 0 means int64, 1 means float32, 2 means int32, 3 means bytes(string)
repeated int32 shape = 6; // shape should include batch
repeated int32 lod = 7; // only for fetch tensor currently
optional string name = 8; // get from the Model prototxt
optional string alias_name = 9; // get from the Model prototxt
};
message FeedInst { repeated Tensor tensor_array = 1; };
message FetchInst { repeated Tensor tensor_array = 1; };
message Request {
repeated FeedInst insts = 1;
repeated Tensor tensor = 1;
repeated string fetch_var_names = 2;
optional bool profile_server = 3 [ default = false ];
required uint64 log_id = 4 [ default = 0 ];
......@@ -46,7 +45,7 @@ message Response {
};
message ModelOutput {
repeated FetchInst insts = 1;
repeated Tensor tensor = 1;
optional string engine_name = 2;
}
......
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
......@@ -242,6 +242,9 @@ InvalidArgumentError: Device id must be less than GPU count, but received id is:
**A:** 支持离线部署,需要把一些相关的[依赖包](https://github.com/PaddlePaddle/Serving/blob/develop/doc/COMPILE.md)提前准备安装好
#### Q: Docker中启动server IP地址 127.0.0.1 与 0.0.0.0 差异
**A:** 您必须将容器的主进程设置为绑定到特殊的 0.0.0.0 “所有接口”地址,否则它将无法从容器外部访问。在Docker中 127.0.0.1 代表“这个容器”,而不是“这台机器”。如果您从容器建立到 127.0.0.1 的出站连接,它将返回到同一个容器;如果您将服务器绑定到 127.0.0.1,接收不到来自外部的连接。
## 预测问题
#### Q: 使用GPU第一次预测时特别慢,如何调整RPC服务的等待时间避免超时?
......@@ -321,6 +324,15 @@ GLOG_v=2 python -m paddle_serving_server.serve --model xxx_conf/ --port 9999
**A:** Logid默认为0(后续应该有自动生成Logid的计划,当前版本0.4.0),Client端通过在predict函数中指定log_id参数传递
#### Q: C++Server出现问题如何调试和定位
**A:** 推荐您使用gdb进行定位和调试,如果您使用docker,在启动容器时候,需要加上docker run --privileged参数,开启特权模式,这样才能在docker容器中使用gdb定位和调试
如果您C++端出现coredump,一般而言会生成一个core文件,若没有,则应开启生成core文件选项,使用ulimit -c unlimited命令。
使用gdb调试core文件的方法为:gdb <可执行文件> <core文件>,进入后输入bt指令,一般即可显示出错在哪一行。
注意:可执行文件路径是C++ bin文件的路径,而不是python命令,一般为类似下面的这种/usr/local/lib/python3.6/site-packages/paddle_serving_server/serving-gpu-102-0.6.2/serving
## 性能优化
# HTTP方式访问Server
Paddle Serving服务端目前提供了支持Http直接访问的功能,本文档显示了详细信息。
## 基本原理
BRPC-Server端支持通过Http的方式被访问,各种语言都有实现Http请求的一些库,所以Java/Python/Go等BRPC支持不太完善的语言,可以通过Http的方式直接访问服务端进行预测。
### Http方式
基本流程和原理:客户端需要将数据按照Proto约定的格式(请参阅[`core/general-server/proto/general_model_service.proto`](../core/general-server/proto/general_model_service.proto))封装在Http请求的请求体中。
BRPC-Server会尝试去JSON字符串中再去反序列化出Proto格式的数据,从而进行后续的处理。
### Http+protobuf方式
各种语言都提供了对ProtoBuf的支持,如果您对此比较熟悉,您也可以先将数据使用ProtoBuf序列化,再将序列化后的数据放入Http请求数据体中,然后指定Content-Type: application/proto,从而使用http/h2+protobuf二进制串访问服务。
**理论上讲,序列化/反序列化的性能从高到底排序为:protobuf > http/h2+protobuf > http**
## 示例
我们将以python/examples/fit_a_line为例,讲解如何通过Http访问Server端。
### 获取模型
```shell
sh get_data.sh
```
## 开启服务端
```shell
python3.6 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --port 9393
```
服务端无须做任何改造,即可支持BRPC和HTTP两种方式。
## 客户端访问
### HttpClient方式发送Http请求(Python/Java)
为了方便用户快速的使用Http方式请求Server端预测服务,我们已经将常用的Http请求的数据体封装、压缩、请求加密等功能封装为一个HttpClient类提供给用户,方便用户使用。
使用HttpClient最简单只需要三步,1、创建一个HttpClient对象。2、加载Client端的prototxt配置文件(本例中为python/examples/fit_a_line/目录下的uci_housing_client/serving_client_conf.prototxt),3、调用Predict函数,通过Http方式请求预测服务。
此外,您可以根据自己的需要配置Server端IP、Port、服务名称(此服务名称需要与[`core/general-server/proto/general_model_service.proto`](../core/general-server/proto/general_model_service.proto)文件中的Service服务名和rpc方法名对应,即`GeneralModelService`字段和`inference`字段),设置Request数据体压缩,设置Response支持压缩传输,模型加密预测(需要配置Server端使用模型加密)、设置响应超时时间等功能。
Python的HttpClient使用示例见[`python/examples/fit_a_line/test_httpclient.py`](../python/examples/fit_a_line/test_httpclient.py),接口详见[`python/paddle_serving_client/httpclient.py`](../python/paddle_serving_client/httpclient.py)
Java的HttpClient使用示例见[`java/examples/src/main/java/PaddleServingClientExample.java`](../java/examples/src/main/java/PaddleServingClientExample.java)接口详见[`java/src/main/java/io/paddle/serving/client/HttpClient.java`](../java/src/main/java/io/paddle/serving/client/HttpClient.java)
如果不能满足您的需求,您也可以在此基础上添加一些功能。
如需支持https或者自定义Response的Status Code等,则需要对C++端brpc-Server进行一定的二次开发,请参考https://github.com/apache/incubator-brpc/blob/master/docs/cn/http_service.md,后续如果需求很大,我们也会将这部分功能加入到Server中,尽情期待。
### curl方式发送Http请求(基本原理)
```shell
curl -XPOST http://0.0.0.0:9393/GeneralModelService/inference -d ' {"tensor":[{"float_data":[0.0137,-0.1136,0.2553,-0.0692,0.0582,-0.0727,-0.1583,-0.0584,0.6283,0.4919,0.1856,0.0795,-0.0332],"elem_type":1,"name":"x","alias_name":"x","shape":[1,13]}],"fetch_var_names":["price"],"log_id":0}'
```
其中`127.0.0.1:9393`为IP和Port,根据您服务端启动的IP和Port自行设定。
`GeneralModelService`字段和`inference`字段分别为Proto文件中的Service服务名和rpc方法名,详见[`core/general-server/proto/general_model_service.proto`](../core/general-server/proto/general_model_service.proto)
-d后面的是请求的数据体,json中一定要包含上述proto中的required字段,否则转化会失败,对应请求会被拒绝。
需要注意的是,数据中的shape字段为模型实际需要的shape信息,包含batch维度在内,可能与proto文件中的shape不一致。
#### message
对应rapidjson Object, 以花括号包围,其中的元素会被递归地解析。
```protobuf
// protobuf
message Foo {
required string field1 = 1;
required int32 field2 = 2;
}
message Bar {
required Foo foo = 1;
optional bool flag = 2;
required string name = 3;
}
// rapidjson
{"foo":{"field1":"hello", "field2":3},"name":"Tom" }
```
#### repeated field
对应rapidjson Array, 以方括号包围,其中的元素会被递归地解析,和message不同,每个元素的类型相同。
```protobuf
// protobuf
repeated int32 numbers = 1;
// rapidjson
{"numbers" : [12, 17, 1, 24] }
```
#### elem_type
表示数据类型,0 means int64, 1 means float32, 2 means int32, 3 means bytes(string)
#### fetch_var_names
表示返回结果中需要的数据名称,请参考模型文件serving_client_conf.prototxt中的`fetch_var`字段下的`alias_name`
### Http压缩
支持gzip压缩,但gzip并不是一个压缩解压速度非常快的方法,当数据量较小时候,使用gzip压缩反而会得不偿失,推荐至少数据大于512字节时才考虑使用gzip压缩。
#### Client请求的数据体压缩
以上面的fit_a_line为例,仍使用上文的请求数据体,但只作为示例演示用法,实际此时使用压缩得不偿失。
```shell
echo ' {"tensor":[{"float_data":[0.0137,-0.1136,0.2553,-0.0692,0.0582,-0.0727,-0.1583,-0.0584,0.6283,0.4919,0.1856,0.0795,-0.0332],"elem_type":1,"shape":[1,13]}],"fetch_var_names":["price"],"log_id":0}' | gzip -c > data.txt.gz
```
```shell
curl --data-binary @data.txt.gz -H'Content-Encoding: gzip' -XPOST http://127.0.0.1:9393/GeneralModelService/inference
```
**注意:当请求数据体压缩时,需要指定请求头中Content-Encoding: gzip**
#### Server端Response压缩
当Http请求头中设置了Accept-encoding: gzip时,Server端会尝试用gzip压缩Response的数据,“尝试“指的是压缩有可能不发生,条件有:
- 请求中没有设置Accept-encoding: gzip。
- body尺寸小于-http_body_compress_threshold指定的字节数,默认是512。gzip并不是一个很快的压缩算法,当body较小时,压缩增加的延时可能比网络传输省下的还多。当包较小时不做压缩可能是个更好的选项。
这时server总是会返回不压缩的结果。
如果使用curl,通常推荐使用--compressed参数来设置Response压缩,--compressed参数会自动地在http请求中设置Accept-encoding: gzip,并在收到压缩后的Response后自动解压,对于用户而言,整个压缩/解压过程就像透明的一样。
```shell
curl --data-binary @data.txt.gz -H'Content-Encoding: gzip' --compressed -XPOST http://127.0.0.1:9393/GeneralModelService/inference
```
若您只是在Http请求头中通过-H'Accept-encoding: gzip'设置了接收压缩的信息,收到的将是压缩后的Response,此时,您需要手动解压。
也就是说,--compressed = -H'Content-Encoding: gzip' + 自动解压,所以推荐您使用--compressed,以下仅作为单独设置请求头+手动解压的原理性示例。
当您想要验证返回值是否真的压缩时,您可以只添加请求头-H'Content-Encoding: gzip',而不解压,可以看到返回信息是压缩后的数据(一般而言是看不懂的压缩码)。
```shell
curl --data-binary @data.txt.gz -H'Content-Encoding: gzip' -H'Accept-encoding: gzip' -XPOST http://127.0.0.1:9393/GeneralModelService/inference | gunzip
```
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
此差异已折叠。
此差异已折叠。
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
此差异已折叠。
此差异已折叠。
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
文件模式从 100755 更改为 100644
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册