提交 82707adf 编写于 作者: H HexToString

Merge branch 'ocr' of github.com:HexToString/Serving into merge_branch

......@@ -18,16 +18,16 @@ set(PADDLE_SERVING_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(PADDLE_SERVING_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
SET(PADDLE_SERVING_INSTALL_DIR ${CMAKE_BINARY_DIR}/output)
SET(CMAKE_INSTALL_RPATH "\$ORIGIN" "${CMAKE_INSTALL_RPATH}")
include(system)
SET(CMAKE_BUILD_TYPE "Debug")
SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -Wall -g2 -ggdb")
SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3 -Wall")
project(paddle-serving CXX C)
message(STATUS "CXX compiler: ${CMAKE_CXX_COMPILER}, version: "
"${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}")
message(STATUS "C compiler: ${CMAKE_C_COMPILER}, version: "
"${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}")
find_package(Git REQUIRED)
find_package(Threads REQUIRED)
find_package(CUDA QUIET)
......@@ -46,19 +46,33 @@ set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
set(THIRD_PARTY_BUILD_TYPE Release)
option(WITH_AVX "Compile Paddle Serving with AVX intrinsics" OFF)
option(WITH_MKL "Compile Paddle Serving with MKL support." OFF)
option(WITH_GPU "Compile Paddle Serving with NVIDIA GPU" OFF)
option(WITH_LITE "Compile Paddle Serving with Paddle Lite Engine" OFF)
option(WITH_XPU "Compile Paddle Serving with Baidu Kunlun" OFF)
option(WITH_PYTHON "Compile Paddle Serving with Python" ON)
option(CLIENT "Compile Paddle Serving Client" OFF)
option(SERVER "Compile Paddle Serving Server" OFF)
option(APP "Compile Paddle Serving App package" OFF)
option(WITH_ELASTIC_CTR "Compile ELASITC-CTR solution" OFF)
option(PACK "Compile for whl" OFF)
option(WITH_TRT "Compile Paddle Serving with TRT" OFF)
option(PADDLE_ON_INFERENCE "Compile for encryption" ON)
option(WITH_AVX "Compile Paddle Serving with AVX intrinsics" OFF)
option(WITH_MKL "Compile Paddle Serving with MKL support." OFF)
option(WITH_GPU "Compile Paddle Serving with NVIDIA GPU" OFF)
option(WITH_LITE "Compile Paddle Serving with Paddle Lite Engine" OFF)
option(WITH_XPU "Compile Paddle Serving with Baidu Kunlun" OFF)
option(WITH_PYTHON "Compile Paddle Serving with Python" ON)
option(CLIENT "Compile Paddle Serving Client" OFF)
option(SERVER "Compile Paddle Serving Server" OFF)
option(APP "Compile Paddle Serving App package" OFF)
option(WITH_ELASTIC_CTR "Compile ELASITC-CTR solution" OFF)
option(PACK "Compile for whl" OFF)
option(WITH_TRT "Compile Paddle Serving with TRT" OFF)
option(PADDLE_ON_INFERENCE "Compile for encryption" ON)
option(WITH_OPENCV "Compile Paddle Serving with OPENCV" OFF)
if (WITH_OPENCV)
SET(OPENCV_DIR "" CACHE PATH "Location of libraries")
if(NOT DEFINED OPENCV_DIR)
message(FATAL_ERROR "please set OPENCV_DIR with -DOPENCV_DIR=/path/opencv")
endif()
if (WIN32)
find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/build/ NO_DEFAULT_PATH)
else ()
find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/share/OpenCV NO_DEFAULT_PATH)
endif ()
include_directories(${OpenCV_INCLUDE_DIRS})
endif()
if (PADDLE_ON_INFERENCE)
add_definitions(-DPADDLE_ON_INFERENCE)
......
......@@ -26,7 +26,7 @@ ExternalProject_Add(
extern_zlib
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/madler/zlib.git"
GIT_TAG "v1.2.8"
GIT_TAG "v1.2.9"
PREFIX ${ZLIB_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
......@@ -54,7 +54,11 @@ ELSE(WIN32)
SET(ZLIB_LIBRARIES "${ZLIB_INSTALL_DIR}/lib/libz.a" CACHE FILEPATH "zlib library." FORCE)
ENDIF(WIN32)
ADD_LIBRARY(zlib STATIC IMPORTED GLOBAL)
IF(WITH_OPENCV)
ELSE()
ADD_LIBRARY(zlib STATIC IMPORTED GLOBAL)
ENDIF()
SET_PROPERTY(TARGET zlib PROPERTY IMPORTED_LOCATION ${ZLIB_LIBRARIES})
ADD_DEPENDENCIES(zlib extern_zlib)
......
......@@ -59,7 +59,7 @@ message SimpleResponse { required int32 err_code = 1; }
message GetClientConfigRequest {}
message GetClientConfigResponse { required string client_config_str = 1; }
message GetClientConfigResponse { repeated string client_config_str_list = 1; }
service MultiLangGeneralModelService {
rpc Inference(InferenceRequest) returns (InferenceResponse) {}
......
......@@ -55,10 +55,10 @@ message ModelToolkitConf { repeated EngineDesc engines = 1; };
// reource conf
message ResourceConf {
required string model_toolkit_path = 1;
required string model_toolkit_file = 2;
optional string general_model_path = 3;
optional string general_model_file = 4;
repeated string model_toolkit_path = 1;
repeated string model_toolkit_file = 2;
repeated string general_model_path = 3;
repeated string general_model_file = 4;
optional string cube_config_path = 5;
optional string cube_config_file = 6;
optional int32 cube_quant_bits = 7; // set 0 if no quant.
......
......@@ -207,7 +207,7 @@ class PredictorClient {
void init_gflags(std::vector<std::string> argv);
int init(const std::string& client_conf);
int init(const std::vector<std::string> &client_conf);
void set_predictor_conf(const std::string& conf_path,
const std::string& conf_file);
......@@ -227,6 +227,10 @@ class PredictorClient {
const std::vector<std::string>& int_feed_name,
const std::vector<std::vector<int>>& int_shape,
const std::vector<std::vector<int>>& int_lod_slot_batch,
const std::vector<std::vector<std::string>>& string_feed_batch,
const std::vector<std::string>& string_feed_name,
const std::vector<std::vector<int>>& string_shape,
const std::vector<std::vector<int>>& string_lod_slot_batch,
const std::vector<std::string>& fetch_name,
PredictorRes& predict_res_batch, // NOLINT
const int& pid,
......
......@@ -28,7 +28,7 @@ using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::FeedInst;
using baidu::paddle_serving::predictor::general_model::FetchInst;
enum ProtoDataType { P_INT64, P_FLOAT32, P_INT32, P_STRING };
std::once_flag gflags_init_flag;
namespace py = pybind11;
......@@ -56,22 +56,20 @@ void PredictorClient::init_gflags(std::vector<std::string> argv) {
});
}
int PredictorClient::init(const std::string &conf_file) {
int PredictorClient::init(const std::vector<std::string> &conf_file) {
try {
GeneralModelConfig model_config;
if (configure::read_proto_conf(conf_file.c_str(), &model_config) != 0) {
if (configure::read_proto_conf(conf_file[0].c_str(), &model_config) != 0) {
LOG(ERROR) << "Failed to load general model config"
<< ", file path: " << conf_file;
<< ", file path: " << conf_file[0];
return -1;
}
_feed_name_to_idx.clear();
_fetch_name_to_idx.clear();
_shape.clear();
int feed_var_num = model_config.feed_var_size();
int fetch_var_num = model_config.fetch_var_size();
VLOG(2) << "feed var num: " << feed_var_num
<< "fetch_var_num: " << fetch_var_num;
VLOG(2) << "feed var num: " << feed_var_num;
for (int i = 0; i < feed_var_num; ++i) {
_feed_name_to_idx[model_config.feed_var(i).alias_name()] = i;
VLOG(2) << "feed alias name: " << model_config.feed_var(i).alias_name()
......@@ -90,6 +88,16 @@ int PredictorClient::init(const std::string &conf_file) {
_shape.push_back(tmp_feed_shape);
}
if (conf_file.size()>1) {
model_config.Clear();
if (configure::read_proto_conf(conf_file[conf_file.size()-1].c_str(), &model_config) != 0) {
LOG(ERROR) << "Failed to load general model config"
<< ", file path: " << conf_file[conf_file.size()-1];
return -1;
}
}
int fetch_var_num = model_config.fetch_var_size();
VLOG(2) << "fetch_var_num: " << fetch_var_num;
for (int i = 0; i < fetch_var_num; ++i) {
_fetch_name_to_idx[model_config.fetch_var(i).alias_name()] = i;
VLOG(2) << "fetch [" << i << "]"
......@@ -146,11 +154,16 @@ int PredictorClient::numpy_predict(
const std::vector<std::string> &int_feed_name,
const std::vector<std::vector<int>> &int_shape,
const std::vector<std::vector<int>> &int_lod_slot_batch,
const std::vector<std::vector<std::string>>& string_feed_batch,
const std::vector<std::string>& string_feed_name,
const std::vector<std::vector<int>>& string_shape,
const std::vector<std::vector<int>>& string_lod_slot_batch,
const std::vector<std::string> &fetch_name,
PredictorRes &predict_res_batch,
const int &pid,
const uint64_t log_id) {
int batch_size = std::max(float_feed_batch.size(), int_feed_batch.size());
int batch_size = std::max( float_feed_batch.size(), int_feed_batch.size() );
batch_size = batch_size>string_feed_batch.size()? batch_size : string_feed_batch.size();
VLOG(2) << "batch size: " << batch_size;
predict_res_batch.clear();
Timer timeline;
......@@ -165,6 +178,7 @@ int PredictorClient::numpy_predict(
VLOG(2) << "fetch general model predictor done.";
VLOG(2) << "float feed name size: " << float_feed_name.size();
VLOG(2) << "int feed name size: " << int_feed_name.size();
VLOG(2) << "string feed name size: " << string_feed_name.size();
VLOG(2) << "max body size : " << brpc::fLU64::FLAGS_max_body_size;
Request req;
req.set_log_id(log_id);
......@@ -178,6 +192,7 @@ int PredictorClient::numpy_predict(
FeedInst *inst = req.add_insts();
std::vector<py::array_t<float>> float_feed = float_feed_batch[bi];
std::vector<py::array_t<int64_t>> int_feed = int_feed_batch[bi];
std::vector<std::string> string_feed = string_feed_batch[bi];
for (auto &name : float_feed_name) {
tensor_vec.push_back(inst->add_tensor_array());
}
......@@ -186,12 +201,13 @@ int PredictorClient::numpy_predict(
tensor_vec.push_back(inst->add_tensor_array());
}
VLOG(2) << "batch [" << bi << "] int_feed_name and float_feed_name "
<< "prepared";
for (auto &name : string_feed_name) {
tensor_vec.push_back(inst->add_tensor_array());
}
VLOG(2) << "batch [" << bi << "] " << "prepared";
int vec_idx = 0;
VLOG(2) << "tensor_vec size " << tensor_vec.size() << " float shape "
<< float_shape.size();
for (auto &name : float_feed_name) {
int idx = _feed_name_to_idx[name];
Tensor *tensor = tensor_vec[idx];
......@@ -203,7 +219,7 @@ int PredictorClient::numpy_predict(
for (uint32_t j = 0; j < float_lod_slot_batch[vec_idx].size(); ++j) {
tensor->add_lod(float_lod_slot_batch[vec_idx][j]);
}
tensor->set_elem_type(1);
tensor->set_elem_type(P_FLOAT32);
const int float_shape_size = float_shape[vec_idx].size();
switch (float_shape_size) {
case 4: {
......@@ -249,7 +265,7 @@ int PredictorClient::numpy_predict(
}
vec_idx++;
}
VLOG(2) << "batch [" << bi << "] "
<< "float feed value prepared";
......@@ -266,7 +282,7 @@ int PredictorClient::numpy_predict(
}
tensor->set_elem_type(_type[idx]);
if (_type[idx] == 0) {
if (_type[idx] == P_INT64) {
VLOG(2) << "prepare int feed " << name << " shape size "
<< int_shape[vec_idx].size();
} else {
......@@ -282,7 +298,7 @@ int PredictorClient::numpy_predict(
for (ssize_t j = 0; j < int_array.shape(1); j++) {
for (ssize_t k = 0; k < int_array.shape(2); k++) {
for (ssize_t l = 0; k < int_array.shape(3); l++) {
if (_type[idx] == 0) {
if (_type[idx] == P_INT64) {
tensor->add_int64_data(int_array(i, j, k, l));
} else {
tensor->add_int_data(int_array(i, j, k, l));
......@@ -298,7 +314,7 @@ int PredictorClient::numpy_predict(
for (ssize_t i = 0; i < int_array.shape(0); i++) {
for (ssize_t j = 0; j < int_array.shape(1); j++) {
for (ssize_t k = 0; k < int_array.shape(2); k++) {
if (_type[idx] == 0) {
if (_type[idx] == P_INT64) {
tensor->add_int64_data(int_array(i, j, k));
} else {
tensor->add_int_data(int_array(i, j, k));
......@@ -312,7 +328,7 @@ int PredictorClient::numpy_predict(
auto int_array = int_feed[vec_idx].unchecked<2>();
for (ssize_t i = 0; i < int_array.shape(0); i++) {
for (ssize_t j = 0; j < int_array.shape(1); j++) {
if (_type[idx] == 0) {
if (_type[idx] == P_INT64) {
tensor->add_int64_data(int_array(i, j));
} else {
tensor->add_int_data(int_array(i, j));
......@@ -324,7 +340,7 @@ int PredictorClient::numpy_predict(
case 1: {
auto int_array = int_feed[vec_idx].unchecked<1>();
for (ssize_t i = 0; i < int_array.shape(0); i++) {
if (_type[idx] == 0) {
if (_type[idx] == P_INT64) {
tensor->add_int64_data(int_array(i));
} else {
tensor->add_int_data(int_array(i));
......@@ -338,6 +354,38 @@ int PredictorClient::numpy_predict(
VLOG(2) << "batch [" << bi << "] "
<< "int feed value prepared";
vec_idx = 0;
for (auto &name : string_feed_name) {
int idx = _feed_name_to_idx[name];
Tensor *tensor = tensor_vec[idx];
for (uint32_t j = 0; j < string_shape[vec_idx].size(); ++j) {
tensor->add_shape(string_shape[vec_idx][j]);
}
for (uint32_t j = 0; j < string_lod_slot_batch[vec_idx].size(); ++j) {
tensor->add_lod(string_lod_slot_batch[vec_idx][j]);
}
tensor->set_elem_type(P_STRING);
const int string_shape_size = string_shape[vec_idx].size();
//string_shape[vec_idx] = [1];cause numpy has no datatype of string.
//we pass string via vector<vector<string> >.
if( string_shape_size!= 1 ){
LOG(ERROR) << "string_shape_size should be 1-D, but received is : " << string_shape_size;
return -1;
}
switch (string_shape_size) {
case 1: {
tensor->add_data(string_feed[vec_idx]);
break;
}
}
vec_idx++;
}
VLOG(2) << "batch [" << bi << "] "
<< "string feed value prepared";
}
int64_t preprocess_end = timeline.TimeStampUS();
......@@ -397,19 +445,19 @@ int PredictorClient::numpy_predict(
for (auto &name : fetch_name) {
// int idx = _fetch_name_to_idx[name];
if (_fetch_name_to_type[name] == 0) {
if (_fetch_name_to_type[name] == P_INT64) {
VLOG(2) << "ferch var " << name << "type int64";
int size = output.insts(0).tensor_array(idx).int64_data_size();
model._int64_value_map[name] = std::vector<int64_t>(
output.insts(0).tensor_array(idx).int64_data().begin(),
output.insts(0).tensor_array(idx).int64_data().begin() + size);
} else if (_fetch_name_to_type[name] == 1) {
} else if (_fetch_name_to_type[name] == P_FLOAT32) {
VLOG(2) << "fetch var " << name << "type float";
int size = output.insts(0).tensor_array(idx).float_data_size();
model._float_value_map[name] = std::vector<float>(
output.insts(0).tensor_array(idx).float_data().begin(),
output.insts(0).tensor_array(idx).float_data().begin() + size);
} else if (_fetch_name_to_type[name] == 2) {
} else if (_fetch_name_to_type[name] == P_INT32) {
VLOG(2) << "fetch var " << name << "type int32";
int size = output.insts(0).tensor_array(idx).int_data_size();
model._int32_value_map[name] = std::vector<int32_t>(
......
......@@ -78,7 +78,7 @@ PYBIND11_MODULE(serving_client, m) {
self.init_gflags(argv);
})
.def("init",
[](PredictorClient &self, const std::string &conf) {
[](PredictorClient &self, const std::vector<std::string> &conf) {
return self.init(conf);
})
.def("set_predictor_conf",
......@@ -107,6 +107,10 @@ PYBIND11_MODULE(serving_client, m) {
const std::vector<std::string> &int_feed_name,
const std::vector<std::vector<int>> &int_shape,
const std::vector<std::vector<int>> &int_lod_slot_batch,
const std::vector<std::vector<std::string>>& string_feed_batch,
const std::vector<std::string>& string_feed_name,
const std::vector<std::vector<int>>& string_shape,
const std::vector<std::vector<int>>& string_lod_slot_batch,
const std::vector<std::string> &fetch_name,
PredictorRes &predict_res_batch,
const int &pid,
......@@ -119,6 +123,10 @@ PYBIND11_MODULE(serving_client, m) {
int_feed_name,
int_shape,
int_lod_slot_batch,
string_feed_batch,
string_feed_name,
string_shape,
string_lod_slot_batch,
fetch_name,
predict_res_batch,
pid,
......
include_directories(SYSTEM ${CMAKE_CURRENT_LIST_DIR}/../../)
include(op/CMakeLists.txt)
include(proto/CMakeLists.txt)
add_executable(serving ${serving_srcs})
add_dependencies(serving pdcodegen paddle_inference_engine pdserving paddle_inference cube-api utils)
......@@ -20,6 +21,9 @@ include_directories(${CUDNN_ROOT}/include/)
target_link_libraries(serving -Wl,--whole-archive paddle_inference_engine
-Wl,--no-whole-archive)
if(WITH_OPENCV)
target_link_libraries(serving ${OpenCV_LIBS})
endif()
target_link_libraries(serving paddle_inference ${paddle_depend_libs})
target_link_libraries(serving brpc)
target_link_libraries(serving protobuf)
......@@ -27,6 +31,7 @@ target_link_libraries(serving pdserving)
target_link_libraries(serving cube-api)
target_link_libraries(serving utils)
if(WITH_GPU)
target_link_libraries(serving ${CUDA_LIBRARIES})
endif()
......
FILE(GLOB op_srcs ${CMAKE_CURRENT_LIST_DIR}/*.cpp ${CMAKE_CURRENT_LIST_DIR}/../../predictor/tools/quant.cpp)
if(WITH_OPENCV)
FILE(GLOB ocrtools_srcs ${CMAKE_CURRENT_LIST_DIR}/../../predictor/tools/ocrtools/*.cpp)
LIST(APPEND op_srcs ${ocrtools_srcs})
else()
set (EXCLUDE_DIR "general_detection_op.cpp")
foreach (TMP_PATH ${op_srcs})
string (FIND ${TMP_PATH} ${EXCLUDE_DIR} EXCLUDE_DIR_FOUND)
if (NOT ${EXCLUDE_DIR_FOUND} EQUAL -1)
list (REMOVE_ITEM op_srcs ${TMP_PATH})
break()
endif ()
endforeach(TMP_PATH)
endif()
LIST(APPEND serving_srcs ${op_srcs})
......@@ -117,8 +117,9 @@ int GeneralDistKVQuantInferOp::inference() {
std::unordered_map<int, int> in_out_map;
baidu::paddle_serving::predictor::Resource &resource =
baidu::paddle_serving::predictor::Resource::instance();
//TODO:Temporary addition, specific details to be studied by HexToString
std::shared_ptr<PaddleGeneralModelConfig> model_config =
resource.get_general_model_config();
resource.get_general_model_config()[0];
int cube_quant_bits = resource.get_cube_quant_bits();
size_t EMBEDDING_SIZE = 0;
if (cube_quant_bits == 0) {
......
......@@ -44,13 +44,57 @@ int GeneralInferOp::inference() {
<< pre_node_names.size();
return -1;
}
if (InferManager::instance().infer(engine_name().c_str())) {
const std::string pre_name = pre_node_names[0];
const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name);
if (!input_blob) {
LOG(ERROR) << "input_blob is nullptr,error";
return -1;
}
uint64_t log_id = input_blob->GetLogId();
VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name;
GeneralBlob *output_blob = mutable_data<GeneralBlob>();
if (!output_blob) {
LOG(ERROR) << "output_blob is nullptr,error";
return -1;
}
output_blob->SetLogId(log_id);
if (!input_blob) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed mutable depended argument, op:" << pre_name;
return -1;
}
const TensorVector *in = &input_blob->tensor_vector;
TensorVector *out = &output_blob->tensor_vector;
int batch_size = input_blob->_batch_size;
VLOG(2) << "(logid=" << log_id << ") input batch size: " << batch_size;
output_blob->_batch_size = batch_size;
VLOG(2) << "(logid=" << log_id << ") infer batch size: " << batch_size;
Timer timeline;
int64_t start = timeline.TimeStampUS();
timeline.Start();
if (InferManager::instance().infer(
engine_name().c_str(), in, out, batch_size)) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed do infer in fluid model: " << engine_name().c_str();
return -1;
}
int64_t end = timeline.TimeStampUS();
CopyBlobInfo(input_blob, output_blob);
AddBlobInfo(output_blob, start);
AddBlobInfo(output_blob, end);
return 0;
}
DEFINE_OP(GeneralInferOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
} // namespace baidu
\ No newline at end of file
......@@ -20,7 +20,6 @@
#include "core/general-server/op/general_infer_helper.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h"
#include "core/util/include/timer.h"
namespace baidu {
......@@ -33,8 +32,7 @@ using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::FeedInst;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
using baidu::paddle_serving::predictor::InferManager;
enum ProtoDataType { P_INT64, P_FLOAT32, P_INT32, P_STRING };
int conf_check(const Request *req,
const std::shared_ptr<PaddleGeneralModelConfig> &model_config) {
int var_num = req->insts(0).tensor_array_size();
......@@ -72,91 +70,181 @@ int conf_check(const Request *req,
}
int GeneralReaderOp::inference() {
// reade request from client
// TODO: only support one engine here
std::string engine_name = "general_infer_0";
// read request from client
const Request *req = dynamic_cast<const Request *>(get_request_message());
uint64_t log_id = req->log_id();
int input_var_num = 0;
std::vector<int64_t> elem_type;
std::vector<int64_t> elem_size;
std::vector<int64_t> capacity;
std::vector<int64_t> databuf_size;
GeneralBlob *res = mutable_data<GeneralBlob>();
TensorVector *out = &(res->tensor_vector);
res->SetLogId(log_id);
if (!res) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed get op tls reader object output";
}
Timer timeline;
int64_t start = timeline.TimeStampUS();
int var_num = req->insts(0).tensor_array_size();
VLOG(2) << "(logid=" << log_id << ") var num: " << var_num
<< ") start to call load general model_conf op";
baidu::paddle_serving::predictor::Resource &resource =
baidu::paddle_serving::predictor::Resource::instance();
VLOG(2) << "(logid=" << log_id << ") get resource pointer done.";
//get the first InferOP's model_config as ReaderOp's model_config by default.
std::shared_ptr<PaddleGeneralModelConfig> model_config =
resource.get_general_model_config();
resource.get_general_model_config().front();
// TODO(guru4elephant): how to do conditional check?
/*
int ret = conf_check(req, model_config);
if (ret != 0) {
LOG(ERROR) << "model conf of server:";
resource.print_general_model_config(model_config);
return 0;
}
*/
// package tensor
elem_type.resize(var_num);
elem_size.resize(var_num);
capacity.resize(var_num);
databuf_size.resize(var_num);
// prepare basic information for input
// specify the memory needed for output tensor_vector
// fill the data into output general_blob
int data_len = 0;
for (int i = 0; i < var_num; ++i) {
std::string tensor_name = model_config->_feed_name[i];
VLOG(2) << "(logid=" << log_id << ") get tensor name: " << tensor_name;
auto lod_tensor = InferManager::instance().GetInputHandle(
engine_name.c_str(), tensor_name.c_str());
std::vector<std::vector<size_t>> lod;
std::vector<int> shape;
// get lod info here
paddle::PaddleTensor lod_tensor;
const Tensor &tensor = req->insts(0).tensor_array(i);
data_len = 0;
elem_type[i] = req->insts(0).tensor_array(i).elem_type();
VLOG(2) << "var[" << i << "] has elem type: " << elem_type[i];
if (elem_type[i] == P_INT64) { // int64
elem_size[i] = sizeof(int64_t);
lod_tensor.dtype = paddle::PaddleDType::INT64;
data_len = tensor.int64_data_size();
} else if (elem_type[i] == P_FLOAT32) {
elem_size[i] = sizeof(float);
lod_tensor.dtype = paddle::PaddleDType::FLOAT32;
data_len = tensor.float_data_size();
} else if (elem_type[i] == P_INT32) {
elem_size[i] = sizeof(int32_t);
lod_tensor.dtype = paddle::PaddleDType::INT32;
data_len = tensor.int_data_size();
} else if (elem_type[i] == P_STRING) {
//use paddle::PaddleDType::UINT8 as for String.
elem_size[i] = sizeof(uint8_t);
lod_tensor.dtype = paddle::PaddleDType::UINT8;
//this is for vector<String>, cause the databuf_size != vector<String>.size()*sizeof(char);
for (int idx = 0; idx < tensor.data_size(); idx++) {
data_len += tensor.data()[idx].length();
}
}
// implement lod tensor here
// only support 1-D lod
// TODO:support 2-D lod
if (req->insts(0).tensor_array(i).lod_size() > 0) {
lod.resize(1);
VLOG(2) << "(logid=" << log_id << ") var[" << i << "] is lod_tensor";
lod_tensor.lod.resize(1);
for (int k = 0; k < req->insts(0).tensor_array(i).lod_size(); ++k) {
lod[0].push_back(req->insts(0).tensor_array(i).lod(k));
}
capacity[i] = 1;
for (int k = 0; k < req->insts(0).tensor_array(i).shape_size(); ++k) {
int dim = req->insts(0).tensor_array(i).shape(k);
VLOG(2) << "(logid=" << log_id << ") shape for var[" << i
<< "]: " << dim;
capacity[i] *= dim;
shape.push_back(dim);
lod_tensor.lod[0].push_back(req->insts(0).tensor_array(i).lod(k));
}
VLOG(2) << "(logid=" << log_id << ") var[" << i
<< "] is tensor, capacity: " << capacity[i];
} else {
capacity[i] = 1;
for (int k = 0; k < req->insts(0).tensor_array(i).shape_size(); ++k) {
int dim = req->insts(0).tensor_array(i).shape(k);
VLOG(2) << "(logid=" << log_id << ") shape for var[" << i
<< "]: " << dim;
capacity[i] *= dim;
shape.push_back(dim);
}
VLOG(2) << "(logid=" << log_id << ") var[" << i
<< "] is tensor, capacity: " << capacity[i];
}
lod_tensor->SetLoD(lod);
lod_tensor->Reshape(shape);
// insert data here
if (req->insts(0).tensor_array(i).elem_type() == 0) {
// TODO: Copy twice here, can optimize
for (int k = 0; k < req->insts(0).tensor_array(i).shape_size(); ++k) {
int dim = req->insts(0).tensor_array(i).shape(k);
VLOG(2) << "(logid=" << log_id << ") shape for var[" << i
<< "]: " << dim;
lod_tensor.shape.push_back(dim);
}
lod_tensor.name = model_config->_feed_name[i];
out->push_back(lod_tensor);
VLOG(2) << "(logid=" << log_id << ") tensor size for var[" << i
<< "]: " << data_len;
databuf_size[i] = data_len * elem_size[i];
out->at(i).data.Resize(data_len * elem_size[i]);
VLOG(2) << "(logid=" << log_id << ") var[" << i
<< "] is lod_tensor and len=" << out->at(i).lod[0].back();
if (elem_type[i] == P_INT64) {
int64_t *dst_ptr = static_cast<int64_t *>(out->at(i).data.data());
VLOG(2) << "(logid=" << log_id << ") first element data in var[" << i
<< "] is " << req->insts(0).tensor_array(i).int64_data(0);
if (!dst_ptr) {
LOG(ERROR) << "dst_ptr is nullptr";
return -1;
}
memcpy(dst_ptr, req->insts(0).tensor_array(i).int64_data().data(),databuf_size[i]);
/*
int elem_num = req->insts(0).tensor_array(i).int64_data_size();
std::vector<int64_t> data(elem_num);
int64_t *dst_ptr = data.data();
for (int k = 0; k < elem_num; ++k) {
dst_ptr[k] = req->insts(0).tensor_array(i).int64_data(k);
}
lod_tensor->CopyFromCpu(dst_ptr);
} else if (req->insts(0).tensor_array(i).elem_type() == 1) {
int elem_num = req->insts(0).tensor_array(i).float_data_size();
std::vector<float> data(elem_num);
float *dst_ptr = data.data();
*/
} else if (elem_type[i] == P_FLOAT32) {
float *dst_ptr = static_cast<float *>(out->at(i).data.data());
VLOG(2) << "(logid=" << log_id << ") first element data in var[" << i
<< "] is " << req->insts(0).tensor_array(i).float_data(0);
if (!dst_ptr) {
LOG(ERROR) << "dst_ptr is nullptr";
return -1;
}
memcpy(dst_ptr, req->insts(0).tensor_array(i).float_data().data(),databuf_size[i]);
/*int elem_num = req->insts(0).tensor_array(i).float_data_size();
for (int k = 0; k < elem_num; ++k) {
dst_ptr[k] = req->insts(0).tensor_array(i).float_data(k);
}*/
} else if (elem_type[i] == P_INT32) {
int32_t *dst_ptr = static_cast<int32_t *>(out->at(i).data.data());
VLOG(2) << "(logid=" << log_id << ") first element data in var[" << i
<< "] is " << req->insts(0).tensor_array(i).int_data(0);
if (!dst_ptr) {
LOG(ERROR) << "dst_ptr is nullptr";
return -1;
}
lod_tensor->CopyFromCpu(dst_ptr);
} else if (req->insts(0).tensor_array(i).elem_type() == 2) {
memcpy(dst_ptr, req->insts(0).tensor_array(i).int_data().data(),databuf_size[i]);
/*
int elem_num = req->insts(0).tensor_array(i).int_data_size();
std::vector<int32_t> data(elem_num);
int32_t *dst_ptr = data.data();
for (int k = 0; k < elem_num; ++k) {
dst_ptr[k] = req->insts(0).tensor_array(i).int_data(k);
}
lod_tensor->CopyFromCpu(dst_ptr);
*/
} else if (elem_type[i] == P_STRING) {
std::string *dst_ptr = static_cast<std::string *>(out->at(i).data.data());
VLOG(2) << "(logid=" << log_id << ") first element data in var[" << i
<< "] is " << req->insts(0).tensor_array(i).data(0);
if (!dst_ptr) {
LOG(ERROR) << "dst_ptr is nullptr";
return -1;
}
int elem_num = req->insts(0).tensor_array(i).data_size();
for (int k = 0; k < elem_num; ++k) {
dst_ptr[k] = req->insts(0).tensor_array(i).data(k);
}
}
}
VLOG(2) << "(logid=" << log_id << ") output size: " << out->size();
timeline.Pause();
int64_t end = timeline.TimeStampUS();
res->p_size = 0;
res->_batch_size = 1;
AddBlobInfo(res, start);
AddBlobInfo(res, end);
VLOG(2) << "(logid=" << log_id << ") read data from client success";
return 0;
}
DEFINE_OP(GeneralReaderOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
} // namespace baidu
\ No newline at end of file
......@@ -40,59 +40,163 @@ using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
int GeneralResponseOp::inference() {
const std::vector<std::string> pre_node_names = pre_names();
VLOG(2) << "pre node names size: " << pre_node_names.size();
const GeneralBlob *input_blob = nullptr;
int var_idx = 0;
int cap = 1;
uint64_t log_id =
get_depend_argument<GeneralBlob>(pre_node_names[0])->GetLogId();
const Request *req = dynamic_cast<const Request *>(get_request_message());
// response inst with only fetch_var_names
Response *res = mutable_data<Response>();
Timer timeline;
// double response_time = 0.0;
// timeline.Start();
int64_t start = timeline.TimeStampUS();
VLOG(2) << "(logid=" << log_id
<< ") start to call load general model_conf op";
baidu::paddle_serving::predictor::Resource &resource =
baidu::paddle_serving::predictor::Resource::instance();
VLOG(2) << "(logid=" << log_id << ") get resource pointer done.";
//get the last InferOP's model_config as ResponseOp's model_config by default.
std::shared_ptr<PaddleGeneralModelConfig> model_config =
resource.get_general_model_config();
std::vector<int> capacity(req->fetch_var_names_size(), 1);
std::string engine_name = "general_infer_0";
ModelOutput *output = res->add_outputs();
FetchInst *fetch_inst = output->add_insts();
FetchInst *fetch_p = output->mutable_insts(0);
std::vector<std::string> outs =
InferManager::instance().GetOutputNames(engine_name.c_str());
resource.get_general_model_config().back();
VLOG(2) << "(logid=" << log_id
<< ") max body size : " << brpc::fLU64::FLAGS_max_body_size;
std::vector<int> fetch_index;
fetch_index.resize(req->fetch_var_names_size());
for (int i = 0; i < req->fetch_var_names_size(); ++i) {
Tensor *tensor = fetch_inst->add_tensor_array();
std::string tensor_name = outs[i];
auto lod_tensor = InferManager::instance().GetOutputHandle(
engine_name.c_str(), tensor_name.c_str());
std::vector<int> shape = lod_tensor->shape();
for (int k = 0; k < shape.size(); ++k) {
capacity[i] *= shape[k];
tensor->add_shape(shape[k]);
fetch_index[i] =
model_config->_fetch_alias_name_to_index[req->fetch_var_names(i)];
}
for (uint32_t pi = 0; pi < pre_node_names.size(); ++pi) {
const std::string &pre_name = pre_node_names[pi];
VLOG(2) << "(logid=" << log_id << ") pre names[" << pi << "]: " << pre_name
<< " (" << pre_node_names.size() << ")";
input_blob = get_depend_argument<GeneralBlob>(pre_name);
// fprintf(stderr, "input(%s) blob address %x\n", pre_names.c_str(),
// input_blob);
if (!input_blob) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed mutable depended argument, op: " << pre_name;
return -1;
}
const TensorVector *in = &input_blob->tensor_vector;
ModelOutput *output = res->add_outputs();
// To get the order of model return values
output->set_engine_name(pre_name);
FetchInst *fetch_inst = output->add_insts();
for (auto &idx : fetch_index) {
Tensor *tensor = fetch_inst->add_tensor_array();
//tensor->set_elem_type(1);
if (model_config->_is_lod_fetch[idx]) {
VLOG(2) << "(logid=" << log_id << ") out[" << idx << "] "
<< model_config->_fetch_name[idx] << " is lod_tensor";
for (int k = 0; k < in->at(idx).shape.size(); ++k) {
VLOG(2) << "(logid=" << log_id << ") shape[" << k
<< "]: " << in->at(idx).shape[k];
tensor->add_shape(in->at(idx).shape[k]);
}
} else {
VLOG(2) << "(logid=" << log_id << ") out[" << idx << "] "
<< model_config->_fetch_name[idx] << " is tensor";
for (int k = 0; k < in->at(idx).shape.size(); ++k) {
VLOG(2) << "(logid=" << log_id << ") shape[" << k
<< "]: " << in->at(idx).shape[k];
tensor->add_shape(in->at(idx).shape[k]);
}
}
}
auto dtype = lod_tensor->type();
if (dtype == paddle::PaddleDType::INT64) {
std::vector<int64_t> datas(capacity[i]);
int64_t *data_ptr = datas.data();
lod_tensor->CopyToCpu(data_ptr);
google::protobuf::RepeatedField<int64_t> tmp_data(data_ptr,
data_ptr + capacity[i]);
tensor->mutable_int64_data()->Swap(&tmp_data);
} else if (dtype == paddle::PaddleDType::FLOAT32) {
std::vector<float> datas(capacity[i]);
float *data_ptr = datas.data();
lod_tensor->CopyToCpu(data_ptr);
google::protobuf::RepeatedField<float> tmp_data(data_ptr,
data_ptr + capacity[i]);
tensor->mutable_float_data()->Swap(&tmp_data);
} else if (dtype == paddle::PaddleDType::INT32) {
std::vector<int32_t> datas(capacity[i]);
int32_t *data_ptr = datas.data();
lod_tensor->CopyToCpu(data_ptr);
google::protobuf::RepeatedField<int32_t> tmp_data(data_ptr,
data_ptr + capacity[i]);
tensor->mutable_int_data()->Swap(&tmp_data);
var_idx = 0;
for (auto &idx : fetch_index) {
cap = 1;
for (int j = 0; j < in->at(idx).shape.size(); ++j) {
cap *= in->at(idx).shape[j];
}
FetchInst *fetch_p = output->mutable_insts(0);
auto dtype = in->at(idx).dtype;
if (dtype == paddle::PaddleDType::INT64) {
VLOG(2) << "(logid=" << log_id << ") Prepare int64 var ["
<< model_config->_fetch_name[idx] << "].";
int64_t *data_ptr = static_cast<int64_t *>(in->at(idx).data.data());
// from
// https://stackoverflow.com/questions/15499641/copy-a-stdvector-to-a-repeated-field-from-protobuf-with-memcpy
// `Swap` method is faster than `{}` method.
google::protobuf::RepeatedField<int64_t> tmp_data(data_ptr,
data_ptr + cap);
fetch_p->mutable_tensor_array(var_idx)->mutable_int64_data()->Swap(
&tmp_data);
} else if (dtype == paddle::PaddleDType::FLOAT32) {
VLOG(2) << "(logid=" << log_id << ") Prepare float var ["
<< model_config->_fetch_name[idx] << "].";
float *data_ptr = static_cast<float *>(in->at(idx).data.data());
google::protobuf::RepeatedField<float> tmp_data(data_ptr,
data_ptr + cap);
fetch_p->mutable_tensor_array(var_idx)->mutable_float_data()->Swap(
&tmp_data);
} else if (dtype == paddle::PaddleDType::INT32) {
VLOG(2) << "(logid=" << log_id << ")Prepare int32 var ["
<< model_config->_fetch_name[idx] << "].";
int32_t *data_ptr = static_cast<int32_t *>(in->at(idx).data.data());
google::protobuf::RepeatedField<int32_t> tmp_data(data_ptr,
data_ptr + cap);
fetch_p->mutable_tensor_array(var_idx)->mutable_int_data()->Swap(
&tmp_data);
}
if (model_config->_is_lod_fetch[idx]) {
if (in->at(idx).lod.size() > 0) {
for (int j = 0; j < in->at(idx).lod[0].size(); ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_lod(
in->at(idx).lod[0][j]);
}
}
}
VLOG(2) << "(logid=" << log_id << ") fetch var ["
<< model_config->_fetch_name[idx] << "] ready";
var_idx++;
}
std::vector<std::vector<size_t>> lod = lod_tensor->lod();
if (lod.size() > 0) {
for (int j = 0; j < lod[0].size(); ++j) {
tensor->add_lod(lod[0][j]);
}
if (req->profile_server()) {
int64_t end = timeline.TimeStampUS();
// TODO(barriery): multi-model profile_time.
// At present, only the response_op is multi-input, so here we get
// the profile_time by hard coding. It needs to be replaced with
// a more elegant way.
for (uint32_t pi = 0; pi < pre_node_names.size(); ++pi) {
input_blob = get_depend_argument<GeneralBlob>(pre_node_names[pi]);
VLOG(2) << "(logid=" << log_id
<< ") p size for input blob: " << input_blob->p_size;
int profile_time_idx = -1;
if (pi == 0) {
profile_time_idx = 0;
} else {
profile_time_idx = input_blob->p_size - 2;
}
for (; profile_time_idx < input_blob->p_size; ++profile_time_idx) {
res->add_profile_time(input_blob->time_stamp[profile_time_idx]);
}
}
// TODO(guru4elephant): find more elegant way to do this
res->add_profile_time(start);
res->add_profile_time(end);
}
return 0;
}
......@@ -101,4 +205,4 @@ DEFINE_OP(GeneralResponseOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
} // namespace baidu
\ No newline at end of file
......@@ -73,7 +73,7 @@ int GeneralTextReaderOp::inference() {
VLOG(2) << "(logid=" << log_id << ") get resource pointer done.";
std::shared_ptr<PaddleGeneralModelConfig> model_config =
resource.get_general_model_config();
resource.get_general_model_config()[0];
VLOG(2) << "(logid=" << log_id << ") print general model config done.";
......
......@@ -58,7 +58,7 @@ int GeneralTextResponseOp::inference() {
VLOG(2) << "(logid=" << log_id << ") get resource pointer done.";
std::shared_ptr<PaddleGeneralModelConfig> model_config =
resource.get_general_model_config();
resource.get_general_model_config().back();
std::vector<int> fetch_index;
fetch_index.resize(req->fetch_var_names_size());
......
此差异已折叠。
......@@ -21,7 +21,7 @@ namespace baidu {
namespace paddle_serving {
namespace predictor {
enum DataType { FLOAT32, INT64 };
enum DataType { FLOAT32, INT64, INT32 };
class DataBuf {
public:
......@@ -80,8 +80,10 @@ struct Tensor {
size_t ele_byte() const {
if (type == INT64) {
return sizeof(int64_t);
} else {
} else if (type == FLOAT32) {
return sizeof(float);
} else {
return sizeof(int32_t);
}
}
......
......@@ -42,8 +42,8 @@ DynamicResource::~DynamicResource() {}
int DynamicResource::initialize() { return 0; }
std::shared_ptr<PaddleGeneralModelConfig> Resource::get_general_model_config() {
return _config;
std::vector<std::shared_ptr<PaddleGeneralModelConfig> > Resource::get_general_model_config() {
return _configs;
}
void Resource::print_general_model_config(
......@@ -149,30 +149,25 @@ int Resource::initialize(const std::string& path, const std::string& file) {
#endif
if (FLAGS_enable_model_toolkit) {
int err = 0;
std::string model_toolkit_path = resource_conf.model_toolkit_path();
if (err != 0) {
LOG(ERROR) << "read model_toolkit_path failed, path[" << path
<< "], file[" << file << "]";
return -1;
}
std::string model_toolkit_file = resource_conf.model_toolkit_file();
if (err != 0) {
LOG(ERROR) << "read model_toolkit_file failed, path[" << path
<< "], file[" << file << "]";
return -1;
}
if (InferManager::instance().proc_initialize(
model_toolkit_path.c_str(), model_toolkit_file.c_str()) != 0) {
LOG(ERROR) << "failed proc initialize modeltoolkit, config: "
<< model_toolkit_path << "/" << model_toolkit_file;
return -1;
}
size_t model_toolkit_num = resource_conf.model_toolkit_path_size();
for (size_t mi = 0; mi < model_toolkit_num; ++mi) {
std::string model_toolkit_path = resource_conf.model_toolkit_path(mi);
std::string model_toolkit_file = resource_conf.model_toolkit_file(mi);
if (InferManager::instance().proc_initialize(
model_toolkit_path.c_str(), model_toolkit_file.c_str()) != 0) {
LOG(ERROR) << "failed proc initialize modeltoolkit, config: "
<< model_toolkit_path << "/" << model_toolkit_file;
return -1;
}
if (KVManager::instance().proc_initialize(
model_toolkit_path.c_str(), model_toolkit_file.c_str()) != 0) {
LOG(ERROR) << "Failed proc initialize kvmanager, config: "
<< model_toolkit_path << "/" << model_toolkit_file;
if (KVManager::instance().proc_initialize(
model_toolkit_path.c_str(), model_toolkit_file.c_str()) != 0) {
LOG(ERROR) << "Failed proc initialize kvmanager, config: "
<< model_toolkit_path << "/" << model_toolkit_file;
}
}
}
......@@ -231,80 +226,79 @@ int Resource::general_model_initialize(const std::string& path,
LOG(ERROR) << "Failed initialize resource from: " << path << "/" << file;
return -1;
}
int err = 0;
std::string general_model_path = resource_conf.general_model_path();
std::string general_model_file = resource_conf.general_model_file();
if (err != 0) {
LOG(ERROR) << "read general_model_path failed, path[" << path << "], file["
<< file << "]";
return -1;
}
size_t general_model_num = resource_conf.general_model_path_size();
for (size_t gi = 0; gi < general_model_num; ++gi) {
GeneralModelConfig model_config;
if (configure::read_proto_conf(general_model_path.c_str(),
general_model_file.c_str(),
&model_config) != 0) {
LOG(ERROR) << "Failed initialize model config from: " << general_model_path
<< "/" << general_model_file;
return -1;
}
_config.reset(new PaddleGeneralModelConfig());
int feed_var_num = model_config.feed_var_size();
VLOG(2) << "load general model config";
VLOG(2) << "feed var num: " << feed_var_num;
_config->_feed_name.resize(feed_var_num);
_config->_feed_alias_name.resize(feed_var_num);
_config->_feed_type.resize(feed_var_num);
_config->_is_lod_feed.resize(feed_var_num);
_config->_capacity.resize(feed_var_num);
_config->_feed_shape.resize(feed_var_num);
for (int i = 0; i < feed_var_num; ++i) {
_config->_feed_name[i] = model_config.feed_var(i).name();
_config->_feed_alias_name[i] = model_config.feed_var(i).alias_name();
VLOG(2) << "feed var[" << i << "]: " << _config->_feed_name[i];
VLOG(2) << "feed var[" << i << "]: " << _config->_feed_alias_name[i];
_config->_feed_type[i] = model_config.feed_var(i).feed_type();
VLOG(2) << "feed type[" << i << "]: " << _config->_feed_type[i];
if (model_config.feed_var(i).is_lod_tensor()) {
VLOG(2) << "var[" << i << "] is lod tensor";
_config->_feed_shape[i] = {-1};
_config->_is_lod_feed[i] = true;
} else {
VLOG(2) << "var[" << i << "] is tensor";
_config->_capacity[i] = 1;
_config->_is_lod_feed[i] = false;
for (int j = 0; j < model_config.feed_var(i).shape_size(); ++j) {
int32_t dim = model_config.feed_var(i).shape(j);
VLOG(2) << "var[" << i << "].shape[" << i << "]: " << dim;
_config->_feed_shape[i].push_back(dim);
_config->_capacity[i] *= dim;
std::string general_model_path = resource_conf.general_model_path(gi);
std::string general_model_file = resource_conf.general_model_file(gi);
GeneralModelConfig model_config;
if (configure::read_proto_conf(general_model_path.c_str(),
general_model_file.c_str(),
&model_config) != 0) {
LOG(ERROR) << "Failed initialize model config from: " << general_model_path
<< "/" << general_model_file;
return -1;
}
auto _config = std::make_shared<PaddleGeneralModelConfig>();
int feed_var_num = model_config.feed_var_size();
VLOG(2) << "load general model config";
VLOG(2) << "feed var num: " << feed_var_num;
_config->_feed_name.resize(feed_var_num);
_config->_feed_alias_name.resize(feed_var_num);
_config->_feed_type.resize(feed_var_num);
_config->_is_lod_feed.resize(feed_var_num);
_config->_capacity.resize(feed_var_num);
_config->_feed_shape.resize(feed_var_num);
for (int i = 0; i < feed_var_num; ++i) {
_config->_feed_name[i] = model_config.feed_var(i).name();
_config->_feed_alias_name[i] = model_config.feed_var(i).alias_name();
VLOG(2) << "feed var[" << i << "]: " << _config->_feed_name[i];
VLOG(2) << "feed var[" << i << "]: " << _config->_feed_alias_name[i];
_config->_feed_type[i] = model_config.feed_var(i).feed_type();
VLOG(2) << "feed type[" << i << "]: " << _config->_feed_type[i];
if (model_config.feed_var(i).is_lod_tensor()) {
VLOG(2) << "var[" << i << "] is lod tensor";
_config->_feed_shape[i] = {-1};
_config->_is_lod_feed[i] = true;
} else {
VLOG(2) << "var[" << i << "] is tensor";
_config->_capacity[i] = 1;
_config->_is_lod_feed[i] = false;
for (int j = 0; j < model_config.feed_var(i).shape_size(); ++j) {
int32_t dim = model_config.feed_var(i).shape(j);
VLOG(2) << "var[" << i << "].shape[" << i << "]: " << dim;
_config->_feed_shape[i].push_back(dim);
_config->_capacity[i] *= dim;
}
}
}
}
int fetch_var_num = model_config.fetch_var_size();
_config->_is_lod_fetch.resize(fetch_var_num);
_config->_fetch_name.resize(fetch_var_num);
_config->_fetch_alias_name.resize(fetch_var_num);
_config->_fetch_shape.resize(fetch_var_num);
for (int i = 0; i < fetch_var_num; ++i) {
_config->_fetch_name[i] = model_config.fetch_var(i).name();
_config->_fetch_alias_name[i] = model_config.fetch_var(i).alias_name();
_config->_fetch_name_to_index[_config->_fetch_name[i]] = i;
_config->_fetch_alias_name_to_index[_config->_fetch_alias_name[i]] = i;
if (model_config.fetch_var(i).is_lod_tensor()) {
VLOG(2) << "fetch var[" << i << "] is lod tensor";
_config->_fetch_shape[i] = {-1};
_config->_is_lod_fetch[i] = true;
} else {
_config->_is_lod_fetch[i] = false;
for (int j = 0; j < model_config.fetch_var(i).shape_size(); ++j) {
int dim = model_config.fetch_var(i).shape(j);
_config->_fetch_shape[i].push_back(dim);
int fetch_var_num = model_config.fetch_var_size();
_config->_is_lod_fetch.resize(fetch_var_num);
_config->_fetch_name.resize(fetch_var_num);
_config->_fetch_alias_name.resize(fetch_var_num);
_config->_fetch_shape.resize(fetch_var_num);
for (int i = 0; i < fetch_var_num; ++i) {
_config->_fetch_name[i] = model_config.fetch_var(i).name();
_config->_fetch_alias_name[i] = model_config.fetch_var(i).alias_name();
_config->_fetch_name_to_index[_config->_fetch_name[i]] = i;
_config->_fetch_alias_name_to_index[_config->_fetch_alias_name[i]] = i;
if (model_config.fetch_var(i).is_lod_tensor()) {
VLOG(2) << "fetch var[" << i << "] is lod tensor";
_config->_fetch_shape[i] = {-1};
_config->_is_lod_fetch[i] = true;
} else {
_config->_is_lod_fetch[i] = false;
for (int j = 0; j < model_config.fetch_var(i).shape_size(); ++j) {
int dim = model_config.fetch_var(i).shape(j);
_config->_fetch_shape[i].push_back(dim);
}
}
}
_configs.push_back(std::move(_config));
}
return 0;
}
......
......@@ -94,7 +94,7 @@ class Resource {
int finalize();
std::shared_ptr<PaddleGeneralModelConfig> get_general_model_config();
std::vector<std::shared_ptr<PaddleGeneralModelConfig> > get_general_model_config();
void print_general_model_config(
const std::shared_ptr<PaddleGeneralModelConfig>& config);
......@@ -107,7 +107,7 @@ class Resource {
private:
int thread_finalize() { return 0; }
std::shared_ptr<PaddleGeneralModelConfig> _config;
std::vector<std::shared_ptr<PaddleGeneralModelConfig> > _configs;
std::string cube_config_fullpath;
int cube_quant_bits; // 0 if no empty
......
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
......@@ -126,7 +126,7 @@ int main(int argc, char** argv) {
return 0;
}
google::ParseCommandLineFlags(&argc, &argv, true);
//google::ParseCommandLineFlags(&argc, &argv, true);
g_change_server_port();
......@@ -202,6 +202,7 @@ int main(int argc, char** argv) {
}
VLOG(2) << "Succ call pthread worker start function";
//this is not used by any code segment,which can be cancelled.
if (Resource::instance().general_model_initialize(FLAGS_resource_path,
FLAGS_resource_file) != 0) {
LOG(ERROR) << "Failed to initialize general model conf: "
......
此差异已折叠。
/*******************************************************************************
* *
* Author : Angus Johnson *
* Version : 6.4.2 *
* Date : 27 February 2017 *
* Website : http://www.angusj.com *
* Copyright : Angus Johnson 2010-2017 *
* *
* License: *
* Use, modification & distribution is subject to Boost Software License Ver 1. *
* http://www.boost.org/LICENSE_1_0.txt *
* *
* Attributions: *
* The code in this library is an extension of Bala Vatti's clipping algorithm: *
* "A generic solution to polygon clipping" *
* Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. *
* http://portal.acm.org/citation.cfm?id=129906 *
* *
* Computer graphics and geometric modeling: implementation and algorithms *
* By Max K. Agoston *
* Springer; 1 edition (January 4, 2005) *
* http://books.google.com/books?q=vatti+clipping+agoston *
* *
* See also: *
* "Polygon Offsetting by Computing Winding Numbers" *
* Paper no. DETC2005-85513 pp. 565-575 *
* ASME 2005 International Design Engineering Technical Conferences *
* and Computers and Information in Engineering Conference (IDETC/CIE2005) *
* September 24-28, 2005 , Long Beach, California, USA *
* http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf *
* *
*******************************************************************************/
#ifndef clipper_hpp
#define clipper_hpp
#define CLIPPER_VERSION "6.4.2"
// use_int32: When enabled 32bit ints are used instead of 64bit ints. This
// improve performance but coordinate values are limited to the range +/- 46340
//#define use_int32
// use_xyz: adds a Z member to IntPoint. Adds a minor cost to perfomance.
//#define use_xyz
// use_lines: Enables line clipping. Adds a very minor cost to performance.
#define use_lines
// use_deprecated: Enables temporary support for the obsolete functions
//#define use_deprecated
#include <cstdlib>
#include <cstring>
#include <functional>
#include <list>
#include <ostream>
#include <queue>
#include <set>
#include <stdexcept>
#include <vector>
namespace ClipperLib {
enum ClipType { ctIntersection, ctUnion, ctDifference, ctXor };
enum PolyType { ptSubject, ptClip };
// By far the most widely used winding rules for polygon filling are
// EvenOdd & NonZero (GDI, GDI+, XLib, OpenGL, Cairo, AGG, Quartz, SVG, Gr32)
// Others rules include Positive, Negative and ABS_GTR_EQ_TWO (only in OpenGL)
// see http://glprogramming.com/red/chapter11.html
enum PolyFillType { pftEvenOdd, pftNonZero, pftPositive, pftNegative };
#ifdef use_int32
typedef int cInt;
static cInt const loRange = 0x7FFF;
static cInt const hiRange = 0x7FFF;
#else
typedef signed long long cInt;
static cInt const loRange = 0x3FFFFFFF;
static cInt const hiRange = 0x3FFFFFFFFFFFFFFFLL;
typedef signed long long long64; // used by Int128 class
typedef unsigned long long ulong64;
#endif
struct IntPoint {
cInt X;
cInt Y;
#ifdef use_xyz
cInt Z;
IntPoint(cInt x = 0, cInt y = 0, cInt z = 0) : X(x), Y(y), Z(z){};
#else
IntPoint(cInt x = 0, cInt y = 0) : X(x), Y(y){};
#endif
friend inline bool operator==(const IntPoint &a, const IntPoint &b) {
return a.X == b.X && a.Y == b.Y;
}
friend inline bool operator!=(const IntPoint &a, const IntPoint &b) {
return a.X != b.X || a.Y != b.Y;
}
};
//------------------------------------------------------------------------------
typedef std::vector<IntPoint> Path;
typedef std::vector<Path> Paths;
inline Path &operator<<(Path &poly, const IntPoint &p) {
poly.push_back(p);
return poly;
}
inline Paths &operator<<(Paths &polys, const Path &p) {
polys.push_back(p);
return polys;
}
std::ostream &operator<<(std::ostream &s, const IntPoint &p);
std::ostream &operator<<(std::ostream &s, const Path &p);
std::ostream &operator<<(std::ostream &s, const Paths &p);
struct DoublePoint {
double X;
double Y;
DoublePoint(double x = 0, double y = 0) : X(x), Y(y) {}
DoublePoint(IntPoint ip) : X((double)ip.X), Y((double)ip.Y) {}
};
//------------------------------------------------------------------------------
#ifdef use_xyz
typedef void (*ZFillCallback)(IntPoint &e1bot, IntPoint &e1top, IntPoint &e2bot,
IntPoint &e2top, IntPoint &pt);
#endif
enum InitOptions {
ioReverseSolution = 1,
ioStrictlySimple = 2,
ioPreserveCollinear = 4
};
enum JoinType { jtSquare, jtRound, jtMiter };
enum EndType {
etClosedPolygon,
etClosedLine,
etOpenButt,
etOpenSquare,
etOpenRound
};
class PolyNode;
typedef std::vector<PolyNode *> PolyNodes;
class PolyNode {
public:
PolyNode();
virtual ~PolyNode(){};
Path Contour;
PolyNodes Childs;
PolyNode *Parent;
PolyNode *GetNext() const;
bool IsHole() const;
bool IsOpen() const;
int ChildCount() const;
private:
// PolyNode& operator =(PolyNode& other);
unsigned Index; // node index in Parent.Childs
bool m_IsOpen;
JoinType m_jointype;
EndType m_endtype;
PolyNode *GetNextSiblingUp() const;
void AddChild(PolyNode &child);
friend class Clipper; // to access Index
friend class ClipperOffset;
};
class PolyTree : public PolyNode {
public:
~PolyTree() { Clear(); };
PolyNode *GetFirst() const;
void Clear();
int Total() const;
private:
// PolyTree& operator =(PolyTree& other);
PolyNodes AllNodes;
friend class Clipper; // to access AllNodes
};
bool Orientation(const Path &poly);
double Area(const Path &poly);
int PointInPolygon(const IntPoint &pt, const Path &path);
void SimplifyPolygon(const Path &in_poly, Paths &out_polys,
PolyFillType fillType = pftEvenOdd);
void SimplifyPolygons(const Paths &in_polys, Paths &out_polys,
PolyFillType fillType = pftEvenOdd);
void SimplifyPolygons(Paths &polys, PolyFillType fillType = pftEvenOdd);
void CleanPolygon(const Path &in_poly, Path &out_poly, double distance = 1.415);
void CleanPolygon(Path &poly, double distance = 1.415);
void CleanPolygons(const Paths &in_polys, Paths &out_polys,
double distance = 1.415);
void CleanPolygons(Paths &polys, double distance = 1.415);
void MinkowskiSum(const Path &pattern, const Path &path, Paths &solution,
bool pathIsClosed);
void MinkowskiSum(const Path &pattern, const Paths &paths, Paths &solution,
bool pathIsClosed);
void MinkowskiDiff(const Path &poly1, const Path &poly2, Paths &solution);
void PolyTreeToPaths(const PolyTree &polytree, Paths &paths);
void ClosedPathsFromPolyTree(const PolyTree &polytree, Paths &paths);
void OpenPathsFromPolyTree(PolyTree &polytree, Paths &paths);
void ReversePath(Path &p);
void ReversePaths(Paths &p);
struct IntRect {
cInt left;
cInt top;
cInt right;
cInt bottom;
};
// enums that are used internally ...
enum EdgeSide { esLeft = 1, esRight = 2 };
// forward declarations (for stuff used internally) ...
struct TEdge;
struct IntersectNode;
struct LocalMinimum;
struct OutPt;
struct OutRec;
struct Join;
typedef std::vector<OutRec *> PolyOutList;
typedef std::vector<TEdge *> EdgeList;
typedef std::vector<Join *> JoinList;
typedef std::vector<IntersectNode *> IntersectList;
//------------------------------------------------------------------------------
// ClipperBase is the ancestor to the Clipper class. It should not be
// instantiated directly. This class simply abstracts the conversion of sets of
// polygon coordinates into edge objects that are stored in a LocalMinima list.
class ClipperBase {
public:
ClipperBase();
virtual ~ClipperBase();
virtual bool AddPath(const Path &pg, PolyType PolyTyp, bool Closed);
bool AddPaths(const Paths &ppg, PolyType PolyTyp, bool Closed);
virtual void Clear();
IntRect GetBounds();
bool PreserveCollinear() { return m_PreserveCollinear; };
void PreserveCollinear(bool value) { m_PreserveCollinear = value; };
protected:
void DisposeLocalMinimaList();
TEdge *AddBoundsToLML(TEdge *e, bool IsClosed);
virtual void Reset();
TEdge *ProcessBound(TEdge *E, bool IsClockwise);
void InsertScanbeam(const cInt Y);
bool PopScanbeam(cInt &Y);
bool LocalMinimaPending();
bool PopLocalMinima(cInt Y, const LocalMinimum *&locMin);
OutRec *CreateOutRec();
void DisposeAllOutRecs();
void DisposeOutRec(PolyOutList::size_type index);
void SwapPositionsInAEL(TEdge *edge1, TEdge *edge2);
void DeleteFromAEL(TEdge *e);
void UpdateEdgeIntoAEL(TEdge *&e);
typedef std::vector<LocalMinimum> MinimaList;
MinimaList::iterator m_CurrentLM;
MinimaList m_MinimaList;
bool m_UseFullRange;
EdgeList m_edges;
bool m_PreserveCollinear;
bool m_HasOpenPaths;
PolyOutList m_PolyOuts;
TEdge *m_ActiveEdges;
typedef std::priority_queue<cInt> ScanbeamList;
ScanbeamList m_Scanbeam;
};
//------------------------------------------------------------------------------
class Clipper : public virtual ClipperBase {
public:
Clipper(int initOptions = 0);
bool Execute(ClipType clipType, Paths &solution,
PolyFillType fillType = pftEvenOdd);
bool Execute(ClipType clipType, Paths &solution, PolyFillType subjFillType,
PolyFillType clipFillType);
bool Execute(ClipType clipType, PolyTree &polytree,
PolyFillType fillType = pftEvenOdd);
bool Execute(ClipType clipType, PolyTree &polytree, PolyFillType subjFillType,
PolyFillType clipFillType);
bool ReverseSolution() { return m_ReverseOutput; };
void ReverseSolution(bool value) { m_ReverseOutput = value; };
bool StrictlySimple() { return m_StrictSimple; };
void StrictlySimple(bool value) { m_StrictSimple = value; };
// set the callback function for z value filling on intersections (otherwise Z
// is 0)
#ifdef use_xyz
void ZFillFunction(ZFillCallback zFillFunc);
#endif
protected:
virtual bool ExecuteInternal();
private:
JoinList m_Joins;
JoinList m_GhostJoins;
IntersectList m_IntersectList;
ClipType m_ClipType;
typedef std::list<cInt> MaximaList;
MaximaList m_Maxima;
TEdge *m_SortedEdges;
bool m_ExecuteLocked;
PolyFillType m_ClipFillType;
PolyFillType m_SubjFillType;
bool m_ReverseOutput;
bool m_UsingPolyTree;
bool m_StrictSimple;
#ifdef use_xyz
ZFillCallback m_ZFill; // custom callback
#endif
void SetWindingCount(TEdge &edge);
bool IsEvenOddFillType(const TEdge &edge) const;
bool IsEvenOddAltFillType(const TEdge &edge) const;
void InsertLocalMinimaIntoAEL(const cInt botY);
void InsertEdgeIntoAEL(TEdge *edge, TEdge *startEdge);
void AddEdgeToSEL(TEdge *edge);
bool PopEdgeFromSEL(TEdge *&edge);
void CopyAELToSEL();
void DeleteFromSEL(TEdge *e);
void SwapPositionsInSEL(TEdge *edge1, TEdge *edge2);
bool IsContributing(const TEdge &edge) const;
bool IsTopHorz(const cInt XPos);
void DoMaxima(TEdge *e);
void ProcessHorizontals();
void ProcessHorizontal(TEdge *horzEdge);
void AddLocalMaxPoly(TEdge *e1, TEdge *e2, const IntPoint &pt);
OutPt *AddLocalMinPoly(TEdge *e1, TEdge *e2, const IntPoint &pt);
OutRec *GetOutRec(int idx);
void AppendPolygon(TEdge *e1, TEdge *e2);
void IntersectEdges(TEdge *e1, TEdge *e2, IntPoint &pt);
OutPt *AddOutPt(TEdge *e, const IntPoint &pt);
OutPt *GetLastOutPt(TEdge *e);
bool ProcessIntersections(const cInt topY);
void BuildIntersectList(const cInt topY);
void ProcessIntersectList();
void ProcessEdgesAtTopOfScanbeam(const cInt topY);
void BuildResult(Paths &polys);
void BuildResult2(PolyTree &polytree);
void SetHoleState(TEdge *e, OutRec *outrec);
void DisposeIntersectNodes();
bool FixupIntersectionOrder();
void FixupOutPolygon(OutRec &outrec);
void FixupOutPolyline(OutRec &outrec);
bool IsHole(TEdge *e);
bool FindOwnerFromSplitRecs(OutRec &outRec, OutRec *&currOrfl);
void FixHoleLinkage(OutRec &outrec);
void AddJoin(OutPt *op1, OutPt *op2, const IntPoint offPt);
void ClearJoins();
void ClearGhostJoins();
void AddGhostJoin(OutPt *op, const IntPoint offPt);
bool JoinPoints(Join *j, OutRec *outRec1, OutRec *outRec2);
void JoinCommonEdges();
void DoSimplePolygons();
void FixupFirstLefts1(OutRec *OldOutRec, OutRec *NewOutRec);
void FixupFirstLefts2(OutRec *InnerOutRec, OutRec *OuterOutRec);
void FixupFirstLefts3(OutRec *OldOutRec, OutRec *NewOutRec);
#ifdef use_xyz
void SetZ(IntPoint &pt, TEdge &e1, TEdge &e2);
#endif
};
//------------------------------------------------------------------------------
class ClipperOffset {
public:
ClipperOffset(double miterLimit = 2.0, double roundPrecision = 0.25);
~ClipperOffset();
void AddPath(const Path &path, JoinType joinType, EndType endType);
void AddPaths(const Paths &paths, JoinType joinType, EndType endType);
void Execute(Paths &solution, double delta);
void Execute(PolyTree &solution, double delta);
void Clear();
double MiterLimit;
double ArcTolerance;
private:
Paths m_destPolys;
Path m_srcPoly;
Path m_destPoly;
std::vector<DoublePoint> m_normals;
double m_delta, m_sinA, m_sin, m_cos;
double m_miterLim, m_StepsPerRad;
IntPoint m_lowest;
PolyNode m_polyNodes;
void FixOrientations();
void DoOffset(double delta);
void OffsetPoint(int j, int &k, JoinType jointype);
void DoSquare(int j, int k);
void DoMiter(int j, int k, double r);
void DoRound(int j, int k);
};
//------------------------------------------------------------------------------
class clipperException : public std::exception {
public:
clipperException(const char *description) : m_descr(description) {}
virtual ~clipperException() throw() {}
virtual const char *what() const throw() { return m_descr.c_str(); }
private:
std::string m_descr;
};
//------------------------------------------------------------------------------
} // ClipperLib namespace
#endif // clipper_hpp
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "postprocess_op.h"
namespace PaddleOCR {
void PostProcessor::GetContourArea(const std::vector<std::vector<float>> &box,
float unclip_ratio, float &distance) {
int pts_num = 4;
float area = 0.0f;
float dist = 0.0f;
for (int i = 0; i < pts_num; i++) {
area += box[i][0] * box[(i + 1) % pts_num][1] -
box[i][1] * box[(i + 1) % pts_num][0];
dist += sqrtf((box[i][0] - box[(i + 1) % pts_num][0]) *
(box[i][0] - box[(i + 1) % pts_num][0]) +
(box[i][1] - box[(i + 1) % pts_num][1]) *
(box[i][1] - box[(i + 1) % pts_num][1]));
}
area = fabs(float(area / 2.0));
distance = area * unclip_ratio / dist;
}
cv::RotatedRect PostProcessor::UnClip(std::vector<std::vector<float>> box,
const float &unclip_ratio) {
float distance = 1.0;
GetContourArea(box, unclip_ratio, distance);
ClipperLib::ClipperOffset offset;
ClipperLib::Path p;
p << ClipperLib::IntPoint(int(box[0][0]), int(box[0][1]))
<< ClipperLib::IntPoint(int(box[1][0]), int(box[1][1]))
<< ClipperLib::IntPoint(int(box[2][0]), int(box[2][1]))
<< ClipperLib::IntPoint(int(box[3][0]), int(box[3][1]));
offset.AddPath(p, ClipperLib::jtRound, ClipperLib::etClosedPolygon);
ClipperLib::Paths soln;
offset.Execute(soln, distance);
std::vector<cv::Point2f> points;
for (int j = 0; j < soln.size(); j++) {
for (int i = 0; i < soln[soln.size() - 1].size(); i++) {
points.emplace_back(soln[j][i].X, soln[j][i].Y);
}
}
cv::RotatedRect res;
if (points.size() <= 0) {
res = cv::RotatedRect(cv::Point2f(0, 0), cv::Size2f(1, 1), 0);
} else {
res = cv::minAreaRect(points);
}
return res;
}
float **PostProcessor::Mat2Vec(cv::Mat mat) {
auto **array = new float *[mat.rows];
for (int i = 0; i < mat.rows; ++i)
array[i] = new float[mat.cols];
for (int i = 0; i < mat.rows; ++i) {
for (int j = 0; j < mat.cols; ++j) {
array[i][j] = mat.at<float>(i, j);
}
}
return array;
}
std::vector<std::vector<int>>
PostProcessor::OrderPointsClockwise(std::vector<std::vector<int>> pts) {
std::vector<std::vector<int>> box = pts;
std::sort(box.begin(), box.end(), XsortInt);
std::vector<std::vector<int>> leftmost = {box[0], box[1]};
std::vector<std::vector<int>> rightmost = {box[2], box[3]};
if (leftmost[0][1] > leftmost[1][1])
std::swap(leftmost[0], leftmost[1]);
if (rightmost[0][1] > rightmost[1][1])
std::swap(rightmost[0], rightmost[1]);
std::vector<std::vector<int>> rect = {leftmost[0], rightmost[0], rightmost[1],
leftmost[1]};
return rect;
}
std::vector<std::vector<float>> PostProcessor::Mat2Vector(cv::Mat mat) {
std::vector<std::vector<float>> img_vec;
std::vector<float> tmp;
for (int i = 0; i < mat.rows; ++i) {
tmp.clear();
for (int j = 0; j < mat.cols; ++j) {
tmp.push_back(mat.at<float>(i, j));
}
img_vec.push_back(tmp);
}
return img_vec;
}
bool PostProcessor::XsortFp32(std::vector<float> a, std::vector<float> b) {
if (a[0] != b[0])
return a[0] < b[0];
return false;
}
bool PostProcessor::XsortInt(std::vector<int> a, std::vector<int> b) {
if (a[0] != b[0])
return a[0] < b[0];
return false;
}
std::vector<std::vector<float>> PostProcessor::GetMiniBoxes(cv::RotatedRect box,
float &ssid) {
ssid = std::max(box.size.width, box.size.height);
cv::Mat points;
cv::boxPoints(box, points);
auto array = Mat2Vector(points);
std::sort(array.begin(), array.end(), XsortFp32);
std::vector<float> idx1 = array[0], idx2 = array[1], idx3 = array[2],
idx4 = array[3];
if (array[3][1] <= array[2][1]) {
idx2 = array[3];
idx3 = array[2];
} else {
idx2 = array[2];
idx3 = array[3];
}
if (array[1][1] <= array[0][1]) {
idx1 = array[1];
idx4 = array[0];
} else {
idx1 = array[0];
idx4 = array[1];
}
array[0] = idx1;
array[1] = idx2;
array[2] = idx3;
array[3] = idx4;
return array;
}
float PostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
cv::Mat pred) {
auto array = box_array;
int width = pred.cols;
int height = pred.rows;
float box_x[4] = {array[0][0], array[1][0], array[2][0], array[3][0]};
float box_y[4] = {array[0][1], array[1][1], array[2][1], array[3][1]};
int xmin = clamp(int(std::floor(*(std::min_element(box_x, box_x + 4)))), 0,
width - 1);
int xmax = clamp(int(std::ceil(*(std::max_element(box_x, box_x + 4)))), 0,
width - 1);
int ymin = clamp(int(std::floor(*(std::min_element(box_y, box_y + 4)))), 0,
height - 1);
int ymax = clamp(int(std::ceil(*(std::max_element(box_y, box_y + 4)))), 0,
height - 1);
cv::Mat mask;
mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);
cv::Point root_point[4];
root_point[0] = cv::Point(int(array[0][0]) - xmin, int(array[0][1]) - ymin);
root_point[1] = cv::Point(int(array[1][0]) - xmin, int(array[1][1]) - ymin);
root_point[2] = cv::Point(int(array[2][0]) - xmin, int(array[2][1]) - ymin);
root_point[3] = cv::Point(int(array[3][0]) - xmin, int(array[3][1]) - ymin);
const cv::Point *ppt[1] = {root_point};
int npt[] = {4};
cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));
cv::Mat croppedImg;
pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1))
.copyTo(croppedImg);
auto score = cv::mean(croppedImg, mask)[0];
return score;
}
std::vector<std::vector<std::vector<int>>>
PostProcessor::BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
const float &box_thresh,
const float &det_db_unclip_ratio) {
const int min_size = 3;
const int max_candidates = 1000;
int width = bitmap.cols;
int height = bitmap.rows;
std::vector<std::vector<cv::Point>> contours;
std::vector<cv::Vec4i> hierarchy;
cv::findContours(bitmap, contours, hierarchy, cv::RETR_LIST,
cv::CHAIN_APPROX_SIMPLE);
int num_contours =
contours.size() >= max_candidates ? max_candidates : contours.size();
std::vector<std::vector<std::vector<int>>> boxes;
for (int _i = 0; _i < num_contours; _i++) {
if (contours[_i].size() <= 2) {
continue;
}
float ssid;
cv::RotatedRect box = cv::minAreaRect(contours[_i]);
auto array = GetMiniBoxes(box, ssid);
auto box_for_unclip = array;
// end get_mini_box
if (ssid < min_size) {
continue;
}
float score;
score = BoxScoreFast(array, pred);
if (score < box_thresh)
continue;
// start for unclip
cv::RotatedRect points = UnClip(box_for_unclip, det_db_unclip_ratio);
if (points.size.height < 1.001 && points.size.width < 1.001) {
continue;
}
// end for unclip
cv::RotatedRect clipbox = points;
auto cliparray = GetMiniBoxes(clipbox, ssid);
if (ssid < min_size + 2)
continue;
int dest_width = pred.cols;
int dest_height = pred.rows;
std::vector<std::vector<int>> intcliparray;
for (int num_pt = 0; num_pt < 4; num_pt++) {
std::vector<int> a{int(clampf(roundf(cliparray[num_pt][0] / float(width) *
float(dest_width)),
0, float(dest_width))),
int(clampf(roundf(cliparray[num_pt][1] /
float(height) * float(dest_height)),
0, float(dest_height)))};
intcliparray.push_back(a);
}
boxes.push_back(intcliparray);
} // end for
return boxes;
}
std::vector<std::vector<std::vector<int>>>
PostProcessor::FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
float ratio_h, float ratio_w, cv::Mat srcimg) {
int oriimg_h = srcimg.rows;
int oriimg_w = srcimg.cols;
std::vector<std::vector<std::vector<int>>> root_points;
for (int n = 0; n < boxes.size(); n++) {
boxes[n] = OrderPointsClockwise(boxes[n]);
for (int m = 0; m < boxes[0].size(); m++) {
boxes[n][m][0] /= ratio_w;
boxes[n][m][1] /= ratio_h;
boxes[n][m][0] = int(_min(_max(boxes[n][m][0], 0), oriimg_w - 1));
boxes[n][m][1] = int(_min(_max(boxes[n][m][1], 0), oriimg_h - 1));
}
}
for (int n = 0; n < boxes.size(); n++) {
int rect_width, rect_height;
rect_width = int(sqrt(pow(boxes[n][0][0] - boxes[n][1][0], 2) +
pow(boxes[n][0][1] - boxes[n][1][1], 2)));
rect_height = int(sqrt(pow(boxes[n][0][0] - boxes[n][3][0], 2) +
pow(boxes[n][0][1] - boxes[n][3][1], 2)));
if (rect_width <= 4 || rect_height <= 4)
continue;
root_points.push_back(boxes[n]);
}
return root_points;
}
} // namespace PaddleOCR
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include "clipper.h"
#include "utility.h"
using namespace std;
namespace PaddleOCR {
class PostProcessor {
public:
void GetContourArea(const std::vector<std::vector<float>> &box,
float unclip_ratio, float &distance);
cv::RotatedRect UnClip(std::vector<std::vector<float>> box,
const float &unclip_ratio);
float **Mat2Vec(cv::Mat mat);
std::vector<std::vector<int>>
OrderPointsClockwise(std::vector<std::vector<int>> pts);
std::vector<std::vector<float>> GetMiniBoxes(cv::RotatedRect box,
float &ssid);
float BoxScoreFast(std::vector<std::vector<float>> box_array, cv::Mat pred);
std::vector<std::vector<std::vector<int>>>
BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
const float &box_thresh, const float &det_db_unclip_ratio);
std::vector<std::vector<std::vector<int>>>
FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
float ratio_h, float ratio_w, cv::Mat srcimg);
private:
static bool XsortInt(std::vector<int> a, std::vector<int> b);
static bool XsortFp32(std::vector<float> a, std::vector<float> b);
std::vector<std::vector<float>> Mat2Vector(cv::Mat mat);
inline int _max(int a, int b) { return a >= b ? a : b; }
inline int _min(int a, int b) { return a >= b ? b : a; }
template <class T> inline T clamp(T x, T min, T max) {
if (x > max)
return max;
if (x < min)
return min;
return x;
}
inline float clampf(float x, float min, float max) {
if (x > max)
return max;
if (x < min)
return min;
return x;
}
};
} // namespace PaddleOCR
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
//#include "paddle_api.h"
//#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include "preprocess_op.h"
namespace PaddleOCR {
void Permute::Run(const cv::Mat *im, float *data) {
int rh = im->rows;
int rw = im->cols;
int rc = im->channels();
for (int i = 0; i < rc; ++i) {
cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i);
}
}
void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &scale, const bool is_scale) {
double e = 1.0;
if (is_scale) {
e /= 255.0;
}
(*im).convertTo(*im, CV_32FC3, e);
for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] =
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) * scale[0];
im->at<cv::Vec3f>(h, w)[1] =
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) * scale[1];
im->at<cv::Vec3f>(h, w)[2] =
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) * scale[2];
}
}
}
void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
int max_size_len, float &ratio_h, float &ratio_w,
bool use_tensorrt) {
int w = img.cols;
int h = img.rows;
float ratio = 1.f;
int max_wh = w >= h ? w : h;
if (max_wh > max_size_len) {
if (h > w) {
ratio = float(max_size_len) / float(h);
} else {
ratio = float(max_size_len) / float(w);
}
}
int resize_h = int(float(h) * ratio);
int resize_w = int(float(w) * ratio);
if (resize_h % 32 == 0)
resize_h = resize_h;
else if (resize_h / 32 < 1 + 1e-5)
resize_h = 32;
else
resize_h = (resize_h / 32) * 32;
if (resize_w % 32 == 0)
resize_w = resize_w;
else if (resize_w / 32 < 1 + 1e-5)
resize_w = 32;
else
resize_w = (resize_w / 32) * 32;
if (!use_tensorrt) {
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
ratio_h = float(resize_h) / float(h);
ratio_w = float(resize_w) / float(w);
} else {
cv::resize(img, resize_img, cv::Size(640, 640));
ratio_h = float(640) / float(h);
ratio_w = float(640) / float(w);
}
}
void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
bool use_tensorrt,
const std::vector<int> &rec_image_shape) {
int imgC, imgH, imgW;
imgC = rec_image_shape[0];
imgH = rec_image_shape[1];
imgW = rec_image_shape[2];
imgW = int(32 * wh_ratio);
float ratio = float(img.cols) / float(img.rows);
int resize_w, resize_h;
if (ceilf(imgH * ratio) > imgW)
resize_w = imgW;
else
resize_w = int(ceilf(imgH * ratio));
if (!use_tensorrt) {
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
cv::INTER_LINEAR);
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0,
int(imgW - resize_img.cols), cv::BORDER_CONSTANT,
{127, 127, 127});
} else {
int k = int(img.cols * 32 / img.rows);
if (k >= 100) {
cv::resize(img, resize_img, cv::Size(100, 32), 0.f, 0.f,
cv::INTER_LINEAR);
} else {
cv::resize(img, resize_img, cv::Size(k, 32), 0.f, 0.f, cv::INTER_LINEAR);
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, int(100 - k),
cv::BORDER_CONSTANT, {127, 127, 127});
}
}
}
void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
bool use_tensorrt,
const std::vector<int> &rec_image_shape) {
int imgC, imgH, imgW;
imgC = rec_image_shape[0];
imgH = rec_image_shape[1];
imgW = rec_image_shape[2];
float ratio = float(img.cols) / float(img.rows);
int resize_w, resize_h;
if (ceilf(imgH * ratio) > imgW)
resize_w = imgW;
else
resize_w = int(ceilf(imgH * ratio));
if (!use_tensorrt) {
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
cv::INTER_LINEAR);
if (resize_w < imgW) {
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
}
} else {
cv::resize(img, resize_img, cv::Size(100, 32), 0.f, 0.f, cv::INTER_LINEAR);
}
}
} // namespace PaddleOCR
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
using namespace std;
//using namespace paddle;
namespace PaddleOCR {
class Normalize {
public:
virtual void Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &scale, const bool is_scale = true);
};
// RGB -> CHW
class Permute {
public:
virtual void Run(const cv::Mat *im, float *data);
};
class ResizeImgType0 {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len,
float &ratio_h, float &ratio_w, bool use_tensorrt);
};
class CrnnResizeImg {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
bool use_tensorrt = false,
const std::vector<int> &rec_image_shape = {3, 32, 320});
};
class ClsResizeImg {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img,
bool use_tensorrt = false,
const std::vector<int> &rec_image_shape = {3, 48, 192});
};
} // namespace PaddleOCR
\ No newline at end of file
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <ostream>
#include <vector>
#include "utility.h"
namespace PaddleOCR {
std::vector<std::string> Utility::ReadDict(const std::string &path) {
std::ifstream in(path);
std::string line;
std::vector<std::string> m_vec;
if (in) {
while (getline(in, line)) {
m_vec.push_back(line);
}
} else {
std::cout << "no such label file: " << path << ", exit the program..."
<< std::endl;
exit(1);
}
return m_vec;
}
void Utility::VisualizeBboxes(
const cv::Mat &srcimg,
const std::vector<std::vector<std::vector<int>>> &boxes) {
cv::Mat img_vis;
srcimg.copyTo(img_vis);
for (int n = 0; n < boxes.size(); n++) {
cv::Point rook_points[4];
for (int m = 0; m < boxes[n].size(); m++) {
rook_points[m] = cv::Point(int(boxes[n][m][0]), int(boxes[n][m][1]));
}
const cv::Point *ppt[1] = {rook_points};
int npt[] = {4};
cv::polylines(img_vis, ppt, npt, 1, 1, CV_RGB(0, 255, 0), 2, 8, 0);
}
cv::imwrite("./ocr_vis.png", img_vis);
std::cout << "The detection visualized image saved in ./ocr_vis.png"
<< std::endl;
}
} // namespace PaddleOCR
\ No newline at end of file
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <stdlib.h>
#include <vector>
#include <algorithm>
#include <cstring>
#include <fstream>
#include <numeric>
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
namespace PaddleOCR {
class Utility {
public:
static std::vector<std::string> ReadDict(const std::string &path);
static void
VisualizeBboxes(const cv::Mat &srcimg,
const std::vector<std::vector<std::vector<int>>> &boxes);
template <class ForwardIterator>
inline static size_t argmax(ForwardIterator first, ForwardIterator last) {
return std::distance(first, std::max_element(first, last));
}
};
} // namespace PaddleOCR
\ No newline at end of file
......@@ -117,6 +117,22 @@ make -j10
you can execute `make install` to put targets under directory `./output`, you need to add`-DCMAKE_INSTALL_PREFIX=./output`to specify output path to cmake command shown above.
### Compile C++ Server under the condition of WITH_OPENCV=ON
First of all , opencv library should be installed, if not, please refer to the `Compile and install opencv` section later in this article.
In the compile command, add `DOPENCV_DIR=${OPENCV_DIR}` and `DWITH_OPENCV=ON`,for example:
``` shell
OPENCV_DIR=your_opencv_dir #`your_opencv_dir` is the installation path of OpenCV library。
mkdir server-build-cpu && cd server-build-cpu
cmake -DPYTHON_INCLUDE_DIR=$PYTHON_INCLUDE_DIR/ \
-DPYTHON_LIBRARIES=$PYTHON_LIBRARIES \
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
-DOPENCV_DIR=${OPENCV_DIR} \
-DWITH_OPENCV=ON
-DSERVER=ON ..
make -j10
```
### Integrated GPU version paddle inference library
Compared with CPU environment, GPU environment needs to refer to the following table,
......@@ -209,6 +225,7 @@ Please use the example under `python/examples` to verify.
| WITH_AVX | Compile Paddle Serving with AVX intrinsics | OFF |
| WITH_MKL | Compile Paddle Serving with MKL support | OFF |
| WITH_GPU | Compile Paddle Serving with NVIDIA GPU | OFF |
| WITH_OPENCV | Compile Paddle Serving with OPENCV | OFF |
| CUDNN_LIBRARY | Define CuDNN library and header path | |
| CUDA_TOOLKIT_ROOT_DIR | Define CUDA PATH | |
| TENSORRT_ROOT | Define TensorRT PATH | |
......@@ -247,3 +264,62 @@ The following is the base library version matching relationship used by the Padd
### How to make the compiler detect the CuDNN library
Download the corresponding CUDNN version from NVIDIA developer official website and decompressing it, add `-DCUDNN_ROOT` to cmake command, to specify the path of CUDNN.
## Compile and install opencv
* First of all, you need to download the source code compiled package in the Linux environment from the opencv official website. Taking opencv3.4.7 as an example, the download command is as follows.
```
wget https://github.com/opencv/opencv/archive/3.4.7.tar.gz
tar -xf 3.4.7.tar.gz
```
Finally, you can see the folder of `opencv-3.4.7/` in the current directory.
* Compile opencv, the opencv source path (`root_path`) and installation path (`install_path`) should be set by yourself. Enter the opencv source code path and compile it in the following way.
```shell
root_path=your_opencv_root_path
install_path=${root_path}/opencv3
rm -rf build
mkdir build
cd build
cmake .. \
-DCMAKE_INSTALL_PREFIX=${install_path} \
-DCMAKE_BUILD_TYPE=Release \
-DBUILD_SHARED_LIBS=OFF \
-DWITH_IPP=OFF \
-DBUILD_IPP_IW=OFF \
-DWITH_LAPACK=OFF \
-DWITH_EIGEN=OFF \
-DCMAKE_INSTALL_LIBDIR=lib64 \
-DWITH_ZLIB=ON \
-DBUILD_ZLIB=ON \
-DWITH_JPEG=ON \
-DBUILD_JPEG=ON \
-DWITH_PNG=ON \
-DBUILD_PNG=ON \
-DWITH_TIFF=ON \
-DBUILD_TIFF=ON
make -j
make install
```
Among them, `root_path` is the downloaded opencv source code path, and `install_path` is the installation path of opencv. After `make install` is completed, the opencv header file and library file will be generated in this folder for later OCR source code compilation.
The final file structure under the opencv installation path is as follows.
```
opencv3/
|-- bin
|-- include
|-- lib
|-- lib64
|-- share
```
\ No newline at end of file
......@@ -116,6 +116,22 @@ make -j10
可以执行`make install`把目标产出放在`./output`目录下,cmake阶段需添加`-DCMAKE_INSTALL_PREFIX=./output`选项来指定存放路径。
### 开启WITH_OPENCV选项编译C++ Server
编译Serving C++ Server部分,开启WITH_OPENCV选项时,需要安装安装openCV库,若没有可参考本文档后面的说明编译安装openCV库。
在编译命令中,加入`DOPENCV_DIR=${OPENCV_DIR}``DWITH_OPENCV=ON`选项,例如:
``` shell
OPENCV_DIR=your_opencv_dir #`your_opencv_dir`为opencv库的安装路径。
mkdir server-build-cpu && cd server-build-cpu
cmake -DPYTHON_INCLUDE_DIR=$PYTHON_INCLUDE_DIR/ \
-DPYTHON_LIBRARIES=$PYTHON_LIBRARIES \
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
-DOPENCV_DIR=${OPENCV_DIR} \
-DWITH_OPENCV=ON
-DSERVER=ON ..
make -j10
```
### 集成GPU版本Paddle Inference Library
相比CPU环境,GPU环境需要参考以下表格,
......@@ -153,6 +169,7 @@ make -j10
**注意:** 编译成功后,需要设置`SERVING_BIN`路径,详见后面的[注意事项](https://github.com/PaddlePaddle/Serving/blob/develop/doc/COMPILE_CN.md#注意事项)
## 编译Client部分
``` shell
......@@ -174,7 +191,7 @@ make -j10
mkdir app-build && cd app-build
cmake -DPYTHON_INCLUDE_DIR=$PYTHON_INCLUDE_DIR \
-DPYTHON_LIBRARIES=$PYTHON_LIBRARIES \
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
-DAPP=ON ..
make
```
......@@ -211,6 +228,7 @@ make
| WITH_MKL | Compile Paddle Serving with MKL support | OFF |
| WITH_GPU | Compile Paddle Serving with NVIDIA GPU | OFF |
| WITH_TRT | Compile Paddle Serving with TensorRT | OFF |
| WITH_OPENCV | Compile Paddle Serving with OPENCV | OFF |
| CUDNN_LIBRARY | Define CuDNN library and header path | |
| CUDA_TOOLKIT_ROOT_DIR | Define CUDA PATH | |
| TENSORRT_ROOT | Define TensorRT PATH | |
......@@ -248,3 +266,60 @@ Paddle Serving通过PaddlePaddle预测库支持在GPU上做预测。WITH_GPU选
### 如何让Paddle Serving编译系统探测到CuDNN库
从NVIDIA developer官网下载对应版本CuDNN并在本地解压后,在cmake编译命令中增加`-DCUDNN_LIBRARY`参数,指定CuDNN库所在路径。
## 编译安装opencv库
* 首先需要从opencv官网上下载在Linux环境下源码编译的包,以opencv3.4.7为例,下载命令如下。
```
wget https://github.com/opencv/opencv/archive/3.4.7.tar.gz
tar -xf 3.4.7.tar.gz
```
最终可以在当前目录下看到`opencv-3.4.7/`的文件夹。
* 编译opencv,设置opencv源码路径(`root_path`)以及安装路径(`install_path`)。进入opencv源码路径下,按照下面的方式进行编译。
```shell
root_path=your_opencv_root_path
install_path=${root_path}/opencv3
rm -rf build
mkdir build
cd build
cmake .. \
-DCMAKE_INSTALL_PREFIX=${install_path} \
-DCMAKE_BUILD_TYPE=Release \
-DBUILD_SHARED_LIBS=OFF \
-DWITH_IPP=OFF \
-DBUILD_IPP_IW=OFF \
-DWITH_LAPACK=OFF \
-DWITH_EIGEN=OFF \
-DCMAKE_INSTALL_LIBDIR=lib64 \
-DWITH_ZLIB=ON \
-DBUILD_ZLIB=ON \
-DWITH_JPEG=ON \
-DBUILD_JPEG=ON \
-DWITH_PNG=ON \
-DBUILD_PNG=ON \
-DWITH_TIFF=ON \
-DBUILD_TIFF=ON
make -j
make install
```
其中`root_path`为下载的opencv源码路径,`install_path`为opencv的安装路径,`make install`完成之后,会在该文件夹下生成opencv头文件和库文件,用于后面的OCR代码编译。
最终在安装路径下的文件结构如下所示。
```
opencv3/
|-- bin
|-- include
|-- lib
|-- lib64
|-- share
```
\ No newline at end of file
......@@ -28,8 +28,13 @@ test_reader = paddle.batch(
batch_size=1)
for data in test_reader():
new_data = np.zeros((1, 1, 13)).astype("float32")
new_data = np.zeros((1, 13)).astype("float32")
print('testclient.py-----data',data[0][0])
print('testclient.py-----shape',data[0][0].shape)
new_data[0] = data[0][0]
print('testclient.py-----newdata',new_data)
print('testclient.py-----newdata-0',new_data[0])
print('testclient.py-----newdata.shape',new_data.shape)
fetch_map = client.predict(
feed={"x": new_data}, fetch=["price"], batch=True)
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
......
......@@ -15,6 +15,31 @@ tar -xzvf ocr_det.tar.gz
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/ocr/test_imgs.tar
tar xf test_imgs.tar
```
## C++ OCR Service
### Start Service
Select a startup mode according to CPU / GPU device
After the -- model parameter, the folder path of multiple model files is passed in to start the prediction service of multiple model concatenation.
```
#for cpu user
python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293
#for gpu user
python -m paddle_serving_server_gpu.serve --model ocr_det_model ocr_rec_model --port 9293 --gpu_id 0
```
### Client Prediction
The pre-processing and post-processing is in the C + + server part, the image's Base64 encoded string is passed into the C + + server.
so the value of parameter `feed_var` which is in the file `ocr_det_client/serving_client_conf.prototxt` should be changed.
for this case, `feed_type` should be 3(which means the data type is string),`shape` should be 1.
By passing in multiple client folder paths, the client can be started for multi model prediction.
```
python ocr_c_client_bytes.py ocr_det_client ocr_rec_client
```
## Web Service
......
......@@ -15,6 +15,31 @@ wget --no-check-certificate https://paddle-serving.bj.bcebos.com/ocr/test_imgs.t
tar xf test_imgs.tar
```
## C++ OCR Service服务
### 启动服务
根据CPU/GPU设备选择一种启动方式
通过--model后,指定多个模型文件的文件夹路径来启动多模型串联的预测服务。
```
#for cpu user
python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293
#for gpu user
python -m paddle_serving_server_gpu.serve --model ocr_det_model ocr_rec_model --port 9293 --gpu_id 0
```
### 启动客户端
由于需要在C++Server部分进行前后处理,传入C++Server的仅仅是图片的base64编码的字符串,故第一个模型的Client配置需要修改
`ocr_det_client/serving_client_conf.prototxt``feed_var`字段
对于本示例而言,`feed_type`应修改为3(数据类型为string),`shape`为1.
通过在客户端启动后加入多个client模型的client配置文件夹路径,启动client进行预测。
```
python ocr_c_client_bytes.py ocr_det_client ocr_rec_client
```
## Web Service服务
### 启动服务
......
......@@ -14,7 +14,6 @@
# pylint: disable=doc-string-missing
from . import version
from . import client
from .client import *
......
文件模式从 100644 更改为 100755
......@@ -22,7 +22,8 @@ import os
import json
import base64
import time
from multiprocessing import Pool, Process
from multiprocessing import Process
from .web_service import WebService, port_is_available
from flask import Flask, request
import sys
if sys.version_info.major == 2:
......@@ -41,7 +42,7 @@ def serve_args():
"--device", type=str, default="gpu", help="Type of device")
parser.add_argument("--gpu_ids", type=str, default="", help="gpu ids")
parser.add_argument(
"--model", type=str, default="", help="Model for serving")
"--model", type=str, default="", nargs="+", help="Model for serving")
parser.add_argument(
"--workdir",
type=str,
......@@ -109,18 +110,33 @@ def start_standard_model(serving_port): # pylint: disable=doc-string-missing
if model == "":
print("You must specify your serving model")
exit(-1)
for single_model_config in args.model:
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
import paddle_serving_server as serving
op_maker = serving.OpMaker()
read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer')
general_response_op = op_maker.create('general_response')
op_seq_maker = serving.OpSeqMaker()
read_op = op_maker.create('general_reader')
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(general_infer_op)
for idx, single_model in enumerate(model):
infer_op_name = "general_infer"
if len(model) == 2 and idx == 0:
infer_op_name = "general_detection"
else:
infer_op_name = "general_infer"
general_infer_op = op_maker.create(infer_op_name)
op_seq_maker.add_op(general_infer_op)
general_response_op = op_maker.create('general_response')
op_seq_maker.add_op(general_response_op)
server = None
if use_multilang:
server = serving.MultiLangServer()
......@@ -269,8 +285,11 @@ class MainService(BaseHTTPRequestHandler):
return False
else:
key = base64.b64decode(post_data["key"].encode())
with open(args.model + "/key", "wb") as f:
f.write(key)
for single_model_config in args.model:
if os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
with open(single_model_config + "/key", "wb") as f:
f.write(key)
return True
def check_key(self, post_data):
......@@ -278,12 +297,17 @@ class MainService(BaseHTTPRequestHandler):
return False
else:
key = base64.b64decode(post_data["key"].encode())
with open(args.model + "/key", "rb") as f:
cur_key = f.read()
return (key == cur_key)
for single_model_config in args.model:
if os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
with open(single_model_config + "/key", "rb") as f:
cur_key = f.read()
if key != cur_key:
return False
return True
def start(self, post_data):
post_data = json.loads(post_data)
post_data = json.loads(post_data.decode('utf-8'))
global p_flag
if not p_flag:
if args.use_encryption_model:
......@@ -323,7 +347,14 @@ class MainService(BaseHTTPRequestHandler):
if __name__ == "__main__":
args = serve_args()
args = parse_args()
for single_model_config in args.model:
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
if args.name == "None":
from .web_service import port_is_available
if args.use_encryption_model:
......
......@@ -43,16 +43,17 @@ class Server(object):
def __init__(self):
self.server_handle_ = None
self.infer_service_conf = None
self.model_toolkit_conf = None
self.model_toolkit_conf = []#The quantity is equal to the InferOp quantity,Engine--OP
self.resource_conf = None
self.memory_optimization = False
self.ir_optimization = False
self.model_conf = None
self.workflow_fn = "workflow.prototxt"
self.resource_fn = "resource.prototxt"
self.infer_service_fn = "infer_service.prototxt"
self.model_toolkit_fn = "model_toolkit.prototxt"
self.general_model_config_fn = "general_model.prototxt"
self.model_conf = collections.OrderedDict()# save the serving_server_conf.prototxt content (feed and fetch information) this is a map for multi-model in a workflow
self.workflow_fn = "workflow.prototxt"#only one for one Service,Workflow--Op
self.resource_fn = "resource.prototxt"#only one for one Service,model_toolkit_fn and general_model_config_fn is recorded in this file
self.infer_service_fn = "infer_service.prototxt"#only one for one Service,Service--Workflow
self.model_toolkit_fn = []#["general_infer_0/model_toolkit.prototxt"]The quantity is equal to the InferOp quantity,Engine--OP
self.general_model_config_fn = []#["general_infer_0/general_model.prototxt"]The quantity is equal to the InferOp quantity,Feed and Fetch --OP
self.subdirectory = []#The quantity is equal to the InferOp quantity, and name = node.name = engine.name
self.cube_config_fn = "cube.conf"
self.workdir = ""
self.max_concurrency = 0
......@@ -69,12 +70,12 @@ class Server(object):
self.use_trt = False
self.use_lite = False
self.use_xpu = False
self.model_config_paths = None # for multi-model in a workflow
self.model_config_paths = collections.OrderedDict() # save the serving_server_conf.prototxt path (feed and fetch information) this is a map for multi-model in a workflow
self.product_name = None
self.container_id = None
def get_fetch_list(self):
fetch_names = [var.alias_name for var in self.model_conf.fetch_var]
def get_fetch_list(self,infer_node_idx = -1 ):
fetch_names = [var.alias_name for var in list(self.model_conf.values())[infer_node_idx].fetch_var]
return fetch_names
def set_max_concurrency(self, concurrency):
......@@ -152,6 +153,7 @@ class Server(object):
def _prepare_engine(self, model_config_paths, device, use_encryption_model):
if self.model_toolkit_conf == None:
self.model_toolkit_conf = []
self.model_toolkit_conf = server_sdk.ModelToolkitConf()
for engine_name, model_config_path in model_config_paths.items():
......@@ -177,8 +179,8 @@ class Server(object):
if use_encryption_model:
engine.encrypted_model = True
engine.type = "PADDLE_INFER"
self.model_toolkit_conf.engines.extend([engine])
self.model_toolkit_conf.append(server_sdk.ModelToolkitConf())
self.model_toolkit_conf[-1].engines.extend([engine])
def _prepare_infer_service(self, port):
if self.infer_service_conf == None:
......@@ -192,79 +194,95 @@ class Server(object):
def _prepare_resource(self, workdir, cube_conf):
self.workdir = workdir
if self.resource_conf == None:
with open("{}/{}".format(workdir, self.general_model_config_fn),
"w") as fout:
fout.write(str(self.model_conf))
self.resource_conf = server_sdk.ResourceConf()
for workflow in self.workflow_conf.workflows:
for node in workflow.nodes:
if "dist_kv" in node.name:
self.resource_conf.cube_config_path = workdir
self.resource_conf.cube_config_file = self.cube_config_fn
if cube_conf == None:
raise ValueError(
"Please set the path of cube.conf while use dist_kv op."
)
shutil.copy(cube_conf, workdir)
self.resource_conf.model_toolkit_path = workdir
self.resource_conf.model_toolkit_file = self.model_toolkit_fn
self.resource_conf.general_model_path = workdir
self.resource_conf.general_model_file = self.general_model_config_fn
if self.product_name != None:
self.resource_conf.auth_product_name = self.product_name
if self.container_id != None:
self.resource_conf.auth_container_id = self.container_id
for idx, op_general_model_config_fn in enumerate(self.general_model_config_fn):
with open("{}/{}".format(workdir, op_general_model_config_fn),
"w") as fout:
fout.write(str(list(self.model_conf.values())[idx]))
for workflow in self.workflow_conf.workflows:
for node in workflow.nodes:
if "dist_kv" in node.name:
self.resource_conf.cube_config_path = workdir
self.resource_conf.cube_config_file = self.cube_config_fn
if cube_conf == None:
raise ValueError(
"Please set the path of cube.conf while use dist_kv op."
)
shutil.copy(cube_conf, workdir)
if "quant" in node.name:
self.resource_conf.cube_quant_bits = 8
self.resource_conf.model_toolkit_path.extend([workdir])
self.resource_conf.model_toolkit_file.extend([self.model_toolkit_fn[idx]])
self.resource_conf.general_model_path.extend([workdir])
self.resource_conf.general_model_file.extend([op_general_model_config_fn])
#TODO:figure out the meaning of product_name and container_id.
if self.product_name != None:
self.resource_conf.auth_product_name = self.product_name
if self.container_id != None:
self.resource_conf.auth_container_id = self.container_id
def _write_pb_str(self, filepath, pb_obj):
with open(filepath, "w") as fout:
fout.write(str(pb_obj))
def load_model_config(self, model_config_paths):
def load_model_config(self, model_config_paths_args):
# At present, Serving needs to configure the model path in
# the resource.prototxt file to determine the input and output
# format of the workflow. To ensure that the input and output
# of multiple models are the same.
workflow_oi_config_path = None
if isinstance(model_config_paths, str):
if isinstance(model_config_paths_args, str):
model_config_paths_args = [model_config_paths_args]
for single_model_config in model_config_paths_args:
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
if isinstance(model_config_paths_args, list):
# If there is only one model path, use the default infer_op.
# Because there are several infer_op type, we need to find
# it from workflow_conf.
default_engine_names = [
'general_infer_0', 'general_dist_kv_infer_0',
'general_dist_kv_quant_infer_0'
default_engine_types = [
'GeneralInferOp', 'GeneralDistKVInferOp',
'GeneralDistKVQuantInferOp','GeneralDetectionOp',
]
engine_name = None
# now only support single-workflow.
# TODO:support multi-workflow
model_config_paths_list_idx = 0
for node in self.workflow_conf.workflows[0].nodes:
if node.name in default_engine_names:
engine_name = node.name
break
if engine_name is None:
raise Exception(
"You have set the engine_name of Op. Please use the form {op: model_path} to configure model path"
)
self.model_config_paths = {engine_name: model_config_paths}
workflow_oi_config_path = self.model_config_paths[engine_name]
elif isinstance(model_config_paths, dict):
self.model_config_paths = {}
for node_str, path in model_config_paths.items():
if node.type in default_engine_types:
if node.name is None:
raise Exception(
"You have set the engine_name of Op. Please use the form {op: model_path} to configure model path"
)
f = open("{}/serving_server_conf.prototxt".format(
model_config_paths_args[model_config_paths_list_idx]), 'r')
self.model_conf[node.name] = google.protobuf.text_format.Merge(str(f.read()), m_config.GeneralModelConfig())
self.model_config_paths[node.name] = model_config_paths_args[model_config_paths_list_idx]
self.general_model_config_fn.append(node.name+"/general_model.prototxt")
self.model_toolkit_fn.append(node.name+"/model_toolkit.prototxt")
self.subdirectory.append(node.name)
model_config_paths_list_idx += 1
if model_config_paths_list_idx == len(model_config_paths_args):
break
#Right now, this is not useful.
elif isinstance(model_config_paths_args, dict):
self.model_config_paths = collections.OrderedDict()
for node_str, path in model_config_paths_args.items():
node = server_sdk.DAGNode()
google.protobuf.text_format.Parse(node_str, node)
self.model_config_paths[node.name] = path
print("You have specified multiple model paths, please ensure "
"that the input and output of multiple models are the same.")
workflow_oi_config_path = list(self.model_config_paths.items())[0][
1]
f = open("{}/serving_server_conf.prototxt".format(path), 'r')
self.model_conf[node.name] = google.protobuf.text_format.Merge(
str(f.read()), m_config.GeneralModelConfig())
else:
raise Exception("The type of model_config_paths must be str or "
raise Exception("The type of model_config_paths must be str or list or "
"dict({op: model_path}), not {}.".format(
type(model_config_paths)))
self.model_conf = m_config.GeneralModelConfig()
f = open(
"{}/serving_server_conf.prototxt".format(workflow_oi_config_path),
'r')
self.model_conf = google.protobuf.text_format.Merge(
str(f.read()), self.model_conf)
type(model_config_paths_args)))
# check config here
# print config here
......@@ -371,6 +389,10 @@ class Server(object):
os.system("mkdir -p {}".format(workdir))
os.system("touch {}/fluid_time_file".format(workdir))
for subdir in self.subdirectory:
os.system("mkdir {}/{}".format(workdir, subdir))
os.system("touch {}/{}/fluid_time_file".format(workdir, subdir))
if not self.port_is_available(port):
raise SystemExit("Port {} is already used".format(port))
......@@ -382,14 +404,17 @@ class Server(object):
self.workdir = workdir
infer_service_fn = "{}/{}".format(workdir, self.infer_service_fn)
workflow_fn = "{}/{}".format(workdir, self.workflow_fn)
resource_fn = "{}/{}".format(workdir, self.resource_fn)
model_toolkit_fn = "{}/{}".format(workdir, self.model_toolkit_fn)
self._write_pb_str(infer_service_fn, self.infer_service_conf)
workflow_fn = "{}/{}".format(workdir, self.workflow_fn)
self._write_pb_str(workflow_fn, self.workflow_conf)
resource_fn = "{}/{}".format(workdir, self.resource_fn)
self._write_pb_str(resource_fn, self.resource_conf)
self._write_pb_str(model_toolkit_fn, self.model_toolkit_conf)
for idx,single_model_toolkit_fn in enumerate(self.model_toolkit_fn):
model_toolkit_fn = "{}/{}".format(workdir, single_model_toolkit_fn)
self._write_pb_str(model_toolkit_fn, self.model_toolkit_conf[idx])
def port_is_available(self, port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
......@@ -476,7 +501,6 @@ class Server(object):
os.system(command)
class MultiLangServer(object):
def __init__(self):
self.bserver_ = Server()
......@@ -532,17 +556,51 @@ class MultiLangServer(object):
def set_gpuid(self, gpuid=0):
self.bserver_.set_gpuid(gpuid)
def load_model_config(self, server_config_paths, client_config_path=None):
self.bserver_.load_model_config(server_config_paths)
def load_model_config(self, server_config_dir_paths, client_config_path=None):
if isinstance(server_config_dir_paths, str):
server_config_dir_paths = [server_config_dir_paths]
elif isinstance(server_config_dir_paths, list):
pass
else:
raise Exception("The type of model_config_paths must be str or list"
", not {}.".format(
type(server_config_dir_paths)))
for single_model_config in server_config_dir_paths:
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
self.bserver_.load_model_config(server_config_dir_paths)
if client_config_path is None:
if isinstance(server_config_paths, dict):
#now dict is not useful.
if isinstance(server_config_dir_paths, dict):
self.is_multi_model_ = True
client_config_path = '{}/serving_server_conf.prototxt'.format(
list(server_config_paths.items())[0][1])
client_config_path = []
for server_config_path_items in list(server_config_dir_paths.items()):
client_config_path.append( server_config_path_items[1] )
elif isinstance(server_config_dir_paths, list):
self.is_multi_model_ = False
client_config_path = server_config_dir_paths
else:
client_config_path = '{}/serving_server_conf.prototxt'.format(
server_config_paths)
self.bclient_config_path_ = client_config_path
raise Exception("The type of model_config_paths must be str or list or "
"dict({op: model_path}), not {}.".format(
type(server_config_dir_paths)))
if isinstance(client_config_path, str):
client_config_path = [client_config_path]
elif isinstance(client_config_path, list):
pass
else:# dict is not support right now.
raise Exception("The type of client_config_path must be str or list or "
"dict({op: model_path}), not {}.".format(
type(client_config_path)))
if len(client_config_path) != len(server_config_dir_paths):
raise Warning("The len(client_config_path) is {}, != len(server_config_dir_paths) {}."
.format( len(client_config_path), len(server_config_dir_paths) )
)
self.bclient_config_path_list = client_config_path
def prepare_server(self,
workdir=None,
......@@ -588,7 +646,7 @@ class MultiLangServer(object):
maximum_concurrent_rpcs=self.concurrency_)
multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
MultiLangServerServiceServicer(
self.bclient_config_path_, self.is_multi_model_,
self.bclient_config_path_list, self.is_multi_model_,
["0.0.0.0:{}".format(self.port_list_[0])]), server)
server.add_insecure_port('[::]:{}'.format(self.gport_))
server.start()
......
......@@ -23,8 +23,7 @@ from paddle_serving_server.serve import start_multi_card
import socket
import sys
import numpy as np
import paddle_serving_server as serving
import os
from paddle_serving_server import pipeline
from paddle_serving_server.pipeline import Op
......@@ -64,23 +63,39 @@ class WebService(object):
def run_service(self):
self._server.run_server()
def load_model_config(self, model_config):
print("This API will be deprecated later. Please do not use it")
self.model_config = model_config
import os
def load_model_config(self, server_config_dir_paths, client_config_path=None):
if isinstance(server_config_dir_paths, str):
server_config_dir_paths = [server_config_dir_paths]
elif isinstance(server_config_dir_paths, list):
pass
for single_model_config in server_config_dir_paths:
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
self.server_config_dir_paths = server_config_dir_paths
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
if os.path.isdir(model_config):
client_config = "{}/serving_server_conf.prototxt".format(
model_config)
elif os.path.isfile(model_config):
client_config = model_config
file_path_list = []
for single_model_config in self.server_config_dir_paths:
file_path_list.append( "{}/serving_server_conf.prototxt".format(single_model_config) )
model_conf = m_config.GeneralModelConfig()
f = open(client_config, 'r')
f = open(file_path_list[0], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.feed_vars = {var.name: var for var in model_conf.feed_var}
if len(file_path_list) > 1:
model_conf = m_config.GeneralModelConfig()
f = open(file_path_list[-1], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.fetch_vars = {var.name: var for var in model_conf.fetch_var}
if client_config_path == None:
self.client_config_path = self.server_config_dir_paths
def set_gpus(self, gpus):
print("This API will be deprecated later. Please do not use it")
......@@ -102,13 +117,21 @@ class WebService(object):
else:
device = "cpu"
op_maker = serving.OpMaker()
read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer')
general_response_op = op_maker.create('general_response')
op_seq_maker = OpSeqMaker()
read_op = op_maker.create('general_reader')
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(general_infer_op)
for idx, single_model in enumerate(self.server_config_dir_paths):
infer_op_name = "general_infer"
if len(self.server_config_dir_paths) == 2 and idx == 0:
infer_op_name = "general_detection"
else:
infer_op_name = "general_infer"
general_infer_op = op_maker.create(infer_op_name)
op_seq_maker.add_op(general_infer_op)
general_response_op = op_maker.create('general_response')
op_seq_maker.add_op(general_response_op)
server = Server()
......@@ -123,7 +146,7 @@ class WebService(object):
if use_xpu:
server.set_xpu()
server.load_model_config(self.model_config)
server.load_model_config(self.server_config_dir_paths)#brpc Server support server_config_dir_paths
if gpuid >= 0:
server.set_gpuid(gpuid)
server.prepare_server(workdir=workdir, port=port, device=device)
......@@ -182,8 +205,7 @@ class WebService(object):
def _launch_web_service(self):
gpu_num = len(self.gpus)
self.client = Client()
self.client.load_client_config("{}/serving_server_conf.prototxt".format(
self.model_config))
self.client.load_client_config(self.client_config_path)
endpoints = ""
if gpu_num > 0:
for i in range(gpu_num):
......@@ -264,6 +286,10 @@ class WebService(object):
self.app_instance = app_instance
def _launch_local_predictor(self, gpu):
# actually, LocalPredictor is like a server, but it is WebService Request initiator
# for WebService it is a Client.
# local_predictor only support single-Model DirPath - Type:str
# so the input must be self.server_config_dir_paths[0]
from paddle_serving_app.local_predict import LocalPredictor
self.client = LocalPredictor()
if gpu:
......@@ -271,11 +297,9 @@ class WebService(object):
# default self.gpus = [0].
if len(self.gpus) == 0:
self.gpus.append(0)
self.client.load_model_config(
"{}".format(self.model_config), use_gpu=True, gpu_id=self.gpus[0])
self.client.load_model_config(self.server_config_dir_paths[0], use_gpu=True, gpu_id=self.gpus[0])
else:
self.client.load_model_config(
"{}".format(self.model_config), use_gpu=False)
self.client.load_model_config(self.server_config_dir_paths[0], use_gpu=False)
def run_web_service(self):
print("This API will be deprecated later. Please do not use it")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册