Unverified commit 52a7b613, authored by Jiawei Wang, committed by GitHub

Merge pull request #1440 from bjjwwang/v0.7.0

v0.7.0: cherry-pick 8 PRs into v0.7.0
......@@ -25,7 +25,7 @@ set(BOOST_PROJECT "extern_boost")
set(BOOST_VER "1.74.0")
set(BOOST_TAR "boost_1_74_0" CACHE STRING "" FORCE)
set(BOOST_URL "http://paddlepaddledeps.cdn.bcebos.com/${BOOST_TAR}.tar.gz" CACHE STRING "" FORCE)
set(BOOST_URL "http://paddlepaddledeps.bj.bcebos.com/${BOOST_TAR}.tar.gz" CACHE STRING "" FORCE)
MESSAGE(STATUS "BOOST_TAR: ${BOOST_TAR}, BOOST_URL: ${BOOST_URL}")
......
......@@ -61,8 +61,11 @@ else()
endif()
if(CUDNN_FOUND)
file(READ ${CUDNN_INCLUDE_DIR}/cudnn.h CUDNN_VERSION_FILE_CONTENTS)
if(EXISTS "${CUDNN_INCLUDE_DIR}/cudnn_version.h")
file(READ ${CUDNN_INCLUDE_DIR}/cudnn_version.h CUDNN_VERSION_FILE_CONTENTS)
elseif(EXISTS "${CUDNN_INCLUDE_DIR}/cudnn.h")
file(READ ${CUDNN_INCLUDE_DIR}/cudnn.h CUDNN_VERSION_FILE_CONTENTS)
endif()
get_filename_component(CUDNN_LIB_PATH ${CUDNN_LIBRARY} DIRECTORY)
string(REGEX MATCH "define CUDNN_VERSION +([0-9]+)"
......
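For context: cuDNN 8 moved the version macros (CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL) out of cudnn.h and into cudnn_version.h, which is why the detection above now prefers cudnn_version.h and only falls back to cudnn.h. A minimal Python sketch of the same prefer-then-fallback parse (the include path in the usage comment is an assumption for illustration):

```python
import os
import re

def read_cudnn_version(include_dir):
    """Mirror the CMake logic above: prefer cudnn_version.h (cuDNN 8+),
    fall back to cudnn.h (cuDNN 7 and earlier)."""
    for header in ("cudnn_version.h", "cudnn.h"):
        path = os.path.join(include_dir, header)
        if os.path.exists(path):
            with open(path) as f:
                contents = f.read()
            break
    else:
        return None  # no cuDNN headers found
    version = {}
    for macro in ("CUDNN_MAJOR", "CUDNN_MINOR", "CUDNN_PATCHLEVEL"):
        m = re.search(rf"#define\s+{macro}\s+(\d+)", contents)
        if m:
            version[macro] = int(m.group(1))
    return version

# e.g. read_cudnn_version("/usr/local/cuda/include")
# -> {'CUDNN_MAJOR': 8, 'CUDNN_MINOR': 1, 'CUDNN_PATCHLEVEL': 0}
```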
......@@ -27,12 +27,12 @@ set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/
message( "WITH_GPU = ${WITH_GPU}")
# Paddle Version should be one of:
# latest: latest develop build
# version number like 1.5.2
SET(PADDLE_VERSION "2.2.0-rc0")
if (WITH_GPU)
message("CUDA: ${CUDA_VERSION}, CUDNN_MAJOR_VERSION: ${CUDNN_MAJOR_VERSION}")
# CUDA 11.0 is not supported; 11.2 will be added.
if(CUDA_VERSION EQUAL 10.1)
set(CUDA_SUFFIX "x86-64_gcc8.2_avx_mkl_cuda10.1_cudnn7.6.5_trt6.0.1.5")
......@@ -52,14 +52,19 @@ if (WITH_GPU)
else()
set(WITH_TRT OFF)
endif()
if (WITH_GPU)
SET(PADDLE_LIB_VERSION "${PADDLE_VERSION}/cxx_c/Linux/GPU/${CUDA_SUFFIX}")
elseif (WITH_LITE)
message("cpu arch: ${CMAKE_SYSTEM_PROCESSOR}")
if (WITH_XPU)
SET(PADDLE_LIB_VERSION "arm64_gcc7.3_openblas")
if (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86-64")
SET(PADDLE_LIB_VERSION "x86-64_gcc8.2_avx_mkl")
elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64")
SET(PADDLE_LIB_VERSION "arm64_gcc7.3_openblas")
endif()
else()
SET(PADDLE_LIB_VERSION "${PADDLE_VERSION}-${CMAKE_SYSTEM_PROCESSOR}")
MESSAGE("paddle lite lib is unknown.")
SET(PADDLE_LIB_VERSION "paddle-lite-unknown")
endif()
else()
if (WITH_AVX)
......
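The branch above chooses which prebuilt Paddle inference library to download based on WITH_GPU, WITH_LITE, WITH_XPU, and the CPU architecture, and now reports unknown Lite targets instead of composing an invalid URL. A rough Python restatement of that selection, using the same flag and value names as assumptions (illustrative only, not part of the build):

```python
def paddle_lib_version(paddle_version, with_gpu, with_lite, with_xpu,
                       processor, cuda_suffix=None):
    """Sketch of the CMake branching that sets PADDLE_LIB_VERSION."""
    if with_gpu:
        # GPU builds are keyed by the CUDA/cuDNN/TensorRT suffix.
        return f"{paddle_version}/cxx_c/Linux/GPU/{cuda_suffix}"
    if with_lite:
        if with_xpu:
            return "arm64_gcc7.3_openblas"
        if processor == "x86-64":
            return "x86-64_gcc8.2_avx_mkl"
        if processor == "aarch64":
            return "arm64_gcc7.3_openblas"
        # Unknown Lite target: report it rather than guess a URL.
        return "paddle-lite-unknown"
    # CPU-only branch (elided in the hunk above) depends on WITH_AVX.
    return None
```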
......@@ -23,8 +23,7 @@ using configure::GeneralModelConfig;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::Tensor;
// paddle inference 2.1 support: FLOAT32, INT64, INT32, UINT8, INT8
// will support: FLOAT16
// support: FLOAT32, INT64, INT32, UINT8, INT8, FLOAT16
enum ProtoDataType {
P_INT64 = 0,
P_FLOAT32,
......@@ -431,7 +430,8 @@ int PredictorOutputs::ParseProto(const Response& res,
output.tensor(idx).int_data().begin(),
output.tensor(idx).int_data().begin() + size);
} else if (fetch_name_to_type[name] == P_UINT8
|| fetch_name_to_type[name] == P_INT8) {
|| fetch_name_to_type[name] == P_INT8
|| fetch_name_to_type[name] == P_FP16) {
VLOG(2) << "fetch var [" << name << "]type="
<< fetch_name_to_type[name];
string_data_map[name] = output.tensor(idx).tensor_content();
......
......@@ -25,8 +25,7 @@ using baidu::paddle_serving::Timer;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::Tensor;
// paddle inference support: FLOAT32, INT64, INT32, UINT8, INT8
// will support: FLOAT16
// support: FLOAT32, INT64, INT32, UINT8, INT8, FLOAT16
enum ProtoDataType {
P_INT64 = 0,
P_FLOAT32,
......
......@@ -31,8 +31,7 @@ using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
// paddle inference 2.1 support: FLOAT32, INT64, INT32, UINT8, INT8
// will support: FLOAT16
// support: FLOAT32, INT64, INT32, UINT8, INT8, FLOAT16
enum ProtoDataType {
P_INT64 = 0,
P_FLOAT32,
......@@ -130,11 +129,11 @@ int GeneralReaderOp::inference() {
data_len = tensor.tensor_content().size();
src_ptr = tensor.tensor_content().data();
} else if (elem_type == P_FP16) {
// paddle inference will support FLOAT16
// elem_size = 1;
// paddleTensor.dtype = paddle::PaddleDType::FLOAT16;
// data_len = tensor.tensor_content().size();
// src_ptr = tensor.tensor_content().data();
// copy bytes from tensor content to TensorVector
elem_size = 1;
paddleTensor.dtype = paddle::PaddleDType::FLOAT16;
data_len = tensor.tensor_content().size();
src_ptr = tensor.tensor_content().data();
} else if (elem_type == P_STRING) {
// use paddle::PaddleDType::UINT8 as for String.
elem_size = sizeof(char);
......
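With this change the server handles an incoming FLOAT16 tensor the same way it handles UINT8/INT8/STRING inputs: the payload arrives as raw bytes in tensor_content and is copied verbatim into the TensorVector (elem_size is 1 because data_len is already a byte count). On the client side that means a float16 numpy array is shipped as its raw buffer; a minimal sketch, assuming the elem_type and tensor_content fields shown in the hunks (shape and name handling elided):

```python
import numpy as np

P_FP16 = 5  # elem_type value used for FLOAT16 in this PR

def pack_fp16(tensor, array):
    """Sketch: place a float16 numpy array into a general_model Tensor.
    FLOAT16 travels as raw bytes, mirroring the UINT8/INT8 path."""
    array = np.ascontiguousarray(array, dtype=np.float16)
    tensor.elem_type = P_FP16
    tensor.tensor_content = array.tobytes()
```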
......@@ -178,14 +178,12 @@ int GeneralResponseOp::inference() {
VLOG(2) << "(logid=" << log_id << ")Prepare int8 var ["
<< model_config->_fetch_name[idx] << "].";
tensor->set_tensor_content(in->at(idx).data.data(), in->at(idx).data.length());
}
// inference will support fp16
// else if (dtype == paddle::PaddleDType::FLOAT16) {
// tensor->set_elem_type(5);
// VLOG(2) << "(logid=" << log_id << ")Prepare float16 var ["
// << model_config->_fetch_name[idx] << "].";
// tensor->set_tensor_content(in->at(idx).data.data(), in->at(idx).data.length());
// }
} else if (dtype == paddle::PaddleDType::FLOAT16) {
tensor->set_elem_type(5);
VLOG(2) << "(logid=" << log_id << ")Prepare float16 var ["
<< model_config->_fetch_name[idx] << "].";
tensor->set_tensor_content(in->at(idx).data.data(), in->at(idx).data.length());
}
VLOG(2) << "(logid=" << log_id << ") fetch var ["
<< model_config->_fetch_name[idx] << "] ready";
......
......@@ -31,6 +31,7 @@
#include "core/predictor/framework/infer_data.h"
#include "core/predictor/framework/memory.h"
#include "paddle_inference_api.h" // NOLINT
#include "experimental/float16.h"
namespace baidu {
namespace paddle_serving {
namespace predictor {
......@@ -541,19 +542,17 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<EngineCore> {
paddle::PaddleDType::INT8) {
int8_t* data = static_cast<int8_t*>(origin_data);
lod_tensor_in->CopyFromCpu(data);
} else if ((*tensorVector_in_pointer)[i].dtype ==
paddle::PaddleDType::FLOAT16) {
paddle::platform::float16* data =
static_cast<paddle::platform::float16*>(origin_data);
lod_tensor_in->CopyFromCpu(data);
} else {
LOG(ERROR) << "Inference not support type["
<< (*tensorVector_in_pointer)[i].dtype << "],name["
<< (*tensorVector_in_pointer)[i].name << "]"
<< " copy into core failed!";
}
// Paddle inference will support FP16 in next version.
// else if ((*tensorVector_in_pointer)[i].dtype ==
// paddle::PaddleDType::FLOAT16) {
// paddle::platform::float16* data =
// static_cast<paddle::platform::float16*>(origin_data);
// lod_tensor_in->CopyFromCpu(data);
// }
VLOG(2) << "Tensor:name=" << (*tensorVector_in_pointer)[i].name
<< ";in_dtype=" << (*tensorVector_in_pointer)[i].dtype
<< ";tensor_dtype=" << lod_tensor_in->type();
......@@ -641,20 +640,18 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<EngineCore> {
int8_t* data_out = reinterpret_cast<int8_t*>(databuf_data);
lod_tensor_out->CopyToCpu(data_out);
databuf_char = reinterpret_cast<char*>(data_out);
} else if (dataType == paddle::PaddleDType::FLOAT16) {
databuf_size = out_num * sizeof(paddle::platform::float16);
databuf_data = MempoolWrapper::instance().malloc(databuf_size);
if (!databuf_data) {
LOG(ERROR) << "Malloc failed, size: " << databuf_size;
return -1;
}
paddle::platform::float16* data_out =
reinterpret_cast<paddle::platform::float16*>(databuf_data);
lod_tensor_out->CopyToCpu(data_out);
databuf_char = reinterpret_cast<char*>(data_out);
}
// Inference will support FP16 in next version
// else if (dataType == paddle::PaddleDType::FLOAT16) {
// using float16 = paddle::platform::float16;
// databuf_size = out_num * sizeof(float16);
// databuf_data = MempoolWrapper::instance().malloc(databuf_size);
// if (!databuf_data) {
// LOG(ERROR) << "Malloc failed, size: " << databuf_size;
// return -1;
// }
// float16* data_out = reinterpret_cast<float16*>(databuf_data);
// lod_tensor_out->CopyToCpu(data_out);
// databuf_char = reinterpret_cast<char*>(data_out);
// }
// Because task scheduling requires OPs to use 'Channel'
// (which is a data structure) to transfer data between OPs.
......
......@@ -266,6 +266,7 @@ class PaddleInferenceEngine : public EngineCore {
if (engine_conf.has_use_xpu() && engine_conf.use_xpu()) {
// 2 MB l3 cache
config.EnableXpu(2 * 1024 * 1024);
config.SetXpuDeviceId(gpu_id);
}
if (engine_conf.has_enable_memory_optimization() &&
......
......@@ -72,9 +72,13 @@ if (SERVER)
if(CUDA_VERSION EQUAL 10.1)
set(VERSION_SUFFIX 101)
elseif(CUDA_VERSION EQUAL 10.2)
set(VERSION_SUFFIX 102)
elseif(CUDA_VERSION EQUAL 11.0)
set(VERSION_SUFFIX 11)
if(CUDNN_MAJOR_VERSION EQUAL 7)
set(VERSION_SUFFIX 1027)
elseif(CUDNN_MAJOR_VERSION EQUAL 8)
set(VERSION_SUFFIX 1028)
endif()
elseif(CUDA_VERSION EQUAL 11.2)
set(VERSION_SUFFIX 112)
endif()
endif()
......
......@@ -219,6 +219,7 @@ class LocalPredictor(object):
if use_xpu:
# 8MB l3 cache
config.enable_xpu(8 * 1024 * 1024)
config.set_xpu_device_id(gpu_id)
# set cpu low precision
if not use_gpu and not use_lite:
if precision_type == paddle_infer.PrecisionType.Int8:
......
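The same pair of calls can be made directly on a Paddle Inference config; a minimal sketch of configuring an XPU predictor the way LocalPredictor does above (model paths are placeholders, and enable_xpu / set_xpu_device_id are assumed to be available as in the hunk):

```python
import paddle.inference as paddle_infer

# Placeholder model files; configure XPU as LocalPredictor does above.
config = paddle_infer.Config("model/__model__", "model/__params__")
config.enable_xpu(8 * 1024 * 1024)   # 8 MB L3 cache, matching the hunk
config.set_xpu_device_id(0)          # select which XPU card to run on
predictor = paddle_infer.create_predictor(config)
```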
......@@ -551,6 +551,22 @@ class Client(object):
tmp_lod = result_batch_handle.get_lod(mi, name)
if np.size(tmp_lod) > 0:
result_map["{}.lod".format(name)] = tmp_lod
elif self.fetch_names_to_type_[name] == float16_type:
# result_map[name] will be py::array(numpy array)
tmp_str = result_batch_handle.get_string_by_name(
mi, name)
result_map[name] = np.fromstring(tmp_str, dtype = np.float16)
if result_map[name].size == 0:
raise ValueError(
"Failed to fetch, maybe the type of [{}]"
" is wrong, please check the model file".format(
name))
shape = result_batch_handle.get_shape(mi, name)
result_map[name].shape = shape
if name in self.lod_tensor_set:
tmp_lod = result_batch_handle.get_lod(mi, name)
if np.size(tmp_lod) > 0:
result_map["{}.lod".format(name)] = tmp_lod
multi_result_map.append(result_map)
ret = None
if len(model_engine_names) == 1:
......
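On the wire a FLOAT16 fetch result is the raw tensor_content byte string, so the client rebuilds the numpy array from it and then applies the shape and lod exactly as for other types. np.frombuffer is the non-deprecated equivalent of the np.fromstring call used above; a self-contained decode sketch with made-up values:

```python
import numpy as np

# Pretend this is what get_string_by_name returned for a float16
# fetch var with shape [2, 3].
raw = np.arange(6, dtype=np.float16).tobytes()
shape = [2, 3]

arr = np.frombuffer(raw, dtype=np.float16).reshape(shape)
print(arr.dtype, arr.shape)   # float16 (2, 3)
```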
......@@ -428,7 +428,7 @@ class Server(object):
if device_type == "0":
device_version = self.get_device_version()
elif device_type == "1":
if version_suffix == "101" or version_suffix == "102":
if version_suffix == "101" or version_suffix == "1027" or version_suffix == "1028" or version_suffix == "112":
device_version = "gpu-" + version_suffix
else:
device_version = "gpu-cuda" + version_suffix
......
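Together with the CMake change above, the version suffix for CUDA 10.2 builds now also encodes the cuDNN major version (1027 / 1028), so the server recognizes those suffixes alongside 101 and 112. A small Python mirror of the branch above:

```python
def gpu_device_version(version_suffix):
    """Newer suffixes already encode CUDA (and cuDNN) fully;
    older ones keep the legacy 'gpu-cuda' prefix."""
    if version_suffix in ("101", "1027", "1028", "112"):
        return "gpu-" + version_suffix
    return "gpu-cuda" + version_suffix

# gpu_device_version("1027") -> "gpu-1027"
```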
......@@ -280,6 +280,10 @@ class LocalServiceHandler(object):
server.set_gpuid(gpuid)
# TODO: support arm or arm + xpu later
server.set_device(self._device_name)
if self._use_xpu:
server.set_xpu()
if self._use_lite:
server.set_lite()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(thread_num)
......
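Standalone server setup follows the same pattern as this handler; a hedged sketch using the setter names that appear in the hunk, together with the usual OpMaker/OpSeqMaker wiring from the Serving examples (import path and op names are assumptions):

```python
from paddle_serving_server import OpMaker, OpSeqMaker, Server

# Assumed standard op pipeline: reader -> infer -> response.
op_maker = OpMaker()
op_seq_maker = OpSeqMaker()
op_seq_maker.add_op(op_maker.create('general_reader'))
op_seq_maker.add_op(op_maker.create('general_infer'))
op_seq_maker.add_op(op_maker.create('general_response'))

server = Server()
server.set_device("arm")   # device name, as self._device_name above
server.set_xpu()           # only when an XPU runtime is requested
server.set_lite()          # Paddle-Lite backend for arm targets
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(4)
```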
# A image for building paddle binaries
# Use cuda devel base image for both cpu and gpu environment
# When you modify it, please be aware of cudnn-runtime version
FROM nvidia/cuda:11.0.3-cudnn8-devel-ubuntu16.04
FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu16.04
MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com>
# ENV variables
......@@ -104,7 +104,7 @@ ENV PATH=usr/local/go/bin:/root/go/bin:${PATH}
# Downgrade TensorRT
COPY tools/dockerfiles/build_scripts /build_scripts
RUN bash /build_scripts/install_trt.sh cuda11
RUN bash /build_scripts/install_trt.sh cuda10.2 cudnn7
RUN rm -rf /build_scripts
# git credential to skip password typing
......
......@@ -104,7 +104,7 @@ ENV PATH=usr/local/go/bin:/root/go/bin:${PATH}
# Downgrade TensorRT
COPY tools/dockerfiles/build_scripts /build_scripts
RUN bash /build_scripts/install_trt.sh cuda10.2
RUN bash /build_scripts/install_trt.sh cuda10.2 cudnn8
RUN rm -rf /build_scripts
# git credential to skip password typing
......
......@@ -15,20 +15,28 @@
# limitations under the License.
VERSION=$1
CUDNN=$2
if [[ "$VERSION" == "cuda10.1" ]];then
wget -q https://paddle-ci.gz.bcebos.com/TRT/TensorRT6-cuda10.1-cudnn7.tar.gz --no-check-certificate
tar -zxf TensorRT6-cuda10.1-cudnn7.tar.gz -C /usr/local
cp -rf /usr/local/TensorRT6-cuda10.1-cudnn7/include/* /usr/include/ && cp -rf /usr/local/TensorRT6-cuda10.1-cudnn7/lib/* /usr/lib/
echo "cuda10.1 trt install ==============>>>>>>>>>>>>"
rm TensorRT6-cuda10.1-cudnn7.tar.gz
elif [[ "$VERSION" == "cuda11" ]];then
wget -q https://paddle-ci.cdn.bcebos.com/TRT/TensorRT-7.1.3.4.Ubuntu-16.04.x86_64-gnu.cuda-11.0.cudnn8.0.tar.gz --no-check-certificate
tar -zxf TensorRT-7.1.3.4.Ubuntu-16.04.x86_64-gnu.cuda-11.0.cudnn8.0.tar.gz -C /usr/local
cp -rf /usr/local/TensorRT-7.1.3.4/include/* /usr/include/ && cp -rf /usr/local/TensorRT-7.1.3.4/lib/* /usr/lib/
rm TensorRT-7.1.3.4.Ubuntu-16.04.x86_64-gnu.cuda-11.0.cudnn8.0.tar.gz
elif [[ "$VERSION" == "cuda11.2" ]];then
wget https://paddle-ci.gz.bcebos.com/TRT/TensorRT-8.0.3.4.Linux.x86_64-gnu.cuda-11.3.cudnn8.2.tar.gz --no-check-certificate
tar -zxf TensorRT-8.0.3.4.Linux.x86_64-gnu.cuda-11.3.cudnn8.2.tar.gz
cp -rf /usr/local/TensorRT-8.0.3.4/include/* /usr/include/ && cp -rf /usr/local/TensorRT-8.0.3.4/lib/* /usr/lib/
rm -rf TensorRT-8.0.3.4.Linux.x86_64-gnu.cuda-11.3.cudnn8.2.tar.gz
elif [[ "$VERSION" == "cuda10.2" ]];then
wget https://paddle-ci.gz.bcebos.com/TRT/TensorRT7-cuda10.2-cudnn8.tar.gz --no-check-certificate
tar -zxf TensorRT7-cuda10.2-cudnn8.tar.gz -C /usr/local
cp -rf /usr/local/TensorRT-7.1.3.4/include/* /usr/include/ && cp -rf /usr/local/TensorRT-7.1.3.4/lib/* /usr/lib/
rm TensorRT7-cuda10.2-cudnn8.tar.gz
if [[ "$CUDNN" == "cudnn8" ]];then
wget https://paddle-ci.gz.bcebos.com/TRT/TensorRT7-cuda10.2-cudnn8.tar.gz --no-check-certificate
tar -zxf TensorRT7-cuda10.2-cudnn8.tar.gz -C /usr/local
cp -rf /usr/local/TensorRT-7.1.3.4/include/* /usr/include/ && cp -rf /usr/local/TensorRT-7.1.3.4/lib/* /usr/lib/
rm TensorRT7-cuda10.2-cudnn8.tar.gz
elif [[ "$CUDNN" == "cudnn7" ]];then
wget https://paddle-ci.gz.bcebos.com/TRT/TensorRT6-cuda10.2-cudnn7.tar.gz --no-check-certificate
tar -zxf TensorRT6-cuda10.2-cudnn7.tar.gz -C /usr/local
cp -rf /usr/local/TensorRT-6.0.1.8/include/* /usr/include/ && cp -rf /usr/local/TensorRT-6.0.1.8/lib/* /usr/lib/
rm -rf TensorRT6-cuda10.2-cudnn7.tar.gz
fi
fi
......@@ -53,7 +53,7 @@ if [[ $SERVING_VERSION == "0.5.0" ]]; then
fi
client_release="paddle-serving-client==$SERVING_VERSION"
app_release="paddle-serving-app==0.3.1"
elif [[ $SERVING_VERSION == "0.6.0" ]]; then
else
if [[ "$RUN_ENV" == "cpu" ]];then
server_release="https://paddle-serving.bj.bcebos.com/test-dev/whl/paddle_serving_server-$SERVING_VERSION-py3-none-any.whl"
serving_bin="https://paddle-serving.bj.bcebos.com/test-dev/bin/serving-cpu-avx-mkl-$SERVING_VERSION.tar.gz"
......@@ -80,10 +80,10 @@ if [[ "$RUN_ENV" == "cpu" ]];then
python$PYTHON_VERSION -m pip install $paddle_whl
cd /usr/local/
wget $serving_bin
tar xf serving-cpu-noavx-openblas-${SERVING_VERSION}.tar.gz
mv $PWD/serving-cpu-noavx-openblas-${SERVING_VERSION} $PWD/serving_bin
tar xf serving-cpu-avx-mkl-${SERVING_VERSION}.tar.gz
mv $PWD/serving-cpu-avx-mkl-${SERVING_VERSION} $PWD/serving_bin
echo "export SERVING_BIN=$PWD/serving_bin/serving">>/root/.bashrc
rm -rf serving-cpu-noavx-openblas-${SERVING_VERSION}.tar.gz
rm -rf serving-cpu-avx-mkl-${SERVING_VERSION}.tar.gz
cd -
elif [[ "$RUN_ENV" == "cuda10.1" ]];then
python$PYTHON_VERSION -m pip install $client_release $app_release $server_release
......