Unverified commit 2bdad6cd, authored by Zhang Jun, committed by GitHub

[inference][trt] Fp16 support for Generic plugin (#48253)

* Support FP16 in generic TensorRT plugin.
* Support FP16 for Pad3D.
Parent 9ffc760f
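
Before the diff itself, here is a minimal standalone sketch of the behaviour this commit adds: the generic plugin records a `with_fp16` flag and only runs in half precision when both the flag is set and a valid FP16 kernel exists, otherwise it falls back to FP32. Everything in the sketch (`DType`, `FakeKernel`, `GenericPluginSketch`) is an illustrative stand-in, not the Paddle/TensorRT API.

```cpp
#include <iostream>
#include <map>

// Illustrative stand-ins for nvinfer1::DataType and phi::Kernel.
enum class DType { kFLOAT, kHALF };

struct FakeKernel {
  bool valid = false;
};

class GenericPluginSketch {
 public:
  explicit GenericPluginSketch(bool with_fp16) : with_fp16_(with_fp16) {
    kernels_[DType::kFLOAT] = {true};  // an FP32 kernel is always registered
    kernels_[DType::kHALF] = {true};   // pretend a valid half kernel was found
  }

  // Mirrors isFp16Supported() in the diff: the flag alone is not enough,
  // a valid half kernel must also exist.
  bool IsFp16Supported() const {
    auto it = kernels_.find(DType::kHALF);
    return with_fp16_ && it != kernels_.end() && it->second.valid;
  }

  // Mirrors the FP32/FP16 decision in protoType2PhiType: only honor a half
  // request when FP16 is actually usable, otherwise fall back to float.
  DType PickRuntimeType(DType requested) const {
    return (requested == DType::kHALF && IsFp16Supported()) ? DType::kHALF
                                                            : DType::kFLOAT;
  }

 private:
  bool with_fp16_;
  std::map<DType, FakeKernel> kernels_;
};

int main() {
  GenericPluginSketch fp16_plugin(/*with_fp16=*/true);
  GenericPluginSketch fp32_plugin(/*with_fp16=*/false);
  std::cout << (fp16_plugin.PickRuntimeType(DType::kHALF) == DType::kHALF) << "\n";   // 1
  std::cout << (fp32_plugin.PickRuntimeType(DType::kHALF) == DType::kFLOAT) << "\n";  // 1
  return 0;
}
```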
......@@ -30,9 +30,7 @@
#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/op_teller.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/utils/io_utils.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
......@@ -42,15 +40,6 @@ namespace inference {
namespace analysis {
namespace {
bool IsFloat(framework::proto::VarType::Type t) {
if (t == framework::proto::VarType::FP16 ||
t == framework::proto::VarType::FP32 ||
t == framework::proto::VarType::FP64 ||
t == framework::proto::VarType::BF16)
return true;
return false;
}
// if in mixed model precision, we should make all tensorrt_engine's output
// floats dtype to float32 dtype.
void OutputProcess(framework::ir::Graph *graph,
......@@ -85,7 +74,7 @@ void OutputProcess(framework::ir::Graph *graph,
for (auto *var_node : op_node->outputs) {
if (!trt_outputs.count(var_node)) continue;
if (!var_node->Var()->Persistable() &&
IsFloat(var_node->Var()->GetDataType()) &&
tensorrt::IsFloatVar(var_node->Var()->GetDataType()) &&
var_node->Var()->GetDataType() != framework::proto::VarType::FP32) {
for (auto *next_op : var_node->outputs) {
// if next_op support mixed_precision, we need to add cast op.
......
......@@ -182,6 +182,8 @@ class GenericPluginCreater : public OpConverter {
phi::DefaultKernelSignatureMap::Instance().Get(op_desc.Type());
}
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
plugin::GenericPlugin::InputOutPutVarInfo in_out_info;
for (auto &param_name : phi_kernel_signature.input_names) {
......@@ -218,7 +220,8 @@ class GenericPluginCreater : public OpConverter {
in_out_info.outputs_data_type.push_back(var->GetDataType());
}
}
plugin::GenericPlugin *plugin = new plugin::GenericPlugin(op, in_out_info);
plugin::GenericPlugin *plugin =
new plugin::GenericPlugin(op, in_out_info, with_fp16);
layer = engine_->AddDynamicPlugin(inputs.data(), inputs.size(), plugin);
RreplenishLayerAndOutput(layer, op_desc.Type(), output_names, test_mode);
......
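
In the converter hunk above, the plugin's FP16 flag is derived from two engine settings before the `GenericPlugin` is constructed. A hedged standalone illustration of that gating follows; `EngineConfig` and `PluginUsesFp16` are made-up stand-ins, the real calls are `engine_->WithFp16()` and `engine_->disable_trt_plugin_fp16()`.

```cpp
#include <cassert>

struct EngineConfig {
  bool with_fp16;                // the engine was built with FP16 enabled
  bool disable_trt_plugin_fp16;  // the user forbids FP16 inside TRT plugins
};

bool PluginUsesFp16(const EngineConfig& cfg) {
  return cfg.with_fp16 && !cfg.disable_trt_plugin_fp16;
}

int main() {
  assert(PluginUsesFp16({true, false}));    // FP16 engine, plugins allowed
  assert(!PluginUsesFp16({true, true}));    // FP16 engine, plugins forced to FP32
  assert(!PluginUsesFp16({false, false}));  // FP32 engine
  return 0;
}
```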
......@@ -60,7 +60,7 @@ class MultiheadMatMulRoformerOpConverter : public OpConverter {
weight_data_tmp.data(), weight_data, weight_t->numel() * sizeof(float));
// (hidden_in, 3, hidden_out)
auto weight_dims = weight_t->dims();
auto& weight_dims = weight_t->dims();
int hidden_in = weight_dims[0]; // channels_in
int three = weight_dims[1]; // channels_out
......
......@@ -22,6 +22,7 @@
#include <utility>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/dynload/tensorrt.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/data_type.h"
......@@ -213,6 +214,15 @@ static inline nvinfer1::DataType PhiType2NvType(phi::DataType type) {
}
return nv_type;
}
static bool IsFloatVar(framework::proto::VarType::Type t) {
if (t == framework::proto::VarType::FP16 ||
t == framework::proto::VarType::FP32 ||
t == framework::proto::VarType::FP64 ||
t == framework::proto::VarType::BF16)
return true;
return false;
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
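
The `IsFloat` helper removed from convert_to_mixed_precision.cc earlier in this diff reappears here as `tensorrt::IsFloatVar` so it can be shared. A trivial standalone rendering of the same predicate, with `VarType` as a stand-in enum:

```cpp
#include <cassert>

// Stand-in for framework::proto::VarType::Type; illustrative only.
enum class VarType { FP16, FP32, FP64, BF16, INT32 };

// Same shape as the IsFloatVar helper hoisted into tensorrt/helper.h above.
bool IsFloatVar(VarType t) {
  return t == VarType::FP16 || t == VarType::FP32 || t == VarType::FP64 ||
         t == VarType::BF16;
}

int main() {
  assert(IsFloatVar(VarType::FP16));
  assert(!IsFloatVar(VarType::INT32));
  return 0;
}
```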
......@@ -30,8 +30,11 @@ namespace plugin {
void BuildPhiKernelContextAttr(const framework::OpDesc& op_desc,
phi::KernelContext* kernel_context,
const phi::KernelSignature& signature,
const phi::Kernel& phi_kernel) {
const phi::KernelArgsDef& args_def = phi_kernel.args_def();
const phi::Kernel* phi_kernel) {
if (!phi_kernel->IsValid()) {
return;
}
const phi::KernelArgsDef& args_def = phi_kernel->args_def();
const auto& attr_names = signature.attr_names;
const auto& attr_defs = args_def.attribute_defs();
......@@ -221,28 +224,34 @@ void BuildPhiKernelContextAttr(const framework::OpDesc& op_desc,
GenericPlugin::GenericPlugin(
const paddle::framework::proto::OpDesc& proto_op_desc,
const InputOutPutVarInfo& in_out_info) {
const InputOutPutVarInfo& in_out_info,
bool with_fp16) {
proto_op_desc_ = proto_op_desc;
op_desc_ = std::move(framework::OpDesc(proto_op_desc_, nullptr));
proto_op_desc_.SerializeToString(&op_meta_data_);
inputs_data_type_ = in_out_info.inputs_data_type;
outputs_data_type_ = in_out_info.outputs_data_type;
with_fp16_ = with_fp16;
}
GenericPlugin::GenericPlugin(
const paddle::framework::proto::OpDesc& proto_op_desc,
const std::vector<int>& inputs_data_type,
const std::vector<int>& outputs_data_type) {
const std::vector<int>& outputs_data_type,
bool with_fp16) {
proto_op_desc_ = proto_op_desc;
op_desc_ = std::move(framework::OpDesc(proto_op_desc_, nullptr));
proto_op_desc_.SerializeToString(&op_meta_data_);
inputs_data_type_ = inputs_data_type;
outputs_data_type_ = outputs_data_type;
with_fp16_ = with_fp16;
}
GenericPlugin::GenericPlugin(void const* serial_data, size_t serial_length) {
DeserializeValue(&serial_data, &serial_length, &inputs_data_type_);
DeserializeValue(&serial_data, &serial_length, &outputs_data_type_);
DeserializeValue(&serial_data, &serial_length, &with_fp16_);
std::string op_meta_data((char*)(serial_data), serial_length); // NOLINT
op_meta_data_ = std::move(op_meta_data);
proto_op_desc_.ParseFromString(op_meta_data_);
......@@ -266,8 +275,8 @@ int GenericPlugin::getNbInputs() const TRT_NOEXCEPT {
}
nvinfer1::IPluginV2DynamicExt* GenericPlugin::clone() const TRT_NOEXCEPT {
nvinfer1::IPluginV2DynamicExt* plugin =
new GenericPlugin(proto_op_desc_, inputs_data_type_, outputs_data_type_);
nvinfer1::IPluginV2DynamicExt* plugin = new GenericPlugin(
proto_op_desc_, inputs_data_type_, outputs_data_type_, with_fp16_);
plugin->initialize();
return plugin;
}
......@@ -277,6 +286,8 @@ void GenericPlugin::serialize(void* buffer) const TRT_NOEXCEPT {
SerializeValue(&buffer, inputs_data_type_);
// outputs_data_type_
SerializeValue(&buffer, outputs_data_type_);
// use fp16
SerializeValue(&buffer, with_fp16_);
// serialize op_meta_data_
std::memcpy(buffer, op_meta_data_.c_str(), op_meta_data_.size());
reinterpret_cast<char*&>(buffer) += op_meta_data_.size();
......@@ -310,6 +321,12 @@ bool GenericPlugin::supportsFormatCombination(
if (pos == 3)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
} else if (op_desc_.Type() == "pad3d") {
return (in_out[pos].type == nvinfer1::DataType::kFLOAT ||
(isFp16Supported() &&
in_out[pos].type == nvinfer1::DataType::kHALF)) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR) &&
(in_out[0].type == in_out[pos].type);
} else {
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
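
The new pad3d branch in supportsFormatCombination accepts linear FP32 always, linear FP16 only when FP16 is usable, and forces every tensor to share the type of input 0. A standalone sketch of that predicate (the `TensorDesc`, `DType` and `Format` types below are simplified stand-ins for the nvinfer1 ones):

```cpp
#include <cassert>
#include <vector>

enum class DType { kFLOAT, kHALF, kINT32 };
enum class Format { kLINEAR, kOther };

struct TensorDesc {
  DType type;
  Format format;
};

// Mirrors the pad3d branch above: linear FP32, or linear FP16 when FP16 is
// supported, and all tensors must match the type of input 0.
bool Pad3dSupportsFormat(int pos,
                         const std::vector<TensorDesc>& in_out,
                         bool fp16_supported) {
  const TensorDesc& d = in_out[pos];
  bool type_ok =
      d.type == DType::kFLOAT || (fp16_supported && d.type == DType::kHALF);
  return type_ok && d.format == Format::kLINEAR && in_out[0].type == d.type;
}

int main() {
  std::vector<TensorDesc> io = {{DType::kHALF, Format::kLINEAR},
                                {DType::kHALF, Format::kLINEAR}};
  assert(Pad3dSupportsFormat(1, io, /*fp16_supported=*/true));
  io[1].type = DType::kFLOAT;  // mixed input/output types are rejected
  assert(!Pad3dSupportsFormat(1, io, /*fp16_supported=*/true));
  return 0;
}
```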
......@@ -337,34 +354,43 @@ int GenericPlugin::initialize() TRT_NOEXCEPT {
phi::DefaultKernelSignatureMap::Instance().Get(op_type);
}
phi::KernelKey phi_kernel_key(
phi::Backend::GPU, phi::DataLayout::ANY, phi::DataType::FLOAT32);
PADDLE_ENFORCE_EQ(
phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type),
true,
platform::errors::Fatal("%s has no compatible phi kernel!",
op_type.c_str()));
const phi::Kernel& phi_kernel = phi::KernelFactory::Instance().SelectKernel(
phi_kernel_signature.name, phi_kernel_key);
phi_kernel_ = &phi_kernel;
PADDLE_ENFORCE_EQ(phi_kernel_->IsValid(),
true,
platform::errors::Fatal("%s phi kernel is invalid!.",
phi_kernel_signature.name));
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
platform::CUDAPlace place(platform::GetCurrentDeviceId());
auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(place));
if (!phi_kernel_context_) {
phi_kernel_context_ = new phi::KernelContext(dev_ctx);
BuildPhiKernelContextAttr(
op_desc_, phi_kernel_context_, phi_kernel_signature, phi_kernel);
std::vector<phi::DataType> precision_types{phi::DataType::FLOAT32,
phi::DataType::FLOAT16};
for (auto& precision_type : precision_types) {
phi::KernelKey phi_kernel_key(
phi::Backend::GPU, phi::DataLayout::ANY, precision_type);
auto nv_dtype = PhiType2NvType(precision_type);
phi_kernels_[nv_dtype].reset(
new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
phi_kernel_signature.name, phi_kernel_key)));
if (phi_kernel_contexts_.find(nv_dtype) == phi_kernel_contexts_.end() ||
!phi_kernel_contexts_[nv_dtype]) {
phi_kernel_contexts_[nv_dtype].reset(new phi::KernelContext(dev_ctx));
BuildPhiKernelContextAttr(op_desc_,
phi_kernel_contexts_[nv_dtype].get(),
phi_kernel_signature,
phi_kernels_[nv_dtype].get());
}
}
PADDLE_ENFORCE_EQ(phi_kernels_[nvinfer1::DataType::kFLOAT]->IsValid() ||
phi_kernels_[nvinfer1::DataType::kHALF]->IsValid(),
true,
platform::errors::Fatal("%s phi kernel is invalid!.",
phi_kernel_signature.name));
if (!dense_tensor_inputs_)
dense_tensor_inputs_ = new std::vector<phi::DenseTensor>(getNbInputs());
if (!dense_tensor_outputs_)
......@@ -396,15 +422,14 @@ void GenericPlugin::configurePlugin(
int nb_inputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nb_outputs) TRT_NOEXCEPT {
CHECK(phi_kernel_context_);
CHECK(phi_kernel_);
CHECK(phi_kernels_[nvinfer1::DataType::kFLOAT]->IsValid() ||
phi_kernels_[nvinfer1::DataType::kHALF]->IsValid());
CHECK(nb_inputs == getNbInputs());
CHECK(nb_outputs == getNbOutputs());
}
// Shutdown the layer. This is called when the engine is destroyed
void GenericPlugin::terminate() TRT_NOEXCEPT {
delete phi_kernel_context_;
delete dense_tensor_inputs_;
delete dense_tensor_outputs_;
}
......@@ -418,27 +443,42 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
platform::CUDAPlace place(platform::GetCurrentDeviceId());
// [TODO]now generic plugin do not support FP16 and INT8 precision
auto protoType2PhiType = [](int proto_type) -> std::pair<phi::DataType, int> {
auto protoType2PhiType =
[&](int proto_type,
nvinfer1::DataType nv_dtype) -> std::pair<phi::DataType, int> {
if (proto_type ==
static_cast<int>(framework::proto::VarType_Type::VarType_Type_FP32))
return {phi::DataType::FLOAT32, sizeof(float)};
else if (proto_type ==
static_cast<int>(
framework::proto::VarType_Type::VarType_Type_INT64) ||
proto_type ==
static_cast<int>(
framework::proto::VarType_Type::VarType_Type_INT32))
static_cast<int>(framework::proto::VarType_Type::VarType_Type_FP16)) {
return {phi::DataType::FLOAT16, sizeof(half)};
} else if (proto_type ==
static_cast<int>(
framework::proto::VarType_Type::VarType_Type_FP32)) {
if (isFp16Supported() && nv_dtype == nvinfer1::DataType::kHALF) {
return {phi::DataType::FLOAT16, sizeof(half)};
} else {
return {phi::DataType::FLOAT32, sizeof(float)};
}
} else if (proto_type ==
static_cast<int>(
framework::proto::VarType_Type::VarType_Type_INT64) ||
proto_type ==
static_cast<int>(
framework::proto::VarType_Type::VarType_Type_INT32)) {
return {phi::DataType::INT32, sizeof(int32_t)};
else if (proto_type ==
static_cast<int>(
framework::proto::VarType_Type::VarType_Type_BOOL))
} else if (proto_type ==
static_cast<int>(
framework::proto::VarType_Type::VarType_Type_BOOL)) {
return {phi::DataType::BOOL, sizeof(bool)};
else
} else {
CHECK(false) << "precision is not supported";
}
};
// input
phi_kernel_context_->ClearInputOutput();
auto data_type = input_desc[0].type;
CHECK((data_type == nvinfer1::DataType::kFLOAT) ||
(data_type == nvinfer1::DataType::kHALF));
phi_kernel_contexts_[data_type]->ClearInputOutput();
for (int i = 0; i < getNbInputs(); i++) {
auto const& input_dims = input_desc[i].dims;
......@@ -450,7 +490,9 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
int input_numel = 1;
for (int k = 0; k < input_shape.size(); k++) input_numel *= input_shape[k];
auto data_type_and_size = protoType2PhiType(inputs_data_type_[i]);
auto data_type_and_size =
protoType2PhiType(inputs_data_type_[i], data_type);
phi::DenseTensorMeta input_meta(data_type_and_size.first,
phi::make_ddim(input_shape));
std::shared_ptr<phi::Allocation> input_alloc(
......@@ -459,9 +501,9 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
place));
(*dense_tensor_inputs_)[i] =
std::move(phi::DenseTensor(input_alloc, input_meta));
phi_kernel_context_->EmplaceBackInput(&((*dense_tensor_inputs_)[i]));
phi_kernel_contexts_[data_type]->EmplaceBackInput(
&((*dense_tensor_inputs_)[i]));
}
// output
for (int i = 0; i < getNbOutputs(); i++) {
auto const& output_dims = output_desc[i].dims;
......@@ -474,23 +516,28 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
for (int k = 0; k < output_shape.size(); k++)
output_numel *= output_shape[k];
auto data_type_and_size = protoType2PhiType(inputs_data_type_[i]);
auto data_type_and_size =
protoType2PhiType(inputs_data_type_[i], data_type);
phi::DenseTensorMeta output_meta(data_type_and_size.first,
phi::make_ddim(output_shape));
std::shared_ptr<phi::Allocation> output_alloc(
new phi::Allocation(reinterpret_cast<void*>(outputs[i]),
output_numel * data_type_and_size.second,
place));
phi::DenseTensor output_densetonsor(output_alloc, output_meta);
(*dense_tensor_outputs_)[i] =
std::move(phi::DenseTensor(output_alloc, output_meta));
phi_kernel_context_->EmplaceBackOutput(&((*dense_tensor_outputs_)[i]));
phi_kernel_contexts_[data_type]->EmplaceBackOutput(
&((*dense_tensor_outputs_)[i]));
}
CHECK_EQ(phi_kernel_context_->InputsSize(), getNbInputs());
CHECK_EQ(phi_kernel_context_->OutputsSize(), getNbOutputs());
CHECK_EQ(phi_kernel_contexts_[data_type]->InputsSize(), getNbInputs());
CHECK_EQ(phi_kernel_contexts_[data_type]->OutputsSize(), getNbOutputs());
(*phi_kernel_)(phi_kernel_context_);
(*phi_kernels_[data_type])(phi_kernel_contexts_[data_type].get());
return cudaGetLastError() != cudaSuccess;
}
......
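
The enqueue() changes above replace the single `phi_kernel_context_` with per-precision maps and dispatch on `input_desc[0].type`: the runtime type of input 0 selects both the kernel and the kernel context prepared for that precision in initialize(). A minimal standalone sketch of that dispatch pattern; `KernelFn`, `Context` and `PluginSketch` are made-up stand-ins for the phi types.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>

enum class DType { kFLOAT, kHALF };

struct Context {
  int inputs = 0;
};

using KernelFn = std::function<void(Context*)>;

struct PluginSketch {
  std::map<DType, KernelFn> kernels;
  std::map<DType, std::unique_ptr<Context>> contexts;

  // Mirrors enqueue(): pick the kernel and context keyed by the runtime type.
  void Enqueue(DType runtime_type) {
    Context* ctx = contexts.at(runtime_type).get();
    ctx->inputs = 0;  // analogue of ClearInputOutput()
    kernels.at(runtime_type)(ctx);
  }
};

int main() {
  PluginSketch plugin;
  bool ran_half = false;
  plugin.kernels[DType::kFLOAT] = [](Context*) {};
  plugin.kernels[DType::kHALF] = [&](Context*) { ran_half = true; };
  plugin.contexts[DType::kFLOAT] = std::make_unique<Context>();
  plugin.contexts[DType::kHALF] = std::make_unique<Context>();
  plugin.Enqueue(DType::kHALF);  // half input picks the half kernel
  assert(ran_half);
  return 0;
}
```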
......@@ -44,7 +44,7 @@ namespace plugin {
void BuildPhiKernelContextAttr(const framework::OpDesc& op_desc,
phi::KernelContext* kernel_context,
const phi::KernelSignature& signature,
const phi::Kernel& phi_kernel);
const phi::Kernel* phi_kernel);
class GenericPlugin : public DynamicPluginTensorRT {
public:
......@@ -57,11 +57,13 @@ class GenericPlugin : public DynamicPluginTensorRT {
GenericPlugin() {}
GenericPlugin(const paddle::framework::proto::OpDesc& proto_op_desc,
const InputOutPutVarInfo& in_out_info);
const InputOutPutVarInfo& in_out_info,
bool with_fp16_ = false);
GenericPlugin(const paddle::framework::proto::OpDesc& proto_op_desc,
const std::vector<int>& inputs_data_type,
const std::vector<int>& outputs_data_type);
const std::vector<int>& outputs_data_type,
bool with_fp16_ = false);
// It was used for tensorrt deserialization.
// It should not be called by users.
......@@ -86,7 +88,7 @@ class GenericPlugin : public DynamicPluginTensorRT {
size_t getSerializationSize() const TRT_NOEXCEPT {
return op_meta_data_.size() + SerializedSize(inputs_data_type_) +
SerializedSize(outputs_data_type_);
SerializedSize(outputs_data_type_) + SerializedSize(with_fp16_);
}
void serialize(void* buffer) const TRT_NOEXCEPT;
......@@ -122,15 +124,24 @@ class GenericPlugin : public DynamicPluginTensorRT {
const nvinfer1::DataType* input_types,
int nb_inputs) const TRT_NOEXCEPT;
bool isFp16Supported() {
auto half_dtype = nvinfer1::DataType::kHALF;
return with_fp16_ &&
!(phi_kernels_.find(half_dtype) == phi_kernels_.end()) &&
phi_kernels_[half_dtype]->IsValid();
}
private:
std::string op_meta_data_;
framework::proto::OpDesc proto_op_desc_;
framework::OpDesc op_desc_;
private:
const phi::Kernel* phi_kernel_{nullptr};
std::unordered_map<nvinfer1::DataType, std::unique_ptr<phi::Kernel>>
phi_kernels_;
std::unordered_map<nvinfer1::DataType, std::unique_ptr<phi::KernelContext>>
phi_kernel_contexts_;
phi::KernelContext* phi_kernel_context_{nullptr};
std::vector<phi::DenseTensor>* dense_tensor_inputs_{nullptr};
std::vector<phi::DenseTensor>* dense_tensor_outputs_{nullptr};
......
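
Finally, the header hunk above grows `getSerializationSize()` by `SerializedSize(with_fp16_)`, and the .cu hunks append `with_fp16_` in `serialize()` and read it back in the deserializing constructor. The sketch below uses hand-rolled `SerializeValue`/`DeserializeValue` stand-ins with simplified signatures (the real Paddle helpers differ) just to show why the new field must be written and read in the same position in the byte stream.

```cpp
#include <cassert>
#include <cstring>
#include <string>

template <typename T>
void SerializeValue(std::string* buf, const T& v) {
  buf->append(reinterpret_cast<const char*>(&v), sizeof(T));
}

template <typename T>
void DeserializeValue(const char** data, size_t* len, T* v) {
  std::memcpy(v, *data, sizeof(T));
  *data += sizeof(T);
  *len -= sizeof(T);
}

int main() {
  int in_type = 5, out_type = 5;  // arbitrary dtype tags for the demo
  bool with_fp16 = true;

  std::string buf;
  SerializeValue(&buf, in_type);
  SerializeValue(&buf, out_type);
  SerializeValue(&buf, with_fp16);  // new field, appended after the dtypes

  const char* p = buf.data();
  size_t len = buf.size();
  int in = 0, out = 0;
  bool fp16 = false;
  DeserializeValue(&p, &len, &in);
  DeserializeValue(&p, &len, &out);
  DeserializeValue(&p, &len, &fp16);  // must be read back in the same order
  assert(in == 5 && out == 5 && fp16);
  return 0;
}
```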