Commit 7da09539 authored by Wang Guibao, committed by GitHub

GPU inference support (#4)

* GPU inference

Change-Id: I90dc187bf5e523422cecf618ec5fea2cd0e5bf7a

* GPU inference

Change-Id: I7a4b418d42911ee2e74437a9e20d817aaf0fb878

* GPU prediction

Change-Id: I60d8dfdb04326cfdc41686f1e1595a6292e34488

* GPU inference

Change-Id: I43a6c20abc4d4ac4dd5bef59b60e467bc77c4e63

* GPU inference:
1) Fix image classification batch prediction
2) Fix documentation
3) Add image classification press test case

Change-Id: I28f55ced27eea48be75d41d616bd50291d9cc506

* GPU inference: Fix press test case

Change-Id: I308e3347737170c81e213f012b29419202e63dce
Parent b48354cd
......@@ -28,6 +28,7 @@ message(STATUS "C compiler: ${CMAKE_C_COMPILER}, version: "
find_package(Git REQUIRED)
find_package(Threads REQUIRED)
find_package(CUDA QUIET)
include(simd)
......@@ -43,10 +44,10 @@ set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
set(THIRD_PARTY_BUILD_TYPE Release)
option(WITH_AVX "Compile Paddle Serving with AVX intrinsics" ${AVX_FOUND})
option(WITH_MKL "Compile Paddle Serving with MKL support." ${AVX_FOUND})
option(CLIENT_ONLY "Compile client libraries and demos only"
FALSE)
option(WITH_AVX "Compile Paddle Serving with AVX intrinsics" ${AVX_FOUND})
option(WITH_MKL "Compile Paddle Serving with MKL support." ${AVX_FOUND})
option(WITH_GPU "Compile Paddle Serving with NVIDIA GPU" ${CUDA_FOUND})
option(CLIENT_ONLY "Compile client libraries and demos only" FALSE)
set(WITH_MKLML ${WITH_MKL})
if (NOT DEFINED WITH_MKLDNN)
......@@ -108,5 +109,8 @@ add_subdirectory(demo-client)
if (NOT CLIENT_ONLY)
add_subdirectory(predictor)
add_subdirectory(inferencer-fluid-cpu)
if (WITH_GPU)
add_subdirectory(inferencer-fluid-gpu)
endif()
add_subdirectory(demo-serving)
endif()
if(NOT WITH_GPU)
return()
endif()
set(paddle_known_gpu_archs "30 35 50 52 60 61 70")
set(paddle_known_gpu_archs7 "30 35 50 52")
set(paddle_known_gpu_archs8 "30 35 50 52 60 61")
set(paddle_known_gpu_archs9 "30 35 50 52 60 61 70")
set(paddle_known_gpu_archs10 "30 35 50 52 60 61 70 75")
######################################################################################
# A function for automatic detection of GPUs installed (if autodetection is enabled)
# Usage:
# detect_installed_gpus(out_variable)
function(detect_installed_gpus out_variable)
if(NOT CUDA_gpu_detect_output)
set(cufile ${PROJECT_BINARY_DIR}/detect_cuda_archs.cu)
file(WRITE ${cufile} ""
"#include <cstdio>\n"
"int main() {\n"
" int count = 0;\n"
" if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n"
" if (count == 0) return -1;\n"
" for (int device = 0; device < count; ++device) {\n"
" cudaDeviceProp prop;\n"
" if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n"
" std::printf(\"%d.%d \", prop.major, prop.minor);\n"
" }\n"
" return 0;\n"
"}\n")
execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}" "-ccbin=${CUDA_HOST_COMPILER}"
"--run" "${cufile}"
WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/"
RESULT_VARIABLE nvcc_res OUTPUT_VARIABLE nvcc_out
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
if(nvcc_res EQUAL 0)
# only keep the last line of nvcc_out
STRING(REGEX REPLACE ";" "\\\\;" nvcc_out "${nvcc_out}")
STRING(REGEX REPLACE "\n" ";" nvcc_out "${nvcc_out}")
list(GET nvcc_out -1 nvcc_out)
string(REPLACE "2.1" "2.1(2.0)" nvcc_out "${nvcc_out}")
set(CUDA_gpu_detect_output ${nvcc_out} CACHE INTERNAL "Returned GPU architectures from detect_installed_gpus tool" FORCE)
endif()
endif()
if(NOT CUDA_gpu_detect_output)
message(STATUS "Automatic GPU detection failed. Building for all known architectures.")
set(${out_variable} ${paddle_known_gpu_archs} PARENT_SCOPE)
else()
set(${out_variable} ${CUDA_gpu_detect_output} PARENT_SCOPE)
endif()
endfunction()
########################################################################
# Function for selecting GPU arch flags for nvcc based on CUDA_ARCH_NAME
# Usage:
# select_nvcc_arch_flags(out_variable)
function(select_nvcc_arch_flags out_variable)
# List of arch names
set(archs_names "Kepler" "Maxwell" "Pascal" "Volta" "Turing" "All" "Manual")
set(archs_name_default "All")
list(APPEND archs_names "Auto")
# set CUDA_ARCH_NAME strings (so it will be seen as a drop-down box in CMake-GUI)
set(CUDA_ARCH_NAME ${archs_name_default} CACHE STRING "Select target NVIDIA GPU architecture.")
set_property( CACHE CUDA_ARCH_NAME PROPERTY STRINGS "" ${archs_names} )
mark_as_advanced(CUDA_ARCH_NAME)
# verify CUDA_ARCH_NAME value
if(NOT ";${archs_names};" MATCHES ";${CUDA_ARCH_NAME};")
string(REPLACE ";" ", " archs_names "${archs_names}")
message(FATAL_ERROR "Only ${archs_names} architeture names are supported.")
endif()
if(${CUDA_ARCH_NAME} STREQUAL "Manual")
set(CUDA_ARCH_BIN ${paddle_known_gpu_archs} CACHE STRING "Specify 'real' GPU architectures to build binaries for, BIN(PTX) format is supported")
set(CUDA_ARCH_PTX "50" CACHE STRING "Specify 'virtual' PTX architectures to build PTX intermediate code for")
mark_as_advanced(CUDA_ARCH_BIN CUDA_ARCH_PTX)
else()
unset(CUDA_ARCH_BIN CACHE)
unset(CUDA_ARCH_PTX CACHE)
endif()
if(${CUDA_ARCH_NAME} STREQUAL "Kepler")
set(cuda_arch_bin "30 35")
elseif(${CUDA_ARCH_NAME} STREQUAL "Maxwell")
set(cuda_arch_bin "50")
elseif(${CUDA_ARCH_NAME} STREQUAL "Pascal")
set(cuda_arch_bin "60 61")
elseif(${CUDA_ARCH_NAME} STREQUAL "Volta")
set(cuda_arch_bin "70")
elseif(${CUDA_ARCH_NAME} STREQUAL "Turing")
set(cuda_arch_bin "75")
elseif(${CUDA_ARCH_NAME} STREQUAL "All")
set(cuda_arch_bin ${paddle_known_gpu_archs})
elseif(${CUDA_ARCH_NAME} STREQUAL "Auto")
detect_installed_gpus(cuda_arch_bin)
else() # (${CUDA_ARCH_NAME} STREQUAL "Manual")
set(cuda_arch_bin ${CUDA_ARCH_BIN})
endif()
# remove dots and convert to lists
string(REGEX REPLACE "\\." "" cuda_arch_bin "${cuda_arch_bin}")
string(REGEX REPLACE "\\." "" cuda_arch_ptx "${CUDA_ARCH_PTX}")
string(REGEX MATCHALL "[0-9()]+" cuda_arch_bin "${cuda_arch_bin}")
string(REGEX MATCHALL "[0-9]+" cuda_arch_ptx "${cuda_arch_ptx}")
list(REMOVE_DUPLICATES cuda_arch_bin)
list(REMOVE_DUPLICATES cuda_arch_ptx)
set(nvcc_flags "")
set(nvcc_archs_readable "")
# Tell NVCC to add binaries for the specified GPUs
foreach(arch ${cuda_arch_bin})
if(arch MATCHES "([0-9]+)\\(([0-9]+)\\)")
# User explicitly specified PTX for the concrete BIN
list(APPEND nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1})
list(APPEND nvcc_archs_readable sm_${CMAKE_MATCH_1})
else()
# User didn't explicitly specify PTX for the concrete BIN, we assume PTX=BIN
list(APPEND nvcc_flags -gencode arch=compute_${arch},code=sm_${arch})
list(APPEND nvcc_archs_readable sm_${arch})
endif()
endforeach()
# Tell NVCC to add PTX intermediate code for the specified architectures
foreach(arch ${cuda_arch_ptx})
list(APPEND nvcc_flags -gencode arch=compute_${arch},code=compute_${arch})
list(APPEND nvcc_archs_readable compute_${arch})
endforeach()
string(REPLACE ";" " " nvcc_archs_readable "${nvcc_archs_readable}")
set(${out_variable} ${nvcc_flags} PARENT_SCOPE)
set(${out_variable}_readable ${nvcc_archs_readable} PARENT_SCOPE)
endfunction()
message(STATUS "CUDA detected: " ${CUDA_VERSION})
if (${CUDA_VERSION} LESS 7.0)
set(paddle_known_gpu_archs ${paddle_known_gpu_archs})
add_definitions("-DPADDLE_CUDA_BINVER=\"60\"")
elseif (${CUDA_VERSION} LESS 8.0) # CUDA 7.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs7})
list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
add_definitions("-DPADDLE_CUDA_BINVER=\"70\"")
elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs8})
list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
# CUDA 8 may complain that sm_20 is no longer supported. Suppress the
# warning for now.
list(APPEND CUDA_NVCC_FLAGS "-Wno-deprecated-gpu-targets")
add_definitions("-DPADDLE_CUDA_BINVER=\"80\"")
elseif (${CUDA_VERSION} LESS 10.0) # CUDA 9.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs9})
list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
add_definitions("-DPADDLE_CUDA_BINVER=\"90\"")
elseif (${CUDA_VERSION} LESS 11.0) # CUDA 10.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs10})
list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
add_definitions("-DPADDLE_CUDA_BINVER=\"100\"")
endif()
include_directories(${CUDA_INCLUDE_DIRS})
list(APPEND EXTERNAL_LIBS ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY})
if(NOT WITH_DSO)
# TODO(panyx0718): CUPTI only allows DSO?
list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUPTI_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY} ${NCCL_LIBRARY})
if(WIN32)
set_property(GLOBAL PROPERTY CUDA_MODULES ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY})
endif(WIN32)
endif(NOT WITH_DSO)
# setting nvcc arch flags
select_nvcc_arch_flags(NVCC_FLAGS_EXTRA)
list(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA})
message(STATUS "Added CUDA NVCC flags for: ${NVCC_FLAGS_EXTRA_readable}")
# Set C++11 support
set(CUDA_PROPAGATE_HOST_FLAGS OFF)
# Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc.
# So, don't set these flags here.
if (NOT WIN32) # windows msvc2015 support c++11 natively.
# -std=c++11 -fPIC not recognized by msvc, -Xcompiler will be added by cmake.
list(APPEND CUDA_NVCC_FLAGS "-std=c++11")
list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC")
endif(NOT WIN32)
if(WITH_FAST_MATH)
# Make use of fast math library. https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html
list(APPEND CUDA_NVCC_FLAGS "--use_fast_math")
endif()
# in cuda9, suppress cuda warning on eigen
list(APPEND CUDA_NVCC_FLAGS "-w")
# Set :expt-relaxed-constexpr to suppress Eigen warnings
list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr")
if (NOT WIN32)
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_DEBUG})
elseif(CMAKE_BUILD_TYPE STREQUAL "Release")
list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELWITHDEBINFO})
elseif(CMAKE_BUILD_TYPE STREQUAL "MinSizeRel")
# nvcc 9 does not support -Os. Use Release flags instead
list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
endif()
else(NOT WIN32)
list(APPEND CUDA_NVCC_FLAGS "-Xcompiler \"/wd 4244 /wd 4267 /wd 4819\"")
list(APPEND CUDA_NVCC_FLAGS "--compiler-options;/bigobj")
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
list(APPEND CUDA_NVCC_FLAGS "-g -G")
# match the cl's _ITERATOR_DEBUG_LEVEL
list(APPEND CUDA_NVCC_FLAGS "-D_DEBUG")
elseif(CMAKE_BUILD_TYPE STREQUAL "Release")
list(APPEND CUDA_NVCC_FLAGS "-O3 -DNDEBUG")
else()
message(FATAL "Windows only support Release or Debug build now. Please set visual studio build type to Release/Debug, x64 build.")
endif()
endif(NOT WIN32)
mark_as_advanced(CUDA_BUILD_CUBIN CUDA_BUILD_EMULATION CUDA_VERBOSE_BUILD)
mark_as_advanced(CUDA_SDK_ROOT_DIR CUDA_SEPARABLE_COMPILATION)
......@@ -24,6 +24,8 @@ INCLUDE_DIRECTORIES(${CMAKE_BINARY_DIR}/Paddle/fluid_install_dir)
# Reference https://stackoverflow.com/questions/45414507/pass-a-list-of-prefix-paths-to-externalproject-add-in-cmake-args
set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib|${THIRD_PARTY_PATH}/install/glog")
message( "WITH_GPU = ${WITH_GPU}")
# If a minimal .a is needed, you can set WITH_DEBUG_SYMBOLS=OFF
ExternalProject_Add(
extern_paddle
......@@ -47,7 +49,7 @@ ExternalProject_Add(
-DWITH_MKL=${WITH_MKL}
-DWITH_AVX=${WITH_AVX}
-DWITH_MKLDNN=OFF
-DWITH_GPU=OFF
-DWITH_GPU=${WITH_GPU}
-DWITH_FLUID_ONLY=ON
-DWITH_TESTING=OFF
-DWITH_DISTRIBUTE=OFF
......
......@@ -20,6 +20,11 @@ target_link_libraries(ximage -Wl,--whole-archive sdk-cpp
-Wl,--no-whole-archive -lpthread -lcrypto -lm -lrt -lssl -ldl
-lz)
add_executable(ximage_press ${CMAKE_CURRENT_LIST_DIR}/src/ximage_press.cpp)
target_link_libraries(ximage_press -Wl,--whole-archive sdk-cpp
-Wl,--no-whole-archive -lpthread -lcrypto -lm -lrt -lssl -ldl
-lz)
add_executable(echo ${CMAKE_CURRENT_LIST_DIR}/src/echo.cpp)
target_link_libraries(echo -Wl,--whole-archive sdk-cpp -Wl,--no-whole-archive
-lpthread -lcrypto -lm -lrt -lssl -ldl
......@@ -51,6 +56,9 @@ target_link_libraries(text_classification_press -Wl,--whole-archive sdk-cpp -Wl,
install(TARGETS ximage
RUNTIME DESTINATION
${PADDLE_SERVING_INSTALL_DIR}/demo/client/image_classification/bin)
install(TARGETS ximage_press
RUNTIME DESTINATION
${PADDLE_SERVING_INSTALL_DIR}/demo/client/image_classification/bin)
install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/conf DESTINATION
${PADDLE_SERVING_INSTALL_DIR}/demo/client/image_classification/)
install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/data/images DESTINATION
......
ILSVRC2012_val_00000001.JPEG 65
ILSVRC2012_val_00000002.JPEG 970
ILSVRC2012_val_00000003.JPEG 230
ILSVRC2012_val_00000004.JPEG 809
ILSVRC2012_val_00000005.JPEG 516
ILSVRC2012_val_00000006.JPEG 57
ILSVRC2012_val_00000007.JPEG 334
ILSVRC2012_val_00000008.JPEG 415
ILSVRC2012_val_00000009.JPEG 674
ILSVRC2012_val_00000010.JPEG 332
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <fstream>
#include <thread> // NOLINT
#include "sdk-cpp/builtin_format.pb.h"
#include "sdk-cpp/image_class.pb.h"
#include "sdk-cpp/include/common.h"
#include "sdk-cpp/include/predictor_sdk.h"
#ifndef BCLOUD
using json2pb::JsonToProtoMessage;
#endif
using baidu::paddle_serving::sdk_cpp::Predictor;
using baidu::paddle_serving::sdk_cpp::PredictorApi;
using baidu::paddle_serving::predictor::format::XImageReqInstance;
using baidu::paddle_serving::predictor::format::DensePrediction;
using baidu::paddle_serving::predictor::image_classification::Request;
using baidu::paddle_serving::predictor::image_classification::Response;
DEFINE_int32(concurrency, 1, "Set the max concurrent number of requests");
DEFINE_int32(requests, 100, "Number of requests to send per thread");
DEFINE_int32(batch_size, 1, "Batch size");
std::atomic<int> g_concurrency(0);
std::vector<std::vector<uint64_t>> g_round_time;
std::vector<char*> g_image_data;
std::vector<size_t> g_image_lengths;
const std::vector<std::string> g_image_paths{
"./data/images/ILSVRC2012_val_00000001.jpeg",
"./data/images/ILSVRC2012_val_00000002.jpeg",
"./data/images/ILSVRC2012_val_00000003.jpeg",
"./data/images/ILSVRC2012_val_00000004.jpeg",
"./data/images/ILSVRC2012_val_00000005.jpeg",
"./data/images/ILSVRC2012_val_00000006.jpeg",
"./data/images/ILSVRC2012_val_00000007.jpeg",
"./data/images/ILSVRC2012_val_00000008.jpeg",
"./data/images/ILSVRC2012_val_00000009.jpeg",
"./data/images/ILSVRC2012_val_00000010.jpeg"};
int prepare_data() {
for (auto x : g_image_paths) {
FILE* fp = fopen(x.c_str(), "rb");
if (!fp) {
LOG(ERROR) << "Failed open image: " << x.c_str();
continue;
}
fseek(fp, 0L, SEEK_END);
size_t isize = ftell(fp);
char* ibuf = new (std::nothrow) char[isize];
if (!ibuf) {
LOG(ERROR) << "Failed malloc image buffer";
fclose(fp);
return -1;
}
fseek(fp, 0, SEEK_SET);
fread(ibuf, sizeof(ibuf[0]), isize, fp);
g_image_data.push_back(ibuf);
g_image_lengths.push_back(isize);
fclose(fp);
}
return 0;
}
int create_req(Request& req) { // NOLINT
for (int i = 0; i < FLAGS_batch_size; ++i) {
XImageReqInstance* ins = req.add_instances();
if (!ins) {
LOG(ERROR) << "Failed create req instance";
return -1;
}
int id = i % g_image_data.size();
ins->set_image_binary(g_image_data[id], g_image_lengths[id]);
ins->set_image_length(g_image_lengths[id]);
}
return 0;
}
void extract_res(const Request& req, const Response& res) {
uint32_t sample_size = res.predictions_size();
std::string err_string;
for (uint32_t si = 0; si < sample_size; ++si) {
DensePrediction json_msg;
std::string json = res.predictions(si).response_json();
butil::IOBuf buf;
buf.clear();
buf.append(json);
butil::IOBufAsZeroCopyInputStream wrapper(buf);
if (!JsonToProtoMessage(&wrapper, &json_msg, &err_string)) {
LOG(ERROR) << "Failed parse json from str:" << json;
return;
}
uint32_t csize = json_msg.categories_size();
if (csize <= 0) {
LOG(ERROR) << "sample-" << si << "has no"
<< "categories props";
continue;
}
float max_prop = json_msg.categories(0);
uint32_t max_idx = 0;
for (uint32_t ci = 1; ci < csize; ++ci) {
if (json_msg.categories(ci) > max_prop) {
max_prop = json_msg.categories(ci);
max_idx = ci;
}
}
LOG(INFO) << "instance " << si << "has class " << max_idx;
} // end for
}
void thread_worker(PredictorApi* api, int thread_id) {
Request req;
Response res;
api->thrd_initialize();
for (int i = 0; i < FLAGS_requests; ++i) {
api->thrd_clear();
Predictor* predictor = api->fetch_predictor("ximage");
if (!predictor) {
LOG(ERROR) << "Failed fetch predictor: ximage";
return;
}
req.Clear();
res.Clear();
if (create_req(req) != 0) {
return;
}
while (g_concurrency.load() >= FLAGS_concurrency) {
}
g_concurrency++;
#if 1
LOG(INFO) << "Current concurrency " << g_concurrency.load();
#endif
timeval start;
timeval end;
gettimeofday(&start, NULL);
if (predictor->inference(&req, &res) != 0) {
LOG(ERROR) << "failed call predictor with req:" << req.ShortDebugString();
return;
}
gettimeofday(&end, NULL);
g_round_time[thread_id].push_back(end.tv_sec * 1000 + end.tv_usec / 1000 -
start.tv_sec * 1000 -
start.tv_usec / 1000);
extract_res(req, res);
res.Clear();
g_concurrency--;
#if 1
LOG(INFO) << "Done. Currenct concurrency " << g_concurrency.load();
#endif
} // for (int i = 0; i < FLAGS_requests; ++i)
api->thrd_finalize();
}
int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, true);
PredictorApi api;
// initialize logger instance
#ifdef BCLOUD
logging::LoggingSettings settings;
settings.logging_dest = logging::LOG_TO_FILE;
std::string filename(argv[0]);
filename = filename.substr(filename.find_last_of('/') + 1);
settings.log_file = (std::string("./log/") + filename + ".log").c_str();
settings.delete_old = logging::DELETE_OLD_LOG_FILE;
logging::InitLogging(settings);
logging::ComlogSinkOptions cso;
cso.process_name = filename;
cso.enable_wf_device = true;
logging::ComlogSink::GetInstance()->Setup(&cso);
#else
struct stat st_buf;
int ret = 0;
if ((ret = stat("./log", &st_buf)) != 0) {
mkdir("./log", 0777);
ret = stat("./log", &st_buf);
if (ret != 0) {
LOG(WARNING) << "Log path ./log not exist, and create fail";
return -1;
}
}
FLAGS_log_dir = "./log";
google::InitGoogleLogging(strdup(argv[0]));
#endif
g_round_time.resize(FLAGS_concurrency);
if (api.create("./conf", "predictors.prototxt") != 0) {
LOG(ERROR) << "Failed create predictors api!";
return -1;
}
if (prepare_data() != 0) {
LOG(ERROR) << "Prepare data fail";
return -1;
}
std::vector<std::thread*> worker_threads;
int i = 0;
for (; i < FLAGS_concurrency; ++i) {
worker_threads.push_back(new std::thread(thread_worker, &api, i));
}
for (i = 0; i < FLAGS_concurrency; ++i) {
worker_threads[i]->join();
delete worker_threads[i];
}
api.destroy();
std::vector<uint64_t> round_times;
for (auto x : g_round_time) {
round_times.insert(round_times.end(), x.begin(), x.end());
}
std::sort(round_times.begin(), round_times.end());
int percent_pos_50 = round_times.size() * 0.5;
int percent_pos_80 = round_times.size() * 0.8;
int percent_pos_90 = round_times.size() * 0.9;
int percent_pos_99 = round_times.size() * 0.99;
int percent_pos_999 = round_times.size() * 0.999;
uint64_t total_ms = 0;
for (auto x : round_times) {
total_ms += x;
}
LOG(INFO) << "Batch size: " << FLAGS_batch_size;
LOG(INFO) << "Total requests: " << round_times.size();
LOG(INFO) << "Max concurrency: " << FLAGS_concurrency;
LOG(INFO) << "Total ms (absolute time): " << total_ms / FLAGS_concurrency;
double qps = 0.0;
if (total_ms != 0) {
qps = (static_cast<double>(FLAGS_concurrency * FLAGS_requests) /
(total_ms / FLAGS_concurrency)) *
1000;
}
LOG(INFO) << "QPS: " << qps << "/s";
LOG(INFO) << "Latency statistics: ";
if (round_times.size() != 0) {
LOG(INFO) << "Average ms: "
<< static_cast<float>(total_ms) / round_times.size();
LOG(INFO) << "50 percent ms: " << round_times[percent_pos_50];
LOG(INFO) << "80 percent ms: " << round_times[percent_pos_80];
LOG(INFO) << "90 percent ms: " << round_times[percent_pos_90];
LOG(INFO) << "99 percent ms: " << round_times[percent_pos_99];
LOG(INFO) << "99.9 percent ms: " << round_times[percent_pos_999];
} else {
LOG(INFO) << "N/A";
}
return 0;
}
/* vim: set expandtab ts=4 sw=4 sts=4 tw=100: */
......@@ -18,9 +18,17 @@ include(proto/CMakeLists.txt)
add_executable(serving ${serving_srcs})
add_dependencies(serving pdcodegen fluid_cpu_engine pdserving paddle_fluid
opencv_imgcodecs)
if (WITH_GPU)
add_dependencies(serving fluid_gpu_engine)
endif()
target_include_directories(serving PUBLIC
${CMAKE_CURRENT_BINARY_DIR}/../predictor
)
if(WITH_GPU)
target_link_libraries(serving ${CUDA_LIBRARIES} -Wl,--whole-archive fluid_gpu_engine
-Wl,--no-whole-archive)
endif()
target_link_libraries(serving opencv_imgcodecs
${opencv_depend_libs} -Wl,--whole-archive fluid_cpu_engine
-Wl,--no-whole-archive pdserving paddle_fluid ${paddle_depend_libs}
......
......@@ -35,7 +35,6 @@ int ClassifyOp::inference() {
}
const TensorVector* in = &reader_out->tensors;
uint32_t sample_size = in->size();
TensorVector* out = butil::get_object<TensorVector>();
if (!out) {
......@@ -43,20 +42,21 @@ int ClassifyOp::inference() {
return -1;
}
if (sample_size <= 0) {
LOG(INFO) << "No samples need to to predicted";
return 0;
if (in->size() != 1) {
LOG(ERROR) << "Samples should have been packed into a single tensor";
return -1;
}
int batch_size = in->at(0).shape[0];
// call paddle fluid model for inferencing
if (InferManager::instance().infer(
IMAGE_CLASSIFICATION_MODEL_NAME, in, out, sample_size)) {
IMAGE_CLASSIFICATION_MODEL_NAME, in, out, batch_size)) {
LOG(ERROR) << "Failed do infer in fluid model: "
<< IMAGE_CLASSIFICATION_MODEL_NAME;
return -1;
}
if (out->size() != sample_size) {
if (out->size() != in->size()) {
LOG(ERROR) << "Output size not eq input size: " << in->size()
<< out->size();
return -1;
......@@ -64,24 +64,35 @@ int ClassifyOp::inference() {
// copy output tensor into response
ClassifyResponse* res = mutable_data<ClassifyResponse>();
const paddle::PaddleTensor& out_tensor = (*out)[0];
#if 0
int out_shape_size = out_tensor.shape.size();
LOG(ERROR) << "out_tensor.shpae";
for (int i = 0; i < out_shape_size; ++i) {
LOG(ERROR) << out_tensor.shape[i] << ":";
}
if (out_shape_size != 2) {
return -1;
}
#endif
int sample_size = out_tensor.shape[0];
#if 0
LOG(ERROR) << "Output sample size " << sample_size;
#endif
for (uint32_t si = 0; si < sample_size; si++) {
const paddle::PaddleTensor& out_tensor = (*out)[si];
DensePrediction* ins = res->add_predictions();
if (!ins) {
LOG(ERROR) << "Failed append new out tensor";
return -1;
}
uint32_t shape_size = out_tensor.shape.size();
if (out_tensor.shape.size() != 2 || out_tensor.shape[0] != 1) {
LOG(ERROR) << "Not valid classification out shape"
<< ", shape size: " << out_tensor.shape.size();
return -1;
}
// assign output data
uint32_t data_size = out_tensor.data.length() / sizeof(float);
float* data = reinterpret_cast<float*>(out_tensor.data.data());
uint32_t data_size = out_tensor.shape[1];
float* data = reinterpret_cast<float*>(out_tensor.data.data() +
si * sizeof(float) * data_size);
for (uint32_t di = 0; di < data_size; ++di) {
ins->add_categories(data[di]);
}
......@@ -95,10 +106,6 @@ int ClassifyOp::inference() {
out->clear();
butil::return_object<TensorVector>(out);
LOG(INFO) << "Response in image classification:"
<< "length:" << res->ByteSize() << ","
<< "data:" << res->ShortDebugString();
return 0;
}
......
......@@ -51,6 +51,26 @@ int ReaderOp::inference() {
resize.height = iresize[0];
resize.width = iresize[1];
paddle::PaddleTensor in_tensor;
in_tensor.name = "tensor";
in_tensor.dtype = paddle::FLOAT32;
// shape assignment
in_tensor.shape.push_back(sample_size); // batch_size
in_tensor.shape.push_back(3);
in_tensor.shape.push_back(resize.width);
in_tensor.shape.push_back(resize.height);
// tls resource assignment
size_t dense_capacity = 3 * resize.width * resize.height;
size_t len = dense_capacity * sizeof(float) * sample_size;
float* data =
reinterpret_cast<float*>(MempoolWrapper::instance().malloc(len));
if (data == NULL) {
LOG(ERROR) << "Failed create temp float array, "
<< "size=" << dense_capacity * sample_size * sizeof(float);
return -1;
}
for (uint32_t si = 0; si < sample_size; si++) {
// parse image object from x-image
const XImageReqInstance& ins = req->instances(si);
......@@ -103,50 +123,31 @@ int ReaderOp::inference() {
const int H = _image_8u_rgb.rows;
const int W = _image_8u_rgb.cols;
const int C = _image_8u_rgb.channels();
size_t dense_capacity = H * W * C;
paddle::PaddleTensor in_tensor;
in_tensor.name = "tensor";
in_tensor.dtype = paddle::FLOAT32;
// shape assignment
in_tensor.shape.push_back(1); // batch_size
// according to the training stage, the instance shape should be
// in order of C-W-H.
in_tensor.shape.push_back(C);
in_tensor.shape.push_back(W);
in_tensor.shape.push_back(H);
if (H != resize.height || W != resize.width || C != 3) {
LOG(ERROR) << "Image " << si << " has incompitable size";
return -1;
}
LOG(INFO) << "Succ read one image, C: " << C << ", W: " << W
<< ", H: " << H;
// tls resource assignment
size_t len = dense_capacity * sizeof(float);
float* data =
reinterpret_cast<float*>(MempoolWrapper::instance().malloc(len));
if (data == NULL) {
LOG(ERROR) << "Failed create temp float array, "
<< "size=" << dense_capacity;
return -1;
}
float* data_ptr = data + dense_capacity * si;
for (int h = 0; h < H; h++) {
// p points to a new line
unsigned char* p = _image_8u_rgb.ptr<unsigned char>(h);
for (int w = 0; w < W; w++) {
for (int c = 0; c < C; c++) {
// HWC(row,column,channel) -> CWH
data[W * H * c + W * h + w] = (p[C * w + c] - pmean[c]) * scale[c];
data_ptr[W * H * c + W * h + w] =
(p[C * w + c] - pmean[c]) * scale[c];
}
}
}
paddle::PaddleBuf pbuf(data, len);
in_tensor.data = pbuf;
in->push_back(in_tensor);
}
paddle::PaddleBuf pbuf(data, len);
in_tensor.data = pbuf;
in->push_back(in_tensor);
return 0;
}
......
......@@ -16,7 +16,7 @@
#include <string>
#ifdef BCLOUD
#include "pb_to_json.h"
#include "pb_to_json.h" // NOLINT
#else
#include "json2pb/pb_to_json.h"
#endif
......@@ -70,7 +70,7 @@ int WriteJsonOp::inference() {
}
}
LOG(INFO) << "Succ write json:" << classify_out->ShortDebugString();
LOG(INFO) << "Succ write json";
return 0;
}
......
......@@ -5,3 +5,22 @@
- If port: xxx is specified in the inferservice_file, that port number is requested;
- Otherwise, if --port: xxx is specified in gflags.conf, that port number is requested (see the one-line example below);
- Otherwise, the default port number specified in the program is used: 8010.
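For example, overriding the listening port in serving/conf/gflags.conf takes a single line (the port value here is only an illustration):

```
--port=8011
```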
## 2. Why does the response time of GPU inference requests fluctuate so much?
Paddle Serving relies on the PaddlePaddle inference library for the actual inference computation. On GPU devices, all requests within a process currently share one GPU stream, so the inference computations of multiple requests in a process are strictly serialized. If two requests reach a Serving instance at the same time, the later one is queued until the earlier one finishes computing, no matter how many worker threads the instance was started with, so the extra threads provide no speedup.
## 3. How can the compute capacity of a GPU card be fully utilized?
As explained in question 2, due to the limitation of the inference library, a single Serving process can bind only one GPU card, and since the process shares one GPU stream, all requests must be computed serially.
To improve GPU utilization, the approach currently available is to start multiple Serving processes on a single GPU card, each bound to its own GPU stream, so that multiple streams compute in parallel. Whether this actually yields a speedup is constrained by several factors, mainly:
1. GPU compute occupied by a single stream: if one stream already uses more than 50% of the GPU's compute capacity, adding a stream will likely cause the jobs of the two streams to queue behind each other and slow down both response times.
2. GPU memory: a Serving process must load the model parameters into GPU memory and allocate temporary variables from the GPU memory pool during computation; if a single Serving process already uses more than 50% of GPU memory, adding another Serving process will exhaust GPU memory and make the process abort with an error.
To find a workable setup, the following testing procedure can be used (example startup commands are sketched after this list):
1. When loading the model, set the model type in model_toolkit.prototxt to FLUID_GPU_ANALYSIS or FLUID_GPU_ANALYSIS_DIR; the model is then statically analyzed and its GPU memory use is optimized to some degree.
2. After step 1, start a single Serving process with the startup parameters `--gpuid=N --bthread_concurrency=4 --bthread_min_concurrency=4`; start one client and run a stress test at concurrency 1, increasing the batch size from small to large and recording the average response time. Because of the compute limit, the response time should increase noticeably once the batch size grows beyond a certain point, or it may stop meeting the system requirements even without an obvious increase.
3. Start one more Serving process with almost the same parameters as in step 2: `--gpuid=N --bthread_concurrency=4 --bthread_min_concurrency=4 --port=8011`, where --port=8011 makes the newly started process use a new service port. Then stress-test the two Serving processes at the same time and keep observing how the average response time changes as the batch size grows, until a trade-off between batch size and response time is reached.
4. Repeat steps 2-3.
5. Use the tests in steps 2-4 to decide how many Serving processes can share a single GPU card; in the actual deployment, start that many Serving processes on one GPU card to serve at the same time.
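A minimal sketch of the startup commands in steps 2 and 3 (the `bin/serving` path follows the demo layout and the GPU id N is a placeholder; adjust both to your deployment):

```
# Step 2: first Serving instance, bound to GPU N, listening on the default port 8010
bin/serving --gpuid=N --bthread_concurrency=4 --bthread_min_concurrency=4 &
# Step 3: second Serving instance on the same card, listening on a new port
bin/serving --gpuid=N --bthread_concurrency=4 --bthread_min_concurrency=4 --port=8011 &
```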
[Design](DESIGN.md)
[Client Configure](CLIENT_CONFIGURE.md)
[Installation](INSTALL.md)
[How to Configure a Clustered Service](CLUSTERING.md)
[Getting Started](GETTING_STARTED.md)
[Creating a Prediction Service](CREATING.md)
[Design](DESIGN.md)
[Client Configure](CLIENT_CONFIGURE.md)
[FAQ](FAQ.md)
[Server Side Configuration](SERVING_CONFIGURE.md)
[Getting Started](GETTING_STARTED.md)
[How to Configure a Clustered Service](CLUSTERING.md)
[Installation](INSTALL.md)
[Multiple Serving Instances over Single GPU Card](MULTI_SERVING_OVER_SINGLE_GPU_CARD.md)
[Server Side Configuration](SERVING_CONFIGURE.md)
[Benchmarking](BENCHMARKING.md)
[FAQ](FAQ.md)
......@@ -58,10 +58,9 @@ $ make install
# CMake build options
Since Paddle Serving is built on top of the PaddlePaddle project, the following build options are in fact passed through to PaddlePaddle (an example invocation is shown after the table):
| Build option | Description |
|----------|------|
| WITH_AVX | Compile PaddlePaddle with AVX intrinsics |
| WITH_MKL | Compile PaddlePaddle with MKLML library |
| WITH_AVX | For configuring PaddlePaddle. Compile PaddlePaddle with AVX intrinsics |
| WITH_MKL | For configuring PaddlePaddle. Compile PaddlePaddle with MKLML library |
| WITH_GPU | For configuring PaddlePaddle. Compile PaddlePaddle with NVIDIA GPU |
| CLIENT_ONLY | Compile client libraries and demos only |
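As an illustration, a GPU-enabled server build could be configured roughly as follows (a sketch only; the source path, job count, and install prefix depend on your environment):

```
# Configure a GPU build; keep CLIENT_ONLY off so the serving side is built as well
cmake -DWITH_GPU=ON -DWITH_MKL=ON -DCLIENT_ONLY=OFF ..
make -j4
make install
```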
# Multiple Serving Instances over Single GPU Card
Paddle Serving relies on the PaddlePaddle inference library for the actual inference computation. Due to the limitation of the current GPU inference library, a single Serving instance can bind only one GPU card, and all worker threads within the process share a single GPU stream. In other words, no matter how many worker threads Serving starts, all requests are computed strictly serially on the GPU, so extra threads provide no speedup. This raises a problem: if the model is not compute-intensive, the Serving process will not actually saturate the GPU's compute capacity.
To make full use of a GPU card's compute capacity, consider starting multiple Serving instances on a single card and using multiple GPU streams to try to saturate the GPU. The startup commands can look like this:
```
bin/serving --gpuid=0 --bthread_concurrency=4 --bthread_min_concurrency=4 --port=8010&
bin/serving --gpuid=0 --bthread_concurrency=4 --bthread_min_concurrency=4 --port=8011&
```
The two commands above start two Serving instances that listen on ports 8010 and 8011 respectively, but both are bound to the same card (gpuid = 0).
Meaning of the command-line parameters:
```
-gpuid=N: specifies the id of the GPU card to bind to
-bthread_concurrency and -bthread_min_concurrency together limit the number of workers started by the process: in GPU inference mode, adding worker threads does not improve concurrency, so they are simply capped to save some resources; both are set to 4 because that is the minimum value allowed by bthread.
-port xxx: the port the Serving instance listens on
```
However, whether this approach can actually improve GPU utilization without hurting response time and other metrics is constrained by several limiting factors, specifically:
1. GPU compute occupied by a single stream: if one stream already uses more than 50% of the GPU's compute capacity, adding a stream will likely cause the jobs of the two streams to queue behind each other and slow down both response times.
2. GPU memory: a Serving process must load the model parameters into GPU memory and allocate temporary variables from the GPU memory pool during computation; if a single Serving process already uses more than 50% of GPU memory, adding another Serving process will exhaust GPU memory and make the process abort with an error.
To find a workable setup, the following testing procedure can be used:
1. When loading the model, set the model type in model_toolkit.prototxt to FLUID_GPU_ANALYSIS or FLUID_GPU_ANALYSIS_DIR; the model is then statically analyzed and its GPU memory use is optimized to some degree.
2. After step 1, start a single Serving process with the startup parameters `--gpuid=N --bthread_concurrency=4 --bthread_min_concurrency=4`; start one client and run a stress test at concurrency 1, increasing the batch size from small to large and recording the average response time. Because of the compute limit, the response time should increase noticeably once the batch size grows beyond a certain point, or it may stop meeting the system requirements even without an obvious increase.
3. Start one more Serving process with almost the same parameters as in step 2: `--gpuid=N --bthread_concurrency=4 --bthread_min_concurrency=4 --port=8011`, where --port=8011 makes the newly started process use a new service port. Then stress-test the two Serving processes at the same time and keep observing how the average response time changes as the batch size grows, until a trade-off between batch size and response time is reached.
4. Repeat steps 2-3.
5. Use the tests in steps 2-4 to decide how many Serving processes can share a single GPU card; in the actual deployment, start that many Serving processes on one GPU card to serve at the same time.
......@@ -142,6 +142,11 @@ type: the type of inference engine. See inferencer-fluid-cpu/src/fluid_cpu_engine.cp
|FLUID_CPU_ANALYSIS_DIR|Uses the fluid Analysis API; all model parameters are saved as separate files, and the whole model is placed in one directory|
|FLUID_CPU_NATIVE|Uses the fluid Native API; all model parameters are saved in a single file|
|FLUID_CPU_NATIVE_DIR|Uses the fluid Native API; all model parameters are saved as separate files, and the whole model is placed in one directory|
|FLUID_GPU_ANALYSIS|GPU inference using the fluid Analysis API; all model parameters are saved in a single file|
|FLUID_GPU_ANALYSIS_DIR|GPU inference using the fluid Analysis API; all model parameters are saved as separate files, and the whole model is placed in one directory|
|FLUID_GPU_NATIVE|GPU inference using the fluid Native API; all model parameters are saved in a single file|
|FLUID_GPU_NATIVE_DIR|GPU inference using the fluid Native API; all model parameters are saved as separate files, and the whole model is placed in one directory|
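For instance, switching an engine entry in model_toolkit.prototxt to GPU inference could look roughly like the sketch below; the `type` value comes from the table above, while the engine name, model path, and the remaining fields are illustrative and should be taken from your existing configuration:

```
engines {
  # engine name and model path are placeholders from the image classification demo
  name: "image_classification_resnet50"
  type: "FLUID_GPU_ANALYSIS_DIR"
  model_data_path: "./data/model/paddle/fluid/SE_ResNeXt50_32x4d"
  enable_batch_align: 0
}
```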
**Differences between the fluid Analysis API and the fluid Native API**
......@@ -182,8 +187,11 @@ enable_batch_align:
|enable_model_toolkit|true|Model management|
|enable_protocol_list|baidu_std|List of brpc communication protocols|
|log_dir|./log|log dir|
|num_threads|Number of system threads used by the brpc server; defaults to the number of CPU cores|
|max_concurrency|Maximum number of requests handled concurrently; <=0 means no limit, while a value >0 caps the number of requests the brpc server processes at the same time|
|num_threads||Number of system threads used by the brpc server; defaults to the number of CPU cores|
|port|8010|Port on which the Serving process listens for requests|
|gpuid|0|GPU device id used by the Serving process for GPU inference; only one GPU card may be bound|
|bthread_concurrency|9|Concurrency of the underlying BRPC bthreads; can be used to limit the number of concurrent workers when the GPU inference engine is in use|
|bthread_min_concurrency|4|Minimum concurrency of the underlying BRPC bthreads; can be used, together with bthread_concurrency, to limit the number of concurrent workers when the GPU inference engine is in use|
Default values can be overridden in serving/conf/gflags.conf, for example:
```
......
......@@ -155,6 +155,8 @@ class FluidCpuNativeCore : public FluidFamilyCore {
native_config.prog_file = data_path + "/__model__";
native_config.use_gpu = false;
native_config.device = 0;
native_config.fraction_of_gpu_memory = 0;
AutoLock lock(GlobalPaddleCreateMutex::instance());
_core = paddle::CreatePaddlePredictor<paddle::NativeConfig,
paddle::PaddleEngineKind::kNative>(
......@@ -209,6 +211,7 @@ class FluidCpuNativeDirCore : public FluidFamilyCore {
native_config.model_dir = data_path;
native_config.use_gpu = false;
native_config.device = 0;
native_config.fraction_of_gpu_memory = 0;
AutoLock lock(GlobalPaddleCreateMutex::instance());
_core = paddle::CreatePaddlePredictor<paddle::NativeConfig,
paddle::PaddleEngineKind::kNative>(
......@@ -458,6 +461,7 @@ class FluidCpuNativeDirWithSigmoidCore : public FluidCpuWithSigmoidCore {
native_config.model_dir = data_path;
native_config.use_gpu = false;
native_config.device = 0;
native_config.fraction_of_gpu_memory = 0;
AutoLock lock(GlobalPaddleCreateMutex::instance());
_core->_fluid_core =
paddle::CreatePaddlePredictor<paddle::NativeConfig,
......
FILE(GLOB fluid_gpu_engine_srcs ${CMAKE_CURRENT_LIST_DIR}/src/*.cpp)
add_library(fluid_gpu_engine ${fluid_gpu_engine_srcs})
target_include_directories(fluid_gpu_engine PUBLIC
${CMAKE_BINARY_DIR}/Paddle/fluid_install_dir/)
add_dependencies(fluid_gpu_engine pdserving extern_paddle configure)
target_link_libraries(fluid_gpu_engine pdserving paddle_fluid -liomp5 -lmklml_intel -lpthread -lcrypto -lm -lrt -lssl -ldl -lz)
install(TARGETS fluid_gpu_engine
ARCHIVE DESTINATION ${PADDLE_SERVING_INSTALL_DIR}/lib
)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <pthread.h>
#include <fstream>
#include <map>
#include <string>
#include <vector>
#include "configure/include/configure_parser.h"
#include "configure/inferencer_configure.pb.h"
#ifdef BCLOUD
#ifdef WITH_GPU
#include "paddle/paddle_inference_api.h"
#else
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#endif
#else
#include "paddle/fluid/inference/paddle_inference_api.h"
#endif
#include "predictor/framework/infer.h"
DECLARE_int32(gpuid);
namespace baidu {
namespace paddle_serving {
namespace fluid_gpu {
using configure::SigmoidConf;
class AutoLock {
public:
explicit AutoLock(pthread_mutex_t& mutex) : _mut(mutex) {
pthread_mutex_lock(&mutex);
}
~AutoLock() { pthread_mutex_unlock(&_mut); }
private:
pthread_mutex_t& _mut;
};
class GlobalPaddleCreateMutex {
public:
pthread_mutex_t& mutex() { return _mut; }
static pthread_mutex_t& instance() {
static GlobalPaddleCreateMutex gmutex;
return gmutex.mutex();
}
private:
GlobalPaddleCreateMutex() { pthread_mutex_init(&_mut, NULL); }
pthread_mutex_t _mut;
};
class GlobalSigmoidCreateMutex {
public:
pthread_mutex_t& mutex() { return _mut; }
static pthread_mutex_t& instance() {
static GlobalSigmoidCreateMutex gmutex;
return gmutex.mutex();
}
private:
GlobalSigmoidCreateMutex() { pthread_mutex_init(&_mut, NULL); }
pthread_mutex_t _mut;
};
// data interface
class FluidFamilyCore {
public:
virtual ~FluidFamilyCore() {}
virtual bool Run(const void* in_data, void* out_data) {
if (!_core->Run(*(std::vector<paddle::PaddleTensor>*)in_data,
(std::vector<paddle::PaddleTensor>*)out_data)) {
LOG(ERROR) << "Failed call Run with paddle predictor";
return false;
}
return true;
}
virtual int create(const std::string& data_path) = 0;
virtual int clone(void* origin_core) {
if (origin_core == NULL) {
LOG(ERROR) << "origin paddle Predictor is null.";
return -1;
}
paddle::PaddlePredictor* p_predictor =
(paddle::PaddlePredictor*)origin_core;
_core = p_predictor->Clone();
if (_core.get() == NULL) {
LOG(ERROR) << "fail to clone paddle predictor: " << origin_core;
return -1;
}
return 0;
}
virtual void* get() { return _core.get(); }
protected:
std::unique_ptr<paddle::PaddlePredictor> _core;
};
// infer interface
class FluidGpuAnalysisCore : public FluidFamilyCore {
public:
int create(const std::string& data_path) {
if (access(data_path.c_str(), F_OK) == -1) {
LOG(ERROR) << "create paddle predictor failed, path not exits: "
<< data_path;
return -1;
}
paddle::AnalysisConfig analysis_config;
analysis_config.SetParamsFile(data_path + "/__params__");
analysis_config.SetProgFile(data_path + "/__model__");
analysis_config.EnableUseGpu(100, FLAGS_gpuid);
analysis_config.SetCpuMathLibraryNumThreads(1);
analysis_config.SwitchSpecifyInputNames(true);
analysis_config.EnableMemoryOptim();
AutoLock lock(GlobalPaddleCreateMutex::instance());
_core =
paddle::CreatePaddlePredictor<paddle::AnalysisConfig>(analysis_config);
if (NULL == _core.get()) {
LOG(ERROR) << "create paddle predictor failed, path: " << data_path;
return -1;
}
LOG(WARNING) << "create paddle predictor sucess, path: " << data_path;
return 0;
}
};
class FluidGpuNativeCore : public FluidFamilyCore {
public:
int create(const std::string& data_path) {
if (access(data_path.c_str(), F_OK) == -1) {
LOG(ERROR) << "create paddle predictor failed, path not exits: "
<< data_path;
return -1;
}
paddle::NativeConfig native_config;
native_config.param_file = data_path + "/__params__";
native_config.prog_file = data_path + "/__model__";
native_config.use_gpu = true;
native_config.fraction_of_gpu_memory = 0.01;
native_config.device = FLAGS_gpuid;
AutoLock lock(GlobalPaddleCreateMutex::instance());
_core = paddle::CreatePaddlePredictor<paddle::NativeConfig,
paddle::PaddleEngineKind::kNative>(
native_config);
if (NULL == _core.get()) {
LOG(ERROR) << "create paddle predictor failed, path: " << data_path;
return -1;
}
LOG(WARNING) << "create paddle predictor sucess, path: " << data_path;
return 0;
}
};
class FluidGpuAnalysisDirCore : public FluidFamilyCore {
public:
int create(const std::string& data_path) {
if (access(data_path.c_str(), F_OK) == -1) {
LOG(ERROR) << "create paddle predictor failed, path not exits: "
<< data_path;
return -1;
}
paddle::AnalysisConfig analysis_config;
analysis_config.SetModel(data_path);
analysis_config.EnableUseGpu(100, FLAGS_gpuid);
analysis_config.SwitchSpecifyInputNames(true);
analysis_config.SetCpuMathLibraryNumThreads(1);
analysis_config.EnableMemoryOptim();
AutoLock lock(GlobalPaddleCreateMutex::instance());
_core =
paddle::CreatePaddlePredictor<paddle::AnalysisConfig>(analysis_config);
if (NULL == _core.get()) {
LOG(ERROR) << "create paddle predictor failed, path: " << data_path;
return -1;
}
LOG(WARNING) << "create paddle predictor sucess, path: " << data_path;
return 0;
}
};
class FluidGpuNativeDirCore : public FluidFamilyCore {
public:
int create(const std::string& data_path) {
if (access(data_path.c_str(), F_OK) == -1) {
LOG(ERROR) << "create paddle predictor failed, path not exits: "
<< data_path;
return -1;
}
paddle::NativeConfig native_config;
native_config.model_dir = data_path;
native_config.use_gpu = true;
native_config.fraction_of_gpu_memory = 0.01;
native_config.device = FLAGS_gpuid;
AutoLock lock(GlobalPaddleCreateMutex::instance());
_core = paddle::CreatePaddlePredictor<paddle::NativeConfig,
paddle::PaddleEngineKind::kNative>(
native_config);
if (NULL == _core.get()) {
LOG(ERROR) << "create paddle predictor failed, path: " << data_path;
return -1;
}
LOG(WARNING) << "create paddle predictor sucess, path: " << data_path;
return 0;
}
};
class Parameter {
public:
Parameter() : _row(0), _col(0), _params(NULL) {}
~Parameter() {
LOG(INFO) << "before destroy Parameter, file_name[" << _file_name << "]";
destroy();
}
int init(int row, int col, const char* file_name) {
destroy();
_file_name = file_name;
_row = row;
_col = col;
_params = reinterpret_cast<float*>(malloc(_row * _col * sizeof(float)));
if (_params == NULL) {
LOG(ERROR) << "Load " << _file_name << " malloc error.";
return -1;
}
LOG(WARNING) << "Load parameter file[" << _file_name << "] success.";
return 0;
}
void destroy() {
_row = 0;
_col = 0;
if (_params != NULL) {
free(_params);
_params = NULL;
}
}
int load() {
if (_params == NULL || _row <= 0 || _col <= 0) {
LOG(ERROR) << "load parameter error [not inited].";
return -1;
}
FILE* fs = fopen(_file_name.c_str(), "rb");
if (fs == NULL) {
LOG(ERROR) << "load " << _file_name << " fopen error.";
return -1;
}
static const uint32_t MODEL_FILE_HEAD_LEN = 16;
char head[MODEL_FILE_HEAD_LEN] = {0};
if (fread(head, 1, MODEL_FILE_HEAD_LEN, fs) != MODEL_FILE_HEAD_LEN) {
destroy();
LOG(ERROR) << "Load " << _file_name << " read head error.";
if (fs != NULL) {
fclose(fs);
fs = NULL;
}
return -1;
}
uint32_t matrix_size = _row * _col;
if (matrix_size == fread(_params, sizeof(float), matrix_size, fs)) {
if (fs != NULL) {
fclose(fs);
fs = NULL;
}
LOG(INFO) << "load " << _file_name << " read ok.";
return 0;
} else {
LOG(ERROR) << "load " << _file_name << " read error.";
destroy();
if (fs != NULL) {
fclose(fs);
fs = NULL;
}
return -1;
}
return 0;
}
public:
std::string _file_name;
int _row;
int _col;
float* _params;
};
class SigmoidModel {
public:
~SigmoidModel() {}
int load(const char* sigmoid_w_file,
const char* sigmoid_b_file,
float exp_max,
float exp_min) {
AutoLock lock(GlobalSigmoidCreateMutex::instance());
if (0 != _sigmoid_w.init(2, 1, sigmoid_w_file) || 0 != _sigmoid_w.load()) {
LOG(ERROR) << "load params sigmoid_w failed.";
return -1;
}
LOG(WARNING) << "load sigmoid_w [" << _sigmoid_w._params[0] << "] ["
<< _sigmoid_w._params[1] << "].";
if (0 != _sigmoid_b.init(2, 1, sigmoid_b_file) || 0 != _sigmoid_b.load()) {
LOG(ERROR) << "load params sigmoid_b failed.";
return -1;
}
LOG(WARNING) << "load sigmoid_b [" << _sigmoid_b._params[0] << "] ["
<< _sigmoid_b._params[1] << "].";
_exp_max_input = exp_max;
_exp_min_input = exp_min;
return 0;
}
int softmax(float x, double& o) { // NOLINT
float _y0 = x * _sigmoid_w._params[0] + _sigmoid_b._params[0];
float _y1 = x * _sigmoid_w._params[1] + _sigmoid_b._params[1];
_y0 = (_y0 > _exp_max_input)
? _exp_max_input
: ((_y0 < _exp_min_input) ? _exp_min_input : _y0);
_y1 = (_y1 > _exp_max_input)
? _exp_max_input
: ((_y1 < _exp_min_input) ? _exp_min_input : _y1);
o = 1.0f / (1.0f + exp(_y0 - _y1));
return 0;
}
public:
Parameter _sigmoid_w;
Parameter _sigmoid_b;
float _exp_max_input;
float _exp_min_input;
};
class SigmoidFluidModel {
public:
int softmax(float x, double& o) { // NOLINT
return _sigmoid_core->softmax(x, o);
} // NOLINT
std::unique_ptr<SigmoidFluidModel> Clone() {
std::unique_ptr<SigmoidFluidModel> clone_model;
clone_model.reset(new SigmoidFluidModel());
clone_model->_sigmoid_core = _sigmoid_core;
clone_model->_fluid_core = _fluid_core->Clone();
return std::move(clone_model);
}
public:
std::unique_ptr<paddle::PaddlePredictor> _fluid_core;
std::shared_ptr<SigmoidModel> _sigmoid_core;
};
class FluidGpuWithSigmoidCore : public FluidFamilyCore {
public:
virtual ~FluidGpuWithSigmoidCore() {}
public:
int create(const std::string& model_path) {
size_t pos = model_path.find_last_of("/\\");
std::string conf_path = model_path.substr(0, pos);
std::string conf_file = model_path.substr(pos);
configure::SigmoidConf conf;
if (configure::read_proto_conf(conf_path, conf_file, &conf) != 0) {
LOG(ERROR) << "failed load model path: " << model_path;
return -1;
}
_core.reset(new SigmoidFluidModel);
std::string fluid_model_data_path = conf.dnn_model_path();
int ret = load_fluid_model(fluid_model_data_path);
if (ret < 0) {
LOG(ERROR) << "fail to load fluid model.";
return -1;
}
const char* sigmoid_w_file = conf.sigmoid_w_file().c_str();
const char* sigmoid_b_file = conf.sigmoid_b_file().c_str();
float exp_max = conf.exp_max_input();
float exp_min = conf.exp_min_input();
_core->_sigmoid_core.reset(new SigmoidModel);
LOG(INFO) << "create sigmoid core[" << _core->_sigmoid_core.get()
<< "], use count[" << _core->_sigmoid_core.use_count() << "].";
ret = _core->_sigmoid_core->load(
sigmoid_w_file, sigmoid_b_file, exp_max, exp_min);
if (ret < 0) {
LOG(ERROR) << "fail to load sigmoid model.";
return -1;
}
return 0;
}
virtual bool Run(const void* in_data, void* out_data) {
if (!_core->_fluid_core->Run(
*(std::vector<paddle::PaddleTensor>*)in_data,
(std::vector<paddle::PaddleTensor>*)out_data)) {
LOG(ERROR) << "Failed call Run with paddle predictor";
return false;
}
return true;
}
virtual int clone(SigmoidFluidModel* origin_core) {
if (origin_core == NULL) {
LOG(ERROR) << "origin paddle Predictor is null.";
return -1;
}
_core = origin_core->Clone();
if (_core.get() == NULL) {
LOG(ERROR) << "fail to clone paddle predictor: " << origin_core;
return -1;
}
LOG(INFO) << "clone sigmoid core[" << _core->_sigmoid_core.get()
<< "] use count[" << _core->_sigmoid_core.use_count() << "].";
return 0;
}
virtual SigmoidFluidModel* get() { return _core.get(); }
virtual int load_fluid_model(const std::string& data_path) = 0;
int softmax(float x, double& o) { // NOLINT
return _core->_sigmoid_core->softmax(x, o);
}
protected:
std::unique_ptr<SigmoidFluidModel> _core;
};
class FluidGpuNativeDirWithSigmoidCore : public FluidGpuWithSigmoidCore {
public:
int load_fluid_model(const std::string& data_path) {
if (access(data_path.c_str(), F_OK) == -1) {
LOG(ERROR) << "create paddle predictor failed, path not exits: "
<< data_path;
return -1;
}
paddle::NativeConfig native_config;
native_config.model_dir = data_path;
native_config.use_gpu = true;
native_config.fraction_of_gpu_memory = 0.01;
native_config.device = FLAGS_gpuid;
AutoLock lock(GlobalPaddleCreateMutex::instance());
_core->_fluid_core =
paddle::CreatePaddlePredictor<paddle::NativeConfig,
paddle::PaddleEngineKind::kNative>(
native_config);
if (NULL == _core->_fluid_core.get()) {
LOG(ERROR) << "create paddle predictor failed, path: " << data_path;
return -1;
}
LOG(WARNING) << "create paddle predictor sucess, path: " << data_path;
return 0;
}
};
class FluidGpuAnalysisDirWithSigmoidCore : public FluidGpuWithSigmoidCore {
public:
int load_fluid_model(const std::string& data_path) {
if (access(data_path.c_str(), F_OK) == -1) {
LOG(ERROR) << "create paddle predictor failed, path not exits: "
<< data_path;
return -1;
}
paddle::AnalysisConfig analysis_config;
analysis_config.SetModel(data_path);
analysis_config.EnableUseGpu(100, FLAGS_gpuid);
analysis_config.SwitchSpecifyInputNames(true);
analysis_config.SetCpuMathLibraryNumThreads(1);
analysis_config.EnableMemoryOptim();
AutoLock lock(GlobalPaddleCreateMutex::instance());
_core->_fluid_core =
paddle::CreatePaddlePredictor<paddle::AnalysisConfig>(analysis_config);
if (NULL == _core->_fluid_core.get()) {
LOG(ERROR) << "create paddle predictor failed, path: " << data_path;
return -1;
}
LOG(WARNING) << "create paddle predictor sucess, path: " << data_path;
return 0;
}
};
} // namespace fluid_gpu
} // namespace paddle_serving
} // namespace baidu
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "inferencer-fluid-gpu/include/fluid_gpu_engine.h"
#include "predictor/framework/factory.h"
DEFINE_int32(gpuid, 0, "GPU device id to use");
namespace baidu {
namespace paddle_serving {
namespace fluid_gpu {
REGIST_FACTORY_OBJECT_IMPL_WITH_NAME(
::baidu::paddle_serving::predictor::FluidInferEngine<FluidGpuAnalysisCore>,
::baidu::paddle_serving::predictor::InferEngine,
"FLUID_GPU_ANALYSIS");
REGIST_FACTORY_OBJECT_IMPL_WITH_NAME(
::baidu::paddle_serving::predictor::FluidInferEngine<
FluidGpuAnalysisDirCore>,
::baidu::paddle_serving::predictor::InferEngine,
"FLUID_GPU_ANALYSIS_DIR");
REGIST_FACTORY_OBJECT_IMPL_WITH_NAME(
::baidu::paddle_serving::predictor::FluidInferEngine<
FluidGpuAnalysisDirWithSigmoidCore>,
::baidu::paddle_serving::predictor::InferEngine,
"FLUID_GPU_ANALYSIS_DIR_SIGMOID");
REGIST_FACTORY_OBJECT_IMPL_WITH_NAME(
::baidu::paddle_serving::predictor::FluidInferEngine<FluidGpuNativeCore>,
::baidu::paddle_serving::predictor::InferEngine,
"FLUID_GPU_NATIVE");
REGIST_FACTORY_OBJECT_IMPL_WITH_NAME(
::baidu::paddle_serving::predictor::FluidInferEngine<FluidGpuNativeDirCore>,
::baidu::paddle_serving::predictor::InferEngine,
"FLUID_GPU_NATIVE_DIR");
REGIST_FACTORY_OBJECT_IMPL_WITH_NAME(
::baidu::paddle_serving::predictor::FluidInferEngine<
FluidGpuNativeDirWithSigmoidCore>,
::baidu::paddle_serving::predictor::InferEngine,
"FLUID_GPU_NATIVE_DIR_SIGMOID");
} // namespace fluid_gpu
} // namespace paddle_serving
} // namespace baidu
......@@ -461,7 +461,7 @@ class CloneDBReloadableInferEngine
};
template <typename FluidFamilyCore>
class FluidInferEngine : public DBReloadableInferEngine<FluidFamilyCore> {
class FluidInferEngine : public CloneDBReloadableInferEngine<FluidFamilyCore> {
public:
FluidInferEngine() {}
~FluidInferEngine() {}
......
......@@ -69,7 +69,30 @@ DEFINE_bool(V, false, "print version, bool");
DEFINE_bool(g, false, "user defined gflag path");
DECLARE_string(flagfile);
void pthread_worker_start_fn() { Resource::instance().thread_initialize(); }
namespace bthread {
extern pthread_mutex_t g_task_control_mutex;
}
pthread_mutex_t g_worker_start_fn_mutex = PTHREAD_MUTEX_INITIALIZER;
void pthread_worker_start_fn() {
while (pthread_mutex_lock(&g_worker_start_fn_mutex) != 0) {
}
// Try to avoid deadlock in bthread
int lock_status = pthread_mutex_trylock(&bthread::g_task_control_mutex);
if (lock_status == EBUSY || lock_status == EAGAIN) {
pthread_mutex_unlock(&bthread::g_task_control_mutex);
}
Resource::instance().thread_initialize();
// Try to avoid deadlock in bthread
if (lock_status == EBUSY || lock_status == EAGAIN) {
while (pthread_mutex_lock(&bthread::g_task_control_mutex) != 0) {
}
}
pthread_mutex_unlock(&g_worker_start_fn_mutex);
}
static void g_change_server_port() {
InferServiceConf conf;
......@@ -111,7 +134,7 @@ int main(int argc, char** argv) {
g_change_server_port();
// initialize logger instance
// initialize logger instance
#ifdef BCLOUD
logging::LoggingSettings settings;
settings.logging_dest = logging::LOG_TO_FILE;
......@@ -183,6 +206,8 @@ int main(int argc, char** argv) {
}
LOG(INFO) << "Succ call pthread worker start function";
FLAGS_logtostderr = false;
if (ServerManager::instance().start_and_wait() != 0) {
LOG(ERROR) << "Failed start server and wait!";
return -1;
......