Commit 53413a48 authored by HexToString

delete unused code

Parent 7efb4a3b
......@@ -27,16 +27,7 @@ message(STATUS "CXX compiler: ${CMAKE_CXX_COMPILER}, version: "
"${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}")
message(STATUS "C compiler: ${CMAKE_C_COMPILER}, version: "
"${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}")
SET(OPENCV_DIR "" CACHE PATH "Location of libraries")
if(NOT DEFINED OPENCV_DIR)
message(FATAL_ERROR "please set OPENCV_DIR with -DOPENCV_DIR=/path/opencv")
endif()
if (WIN32)
find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/build/ NO_DEFAULT_PATH)
else ()
find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/share/OpenCV NO_DEFAULT_PATH)
endif ()
include_directories(${OpenCV_INCLUDE_DIRS})
find_package(Git REQUIRED)
find_package(Threads REQUIRED)
find_package(CUDA QUIET)
......@@ -55,19 +46,33 @@ set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
set(THIRD_PARTY_BUILD_TYPE Release)
option(WITH_AVX "Compile Paddle Serving with AVX intrinsics" OFF)
option(WITH_MKL "Compile Paddle Serving with MKL support." OFF)
option(WITH_GPU "Compile Paddle Serving with NVIDIA GPU" OFF)
option(WITH_LITE "Compile Paddle Serving with Paddle Lite Engine" OFF)
option(WITH_XPU "Compile Paddle Serving with Baidu Kunlun" OFF)
option(WITH_PYTHON "Compile Paddle Serving with Python" ON)
option(CLIENT "Compile Paddle Serving Client" OFF)
option(SERVER "Compile Paddle Serving Server" OFF)
option(APP "Compile Paddle Serving App package" OFF)
option(WITH_ELASTIC_CTR "Compile ELASTIC-CTR solution" OFF)
option(PACK "Compile for whl" OFF)
option(WITH_TRT "Compile Paddle Serving with TRT" OFF)
option(PADDLE_ON_INFERENCE "Compile for encryption" ON)
option(WITH_AVX "Compile Paddle Serving with AVX intrinsics" OFF)
option(WITH_MKL "Compile Paddle Serving with MKL support." OFF)
option(WITH_GPU "Compile Paddle Serving with NVIDIA GPU" OFF)
option(WITH_LITE "Compile Paddle Serving with Paddle Lite Engine" OFF)
option(WITH_XPU "Compile Paddle Serving with Baidu Kunlun" OFF)
option(WITH_PYTHON "Compile Paddle Serving with Python" ON)
option(CLIENT "Compile Paddle Serving Client" OFF)
option(SERVER "Compile Paddle Serving Server" OFF)
option(APP "Compile Paddle Serving App package" OFF)
option(WITH_ELASTIC_CTR "Compile ELASTIC-CTR solution" OFF)
option(PACK "Compile for whl" OFF)
option(WITH_TRT "Compile Paddle Serving with TRT" OFF)
option(PADDLE_ON_INFERENCE "Compile for encryption" ON)
option(WITH_OPENCV "Compile Paddle Serving with OPENCV" OFF)
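# WITH_OPENCV requires OPENCV_DIR to point at an existing OpenCV install,
# e.g. (hypothetical path): cmake -DWITH_OPENCV=ON -DOPENCV_DIR=/path/to/opencv ..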
if (WITH_OPENCV)
SET(OPENCV_DIR "" CACHE PATH "Location of libraries")
if(NOT DEFINED OPENCV_DIR)
message(FATAL_ERROR "please set OPENCV_DIR with -DOPENCV_DIR=/path/opencv")
endif()
if (WIN32)
find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/build/ NO_DEFAULT_PATH)
else ()
find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/share/OpenCV NO_DEFAULT_PATH)
endif ()
include_directories(${OpenCV_INCLUDE_DIRS})
endif()
if (PADDLE_ON_INFERENCE)
add_definitions(-DPADDLE_ON_INFERENCE)
......
......@@ -54,7 +54,11 @@ ELSE(WIN32)
SET(ZLIB_LIBRARIES "${ZLIB_INSTALL_DIR}/lib/libz.a" CACHE FILEPATH "zlib library." FORCE)
ENDIF(WIN32)
#ADD_LIBRARY(zlib STATIC IMPORTED GLOBAL)
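# When WITH_OPENCV is ON, skip creating the imported zlib target here
# (presumably a zlib target is already provided through OpenCV's
# dependencies, so redefining it would clash); otherwise import the
# static library built above.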
IF(NOT WITH_OPENCV)
  ADD_LIBRARY(zlib STATIC IMPORTED GLOBAL)
ENDIF()
SET_PROPERTY(TARGET zlib PROPERTY IMPORTED_LOCATION ${ZLIB_LIBRARIES})
ADD_DEPENDENCIES(zlib extern_zlib)
......
......@@ -28,7 +28,9 @@ endif()
target_link_libraries(serving -Wl,--whole-archive fluid_cpu_engine
-Wl,--no-whole-archive)
target_link_libraries(serving ${OpenCV_LIBS})
if(WITH_OPENCV)
target_link_libraries(serving ${OpenCV_LIBS})
endif()
target_link_libraries(serving paddle_fluid ${paddle_depend_libs})
target_link_libraries(serving brpc)
target_link_libraries(serving protobuf)
......
FILE(GLOB op_srcs ${CMAKE_CURRENT_LIST_DIR}/*.cpp ${CMAKE_CURRENT_LIST_DIR}/../../predictor/tools/quant.cpp ${CMAKE_CURRENT_LIST_DIR}/../../predictor/tools/ocrtools/*.cpp)
FILE(GLOB op_srcs ${CMAKE_CURRENT_LIST_DIR}/*.cpp ${CMAKE_CURRENT_LIST_DIR}/../../predictor/tools/quant.cpp)
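# general_detection_op.cpp depends on OpenCV: compile the ocrtools sources
# only when WITH_OPENCV is ON, and drop the detection op otherwise.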
if(WITH_OPENCV)
FILE(GLOB ocrtools_srcs ${CMAKE_CURRENT_LIST_DIR}/../../predictor/tools/ocrtools/*.cpp)
LIST(APPEND op_srcs ${ocrtools_srcs})
else()
set (EXCLUDE_DIR "general_detection_op.cpp")
foreach (TMP_PATH ${op_srcs})
string (FIND ${TMP_PATH} ${EXCLUDE_DIR} EXCLUDE_DIR_FOUND)
if (NOT ${EXCLUDE_DIR_FOUND} EQUAL -1)
list (REMOVE_ITEM op_srcs ${TMP_PATH})
break()
endif ()
endforeach(TMP_PATH)
endif()
LIST(APPEND serving_srcs ${op_srcs})
......@@ -37,7 +37,6 @@ using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
int GeneralInferOp::inference() {
VLOG(2) << "Going to run inference";
std::cout<<"I am GeneralInferOp"<<std::endl;
const std::vector<std::string> pre_node_names = pre_names();
if (pre_node_names.size() != 1) {
LOG(ERROR) << "This op(" << op_name()
......@@ -88,7 +87,6 @@ int GeneralInferOp::inference() {
<< ") Failed do infer in fluid model: " << engine_name().c_str();
return -1;
}
std::cout<<"I am GeneralInferOp finish"<<std::endl;
int64_t end = timeline.TimeStampUS();
CopyBlobInfo(input_blob, output_blob);
AddBlobInfo(output_blob, start);
......
......@@ -70,7 +70,7 @@ int conf_check(const Request *req,
}
int GeneralReaderOp::inference() {
// reade request from client
// read request from client
const Request *req = dynamic_cast<const Request *>(get_request_message());
uint64_t log_id = req->log_id();
int input_var_num = 0;
......@@ -100,7 +100,7 @@ int GeneralReaderOp::inference() {
VLOG(2) << "(logid=" << log_id << ") get resource pointer done.";
//get the first InferOP's model_config as ReaderOp's model_config by default.
std::shared_ptr<PaddleGeneralModelConfig> model_config =
resource.get_general_model_config()[0];
resource.get_general_model_config().front();
// TODO(guru4elephant): how to do conditional check?
/*
......@@ -183,10 +183,13 @@ int GeneralReaderOp::inference() {
LOG(ERROR) << "dst_ptr is nullptr";
return -1;
}
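// bulk-copy the request's int64 buffer into the output tensor
// (databuf_size[i] is presumably computed from the tensor shape earlier in this op)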
memcpy(dst_ptr, req->insts(0).tensor_array(i).int64_data().data(),databuf_size[i]);
/*
int elem_num = req->insts(0).tensor_array(i).int64_data_size();
for (int k = 0; k < elem_num; ++k) {
dst_ptr[k] = req->insts(0).tensor_array(i).int64_data(k);
}
*/
} else if (elem_type[i] == P_FLOAT32) {
float *dst_ptr = static_cast<float *>(out->at(i).data.data());
VLOG(2) << "(logid=" << log_id << ") first element data in var[" << i
......@@ -195,11 +198,11 @@ int GeneralReaderOp::inference() {
LOG(ERROR) << "dst_ptr is nullptr";
return -1;
}
//memcpy(dst_ptr,req->insts(0).tensor_array(i).float_data(),databuf_size[i]);
int elem_num = req->insts(0).tensor_array(i).float_data_size();
memcpy(dst_ptr, req->insts(0).tensor_array(i).float_data().data(),databuf_size[i]);
/*int elem_num = req->insts(0).tensor_array(i).float_data_size();
for (int k = 0; k < elem_num; ++k) {
dst_ptr[k] = req->insts(0).tensor_array(i).float_data(k);
}
}*/
} else if (elem_type[i] == P_INT32) {
int32_t *dst_ptr = static_cast<int32_t *>(out->at(i).data.data());
VLOG(2) << "(logid=" << log_id << ") first element data in var[" << i
......@@ -208,10 +211,13 @@ int GeneralReaderOp::inference() {
LOG(ERROR) << "dst_ptr is nullptr";
return -1;
}
memcpy(dst_ptr, req->insts(0).tensor_array(i).int_data().data(),databuf_size[i]);
/*
int elem_num = req->insts(0).tensor_array(i).int_data_size();
for (int k = 0; k < elem_num; ++k) {
dst_ptr[k] = req->insts(0).tensor_array(i).int_data(k);
}
*/
} else if (elem_type[i] == P_STRING) {
std::string *dst_ptr = static_cast<std::string *>(out->at(i).data.data());
VLOG(2) << "(logid=" << log_id << ") first element data in var[" << i
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/general-server/op/general_ysl_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h"
#include "core/util/include/timer.h"
/*
#include "opencv2/imgcodecs/legacy/constants_c.h"
#include "opencv2/imgproc/types_c.h"
*/
namespace baidu {
namespace paddle_serving {
namespace serving {
using baidu::paddle_serving::Timer;
using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::FetchInst;
using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
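// GeneralYSLOp bundles a full OCR pipeline into a single op: it decodes a
// base64-encoded image from the request, runs DB text-detection
// preprocessing and inference, post-processes the output into boxes, then
// crops and resizes each box into inputs for a recognition model.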
int GeneralYSLOp::inference() {
VLOG(2) << "Going to run inference";
std::cout<< "i am GeneralYSLOp"<<std::endl;
const std::vector<std::string> pre_node_names = pre_names();
if (pre_node_names.size() != 1) {
LOG(ERROR) << "This op(" << op_name()
<< ") can only have one predecessor op, but received "
<< pre_node_names.size();
return -1;
}
const std::string pre_name = pre_node_names[0];
const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name);
if (!input_blob) {
LOG(ERROR) << "input_blob is nullptr,error";
return -1;
}
uint64_t log_id = input_blob->GetLogId();
VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name;
GeneralBlob *output_blob = mutable_data<GeneralBlob>();
if (!output_blob) {
LOG(ERROR) << "output_blob is nullptr,error";
return -1;
}
output_blob->SetLogId(log_id);
if (!input_blob) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed mutable depended argument, op:" << pre_name;
return -1;
}
// input_data is a base64-encoded OCR image string (TensorVector.size() == 1)
const TensorVector *in = &input_blob->tensor_vector;
TensorVector* out = &output_blob->tensor_vector;
int batch_size = input_blob->_batch_size;
VLOG(2) << "(logid=" << log_id << ") input batch size: " << batch_size;
output_blob->_batch_size = batch_size;
VLOG(2) << "(logid=" << log_id << ") infer batch size: " << batch_size;
std::vector<int> input_shape;
int in_num =0;
void* databuf_data = NULL;
char* databuf_char = NULL;
size_t databuf_size = 0;
std::string* input_ptr = static_cast<std::string*>(in->at(0).data.data());
std::string base64str = input_ptr[0];
float ratio_h{};
float ratio_w{};
cv::Mat img = Base2Mat(base64str);
cv::Mat srcimg;
cv::Mat resize_img;
cv::Mat resize_img_rec;
cv::Mat crop_img;
img.copyTo(srcimg);
this->resize_op_.Run(img, resize_img, this->max_side_len_, ratio_h, ratio_w,
this->use_tensorrt_);
this->normalize_op_.Run(&resize_img, this->mean_det, this->scale_det,
this->is_scale_);
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
this->permute_op_.Run(&resize_img, input.data());
TensorVector* real_in = new TensorVector();
if (!real_in) {
LOG(ERROR) << "real_in is nullptr,error";
return -1;
}
for (int i = 0; i < in->size(); ++i) {
input_shape = {1, 3, resize_img.rows, resize_img.cols};
std::cout<< "i am thomas young and i want to know the out info name : "<<",shapesize:" <<input_shape.size()<<std::endl;
in_num = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int>());
databuf_size = in_num*sizeof(float);
databuf_data = MempoolWrapper::instance().malloc(databuf_size);
if (!databuf_data) {
LOG(ERROR) << "Malloc failed, size: " << databuf_size;
return -1;
}
memcpy(databuf_data,input.data(),databuf_size);
std::cout<< "the out num: "<<in_num<<std::endl;
databuf_char = reinterpret_cast<char*>(databuf_data);
paddle::PaddleBuf paddleBuf(databuf_char, databuf_size);
paddle::PaddleTensor tensor_in;
tensor_in.name = in->at(i).name;
tensor_in.dtype = paddle::PaddleDType::FLOAT32;
tensor_in.shape = {1, 3, resize_img.rows, resize_img.cols};
tensor_in.lod = in->at(i).lod;
tensor_in.data = paddleBuf;
real_in->push_back(tensor_in);
}
Timer timeline;
int64_t start = timeline.TimeStampUS();
timeline.Start();
if (InferManager::instance().infer(
engine_name().c_str(), real_in, out, batch_size)) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed do infer in fluid model: " << engine_name().c_str();
return -1;
}
std::cout<< "success after infer "<<std::endl;
std::vector<int> output_shape;
int out_num =0;
void* databuf_data_out = NULL;
char* databuf_char_out = NULL;
size_t databuf_size_out = 0;
// special handling for PaddleOCR post-processing
int infer_outnum = out->size();
for (int k = 0;k <infer_outnum; ++k) {
int n2 = out->at(k).shape[2];
int n3 = out->at(k).shape[3];
int n = n2 * n3;
float* out_data = static_cast<float*>(out->at(k).data.data());
std::vector<float> pred(n, 0.0);
std::vector<unsigned char> cbuf(n, ' ');
for (int i = 0; i < n; i++) {
pred[i] = float(out_data[i]);
cbuf[i] = (unsigned char)((out_data[i]) * 255);
}
cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char *)cbuf.data());
cv::Mat pred_map(n2, n3, CV_32F, (float *)pred.data());
const double threshold = this->det_db_thresh_ * 255;
const double maxvalue = 255;
cv::Mat bit_map;
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
cv::Mat dilation_map;
cv::Mat dila_ele = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
cv::dilate(bit_map, dilation_map, dila_ele);
boxes = post_processor_.BoxesFromBitmap(pred_map, dilation_map,
this->det_db_box_thresh_,
this->det_db_unclip_ratio_);
boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg);
for (int i = boxes.size() - 1; i >= 0; i--) {
crop_img = GetRotateCropImage(img, boxes[i]);
float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
this->resize_op_rec.Run(crop_img, resize_img_rec, wh_ratio, this->use_tensorrt_);
this->normalize_op_.Run(&resize_img_rec, this->mean_rec, this->scale_rec,
this->is_scale_);
std::vector<float> output_rec(1 * 3 * resize_img_rec.rows * resize_img_rec.cols, 0.0f);
this->permute_op_.Run(&resize_img_rec, output_rec.data());
// Inference.
output_shape = {1, 3, resize_img_rec.rows, resize_img_rec.cols};
std::cout<< "i am thomas young and i want to know the out info name : "<<",shapesize:" <<output_shape.size()<<"shape :";
out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
databuf_size_out = out_num*sizeof(float);
databuf_data_out = MempoolWrapper::instance().malloc(databuf_size_out);
if (!databuf_data_out) {
LOG(ERROR) << "Malloc failed, size: " << databuf_size_out;
return -1;
}
memcpy(databuf_data_out,output_rec.data(),databuf_size_out);
std::cout<< "the out num: "<<out_num<<" value = "<<" ,"<<std::endl;
databuf_char_out = reinterpret_cast<char*>(databuf_data_out);
paddle::PaddleBuf paddleBuf(databuf_char_out, databuf_size_out);
paddle::PaddleTensor tensor_out;
tensor_out.name = "image";
tensor_out.dtype = paddle::PaddleDType::FLOAT32;
tensor_out.shape = {1, 3, resize_img_rec.rows, resize_img_rec.cols};
//tensor_in.lod = in->at(i).lod;
tensor_out.data = paddleBuf;
out->push_back(tensor_out);
}
}
out->erase(out->begin(),out->begin()+infer_outnum);
std::cout<< "success after out process "<<std::endl;
/* special addition for the two fit_a_line InferOp test
int var_num = in->size();
out->clear();
for (int k =0; k<var_num; ++k){
out->push_back(in->at(k));
}
*/
/*
for (int k = 0;k <out->size(); ++k) {
out->at(k).data.Resize(13 * sizeof(float));
out->at(k).shape[1] = 13;
out->at(k).name = "x";
float *dst_ptr = static_cast<float *>(out->at(k).data.data());
for(int l =0; l<13; ++l){dst_ptr[l] = (0.1+l);}
}*/
int64_t end = timeline.TimeStampUS();
CopyBlobInfo(input_blob, output_blob);
AddBlobInfo(output_blob, start);
AddBlobInfo(output_blob, end);
return 0;
}
cv::Mat GeneralYSLOp::Base2Mat(std::string &base64_data)
{
cv::Mat img;
std::string s_mat;
s_mat = base64Decode(base64_data.data(), base64_data.size());
std::vector<char> base64_img(s_mat.begin(), s_mat.end());
img = cv::imdecode(base64_img, cv::IMREAD_COLOR);  // legacy constant: CV_LOAD_IMAGE_COLOR
return img;
}
std::string GeneralYSLOp::base64Decode(const char* Data, int DataByte)
{
// decoding table
const char DecodeTable[] =
{
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
62, // '+'
0, 0, 0,
63, // '/'
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
0, 0, 0, 0, 0, 0, 0,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
0, 0, 0, 0, 0, 0,
26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
};
// decoded output (return value)
std::string strDecode;
int nValue;
int i = 0;
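// each iteration consumes four base64 characters and emits up to three
// bytes; '=' padding ends a group early, and CR/LF characters are skipped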
while (i < DataByte)
{
if (*Data != '\r' && *Data != '\n')
{
nValue = DecodeTable[*Data++] << 18;
nValue += DecodeTable[*Data++] << 12;
strDecode += (nValue & 0x00FF0000) >> 16;
if (*Data != '=')
{
nValue += DecodeTable[*Data++] << 6;
strDecode += (nValue & 0x0000FF00) >> 8;
if (*Data != '=')
{
nValue += DecodeTable[*Data++];
strDecode += nValue & 0x000000FF;
}
}
i += 4;
}
else  // skip carriage returns and line feeds
{
Data++;
i++;
}
}
return strDecode;
}
cv::Mat GeneralYSLOp::GetRotateCropImage(const cv::Mat &srcimage,
std::vector<std::vector<int>> box) {
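// Crop the axis-aligned bounding rectangle of the quadrilateral, warp its
// four corner points upright with a perspective transform, and rotate the
// result when it is much taller than it is wide (vertical text).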
cv::Mat image;
srcimage.copyTo(image);
std::vector<std::vector<int>> points = box;
int x_collect[4] = {box[0][0], box[1][0], box[2][0], box[3][0]};
int y_collect[4] = {box[0][1], box[1][1], box[2][1], box[3][1]};
int left = int(*std::min_element(x_collect, x_collect + 4));
int right = int(*std::max_element(x_collect, x_collect + 4));
int top = int(*std::min_element(y_collect, y_collect + 4));
int bottom = int(*std::max_element(y_collect, y_collect + 4));
cv::Mat img_crop;
image(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop);
for (int i = 0; i < points.size(); i++) {
points[i][0] -= left;
points[i][1] -= top;
}
int img_crop_width = int(sqrt(pow(points[0][0] - points[1][0], 2) +
pow(points[0][1] - points[1][1], 2)));
int img_crop_height = int(sqrt(pow(points[0][0] - points[3][0], 2) +
pow(points[0][1] - points[3][1], 2)));
cv::Point2f pts_std[4];
pts_std[0] = cv::Point2f(0., 0.);
pts_std[1] = cv::Point2f(img_crop_width, 0.);
pts_std[2] = cv::Point2f(img_crop_width, img_crop_height);
pts_std[3] = cv::Point2f(0.f, img_crop_height);
cv::Point2f pointsf[4];
pointsf[0] = cv::Point2f(points[0][0], points[0][1]);
pointsf[1] = cv::Point2f(points[1][0], points[1][1]);
pointsf[2] = cv::Point2f(points[2][0], points[2][1]);
pointsf[3] = cv::Point2f(points[3][0], points[3][1]);
cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std);
cv::Mat dst_img;
cv::warpPerspective(img_crop, dst_img, M,
cv::Size(img_crop_width, img_crop_height),
cv::BORDER_REPLICATE);
if (float(dst_img.rows) >= float(dst_img.cols) * 1.5) {
cv::Mat srcCopy = cv::Mat(dst_img.rows, dst_img.cols, dst_img.depth());
cv::transpose(dst_img, srcCopy);
cv::flip(srcCopy, srcCopy, 0);
return srcCopy;
} else {
return dst_img;
}
}
DEFINE_OP(GeneralYSLOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
\ No newline at end of file
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include <numeric>
#include "core/general-server/general_model_service.pb.h"
#include "core/general-server/op/general_infer_helper.h"
#include "core/predictor/tools/ocrtools/postprocess_op.h"
#include "core/predictor/tools/ocrtools/preprocess_op.h"
#include "paddle_inference_api.h" // NOLINT
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
namespace baidu {
namespace paddle_serving {
namespace serving {
class GeneralYSLOp
: public baidu::paddle_serving::predictor::OpWithChannel<GeneralBlob> {
public:
typedef std::vector<paddle::PaddleTensor> TensorVector;
DECLARE_OP(GeneralYSLOp);
int inference();
private:
//config info
bool use_gpu_ = false;
int gpu_id_ = 0;
int gpu_mem_ = 4000;
int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false;
// pre-process
PaddleOCR::ResizeImgType0 resize_op_;
PaddleOCR::Normalize normalize_op_;
PaddleOCR::Permute permute_op_;
PaddleOCR::CrnnResizeImg resize_op_rec;
bool use_tensorrt_ = false;
bool use_fp16_ = false;
// post-process
PaddleOCR::PostProcessor post_processor_;
// detection config
int max_side_len_ = 960;
double det_db_thresh_ = 0.3;
double det_db_box_thresh_ = 0.5;
double det_db_unclip_ratio_ = 2.0;
std::vector<float> mean_det = {0.485f, 0.456f, 0.406f};
std::vector<float> scale_det = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
bool is_scale_ = true;
// recognition config
std::vector<std::string> label_list_;
std::vector<float> mean_rec = {0.5f, 0.5f, 0.5f};
std::vector<float> scale_rec = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
cv::Mat GetRotateCropImage(const cv::Mat &srcimage,
std::vector<std::vector<int>> box);
cv::Mat Base2Mat(std::string &base64_data);
std::string base64Decode(const char* Data, int DataByte);
std::vector<std::vector<std::vector<int>>> boxes;
};
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
......@@ -608,10 +608,6 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<FluidFamilyCore> {
for(int i =0; i< tensorVector_in_pointer->size();++i){
auto lod_tensor_in = core->GetInputHandle((*tensorVector_in_pointer)[i].name);
lod_tensor_in->SetLoD((*tensorVector_in_pointer)[i].lod);
std::cout<< "i am thomas young and i want to know the in info name : "<<(*tensorVector_in_pointer)[i].name
<<",shapesize:" <<(*tensorVector_in_pointer)[i].shape.size()<<"shape :";;
for (auto l = 0; l != (*tensorVector_in_pointer)[i].shape.size(); ++l) std::cout << (*tensorVector_in_pointer)[i].shape[l] << " ,";
std::cout<< std::endl;
lod_tensor_in->Reshape((*tensorVector_in_pointer)[i].shape);
void* origin_data = (*tensorVector_in_pointer)[i].data.data();
// The core determines the required memory size from the data type passed in.
......@@ -652,10 +648,6 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<FluidFamilyCore> {
for (int i = 0; i < outnames.size(); ++i){
auto lod_tensor_out = core->GetOutputHandle(outnames[i]);
output_shape = lod_tensor_out->shape();
std::cout<< "i am thomas young and i want to know the out info name : "<<outnames[i]
<<",shapesize:" <<output_shape.size()<<"shape :";
for (auto l = 0; l != output_shape.size(); ++l) std::cout << output_shape[l] << " ,";
std::cout<< std::endl;
out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
dataType = lod_tensor_out->type();
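// allocate a host buffer sized by the output dtype, then copy the tensor back to CPU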
if (dataType == paddle::PaddleDType::FLOAT32) {
......@@ -667,7 +659,6 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<FluidFamilyCore> {
}
float* data_out = reinterpret_cast<float*>(databuf_data);
lod_tensor_out->CopyToCpu(data_out);
std::cout<< "the out num: "<<out_num<<" value = "<< data_out[0] <<" ,"<<std::endl;
databuf_char = reinterpret_cast<char*>(data_out);
}else if (dataType == paddle::PaddleDType::INT64) {
databuf_size = out_num*sizeof(int64_t);
......
......@@ -49,7 +49,7 @@ class OpMaker(object):
"general_dist_kv_infer": "GeneralDistKVInferOp",
"general_dist_kv_quant_infer": "GeneralDistKVQuantInferOp",
"general_copy": "GeneralCopyOp",
"general_YSL":"GeneralYSLOp",
"general_detection":"GeneralDetectionOp",
}
self.node_name_suffix_ = collections.defaultdict(int)
......@@ -307,7 +307,7 @@ class Server(object):
# it from workflow_conf.
default_engine_types = [
'GeneralInferOp', 'GeneralDistKVInferOp',
'GeneralDistKVQuantInferOp','GeneralYSLOp',
'GeneralDistKVQuantInferOp','GeneralDetectionOp',
]
model_config_paths_list_idx = 0
for node in self.workflow_conf.workflows[0].nodes:
......
......@@ -114,7 +114,7 @@ def start_standard_model(serving_port): # pylint: disable=doc-string-missing
for idx, single_model in enumerate(model):
infer_op_name = "general_infer"
if len(model) == 2 and idx == 0:
infer_op_name = "general_YSL"
infer_op_name = "general_detection"
else:
infer_op_name = "general_infer"
general_infer_op = op_maker.create(infer_op_name)
......
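For reference, a minimal sketch of how these op names are assembled into a serving pipeline, mirroring start_standard_model above (the model directory names are hypothetical, and list-style load_model_config is assumed to map the configs onto the detection and infer ops in order):

# minimal sketch, assuming the classic paddle_serving_server API;
# model directories, workdir, and port are hypothetical
from paddle_serving_server import OpMaker, OpSeqMaker, Server

op_maker = OpMaker()
read_op = op_maker.create('general_reader')
detection_op = op_maker.create('general_detection')  # OCR detection stage
infer_op = op_maker.create('general_infer')          # recognition stage
response_op = op_maker.create('general_response')

op_seq_maker = OpSeqMaker()
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(detection_op)
op_seq_maker.add_op(infer_op)
op_seq_maker.add_op(response_op)

server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.load_model_config(["ocr_det_model", "ocr_rec_model"])  # hypothetical paths
server.prepare_server(workdir="workdir", port=9292, device="cpu")
server.run_server()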