Commit 07eadb82 authored by HexToString

fix bug and names

Parent 9256114c
@@ -21,6 +21,7 @@
 #include <utility>
 #include <vector>
 #include <numeric>
+#include <functional>
 #include "core/predictor/common/inner_common.h"
 #include "core/predictor/framework/bsf.h"
 #include "core/predictor/framework/factory.h"
@@ -68,7 +69,9 @@ class InferEngine {
   virtual int thrd_initialize() { return thrd_initialize_impl(); }
   virtual int thrd_clear() { return thrd_clear_impl(); }
   virtual int thrd_finalize() { return thrd_finalize_impl(); }
-  virtual int infer(const void* in, void* out, uint32_t batch_size = -1) { return infer_impl(in, out, batch_size); }
+  virtual int infer(const void* in, void* out, uint32_t batch_size = -1) {
+    return infer_impl(in, out, batch_size);
+  }
 
   virtual int reload() = 0;
@@ -208,7 +211,6 @@ class ReloadableInferEngine : public InferEngine {
   }
 
   uint64_t version() const { return _version; }
   uint32_t thread_num() const { return _infer_thread_num; }
-
  private:
@@ -335,7 +337,7 @@ class DBReloadableInferEngine : public ReloadableInferEngine {
     md->cores[next_idx] = new (std::nothrow) EngineCore;
-    //params.dump();
+    // params.dump();
     if (!md->cores[next_idx] || md->cores[next_idx]->create(conf) != 0) {
       LOG(ERROR) << "Failed create model, path: " << conf.model_dir();
       return -1;
@@ -491,71 +493,86 @@ class CloneDBReloadableInferEngine
       _pd;  // process-level EngineCore; the thread-level EngineCores share this object's model data
 };
-template <typename PaddleInferenceCore>
+template <typename EngineCore>
 #ifdef WITH_TRT
-class FluidInferEngine : public DBReloadableInferEngine<PaddleInferenceCore> {
+class FluidInferEngine : public DBReloadableInferEngine<EngineCore> {
 #else
-class FluidInferEngine : public CloneDBReloadableInferEngine<PaddleInferenceCore> {
+class FluidInferEngine : public CloneDBReloadableInferEngine<EngineCore> {
 #endif
  public:  // NOLINT
   FluidInferEngine() {}
   ~FluidInferEngine() {}
   typedef std::vector<paddle::PaddleTensor> TensorVector;
   int infer_impl(const void* in, void* out, uint32_t batch_size = -1) {
-    //First of all, get the real core acording to the template parameter 'PaddleInferenceCore'.
-    PaddleInferenceCore* core =DBReloadableInferEngine<PaddleInferenceCore>::get_core();
+    // First of all, get the real core according to the
+    // template parameter <EngineCore>.
+    EngineCore* core = DBReloadableInferEngine<EngineCore>::get_core();
     if (!core || !core->get()) {
       LOG(ERROR) << "Failed get fluid core in infer_impl()";
       return -1;
     }
-    //We use the for loop to process the input data.
-    //Inside each for loop, use the in[i]->name as inputName and call 'core->GetInputHandle(inputName)' to get the pointer of InputData.
-    //Set the lod and shape information of InputData first. then copy data from cpu to the core.
-    const TensorVector* tensorVector_in_pointer = reinterpret_cast<const TensorVector*>(in);
+    // We use a for loop to process the input data.
+    // In each iteration, use in[i].name as the input name and call
+    // 'core->GetInputHandle(inputName)' to get the pointer of InputData.
+    // Set the lod and shape information of InputData first,
+    // then copy the data from cpu to the core.
+    const TensorVector* tensorVector_in_pointer =
+        reinterpret_cast<const TensorVector*>(in);
     for (int i=0; i < tensorVector_in_pointer->size(); ++i) {
-      auto lod_tensor_in = core->GetInputHandle((*tensorVector_in_pointer)[i].name);
+      auto lod_tensor_in =
+          core->GetInputHandle((*tensorVector_in_pointer)[i].name);
       lod_tensor_in->SetLoD((*tensorVector_in_pointer)[i].lod);
       lod_tensor_in->Reshape((*tensorVector_in_pointer)[i].shape);
       void* origin_data = (*tensorVector_in_pointer)[i].data.data();
-      //Because the core needs to determine the size of memory space according to the data type passed in.
-      //The pointer type of data must be one of float *,int64_t*,int32_t* instead void*.
+      // Because the core needs to determine the size of the memory space
+      // according to the data type passed in,
+      // the pointer type of data must be one of
+      // float*, int64_t*, int32_t* instead of void*.
       if ((*tensorVector_in_pointer)[i].dtype == paddle::PaddleDType::FLOAT32) {
        float* data = static_cast<float*>(origin_data);
        lod_tensor_in->CopyFromCpu(data);
-      }else if ((*tensorVector_in_pointer)[i].dtype == paddle::PaddleDType::INT64) {
+      } else if ((*tensorVector_in_pointer)[i].dtype ==
+                 paddle::PaddleDType::INT64) {
        int64_t* data = static_cast<int64_t*>(origin_data);
        lod_tensor_in->CopyFromCpu(data);
-      }else if ((*tensorVector_in_pointer)[i].dtype == paddle::PaddleDType::INT32) {
+      } else if ((*tensorVector_in_pointer)[i].dtype ==
+                 paddle::PaddleDType::INT32) {
        int32_t* data = static_cast<int32_t*>(origin_data);
        lod_tensor_in->CopyFromCpu(data);
      }
    }
-    //After the input data is passed in, call 'core->Run()' perform the prediction process.
+    // After the input data is passed in,
+    // call 'core->Run()' to perform the prediction.
     if (!core->Run()) {
       LOG(ERROR) << "Failed run fluid family core";
       return -1;
     }
-    //In order to get the results, first, call the 'core->GetOutputNames()' to get the name of output(which is a dict like {OutputName:pointer of OutputValue}).
-    //Then, use for-loop to get OutputValue by calling 'core->GetOutputHandle'.
+    // In order to get the results,
+    // first call 'core->GetOutputNames()' to get the output names
+    // (which behaves like a dict {OutputName: pointer of OutputValue}).
+    // Then use a for loop to get each OutputValue by calling
+    // 'core->GetOutputHandle'.
     std::vector<std::string> outnames = core->GetOutputNames();
     std::vector<int> output_shape;
-    int out_num =0;
-    int dataType =0;
+    int out_num = 0;
+    int dataType = 0;
     void* databuf_data = NULL;
     char* databuf_char = NULL;
     size_t databuf_size = 0;
-    TensorVector* tensorVector_out_pointer = reinterpret_cast<TensorVector*>(out);
+    TensorVector* tensorVector_out_pointer =
+        reinterpret_cast<TensorVector*>(out);
     if (!tensorVector_out_pointer) {
       LOG(ERROR) << "tensorVector_out_pointer is nullptr,error";
       return -1;
     }
-    //Get the type and shape information of OutputData first. then copy data to cpu from the core.
-    //The pointer type of data_out must be one of float *,int64_t*,int32_t* instead void*.
+    // Get the type and shape information of OutputData first,
+    // then copy the data from the core back to cpu.
+    // The pointer type of data_out must be one of
+    // float*, int64_t*, int32_t* instead of void*.
     for (int i=0; i < outnames.size(); ++i) {
       auto lod_tensor_out = core->GetOutputHandle(outnames[i]);
       output_shape = lod_tensor_out->shape();
-      out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
+      out_num = std::accumulate(
+          output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
       dataType = lod_tensor_out->type();
       if (dataType == paddle::PaddleDType::FLOAT32) {
         databuf_size = out_num*sizeof(float);
@@ -567,7 +584,7 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<PaddleInferenceCore>
         float* data_out = reinterpret_cast<float*>(databuf_data);
         lod_tensor_out->CopyToCpu(data_out);
         databuf_char = reinterpret_cast<char*>(data_out);
-      }else if (dataType == paddle::PaddleDType::INT64) {
+      } else if (dataType == paddle::PaddleDType::INT64) {
         databuf_size = out_num*sizeof(int64_t);
         databuf_data = MempoolWrapper::instance().malloc(databuf_size);
         if (!databuf_data) {
@@ -577,7 +594,7 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<PaddleInferenceCore>
         int64_t* data_out = reinterpret_cast<int64_t*>(databuf_data);
         lod_tensor_out->CopyToCpu(data_out);
         databuf_char = reinterpret_cast<char*>(data_out);
-      }else if (dataType == paddle::PaddleDType::INT32) {
+      } else if (dataType == paddle::PaddleDType::INT32) {
         databuf_size = out_num*sizeof(int32_t);
         databuf_data = MempoolWrapper::instance().malloc(databuf_size);
         if (!databuf_data) {
@@ -588,9 +605,11 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<PaddleInferenceCore>
         lod_tensor_out->CopyToCpu(data_out);
         databuf_char = reinterpret_cast<char*>(data_out);
       }
-      //Because task scheduling requires OPs to use 'Channel'(which is a data structure) to transfer data between OPs.
-      //We need to copy the processed data to the 'Channel' for the next OP.
-      //In this function, it means we should copy the 'databuf_char' to the pointer 'void* out'.(which is also called ‘tensorVector_out_pointer’)
+      // Because task scheduling requires OPs to use a 'Channel'
+      // (which is a data structure) to transfer data between OPs,
+      // we need to copy the processed data into the 'Channel' for the next OP.
+      // In this function, that means copying 'databuf_char' into
+      // 'void* out' (which is also called 'tensorVector_out_pointer').
       paddle::PaddleTensor tensor_out;
       tensor_out.name = outnames[i];
       tensor_out.dtype = paddle::PaddleDType(dataType);
@@ -611,8 +630,6 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<PaddleInferenceCore>
   int task_infer_impl(const BatchTensor& in, BatchTensor& out) {  // NOLINT
     return infer_impl(&in, &out);
   }
-
-
 };
 
 typedef FactoryPool<InferEngine> StaticInferFactory;
@@ -797,7 +814,9 @@ class VersionedInferEngine : public InferEngine {
   int thrd_finalize_impl() { return -1; }
   int thrd_clear_impl() { return -1; }
   int proc_finalize_impl() { return -1; }
-  int infer_impl(const void* in, void* out, uint32_t batch_size = -1) { return -1; }
+  int infer_impl(const void* in, void* out, uint32_t batch_size = -1) {
+    return -1;
+  }
   int task_infer_impl(const BatchTensor& in, BatchTensor& out) {  // NOLINT
     return -1;
   }  // NOLINT
......
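
For orientation, below is a minimal sketch (not part of this commit) of the calling convention that the reformatted infer_impl() above expects: the caller packs its feeds into a TensorVector passed as 'const void* in' and receives one paddle::PaddleTensor per model output in the TensorVector passed as 'void* out'. The input name "x", the input shape, the helper function, and the exact namespace of InferEngine are illustrative assumptions; only paddle::PaddleTensor/PaddleBuf and the infer() signature shown above are taken from the source.

```cpp
// Sketch only: driving the engine's infer() from caller code, assuming an
// already-initialized engine and a model with one FLOAT32 input named "x".
#include <vector>
#include "paddle_inference_api.h"  // paddle::PaddleTensor, paddle::PaddleBuf

using TensorVector = std::vector<paddle::PaddleTensor>;

// InferEngine is the interface patched above; the namespace is assumed to
// match the predictor framework in this repo.
int run_once(baidu::paddle_serving::predictor::InferEngine* engine,
             std::vector<float>* feed_values) {
  paddle::PaddleTensor input;
  input.name = "x";  // must match one of the model's input names (assumption)
  input.dtype = paddle::PaddleDType::FLOAT32;
  input.shape = {1, static_cast<int>(feed_values->size())};
  // Wrap the caller-owned buffer; infer_impl() copies it into the core
  // via CopyFromCpu(), so no extra allocation is needed here.
  input.data = paddle::PaddleBuf(feed_values->data(),
                                 feed_values->size() * sizeof(float));

  TensorVector in = {input};
  TensorVector out;  // filled with one PaddleTensor per output name
  return engine->infer(&in, &out, 1);  // negative on failure, per the checks above
}
```

The output tensors can then be read back by casting out[j].data.data() to the pointer type matching out[j].dtype, mirroring the copy-out logic in the hunks above.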
@@ -42,9 +42,9 @@ static const int max_batch = 32;
 static const int min_subgraph_size = 3;
 
 // Engine Base
-class PaddleEngineBase {
+class EngineCore {
  public:
-  virtual ~PaddleEngineBase() {}
+  virtual ~EngineCore() {}
   virtual std::vector<std::string> GetInputNames() {
     return _predictor->GetInputNames();
   }
@@ -92,7 +92,7 @@ class PaddleEngineBase {
 };
 
 // Paddle Inference Engine
-class PaddleInferenceEngine : public PaddleEngineBase {
+class PaddleInferenceEngine : public EngineCore {
  public:
   int create(const configure::EngineDesc& engine_conf) {
     std::string model_path = engine_conf.model_dir();
......
@@ -93,7 +93,7 @@ def serve_args():
 def start_standard_model(serving_port):  # pylint: disable=doc-string-missing
-    args = parse_args()
+    args = serve_args()
     thread_num = args.thread
     model = args.model
     port = serving_port
@@ -135,7 +135,6 @@ def start_standard_model(serving_port):  # pylint: disable=doc-string-missing
     general_response_op = op_maker.create('general_response')
     op_seq_maker.add_op(general_response_op)
-
     server = None
     if use_multilang:
         server = serving.MultiLangServer()
@@ -297,7 +296,8 @@ class MainService(BaseHTTPRequestHandler):
         key = base64.b64decode(post_data["key"].encode())
         for single_model_config in args.model:
             if os.path.isfile(single_model_config):
-                raise ValueError("The input of --model should be a dir not file.")
+                raise ValueError(
+                    "The input of --model should be a dir not file.")
             with open(single_model_config + "/key", "wb") as f:
                 f.write(key)
         return True
@@ -309,7 +309,8 @@ class MainService(BaseHTTPRequestHandler):
         key = base64.b64decode(post_data["key"].encode())
         for single_model_config in args.model:
             if os.path.isfile(single_model_config):
-                raise ValueError("The input of --model should be a dir not file.")
+                raise ValueError(
+                    "The input of --model should be a dir not file.")
             with open(single_model_config + "/key", "rb") as f:
                 cur_key = f.read()
             if key != cur_key:
@@ -394,7 +395,8 @@ if __name__ == "__main__":
             device=args.device,
             use_lite=args.use_lite,
             use_xpu=args.use_xpu,
-            ir_optim=args.ir_optim)
+            ir_optim=args.ir_optim,
+            thread_num=args.thread)
         web_service.run_rpc_service()
 
         app_instance = Flask(__name__)
......
@@ -27,6 +27,7 @@ import os
 from paddle_serving_server import pipeline
 from paddle_serving_server.pipeline import Op
 
+
 def port_is_available(port):
     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
         sock.settimeout(2)
@@ -36,6 +37,7 @@ def port_is_available(port):
         else:
             return False
 
+
 class WebService(object):
     def __init__(self, name="default_service"):
         self.name = name
@@ -63,7 +65,9 @@ class WebService(object):
     def run_service(self):
         self._server.run_server()
 
-    def load_model_config(self, server_config_dir_paths, client_config_path=None):
+    def load_model_config(self,
+                          server_config_dir_paths,
+                          client_config_path=None):
         if isinstance(server_config_dir_paths, str):
             server_config_dir_paths = [server_config_dir_paths]
         elif isinstance(server_config_dir_paths, list):
@@ -73,13 +77,15 @@ class WebService(object):
             if os.path.isdir(single_model_config):
                 pass
             elif os.path.isfile(single_model_config):
-                raise ValueError("The input of --model should be a dir not file.")
+                raise ValueError(
+                    "The input of --model should be a dir not file.")
         self.server_config_dir_paths = server_config_dir_paths
         from .proto import general_model_config_pb2 as m_config
         import google.protobuf.text_format
         file_path_list = []
         for single_model_config in self.server_config_dir_paths:
-            file_path_list.append( "{}/serving_server_conf.prototxt".format(single_model_config) )
+            file_path_list.append("{}/serving_server_conf.prototxt".format(
+                single_model_config))
 
         model_conf = m_config.GeneralModelConfig()
         f = open(file_path_list[0], 'r')
@@ -146,7 +152,8 @@ class WebService(object):
         if use_xpu:
             server.set_xpu()
-        server.load_model_config(self.server_config_dir_paths)#brpc Server support server_config_dir_paths
+        server.load_model_config(self.server_config_dir_paths
+                                 )  # the brpc Server supports server_config_dir_paths
         if gpuid >= 0:
             server.set_gpuid(gpuid)
         server.prepare_server(workdir=workdir, port=port, device=device)
@@ -163,10 +170,12 @@ class WebService(object):
                        use_xpu=False,
                        ir_optim=False,
                        gpuid=0,
+                       thread_num=2,
                        mem_optim=True):
         print("This API will be deprecated later. Please do not use it")
         self.workdir = workdir
         self.port = port
+        self.thread_num = thread_num
         self.device = device
         self.gpuid = gpuid
         self.port_list = []
@@ -184,7 +193,7 @@ class WebService(object):
             self.workdir,
             self.port_list[0],
             -1,
-            thread_num=2,
+            thread_num=self.thread_num,
             mem_optim=mem_optim,
             use_lite=use_lite,
             use_xpu=use_xpu,
@@ -196,7 +205,7 @@ class WebService(object):
                 "{}_{}".format(self.workdir, i),
                 self.port_list[i],
                 gpuid,
-                thread_num=2,
+                thread_num=self.thread_num,
                 mem_optim=mem_optim,
                 use_lite=use_lite,
                 use_xpu=use_xpu,
@@ -297,9 +306,13 @@ class WebService(object):
             # default self.gpus = [0].
             if len(self.gpus) == 0:
                 self.gpus.append(0)
-            self.client.load_model_config(self.server_config_dir_paths[0], use_gpu=True, gpu_id=self.gpus[0])
+            self.client.load_model_config(
+                self.server_config_dir_paths[0],
+                use_gpu=True,
+                gpu_id=self.gpus[0])
         else:
-            self.client.load_model_config(self.server_config_dir_paths[0], use_gpu=False)
+            self.client.load_model_config(
+                self.server_config_dir_paths[0], use_gpu=False)
 
     def run_web_service(self):
         print("This API will be deprecated later. Please do not use it")
......