未验证 提交 fca8595e 编写于 作者: Y Yuanle Liu 提交者: GitHub

[Paddle-TRT] add generic plugin for lookup_table_v2 (embedding) op (#53539)

* add embedding generic plugin, not enabled
上级 2bf61284
......@@ -17,13 +17,13 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/generic_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h"
#include "paddle/phi/api/ext/op_meta_info.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* Stack converter from fluid to tensorRT.
*/
class CustomPluginCreater : public OpConverter {
public:
void operator()(const framework::proto::OpDesc &op,
......@@ -164,6 +164,8 @@ class GenericPluginCreater : public OpConverter {
const framework::Scope &scope,
bool test_mode) override {
framework::OpDesc op_desc(op, nullptr);
VLOG(3) << "convert " << op_desc.Type() << " op to generic pluign layer";
CHECK(block_);
const framework::BlockDesc block_desc(
nullptr, const_cast<framework::proto::BlockDesc *>(block_));
......@@ -181,6 +183,14 @@ class GenericPluginCreater : public OpConverter {
phi_kernel_signature =
phi::DefaultKernelSignatureMap::Instance().Get(op_desc.Type());
}
VLOG(3) << phi_kernel_signature;
PADDLE_ENFORCE_EQ(
phi_kernel_signature.input_names.empty() ||
phi_kernel_signature.output_names.empty(),
false,
platform::errors::PreconditionNotMet(
"The %s op's kernel signature (inputs and output) should not null.",
op_desc.Type()));
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
......
......@@ -417,6 +417,24 @@ nvinfer1::DimsExprs GridSamplerInferMeta(
return output;
}
// Shape inference for lookup_table_v2 (embedding): the output keeps every
// dimension of the ids tensor and appends the embedding width, i.e. ids of
// shape [d0, ..., dk] gathered from a weight of shape [V, E] yields
// [d0, ..., dk, E].
// NOTE(review): assumes weight_dims has rank >= 2 (reads d[1]) and that
// ids rank + 1 fits in nvinfer1::Dims::MAX_DIMS — TODO confirm upstream
// op_teller guarantees both.
nvinfer1::DimsExprs LookupTableV2InferMeta(
    int output_index,
    const nvinfer1::DimsExprs* inputs,
    int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder,  // NOLINT
    const framework::OpDesc& op_desc) {
  const auto& ids_dims = inputs[0];    // indices into the table
  const auto& table_dims = inputs[1];  // [vocab_size, embedding_dim]

  nvinfer1::DimsExprs result;
  result.nbDims = ids_dims.nbDims + 1;
  int axis = 0;
  while (axis < ids_dims.nbDims) {
    result.d[axis] = ids_dims.d[axis];
    ++axis;
  }
  // Trailing dimension is the embedding width taken from the weight tensor.
  result.d[axis] = table_dims.d[1];
  return result;
}
PD_REGISTER_DYNAMIC_INFER_META_FN(gather_nd, GatherNdInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(yolo_box, YoloBoxInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(instance_norm, InstanceNormInferMeta);
......@@ -427,6 +445,7 @@ PD_REGISTER_DYNAMIC_INFER_META_FN(moe, MoeInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(pad3d, Pad3dInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(grid_sampler, GridSamplerInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(p_norm, PNormInferMeta);
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -2865,10 +2865,10 @@ struct SimpleOpTypeSetTeller : public Teller {
"logsigmoid",
"preln_layernorm_shift_partition",
"lookup_table",
"lookup_table_v2",
"trans_layernorm",
"merge_layernorm",
"skip_merge_layernorm",
"lookup_table_v2",
"expand_v2",
"expand_as_v2",
"fuse_eleadd_transpose",
......@@ -3143,6 +3143,7 @@ OpTeller::OpTeller() {
tellers_.emplace_back(new tensorrt::GenericPluginTeller);
tellers_.emplace_back(new tensorrt::CustomPluginTeller);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_registry.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"
......@@ -354,6 +355,19 @@ bool GenericPlugin::supportsFormatCombination(
if (pos == 3)
return in_out[0].type == in_out[pos].type &&
in_out[0].format == in_out[pos].format;
} else if (op_desc_.Type() == "lookup_table_v2") {
if (pos == 0)
return (in_out[pos].type == nvinfer1::DataType::kINT32 &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR));
if (pos == 1)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) ||
((isFp16Supported() &&
in_out[pos].type == nvinfer1::DataType::kHALF)) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
// output
if (pos == 2)
return in_out[1].type == in_out[pos].type &&
in_out[1].format == in_out[pos].format;
} else {
return (in_out[pos].type == nvinfer1::DataType::kFLOAT ||
(isFp16Supported() &&
......@@ -367,6 +381,9 @@ nvinfer1::DataType GenericPlugin::getOutputDataType(
int index,
const nvinfer1::DataType* input_types,
int nb_inputs) const TRT_NOEXCEPT {
if (op_desc_.Type() == "lookup_table_v2") {
return input_types[1];
}
return input_types[0];
}
......@@ -472,7 +489,7 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
cudaStream_t stream) TRT_NOEXCEPT {
platform::CUDAPlace place(platform::GetCurrentDeviceId());
// [TODO]now generic plugin do not support INT8 precision
// TODO(inference): generic plugin do not support INT8 precision now.
auto protoType2PhiType =
[&](int proto_type,
nvinfer1::DataType nv_dtype) -> std::pair<phi::DataType, int> {
......@@ -503,8 +520,13 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
}
};
nvinfer1::DataType data_type;
// input
auto data_type = input_desc[0].type;
if (op_desc_.Type() == "lookup_table_v2") {
data_type = input_desc[1].type;
} else {
data_type = input_desc[0].type;
}
CHECK((data_type == nvinfer1::DataType::kFLOAT) ||
(data_type == nvinfer1::DataType::kHALF));
......@@ -555,8 +577,6 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
output_numel * data_type_and_size.second,
place));
phi::DenseTensor output_densetonsor(output_alloc, output_meta);
(*dense_tensor_outputs_)[i] =
std::move(phi::DenseTensor(output_alloc, output_meta));
......
......@@ -15,8 +15,6 @@
#pragma once
#include <NvInfer.h>
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
......@@ -84,45 +82,50 @@ class GenericPlugin : public DynamicPluginTensorRT {
// Shutdown the layer. This is called when the engine is destroyed
void terminate() TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT{};
void destroy() TRT_NOEXCEPT override{};
size_t getSerializationSize() const TRT_NOEXCEPT {
return op_meta_data_.size() + SerializedSize(inputs_data_type_) +
SerializedSize(outputs_data_type_) + SerializedSize(with_fp16_);
size_t getSerializationSize() const TRT_NOEXCEPT override {
size_t sum = 0;
sum += SerializedSize(inputs_data_type_);
sum += SerializedSize(outputs_data_type_);
sum += SerializedSize(with_fp16_);
sum += op_meta_data_.size();
return sum;
}
void serialize(void* buffer) const TRT_NOEXCEPT;
void serialize(void* buffer) const TRT_NOEXCEPT override;
// The Func in IPluginV2
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT;
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) // NOLINT
TRT_NOEXCEPT;
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* in_out,
int nb_inputs,
int nb_outputs) TRT_NOEXCEPT;
int nb_outputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nb_inputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nb_outputs) TRT_NOEXCEPT;
int nb_outputs) TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT;
cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* input_types,
int nb_inputs) const TRT_NOEXCEPT;
int nb_inputs) const
TRT_NOEXCEPT override;
bool isFp16Supported() {
auto half_dtype = nvinfer1::DataType::kHALF;
......@@ -146,7 +149,6 @@ class GenericPlugin : public DynamicPluginTensorRT {
std::vector<phi::DenseTensor>* dense_tensor_outputs_{nullptr};
private:
InputOutPutVarInfo in_out_info_;
std::vector<int> inputs_data_type_;
std::vector<int> outputs_data_type_;
};
......@@ -166,6 +168,7 @@ class GenericPluginCreator : public TensorRTPluginCreator {
return new GenericPlugin(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(GenericPluginCreator);
} // namespace plugin
......
......@@ -19,7 +19,7 @@ namespace inference {
namespace tensorrt {
bool PluginArgumentMappingContext::HasInput(const std::string& name) const {
auto inputs = op_desc_ptr_->Inputs();
auto inputs = op_desc_->Inputs();
for (auto& i : inputs) {
if (i.first == name && !i.second.empty()) return true;
}
......@@ -27,7 +27,7 @@ bool PluginArgumentMappingContext::HasInput(const std::string& name) const {
}
bool PluginArgumentMappingContext::HasOutput(const std::string& name) const {
auto outputs = op_desc_ptr_->Outputs();
auto outputs = op_desc_->Outputs();
for (auto& i : outputs) {
if (i.first == name && !i.second.empty()) return true;
}
......@@ -35,47 +35,44 @@ bool PluginArgumentMappingContext::HasOutput(const std::string& name) const {
}
bool PluginArgumentMappingContext::HasAttr(const std::string& name) const {
return op_desc_ptr_->HasAttr(name);
return op_desc_->HasAttr(name);
}
paddle::any PluginArgumentMappingContext::Attr(
const std::string& attr_name) const {
auto attr_type = op_desc_ptr_->GetAttrType(attr_name);
auto attr_type = op_desc_->GetAttrType(attr_name);
switch (attr_type) {
case framework::proto::AttrType::INT: {
return PADDLE_GET_CONST(int, op_desc_ptr_->GetAttr(attr_name));
return PADDLE_GET_CONST(int, op_desc_->GetAttr(attr_name));
break;
};
case framework::proto::AttrType::FLOAT: {
return PADDLE_GET_CONST(float, op_desc_ptr_->GetAttr(attr_name));
return PADDLE_GET_CONST(float, op_desc_->GetAttr(attr_name));
break;
};
case framework::proto::AttrType::STRING: {
return PADDLE_GET_CONST(std::string, op_desc_ptr_->GetAttr(attr_name));
return PADDLE_GET_CONST(std::string, op_desc_->GetAttr(attr_name));
break;
};
case framework::proto::AttrType::INTS: {
return PADDLE_GET_CONST(std::vector<int>,
op_desc_ptr_->GetAttr(attr_name));
return PADDLE_GET_CONST(std::vector<int>, op_desc_->GetAttr(attr_name));
break;
};
case framework::proto::AttrType::FLOATS: {
return PADDLE_GET_CONST(std::vector<float>,
op_desc_ptr_->GetAttr(attr_name));
return PADDLE_GET_CONST(std::vector<float>, op_desc_->GetAttr(attr_name));
break;
};
case framework::proto::AttrType::STRINGS: {
return PADDLE_GET_CONST(std::vector<std::string>,
op_desc_ptr_->GetAttr(attr_name));
op_desc_->GetAttr(attr_name));
break;
};
case framework::proto::AttrType::BOOLEAN: {
return PADDLE_GET_CONST(bool, op_desc_ptr_->GetAttr(attr_name));
return PADDLE_GET_CONST(bool, op_desc_->GetAttr(attr_name));
break;
};
case framework::proto::AttrType::BOOLEANS: {
return PADDLE_GET_CONST(std::vector<bool>,
op_desc_ptr_->GetAttr(attr_name));
return PADDLE_GET_CONST(std::vector<bool>, op_desc_->GetAttr(attr_name));
break;
};
default: {
......@@ -87,54 +84,82 @@ paddle::any PluginArgumentMappingContext::Attr(
}
size_t PluginArgumentMappingContext::InputSize(const std::string& name) const {
return op_desc_ptr_->Inputs().at(name).size();
return op_desc_->Inputs().at(name).size();
}
size_t PluginArgumentMappingContext::OutputSize(const std::string& name) const {
return op_desc_ptr_->Outputs().at(name).size();
return op_desc_->Outputs().at(name).size();
}
bool PluginArgumentMappingContext::IsDenseTensorInput(
const std::string& name) const {
return false;
return true;
}
bool PluginArgumentMappingContext::IsDenseTensorInputs(
const std::string& name) const {
return false;
return true;
}
bool PluginArgumentMappingContext::IsSelectedRowsInput(
bool PluginArgumentMappingContext::IsDenseTensorVectorInput(
const std::string& name) const {
PADDLE_THROW(phi::errors::Unimplemented(
"Not supported for input vector of DenseTensor."));
return false;
}
bool PluginArgumentMappingContext::IsSelectedRowsInputs(
bool PluginArgumentMappingContext::IsDenseTensorOutput(
const std::string& name) const {
return false;
return true;
}
bool PluginArgumentMappingContext::IsSparseCooTensorInput(
bool PluginArgumentMappingContext::IsSelectedRowsInput(
const std::string& name) const {
PADDLE_THROW(
phi::errors::Unimplemented("Not supported for input of SelectedRows."));
return false;
}
bool PluginArgumentMappingContext::IsSparseCooTensorOutput(
bool PluginArgumentMappingContext::IsSelectedRowsInputs(
const std::string& name) const {
PADDLE_THROW(
phi::errors::Unimplemented("Not supported for inputs of SelectedRows."));
return false;
}
bool PluginArgumentMappingContext::IsSparseCsrTensorInput(
bool PluginArgumentMappingContext::IsSelectedRowsOutput(
const std::string& name) const {
PADDLE_THROW(
phi::errors::Unimplemented("Not supported for output of SelectedRows."));
return false;
}
bool PluginArgumentMappingContext::IsDenseTensorVectorInput(
bool PluginArgumentMappingContext::IsSparseCooTensorInput(
const std::string& name) const {
PADDLE_THROW(phi::errors::Unimplemented(
"Not supported for input of SparseCooTensor."));
return false;
}
bool PluginArgumentMappingContext::IsDenseTensorOutput(
bool PluginArgumentMappingContext::IsSparseCooTensorOutput(
const std::string& name) const {
PADDLE_THROW(phi::errors::Unimplemented(
"Not supported for output of SparseCooTensor."));
return false;
}
bool PluginArgumentMappingContext::IsSelectedRowsOutput(
bool PluginArgumentMappingContext::IsSparseCsrTensorInput(
const std::string& name) const {
PADDLE_THROW(phi::errors::Unimplemented(
"Not supported for input of SparseCsrTensor."));
return false;
}
bool PluginArgumentMappingContext::IsForInferShape() const {
PADDLE_THROW(phi::errors::Unimplemented("Not supported for InferShape."));
return false;
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -25,8 +25,8 @@ namespace tensorrt {
class PluginArgumentMappingContext : public ::phi::ArgumentMappingContext {
public:
explicit PluginArgumentMappingContext(framework::OpDesc* op_desc_ptr)
: op_desc_ptr_(op_desc_ptr) {}
explicit PluginArgumentMappingContext(const framework::OpDesc* op_desc)
: op_desc_(op_desc) {}
bool HasInput(const std::string& name) const override;
......@@ -60,11 +60,12 @@ class PluginArgumentMappingContext : public ::phi::ArgumentMappingContext {
bool IsSelectedRowsOutput(const std::string& name) const override;
bool IsForInferShape() const override { return false; }
bool IsForInferShape() const override;
private:
framework::OpDesc* op_desc_ptr_;
const framework::OpDesc* op_desc_;
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -117,15 +117,10 @@ TEST(ArgMappingContexTest, BasicFunction) {
EXPECT_EQ(context.InputSize("X"), true);
EXPECT_EQ(context.OutputSize("Out"), true);
EXPECT_EQ(context.IsDenseTensorInput("X"), false);
EXPECT_EQ(context.IsDenseTensorInputs("X"), false);
EXPECT_EQ(context.IsSelectedRowsInput("X"), false);
EXPECT_EQ(context.IsDenseTensorVectorInput("X"), false);
EXPECT_EQ(context.IsDenseTensorOutput("Out"), false);
EXPECT_EQ(context.IsSelectedRowsOutput("Out"), false);
EXPECT_EQ(context.IsSparseCooTensorOutput("Out"), false);
EXPECT_EQ(context.IsForInferShape(), false);
EXPECT_EQ(context.IsDenseTensorInput("X"), true);
EXPECT_EQ(context.IsDenseTensorInputs("X"), true);
EXPECT_EQ(context.IsDenseTensorOutput("Out"), true);
}
} // namespace tensorrt
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册