Unverified commit ea851796, authored by Shang Zhizhou, committed by GitHub

Optimize ernie model inference performance in TensorRT and support variable-length input (#28367)

* fp16 result ok

* change -DWITH_NVINFER_PLUGIN to config.EnableTensorRtOSS

* auto detect special slice op converter for ernie with trt oss

* ernie oss only support fp16

* fix special_slice_plugin serialize bug

* matmul in tensorrt ok

* ernie unittest ok

* add matmul tensorrt unittest

* remove demo code
Parent 84cc61b2
@@ -207,6 +207,7 @@ struct Argument {
  DECL_ARGUMENT_FIELD(tensorrt_use_static_engine, TensorRtUseStaticEngine,
                      bool);
  DECL_ARGUMENT_FIELD(tensorrt_use_calib_mode, TensorRtUseCalibMode, bool);
  DECL_ARGUMENT_FIELD(tensorrt_use_oss, TensorRtUseOSS, bool);
  DECL_ARGUMENT_FIELD(lite_passes_filter, LitePassesFilter,
                      std::vector<std::string>);
...
@@ -95,6 +95,7 @@ void IRPassManager::CreatePasses(Argument *argument,
      bool use_calib_mode = argument->tensorrt_use_calib_mode();
      pass->Set("enable_int8", new bool(enable_int8));
      pass->Set("use_calib_mode", new bool(use_calib_mode));
      pass->Set("use_oss", new bool(argument->tensorrt_use_oss()));
      pass->Set("precision_mode",
                new AnalysisConfig::Precision(precision_mode));
...
@@ -117,11 +117,20 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
  block_desc.Proto()->set_idx(0);
  LOG(INFO) << "--- detect a sub-graph with " << subgraph.size() << " nodes";
  bool has_fused_embedding_eltwise_layernorm = false;
  bool has_multihead_matmul = false;
  for (auto *node : subgraph) {
    auto *new_block_op = new_block->AppendOp();
    auto *op = block_desc.AppendOp();
    *new_block_op->Proto() = *node->Op()->Proto();
    *op->Proto() = *node->Op()->Proto();
    if (!has_fused_embedding_eltwise_layernorm
        && op->Type() == "fused_embedding_eltwise_layernorm") {
      has_fused_embedding_eltwise_layernorm = true;
    }
    if (!has_multihead_matmul && op->Type() == "multihead_matmul") {
      has_multihead_matmul = true;
    }
  }
  // Then, we will use the input_names_with_id and output_names_with_id to
@@ -308,6 +317,9 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
      precision_mode, calibrator.get(), Get<int>("gpu_device_id"),
      min_input_shape, max_input_shape, opt_input_shape,
      disable_trt_plugin_fp16);
  trt_engine->SetUseOSS(Get<bool>("use_oss"));
  trt_engine->SetWithErnie(
      has_multihead_matmul && has_fused_embedding_eltwise_layernorm);
  bool need_serialize = (use_static_engine && !load_from_memory);
  if (need_serialize) {
@@ -386,4 +398,5 @@ REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
            .EQ("instance_norm", 0)
            .EQ("gelu", 0)
            .EQ("layer_norm", 0)
            .EQ("scale", 0)
            .EQ("matmul", 0));
@@ -122,6 +122,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
  CP_MEMBER(tensorrt_precision_mode_);
  CP_MEMBER(trt_use_static_engine_);
  CP_MEMBER(trt_use_calib_mode_);
  CP_MEMBER(trt_use_oss_);
  // MKLDNN related.
  CP_MEMBER(use_mkldnn_);
  CP_MEMBER(mkldnn_enabled_op_types_);
@@ -280,6 +281,10 @@ void AnalysisConfig::SetTRTDynamicShapeInfo(
  disable_trt_plugin_fp16_ = disable_trt_plugin_fp16;
}

void AnalysisConfig::EnableTensorRtOSS() {
  trt_use_oss_ = true;
}

// TODO(Superjomn) refactor this, buggy.
void AnalysisConfig::Update() {
  auto info = SerializeInfoCache();
...
@@ -470,6 +470,7 @@ void AnalysisPredictor::PrepareArgument() {
    argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
    argument_.SetTensorRtUseStaticEngine(config_.trt_use_static_engine_);
    argument_.SetTensorRtUseCalibMode(config_.trt_use_calib_mode_);
    argument_.SetTensorRtUseOSS(config_.trt_use_oss_);
    argument_.SetMinInputShape(config_.min_input_shape_);
    argument_.SetMaxInputShape(config_.max_input_shape_);
    argument_.SetOptimInputShape(config_.optim_input_shape_);
@@ -1055,7 +1056,7 @@ USE_TRT_CONVERTER(elementwise_mul_tensor);
USE_TRT_CONVERTER(elementwise_max_tensor);
USE_TRT_CONVERTER(elementwise_min_tensor);
USE_TRT_CONVERTER(elementwise_pow_tensor);
USE_TRT_CONVERTER(matmul);
USE_TRT_CONVERTER(conv2d);
USE_TRT_CONVERTER(relu);
USE_TRT_CONVERTER(sigmoid);
...
@@ -312,6 +312,20 @@ struct PD_INFER_DECL AnalysisConfig {
      std::map<std::string, std::vector<int>> max_input_shape,
      std::map<std::string, std::vector<int>> optim_input_shape,
      bool disable_trt_plugin_fp16 = false);

  ///
  /// \brief Replace some TensorRT plugins with the TensorRT OSS implementation
  /// (https://github.com/NVIDIA/TensorRT), which may give higher inference
  /// performance for some models. libnvinfer_plugin.so of version 7.2.1 or
  /// later is required.
  ///
  void EnableTensorRtOSS();

  ///
  /// \brief A boolean state telling whether the TensorRT OSS plugins are used.
  ///
  /// \return bool Whether the TensorRT OSS plugins are used.
  ///
  bool tensorrt_oss_enabled() { return trt_use_oss_; }

  ///
  /// \brief Turn on the usage of Lite sub-graph engine.
  ///
@@ -569,6 +583,7 @@ struct PD_INFER_DECL AnalysisConfig {
  Precision tensorrt_precision_mode_{Precision::kFloat32};
  bool trt_use_static_engine_{false};
  bool trt_use_calib_mode_{true};
  bool trt_use_oss_{false};
  std::map<std::string, std::vector<int>> min_input_shape_{};
  std::map<std::string, std::vector<int>> max_input_shape_{};
  std::map<std::string, std::vector<int>> optim_input_shape_{};
...
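As a usage sketch (not part of this patch), the new option would typically be combined with dynamic-shape TensorRT inference roughly as below. The model path, input names and the EnableTensorRtEngine arguments are illustrative assumptions; only EnableTensorRtOSS() and SetTRTDynamicShapeInfo() come from this header, and the input names mirror the ernie unit tests further down.

// Hypothetical C++ configuration sketch for variable-length ernie with OSS plugins.
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/inference/api/paddle_analysis_config.h"

void ConfigureErnieOSS(paddle::AnalysisConfig* config,
                       const std::string& model_dir) {
  config->SetModel(model_dir);
  config->EnableUseGpu(100 /*MB*/, 0 /*gpu_id*/);
  config->SwitchUseFeedFetchOps(false);
  // The OSS path only supports fp16 for ernie/bert, so Precision::kHalf is used.
  config->EnableTensorRtEngine(1 << 30, 1 /*max_batch*/, 5 /*min_subgraph_size*/,
                               paddle::AnalysisConfig::Precision::kHalf,
                               false /*use_static*/, false /*use_calib_mode*/);
  config->EnableTensorRtOSS();

  // Helper that assigns the same {batch, seq_len, 1} shape to the four
  // variable-length inputs (names taken from the unit tests, illustrative only).
  auto shapes = [](int batch, int seq_len) {
    std::map<std::string, std::vector<int>> m;
    for (auto name : {"read_file_0.tmp_0", "read_file_0.tmp_1",
                      "read_file_0.tmp_2", "read_file_0.tmp_4"}) {
      m[name] = {batch, seq_len, 1};
    }
    return m;
  };
  config->SetTRTDynamicShapeInfo(shapes(1, 1), shapes(32, 128), shapes(32, 128));
}

The fp16 requirement matches the converter-side check below, which rejects Precision::kFloat32 when EnableTensorRtOSS() is active.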
# Add TRT tests
nv_library(tensorrt_converter
  SRCS matmul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
  batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
  pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
  shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc
...
@@ -49,6 +49,9 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
input_ids.push_back(engine_->GetITensor(id_names[i]));
}
// input_embs[0]: word_embedding
// input_embs[1]: pos_embedding
// input_embs[2]: sent_embedding
std::vector<float*> input_embs;
std::vector<int> emb_sizes;
@@ -85,15 +88,88 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
get_persistable_data(op_desc.Input("Scale").front(), &scale_dims);
int64_t bias_size = framework::product(bias_dims);
int64_t scale_size = framework::product(scale_dims);
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
if (engine_->use_oss()) {
int output_fp16 = static_cast<int>((engine_->WithFp16() == 1) ? 1 : 0);
PADDLE_ENFORCE_EQ(output_fp16, 1,
platform::errors::InvalidArgument(
"Only Precision::KHalf(fp16) is supported when infering "
"ernie(bert) model with config.EnableTensorRtOSS(). "
"But Precision::KFloat32 is setted."));
const std::vector<nvinfer1::PluginField> fields{
{"bert_embeddings_layernorm_beta", bias,
nvinfer1::PluginFieldType::kFLOAT32,
static_cast<int32_t>(bias_size)},
{"bert_embeddings_layernorm_gamma", scale,
nvinfer1::PluginFieldType::kFLOAT32,
static_cast<int32_t>(scale_size)},
{"bert_embeddings_word_embeddings", input_embs[0],
nvinfer1::PluginFieldType::kFLOAT32,
static_cast<int32_t>(emb_sizes[0])},
{"bert_embeddings_token_type_embeddings", input_embs[2],
nvinfer1::PluginFieldType::kFLOAT32,
static_cast<int32_t>(emb_sizes[2])},
{"bert_embeddings_position_embeddings", input_embs[1],
nvinfer1::PluginFieldType::kFLOAT32,
static_cast<int32_t>(emb_sizes[1])},
{"output_fp16", &output_fp16, nvinfer1::PluginFieldType::kINT32, 1},
};
// remember to free
nvinfer1::PluginFieldCollection* plugin_ptr =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*plugin_ptr) +
fields.size() * sizeof(nvinfer1::PluginField)));
plugin_ptr->nbFields = static_cast<int>(fields.size());
plugin_ptr->fields = fields.data();
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(engine_->GetITensor(
engine_->network()->getInput(0)->getName())); // word_embedding,
// eval_placeholder_0
plugin_inputs.emplace_back(engine_->GetITensor(
engine_->network()->getInput(1)->getName())); // sent_embedding,
// eval_placeholder_1
plugin_inputs.emplace_back(engine_->GetITensor(
engine_->network()->getInput(2)->getName())); // cu_seqlens,
// eval_placeholder_2
auto max_seqlen_tensor = engine_->GetITensor(
engine_->network()->getInput(3)->getName());
auto* shuffle_layer = TRT_ENGINE_ADD_LAYER(
engine_, Shuffle, *const_cast<nvinfer1::ITensor*>(max_seqlen_tensor));
nvinfer1::Dims shape_dim;
shape_dim.nbDims = 1;
shape_dim.d[0] = -1;
shuffle_layer->setReshapeDimensions(shape_dim);
plugin_inputs.emplace_back(shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomEmbLayerNormPluginDynamic", "2");
auto plugin_obj =
creator->createPlugin("CustomEmbLayerNormPluginDynamic", plugin_ptr);
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin_obj);
layer = plugin_layer;
free(plugin_ptr);
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "emb_eltwise_layernorm",
{output_name, std::string("qkv_plugin_mask")},
test_mode);
} else {
bool use_fp16 = engine_->WithFp16();
float eps = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon"));
plugin::DynamicPluginTensorRT* plugin = nullptr;
plugin = new plugin::EmbEltwiseLayernormPluginDynamic(
input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden,
eps, use_fp16);
layer = engine_->AddPluginV2(input_ids.data(), input_num, plugin);
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "emb_eltwise_layernorm", {output_name},
test_mode);
}
} else {
PADDLE_THROW(platform::errors::Fatal(
"You are running the Ernie(Bert) model in static"
@@ -102,9 +178,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
" to set the shape information to run the dynamic shape mode."));
}
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
...
@@ -28,25 +28,55 @@ namespace inference {
namespace tensorrt {
/*
* MatMulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights.
*/
class MatMulOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(3) << "convert a fluid matmul op to tensorrt mul layer without bias";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
bool transpose_X = BOOST_GET_CONST(bool, op_desc.GetAttr("transpose_X"));
bool transpose_Y = BOOST_GET_CONST(bool, op_desc.GetAttr("transpose_Y"));
auto* layer = TRT_ENGINE_ADD_LAYER(
engine_, MatrixMultiply, *const_cast<nvinfer1::ITensor*>(input1), transpose_X,
*const_cast<nvinfer1::ITensor*>(input2), transpose_Y);
float alpha = BOOST_GET_CONST(float, op_desc.GetAttr("alpha"));
auto output_name = op_desc.Output("Out")[0];
if (fabs(alpha - 1.0) < std::numeric_limits<float>::epsilon()) {
engine_->SetITensor(output_name, layer->getOutput(0));
} else {
auto create_weights = [&](float data, const std::string &type) -> float* {
std::unique_ptr<framework::Tensor> tmp_tensor(new framework::Tensor());
tmp_tensor->Resize({1});
auto* tmp_data = tmp_tensor->mutable_data<float>(platform::CPUPlace());
tmp_data[0] = data;
engine_->SetWeights(output_name + "_add_scale_op_" + type,
std::move(tmp_tensor));
return tmp_data;
};
float* alpha_data = create_weights(alpha, "alpha");
float* shift_data = create_weights(0.0, "shift");
float* power_data = create_weights(1.0, "power");
TensorRTEngine::Weight nv_alpha{nvinfer1::DataType::kFLOAT,
static_cast<void*>(alpha_data), 1};
TensorRTEngine::Weight nv_shift{nvinfer1::DataType::kFLOAT,
static_cast<void*>(shift_data), 1};
TensorRTEngine::Weight nv_power{nvinfer1::DataType::kFLOAT,
static_cast<void*>(power_data), 1};
auto* scale_layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *layer->getOutput(0),
nvinfer1::ScaleMode::kUNIFORM,
nv_shift.get(), nv_alpha.get(), nv_power.get());
engine_->SetITensor(output_name, scale_layer->getOutput(0));
}
if (test_mode) {  // the test framework can not determine which is the
// output, so place the declaration inside.
engine_->DeclareOutput(output_name);
@@ -58,4 +88,4 @@ class MulOpConverter : public OpConverter {
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(matmul, MatMulOpConverter);
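A note on the alpha branch above (not part of the diff): TensorRT's IScaleLayer in kUNIFORM mode applies an elementwise (input * scale + shift)^power, so with shift = 0, scale = alpha and power = 1 the two layers together compute

\[ \text{Out} = \bigl(0 + \alpha \cdot (XY)\bigr)^{1} = \alpha\, XY, \]

which matches the semantics of Paddle's matmul op when its alpha attribute is not 1.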
@@ -30,7 +30,6 @@ class MultiheadMatMulOpConverter : public OpConverter {
// Declare inputs
// Shouble be a 5 dims tensor.
auto* input = engine_->GetITensor(op_desc.Input("Input").front());
// fc weights and fc bias
auto weight_name = op_desc.Input("W").front();
@@ -50,7 +49,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
memcpy(weight_data_tmp.data(), weight_data,
weight_t->numel() * sizeof(float));
// (hidden, 3, all_head_size)
auto weight_dims = weight_t->dims();
int hidden = weight_dims[0]; // channels_in
@@ -65,36 +64,136 @@
}
}
};
tranpose_weight(weight_data_tmp.data(), weight_data, m, n);
int head_number = BOOST_GET_CONST(int, op_desc.GetAttr("head_number"));
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
if (engine_->use_oss()) {
int head_size = hidden / head_number;
// [3, Nout, Hout, Nin, Hin] -> [Nout, 3, Hout, Nin, Hin]
auto transpose_weight_v2 = [](const float* src, float* dst, int N,
int H) {
const int HNH = H * N * H;
for (int i = 0; i < 3; ++i) {
for (int n = 0; n < N; ++n) {
for (int hnh = 0; hnh < HNH; ++hnh) {
dst[n * 3 * HNH + i * HNH + hnh] =
src[i * N * HNH + n * HNH + hnh];
}
}
}
};
// [3, N, H] -> [N, 3, H]
auto transpose_bias_v2 = [](const float* src, float* dst, int N, int H) {
for (int i = 0; i < 3; ++i) {
for (int n = 0; n < N; ++n) {
for (int h = 0; h < H; ++h) {
dst[n * 3 * H + i * H + h] = src[i * N * H + n * H + h];
}
}
}
};
memcpy(weight_data_tmp.data(), weight_data,
weight_t->numel() * sizeof(float));
transpose_weight_v2(weight_data_tmp.data(), weight_data, head_number,
head_size);
nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
static_cast<int32_t>(weight_t->numel())};
std::vector<float> bias_data_tmp;
bias_data_tmp.reserve(bias_t->numel());
memcpy(bias_data_tmp.data(), bias_data, bias_t->numel() * sizeof(float));
transpose_bias_v2(bias_data_tmp.data(), bias_data, head_number,
head_size);
nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_data),
static_cast<int32_t>(bias_t->numel())};
auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input, n,
weight, bias);
auto mask_tensor = engine_->GetITensor("qkv_plugin_mask");
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomQKVToContextPluginDynamic", "2");
assert(creator != nullptr);
int type = static_cast<int>((engine_->WithFp16() == 1)
? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT);
bool has_mask = true;
int var_seqlen = 1;
const std::vector<nvinfer1::PluginField> fields{
{"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
{"hidden_size", &hidden, nvinfer1::PluginFieldType::kINT32, 1},
{"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1},
{"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1},
{"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, 1},
};
nvinfer1::PluginFieldCollection* plugin_collection =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*plugin_collection) +
fields.size() *
sizeof(nvinfer1::PluginField))); // remember to free
plugin_collection->nbFields = static_cast<int>(fields.size());
plugin_collection->fields = fields.data();
auto plugin = creator->createPlugin("CustomQKVToContextPluginDynamic",
plugin_collection);
free(plugin_collection);
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(fc_layer->getOutput(0));
plugin_inputs.emplace_back(mask_tensor);
plugin_inputs.emplace_back(engine_->GetITensor(
engine_->network()->getInput(2)->getName())); // cu_seqlens,
// eval_placeholder_2
auto max_seqlen_tensor = engine_->GetITensor(
engine_->network()->getInput(3)->getName());
auto* shuffle_layer = TRT_ENGINE_ADD_LAYER(
engine_, Shuffle, *const_cast<nvinfer1::ITensor*>(max_seqlen_tensor));
nvinfer1::Dims shape_dim;
shape_dim.nbDims = 1;
shape_dim.d[0] = -1;
shuffle_layer->setReshapeDimensions(shape_dim);
plugin_inputs.emplace_back(shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
layer = plugin_layer;
} else {
// transpose weight_data from m * n to n * m
auto* input_bias_qk =
engine_->GetITensor(op_desc.Input("BiasQK").front());
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
static_cast<size_t>(weight_t->numel())};
weight.dims.assign({n, m});
TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_data),
static_cast<size_t>(bias_t->numel())};
auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input, n,
weight.get(), bias.get());
auto* fc_out = fc_layer->getOutput(0);
// add qkv to context
int head_size = all_head_size / head_number;
float scale = BOOST_GET_CONST(float, op_desc.GetAttr("alpha"));
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.push_back(fc_out);
plugin_inputs.push_back(input_bias_qk);
bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin::DynamicPluginTensorRT* plugin =
new plugin::QkvToContextPluginDynamic(hidden, head_number, head_size,
scale, ban_fp16);
layer = engine_->AddPluginV2(plugin_inputs.data(), 2, plugin);
}
} else {
PADDLE_THROW(platform::errors::Fatal(
"You are running the Ernie(Bert) model in static shape mode, which "
...
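As a side note (not from the patch): the weight and bias reordering above only moves the leading q/k/v axis behind the head axis, which is the layout the OSS CustomQKVToContextPluginDynamic expects. Below is a tiny standalone sketch of the bias case using the same index mapping as transpose_bias_v2, with made-up sizes N = 2 heads and H = 3.

// Standalone illustration of the [3, N, H] -> [N, 3, H] bias reordering.
// N and H are toy values; the index mapping is the one used in the converter.
#include <cstdio>
#include <vector>

int main() {
  const int N = 2, H = 3;  // number of heads and head size (toy values)
  std::vector<float> src(3 * N * H), dst(3 * N * H);
  for (int i = 0; i < 3 * N * H; ++i) src[i] = static_cast<float>(i);

  // dst[n][i][h] = src[i][n][h]: the q/k/v index i moves behind the head index n.
  for (int i = 0; i < 3; ++i)
    for (int n = 0; n < N; ++n)
      for (int h = 0; h < H; ++h)
        dst[n * 3 * H + i * H + h] = src[i * N * H + n * H + h];

  // Each printed row shows head n holding its q, k and v slices back to back.
  for (int n = 0; n < N; ++n) {
    for (int i = 0; i < 3; ++i)
      for (int h = 0; h < H; ++h) std::printf("%4.0f", dst[(n * 3 + i) * H + h]);
    std::printf("\n");
  }
  return 0;
}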
@@ -47,17 +47,50 @@ class SkipLayerNormOpConverter : public OpConverter {
framework::DDim bias_dims, scale_dims;
auto* bias = get_persistable_data("Bias", &bias_dims);
auto* scale = get_persistable_data("Scale", &scale_dims);
int bias_size = framework::product(bias_dims);
int scale_size = framework::product(scale_dims);
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
if (engine_->use_oss()) {
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomSkipLayerNormPluginDynamic", "2");
assert(creator != nullptr);
int type = static_cast<int>((engine_->WithFp16() == 1)
? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT);
int ld = input1->getDimensions().d[2]; // hidden dimension
assert(ld > 0);
const std::vector<nvinfer1::PluginField> fields{
{"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
{"ld", &ld, nvinfer1::PluginFieldType::kINT32, 1},
{"beta", bias, nvinfer1::PluginFieldType::kFLOAT32, bias_size},
{"gamma", scale, nvinfer1::PluginFieldType::kFLOAT32, scale_size},
};
nvinfer1::PluginFieldCollection* pluginPtr =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*pluginPtr) +
fields.size() *
sizeof(nvinfer1::PluginField))); // remember to free
pluginPtr->nbFields = static_cast<int>(fields.size());
pluginPtr->fields = fields.data();
auto pluginObj =
creator->createPlugin("CustomSkipLayerNormPluginDynamic", pluginPtr);
auto plugin_layer = engine_->network()->addPluginV2(
inputs.data(), inputs.size(), *pluginObj);
assert(plugin_layer != nullptr);
layer = plugin_layer;
} else {
float eps = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon"));
bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin::SkipLayerNormPluginDynamic* plugin =
new plugin::SkipLayerNormPluginDynamic(bias, scale, bias_size,
scale_size, eps, ban_fp16);
layer = engine_->AddPluginV2(inputs.data(), 2, plugin);
}
} else {
PADDLE_THROW(platform::errors::Fatal(
"You are running the Ernie(Bert) model in static"
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"
namespace paddle {
namespace inference {
@@ -77,16 +78,31 @@ class SliceOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
if (engine_->use_oss() && engine_->with_ernie()) {
std::vector<nvinfer1::ITensor*> plugin_inputs;
// plugin_inputs.emplace_back(trans_layer->getOutput(0));
plugin_inputs.emplace_back(input);
plugin_inputs.emplace_back(engine_->GetITensor(
engine_->network()->getInput(2)->getName())); // cu_seqlens,
// eval_placeholder_2
// bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin::SpecialSlicePluginDynamic* plugin =
new plugin::SpecialSlicePluginDynamic();
layer = engine_->AddPluginV2(plugin_inputs.data(), plugin_inputs.size(),
plugin);
} else {
#if IS_TRT_VERSION_GE(6000)
bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin::SlicePluginDynamic* plugin =
new plugin::SlicePluginDynamic(starts, ends, axes, ban_fp16);
layer = engine_->AddPluginV2(&input, 1, plugin);
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"));
#endif
}
} else {
bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin::SlicePlugin* plugin =
...
@@ -71,9 +71,9 @@ TRT_DT FluidDataType2TRT(FluidDT type) {
template <typename T>
nvinfer1::Dims Vec2TRT_Dims(const std::vector<T>& shape, std::string input,
                            bool with_dynamic_shape = false) {
  PADDLE_ENFORCE_GT(shape.size(), 0UL,
                    platform::errors::InvalidArgument(
                        "TensorRT's tensor input requires at least 1 "
                        "dimensions, but input %s has %d dims.",
                        input, shape.size()));
  PADDLE_ENFORCE_LE(shape.size(), 4UL,
@@ -174,6 +174,7 @@ class TensorRTEngine {
          "version should be at least 6.";
#endif
    }
    dy::initLibNvInferPlugins(&logger, "");
  }
  ~TensorRTEngine() {}
@@ -285,6 +286,9 @@ class TensorRTEngine {
      suffix_counter += 1;
  }
  void SetUseOSS(bool use_oss) { use_oss_ = use_oss; }
  void SetWithErnie(bool with_ernie) { with_ernie_ = with_ernie; }
  void ClearWeights() {
    for (auto& weight_pair : weight_map) {
      weight_pair.second.reset(nullptr);
@@ -312,6 +316,8 @@ class TensorRTEngine {
  ShapeMapType min_input_shape() { return min_input_shape_; }
  ShapeMapType max_input_shape() { return max_input_shape_; }
  ShapeMapType optim_input_shape() { return optim_input_shape_; }
  bool use_oss() { return use_oss_; };
  bool with_ernie() { return with_ernie_; };
  bool disable_trt_plugin_fp16() { return disable_trt_plugin_fp16_; }
  bool with_dynamic_shape() { return with_dynamic_shape_; }
@@ -347,6 +353,8 @@ class TensorRTEngine {
  ShapeMapType max_input_shape_;
  ShapeMapType optim_input_shape_;
  bool disable_trt_plugin_fp16_{false};
  bool use_oss_{false};
  bool with_ernie_{false};
  nvinfer1::ILogger& logger_;
  // max data size for the buffers.
...
@@ -13,6 +13,8 @@
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/op_teller.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/var_desc.h"
namespace paddle {
namespace framework {
@@ -70,6 +72,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"hard_swish"};
std::unordered_set<std::string> teller_set{
"mul",
"matmul",
"conv2d",
"pool2d",
"relu",
@@ -122,6 +125,20 @@ bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc,
(padding_algorithm == "SAME" && op_type != "pool2d"))
return false;
}
if (op_type == "matmul") {
auto* block = desc.Block();
for (auto& param_name : desc.Inputs()) {
for (auto& var_name : param_name.second) {
auto* var_desc = block->FindVar(var_name);
const auto shape = var_desc->GetShape();
if (shape.size() < 3) {
VLOG(1) << "matmul op dims < 3 not supported in tensorrt, but got dims "
<< shape.size() << ", so jump it.";
return false;
}
}
}
}
if ((*teller)(op_type, desc, use_no_calib_int8)) return true;
}
return false;
...
@@ -4,5 +4,5 @@ nv_library(tensorrt_plugin
  pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
  instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
  qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu
  hard_swish_op_plugin.cu stack_op_plugin.cu special_slice_plugin.cu
  DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
SpecialSlicePluginDynamic::SpecialSlicePluginDynamic() {}
SpecialSlicePluginDynamic::SpecialSlicePluginDynamic(void const* serial_data,
size_t serial_length) {}
SpecialSlicePluginDynamic::~SpecialSlicePluginDynamic() {}
nvinfer1::IPluginV2DynamicExt* SpecialSlicePluginDynamic::clone() const {
return new SpecialSlicePluginDynamic();
}
const char* SpecialSlicePluginDynamic::getPluginType() const {
return "special_slice_plugin";
}
int SpecialSlicePluginDynamic::getNbOutputs() const { return 1; }
int SpecialSlicePluginDynamic::initialize() { return 0; }
size_t SpecialSlicePluginDynamic::getSerializationSize() const {
size_t serialize_size = 0;
return serialize_size;
}
void SpecialSlicePluginDynamic::serialize(void* buffer) const {}
nvinfer1::DimsExprs SpecialSlicePluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) {
nvinfer1::DimsExprs output(inputs[0]);
auto one = expr_builder.constant(1);
output.d[0] = expr_builder.operation(nvinfer1::DimensionOperation::kSUB,
*inputs[1].d[0], *one);
return output;
}
void SpecialSlicePluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {}
size_t SpecialSlicePluginDynamic::getWorkspaceSize(
const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
return 0;
}
void SpecialSlicePluginDynamic::destroy() { delete this; }
void SpecialSlicePluginDynamic::terminate() {}
bool SpecialSlicePluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* desc, int nb_inputs,
int nb_outputs) {
if (pos == 0) // slice tensor
return (desc[pos].type == nvinfer1::DataType::kHALF &&
desc[pos].format ==
nvinfer1::TensorFormat::kLINEAR); // || desc[pos].type ==
// nvinfer1::DataType::kFLOAT);
if (pos == 1) // cu_seqlen
return (desc[pos].type == nvinfer1::DataType::kINT32 &&
desc[pos].format == nvinfer1::TensorFormat::kLINEAR);
return (desc[pos].type == nvinfer1::DataType::kHALF &&
desc[pos].format ==
nvinfer1::TensorFormat::kLINEAR); // || desc[pos].type ==
// nvinfer1::DataType::kFLOAT);
}
nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType* input_types, int nb_inputs) const {
PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
"The index should be equal to 0"));
return input_types[0];
}
template <typename T>
__global__ void SpecialSliceKernel(const T* slice_input,
const int32_t* cu_seqlens, T* output) {
const int hidden = blockDim.x;
const int batch = blockIdx.x;
output[batch * hidden + threadIdx.x] =
slice_input[cu_seqlens[batch] * hidden + threadIdx.x];
}
int SpecialSlicePluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs,
void* const* outputs, void* workspace, cudaStream_t stream) {
auto input_dims = input_desc[0].dims; // (sum(S), 768, 1, 1)
auto out_dims = output_desc[0].dims; // (batch, 768, 1, 1)
assert(input_desc[0].type == nvinfer1::DataType::kHALF);
const int32_t hidden = input_dims.d[1];
const int num_blocks = out_dims.d[0]; // batch size
const int num_threads = hidden;
const half* slice_input = static_cast<const half*>(inputs[0]);
const int32_t* cu_seqlens = static_cast<const int32_t*>(inputs[1]);
half* output = static_cast<half*>(outputs[0]);
SpecialSliceKernel<<<num_blocks, num_threads, 0, stream>>>(
slice_input, cu_seqlens, output);
return cudaGetLastError() != cudaSuccess;
}
SpecialSlicePluginDynamicCreator::SpecialSlicePluginDynamicCreator() {}
const char* SpecialSlicePluginDynamicCreator::getPluginName() const {
return "special_slice_plugin";
}
const char* SpecialSlicePluginDynamicCreator::getPluginVersion() const {
return "1";
}
const nvinfer1::PluginFieldCollection*
SpecialSlicePluginDynamicCreator::getFieldNames() {
return &field_collection_;
}
nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) {
return new SpecialSlicePluginDynamic();
}
nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::deserializePlugin(
const char* name, const void* serial_data, size_t serial_length) {
auto plugin = new SpecialSlicePluginDynamic(serial_data, serial_length);
return plugin;
}
void SpecialSlicePluginDynamicCreator::setPluginNamespace(
const char* lib_namespace) {
plugin_namespace_ = lib_namespace;
}
const char* SpecialSlicePluginDynamicCreator::getPluginNamespace() const {
return plugin_namespace_.c_str();
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
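A brief aside (not from the patch): with variable-length batching the sequences appear packed back to back and cu_seqlens holds the starting offset of each sequence, so the kernel above simply gathers the hidden vector at each sequence's first token. A CPU sketch of the same gather, with made-up sizes, follows.

// CPU reference for the gather performed by SpecialSliceKernel above.
// Sizes and values are illustrative only.
#include <cstdio>
#include <vector>

int main() {
  const int hidden = 4;
  // Three packed sequences of lengths 2, 3 and 1 -> prefix offsets {0, 2, 5, 6}.
  std::vector<int> cu_seqlens = {0, 2, 5, 6};
  const int batch = 3;
  // Packed input of shape (sum of sequence lengths) x hidden; each row stores
  // its own token index so the result is easy to read.
  std::vector<float> input(6 * hidden);
  for (int t = 0; t < 6; ++t)
    for (int h = 0; h < hidden; ++h) input[t * hidden + h] = static_cast<float>(t);

  // output[b][h] = input[cu_seqlens[b]][h], i.e. the first token of sequence b.
  std::vector<float> output(batch * hidden);
  for (int b = 0; b < batch; ++b)
    for (int h = 0; h < hidden; ++h)
      output[b * hidden + h] = input[cu_seqlens[b] * hidden + h];

  for (int b = 0; b < batch; ++b) std::printf("%g ", output[b * hidden]);
  std::printf("\n");  // prints: 0 2 5
  return 0;
}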
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
class SpecialSlicePluginDynamic : public DynamicPluginTensorRT {
public:
SpecialSlicePluginDynamic();
SpecialSlicePluginDynamic(void const* serial_data, size_t serial_length);
~SpecialSlicePluginDynamic();
nvinfer1::IPluginV2DynamicExt* clone() const override;
nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs, int nbOutputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const override;
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const override;
const char* getPluginType() const override;
int getNbOutputs() const override;
int initialize() override;
void terminate() override;
size_t getSerializationSize() const override;
void serialize(void* buffer) const override;
void destroy() override;
private:
int axis_;
int num_stack_;
};
class SpecialSlicePluginDynamicCreator : public nvinfer1::IPluginCreator {
public:
SpecialSlicePluginDynamicCreator();
const char* getPluginName() const override;
const char* getPluginVersion() const override;
const nvinfer1::PluginFieldCollection* getFieldNames() override;
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override;
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override;
void setPluginNamespace(const char* lib_namespace) override;
const char* getPluginNamespace() const override;
private:
std::string plugin_namespace_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(SpecialSlicePluginDynamicCreator);
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
@@ -126,17 +126,17 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
      {"read_file_0.tmp_0", min_shape},
      {"read_file_0.tmp_1", min_shape},
      {"read_file_0.tmp_2", min_shape},
      {"read_file_0.tmp_4", min_shape}};
  std::map<std::string, std::vector<int>> max_input_shape = {
      {"read_file_0.tmp_0", max_shape},
      {"read_file_0.tmp_1", max_shape},
      {"read_file_0.tmp_2", max_shape},
      {"read_file_0.tmp_4", max_shape}};
  std::map<std::string, std::vector<int>> opt_input_shape = {
      {"read_file_0.tmp_0", opt_shape},
      {"read_file_0.tmp_1", opt_shape},
      {"read_file_0.tmp_2", opt_shape},
      {"read_file_0.tmp_4", opt_shape}};
  auto precision = AnalysisConfig::Precision::kFloat32;
  if (with_fp16) {
...
@@ -86,16 +86,16 @@ void run(const AnalysisConfig& config, std::vector<float>* out_data) {
void trt_ernie(bool with_fp16, std::vector<float> result) {
  AnalysisConfig config;
  std::string model_dir = FLAGS_infer_model;
  SetConfig(&config, model_dir, true);
  config.SwitchUseFeedFetchOps(false);
  int batch = 32;
  int min_seq_len = 1;
  int max_seq_len = 128;
  int opt_seq_len = 128;
  std::vector<int> min_shape = {1, min_seq_len, 1};
  std::vector<int> max_shape = {batch, max_seq_len, 1};
  std::vector<int> opt_shape = {batch, opt_seq_len, 1};
  // Set the input's min, max, opt shape
@@ -103,17 +103,17 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
      {"read_file_0.tmp_0", min_shape},
      {"read_file_0.tmp_1", min_shape},
      {"read_file_0.tmp_2", min_shape},
      {"read_file_0.tmp_4", min_shape}};
  std::map<std::string, std::vector<int>> max_input_shape = {
      {"read_file_0.tmp_0", max_shape},
      {"read_file_0.tmp_1", max_shape},
      {"read_file_0.tmp_2", max_shape},
      {"read_file_0.tmp_4", max_shape}};
  std::map<std::string, std::vector<int>> opt_input_shape = {
      {"read_file_0.tmp_0", opt_shape},
      {"read_file_0.tmp_1", opt_shape},
      {"read_file_0.tmp_2", opt_shape},
      {"read_file_0.tmp_4", opt_shape}};
  auto precision = AnalysisConfig::Precision::kFloat32;
  if (with_fp16) {
@@ -124,6 +124,7 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
                                opt_input_shape);
  std::vector<float> out_data;
  run(config, &out_data);
  for (size_t i = 0; i < out_data.size(); i++) {
    EXPECT_NEAR(result[i], out_data[i], 1e-5);
  }
...
@@ -278,9 +278,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
        buffers[bind_index] = static_cast<void *>(t.data<float>());
      } else if (type == framework::proto::VarType::INT64) {
        buffers[bind_index] = static_cast<void *>(t.data<int64_t>());
      } else if (type == framework::proto::VarType::INT32) {
        buffers[bind_index] = static_cast<void *>(t.data<int32_t>());
      } else {
        PADDLE_THROW(platform::errors::Fatal(
            "The TRT Engine OP only support float/int32_t/int64_t input."));
      }
    }
...
@@ -22,19 +22,15 @@ namespace dynload {
std::once_flag tensorrt_dso_flag;
void* tensorrt_dso_handle;
std::once_flag tensorrt_plugin_dso_flag;
void* tensorrt_plugin_dso_handle;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
TENSORRT_RAND_ROUTINE_EACH(DEFINE_WRAP);
TENSORRT_PLUGIN_RAND_ROUTINE_EACH(DEFINE_WRAP);
void* GetDsoHandle(const std::string& dso_name) {
#if !defined(_WIN32)
int dynload_flags = RTLD_LAZY | RTLD_LOCAL;
#else
@@ -49,10 +45,31 @@ void* GetTensorRtHandle() {
"library is not found. Ignore this if TensorRT is not needed.";
std::cerr << error_msg;
}
return dso_handle;
}
void* GetTensorRtHandle() {
#if defined(__APPLE__) || defined(__OSX__)
std::string dso_name = "libnvinfer.dylib";
#elif defined(_WIN32)
std::string dso_name = "nvinfer.dll";
#else
std::string dso_name = "libnvinfer.so";
#endif
return GetDsoHandle(dso_name);
}
void* GetTensorRtPluginHandle() {
#if defined(__APPLE__) || defined(__OSX__)
std::string dso_name = "libnvinfer_plugin.dylib";
#elif defined(_WIN32)
std::string dso_name = "nvinfer_plugin.dll";
#else
std::string dso_name = "libnvinfer_plugin.so";
#endif
return GetDsoHandle(dso_name);
}
} // namespace dynload
} // namespace platform
} // namespace paddle
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <NvInfer.h>
#include <NvInferPlugin.h>
#if !defined(_WIN32)
#include <dlfcn.h>
#endif
@@ -32,6 +33,10 @@ void* GetTensorRtHandle();
extern std::once_flag tensorrt_dso_flag;
extern void* tensorrt_dso_handle;
void* GetTensorRtPluginHandle();
extern std::once_flag tensorrt_plugin_dso_flag;
extern void* tensorrt_plugin_dso_handle;
#define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
@@ -50,7 +55,26 @@ extern void* tensorrt_dso_handle;
}; \
extern DynLoad__##__name __name
#define DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
std::call_once(tensorrt_plugin_dso_flag, []() { \
tensorrt_plugin_dso_handle = \
paddle::platform::dynload::GetTensorRtPluginHandle(); \
}); \
static void* p_##__name = dlsym(tensorrt_plugin_dso_handle, #__name); \
PADDLE_ENFORCE_NOT_NULL(p_##__name, \
platform::errors::Unavailable( \
"Load tensorrt plugin %s failed", #__name)); \
using tensorrt_plugin_func = decltype(&::__name); \
return reinterpret_cast<tensorrt_plugin_func>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
#ifdef NV_TENSORRT_MAJOR
#if (NV_TENSORRT_MAJOR >= 6)
#define TENSORRT_RAND_ROUTINE_EACH(__macro) \
__macro(createInferBuilder_INTERNAL); \
@@ -62,8 +86,13 @@ extern void* tensorrt_dso_handle;
__macro(createInferRuntime_INTERNAL);
#endif
#define TENSORRT_PLUGIN_RAND_ROUTINE_EACH(__macro) \
__macro(initLibNvInferPlugins);
TENSORRT_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP)
TENSORRT_PLUGIN_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP)
#endif // end of NV_TENSORRT_MAJOR
} // namespace dynload
} // namespace platform
...
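For readers unfamiliar with the pattern (not part of the patch): DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP generates a functor that lazily opens libnvinfer_plugin and resolves the requested symbol on first call, which is what lets dy::initLibNvInferPlugins(&logger, "") in the engine constructor work without linking the plugin library at build time. Below is a stripped-down sketch of the same idea; the helper name and the commented usage are hypothetical and error handling is simplified.

// Minimal sketch of a lazy dlopen/dlsym wrapper (illustrative only).
#include <dlfcn.h>
#include <mutex>

template <typename FuncT>
FuncT LoadSymbol(const char* lib, const char* symbol) {
  static std::once_flag flag;
  static void* handle = nullptr;
  // Open the shared library once, then look the symbol up on demand.
  std::call_once(flag, [&] { handle = dlopen(lib, RTLD_LAZY | RTLD_LOCAL); });
  return handle ? reinterpret_cast<FuncT>(dlsym(handle, symbol)) : nullptr;
}

// Hypothetical usage: resolve initLibNvInferPlugins only when first needed.
// using InitPluginsFn = bool (*)(void* logger, const char* ns);
// auto init = LoadSymbol<InitPluginsFn>("libnvinfer_plugin.so",
//                                       "initLibNvInferPlugins");
// if (init) init(&logger, "");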
@@ -487,6 +487,8 @@ void BindAnalysisConfig(py::module *m) {
           py::arg("optim_input_shape") =
               std::map<std::string, std::vector<int>>({}),
           py::arg("disable_trt_plugin_fp16") = false)
      .def("enable_tensorrt_oss", &AnalysisConfig::EnableTensorRtOSS)
      .def("tensorrt_oss_enabled", &AnalysisConfig::tensorrt_oss_enabled)
      .def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
      .def("enable_lite_engine", &AnalysisConfig::EnableLiteEngine,
           py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
...
@@ -20,6 +20,7 @@ import random
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import PaddleTensor
@@ -34,6 +35,7 @@ from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
class InferencePassTest(unittest.TestCase):
    def __init__(self, methodName='runTest'):
        paddle.enable_static()
        super(InferencePassTest, self).__init__(methodName)
        self.main_program = fluid.Program()
        self.startup_program = fluid.Program()
@@ -211,6 +213,7 @@ class InferencePassTest(unittest.TestCase):
        if flatten:
            out = out.flatten()
            analysis_output = analysis_output.flatten()
        self.assertTrue(
            np.allclose(
                out, analysis_output, atol=atol),
@@ -232,6 +235,7 @@ class InferencePassTest(unittest.TestCase):
        if flatten:
            out = out.flatten()
            tensorrt_output = tensorrt_output.flatten()
        self.assertTrue(
            np.allclose(
                out, tensorrt_output, atol=atol),
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import PassVersionChecker
from paddle.fluid.core import AnalysisConfig
class TensorRTMatMulDims2Test(InferencePassTest):
    def setUp(self):
        self.set_params()
        with fluid.program_guard(self.main_program, self.startup_program):
            data = fluid.data(
                name="data", shape=[24, 24], dtype="float32")
            matmul_out = fluid.layers.matmul(
                x=data,
                y=data,
                transpose_x = self.transpose_x,
                transpose_y = self.transpose_y,
                alpha = self.alpha)
            out = fluid.layers.batch_norm(matmul_out, is_test=True)

        self.feeds = {
            "data": np.ones([24, 24]).astype("float32"),
        }
        self.enable_trt = True
        self.trt_parameters = TensorRTMatMulDims2Test.TensorRTParam(
            1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
        self.fetch_list = [out]

    def set_params(self):
        self.transpose_x = True
        self.transpose_y = True
        self.alpha = 2.0

    def test_check_output(self):
        if core.is_compiled_with_cuda():
            use_gpu = True
            self.check_output_with_option(use_gpu)
            self.assertTrue(
                PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))


class TensorRTMatMulTest(InferencePassTest):
    def setUp(self):
        self.set_params()
        with fluid.program_guard(self.main_program, self.startup_program):
            data = fluid.data(
                name="data", shape=[-1, 6, 24, 24], dtype="float32")
            matmul_out = fluid.layers.matmul(
                x=data,
                y=data,
                transpose_x = self.transpose_x,
                transpose_y = self.transpose_y,
                alpha = self.alpha)
            out = fluid.layers.batch_norm(matmul_out, is_test=True)

        self.feeds = {
            "data": np.ones([1, 6, 24, 24]).astype("float32"),
        }
        self.enable_trt = True
        self.trt_parameters = TensorRTMatMulTest.TensorRTParam(
            1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
        self.fetch_list = [out]

    def set_params(self):
        self.transpose_x = False
        self.transpose_y = False
        self.alpha = 1.0

    def test_check_output(self):
        if core.is_compiled_with_cuda():
            use_gpu = True
            self.check_output_with_option(use_gpu)
            self.assertTrue(
                PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))


class TensorRTMatMulTransposeXTest(TensorRTMatMulTest):
    def set_params(self):
        self.transpose_x = True
        self.transpose_y = False
        self.alpha = 1.0


class TensorRTMatMulTransposeYTest(TensorRTMatMulTest):
    def set_params(self):
        self.transpose_x = False
        self.transpose_y = True
        self.alpha = 1.0


class TensorRTMatMulScaleTest(TensorRTMatMulTest):
    def set_params(self):
        self.transpose_x = False
        self.transpose_y = False
        self.alpha = 2.0


if __name__ == "__main__":
    unittest.main()