From ad6e3dd69cd915dd61287e96de7ec4ae132d24a5 Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Tue, 1 Sep 2020 20:41:56 +0800 Subject: [PATCH] [Paddle-TRT] Stack op plugin (#25605) * add stack_op to CMakeLists * add dim=3 support for scale op * add trt stack op, test=develop * remove debug message * add stack plugin serialize * remove slice, scale op, will add later * enhence error message * revise trt ernie test to conver the stack op CI testi, test=develop * add stack op serialization * fix test shape after adding stack op * remove slice op, will add after implementing serialization * roll back to min_graph=5 to avoid using slice op * fix scale op output layer * implement stack op createPlugin * use workspace and move the defination to .cu * move stack plugin creator definition to .cu, test=develop --- .../fluid/inference/api/analysis_predictor.cc | 1 + .../inference/tensorrt/convert/CMakeLists.txt | 4 +- .../inference/tensorrt/convert/scale_op.cc | 30 +++ .../inference/tensorrt/convert/stack_op.cc | 75 ++++++ paddle/fluid/inference/tensorrt/op_teller.cc | 1 + .../inference/tensorrt/plugin/CMakeLists.txt | 9 +- .../tensorrt/plugin/stack_op_plugin.cu | 247 ++++++++++++++++++ .../tensorrt/plugin/stack_op_plugin.h | 96 +++++++ ...rt_dynamic_shape_ernie_deserialize_test.cc | 7 +- .../tests/api/trt_dynamic_shape_ernie_test.cc | 7 +- 10 files changed, 463 insertions(+), 14 deletions(-) create mode 100644 paddle/fluid/inference/tensorrt/convert/stack_op.cc create mode 100644 paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu create mode 100644 paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 127a41aee89..500aa8341d6 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1058,6 +1058,7 @@ USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm); USE_TRT_CONVERTER(skip_layernorm); USE_TRT_CONVERTER(slice); USE_TRT_CONVERTER(scale); +USE_TRT_CONVERTER(stack); #endif namespace paddle_infer { diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 8b7371490c0..39d02909abd 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -3,8 +3,8 @@ nv_library(tensorrt_converter SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc - shuffle_channel_op.cc swish_op.cc instance_norm_op.cc -emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc + shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc + emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS diff --git a/paddle/fluid/inference/tensorrt/convert/scale_op.cc b/paddle/fluid/inference/tensorrt/convert/scale_op.cc index 19e1895635a..f9a1fe41ddc 100644 --- a/paddle/fluid/inference/tensorrt/convert/scale_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/scale_op.cc @@ -58,6 +58,24 @@ class ScaleOpConverter : public OpConverter { TensorRTEngine::Weight 
power_weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
   nvinfer1::ILayer* layer = nullptr;
+
+    auto input_dim = input->getDimensions();
+    PADDLE_ENFORCE_GE(input_dim.nbDims, 3,
+                      platform::errors::Fatal(
+                          "Paddle-TRT scale mode only supports dimensions >= 3"));
+
+    nvinfer1::IShuffleLayer* expand_layer = nullptr;
+    nvinfer1::IShuffleLayer* squeeze_layer = nullptr;
+
+    if (input_dim.nbDims == 3) {
+      // TensorRT scale layer does not support input dims < 4 when using
+      // explicit batch, so expand the 3-D input to 4-D first.
+      expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
+      nvinfer1::Dims4 target_shape(0, 0, 0, 1);  // expand 1 dim
+      expand_layer->setReshapeDimensions(target_shape);
+      input = expand_layer->getOutput(0);
+    }
+
     if (bias_after_scale) {
       layer = TRT_ENGINE_ADD_LAYER(
           engine_, Scale, *input, nvinfer1::ScaleMode::kUNIFORM,
@@ -73,6 +91,18 @@ class ScaleOpConverter : public OpConverter {
           power_weights.get(), scale_weights.get(), power_weights.get());
     }
+    PADDLE_ENFORCE_EQ(layer != nullptr, true,
+                      platform::errors::Fatal("Create scale layer failed."));
+
+    if (input_dim.nbDims == 3) {
+      // TensorRT scale layer does not support input dims < 4 when using
+      // explicit batch, so squeeze the padded dimension back out afterwards.
+      squeeze_layer =
+          TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(layer->getOutput(0)));
+      nvinfer1::Dims3 target_shape(0, 0, 0);  // squeeze 1 dim
+      squeeze_layer->setReshapeDimensions(target_shape);
+      layer = static_cast<nvinfer1::ILayer*>(squeeze_layer);
+    }
     RreplenishLayerAndOutput(layer, "scale", {out_name}, test_mode);
   }
 };
diff --git a/paddle/fluid/inference/tensorrt/convert/stack_op.cc b/paddle/fluid/inference/tensorrt/convert/stack_op.cc
new file mode 100644
index 00000000000..f35024529c6
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/stack_op.cc
@@ -0,0 +1,75 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * Stack converter from fluid to TensorRT.
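+ *
+ * The converter gathers every input tensor of the fluid stack op and, in
+ * dynamic shape mode, replaces the op with a StackPluginDynamic plugin that
+ * stacks the inputs along a new axis; static shape mode is rejected below.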
+ */ +class StackOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(4) << "convert fluid stack op to tensorrt stack layer"; + + framework::OpDesc op_desc(op, nullptr); + auto input = op_desc.Input("X"); + int input_num = input.size(); + nvinfer1::ITensor** inputs = + (nvinfer1::ITensor**)malloc(input_num * sizeof(nvinfer1::ITensor*)); + + for (int i = 0; i < input_num; ++i) { + inputs[i] = engine_->GetITensor(input[i]); + } + + int axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis")); + if (axis < 0) { + axis = axis + inputs[0]->getDimensions().nbDims + 1; + } + + nvinfer1::ILayer* layer = nullptr; + if (engine_->with_dynamic_shape()) { +#if IS_TRT_VERSION_GE(6000) + plugin::StackPluginDynamic* plugin = + new plugin::StackPluginDynamic(axis, input_num); + layer = engine_->AddPluginV2(inputs, input_num, plugin); + assert(layer != nullptr); +#else + PADDLE_THROW(platform::errors::Fatal( + "You are running the TRT Dynamic Shape mode, need to confirm that " + "your TRT version is no less than 6.0")); +#endif + } else { + PADDLE_THROW(platform::errors::Fatal( + "You are running the Ernie(Bert) model in static" + "shape mode, which is not supported for the time being.\n" + "You can use the config.SetTRTDynamicShapeInfo(...) interface" + " to set the shape information to run the dynamic shape mode.")); + } + auto output_name = op_desc.Output("Y").front(); + RreplenishLayerAndOutput(layer, "stack", {output_name}, test_mode); + free(inputs); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(stack, StackOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index e8cbb9431cb..a5b71356d0e 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -88,6 +88,7 @@ struct SimpleOpTypeSetTeller : public Teller { "gelu", "layer_norm", "scale", + "stack", }; }; diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index e417fcbb2ce..98afdbe254a 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -1,7 +1,8 @@ nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu - prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu + prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu -instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu -qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu - DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor) + instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu + qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu + hard_swish_op_plugin.cu stack_op_plugin.cu + DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor) diff --git a/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu new file mode 100644 index 00000000000..1ecbf4be154 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu @@ -0,0 +1,247 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cassert>
+#include <cstring>
+#include <vector>
+#include "paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {
+
+#if IS_TRT_VERSION_GE(6000)
+StackPluginDynamic::StackPluginDynamic(int axis, int num_stack)
+    : axis_(axis), num_stack_(num_stack) {}
+
+StackPluginDynamic::StackPluginDynamic(void const* serial_data,
+                                       size_t serial_length) {
+  DeserializeValue(&serial_data, &serial_length, &axis_);
+  DeserializeValue(&serial_data, &serial_length, &num_stack_);
+}
+
+StackPluginDynamic::~StackPluginDynamic() {}
+
+nvinfer1::IPluginV2DynamicExt* StackPluginDynamic::clone() const {
+  return new StackPluginDynamic(axis_, num_stack_);
+}
+
+const char* StackPluginDynamic::getPluginType() const { return "stack_plugin"; }
+
+int StackPluginDynamic::getNbOutputs() const { return 1; }
+
+int StackPluginDynamic::initialize() { return 0; }
+
+size_t StackPluginDynamic::getSerializationSize() const {
+  size_t serialize_size = 0;
+  serialize_size += SerializedSize(axis_);
+  serialize_size += SerializedSize(num_stack_);
+  return serialize_size;
+}
+
+void StackPluginDynamic::serialize(void* buffer) const {
+  SerializeValue(&buffer, axis_);
+  SerializeValue(&buffer, num_stack_);
+}
+
+nvinfer1::DimsExprs StackPluginDynamic::getOutputDimensions(
+    int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
+    nvinfer1::IExprBuilder& expr_builder) {
+  nvinfer1::DimsExprs output(inputs[0]);
+  output.nbDims = inputs[0].nbDims + 1;
+  // Shift the trailing dimensions right by one and insert the stack axis.
+  for (int i = inputs[0].nbDims; i > axis_; --i) {
+    output.d[i] = inputs[0].d[i - 1];
+  }
+  output.d[axis_] = expr_builder.constant(nb_inputs);
+  return output;
+}
+
+void StackPluginDynamic::configurePlugin(
+    const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
+    const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {}
+
+size_t StackPluginDynamic::getWorkspaceSize(
+    const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
+    const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
+  return num_stack_ * sizeof(uintptr_t);
+}
+
+void StackPluginDynamic::destroy() { delete this; }
+
+void StackPluginDynamic::terminate() {}
+
+bool StackPluginDynamic::supportsFormatCombination(
+    int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
+    int nb_outputs) {
+  PADDLE_ENFORCE_NOT_NULL(
+      in_out, platform::errors::InvalidArgument(
+                  "The input of stack plugin should not be nullptr."));
+
+  PADDLE_ENFORCE_LT(
+      pos, nb_inputs + nb_outputs,
+      platform::errors::InvalidArgument("The pos(%d) should be less than the "
+                                        "num(%d) of the input and the output.",
+                                        pos, nb_inputs + nb_outputs));
+
+  const nvinfer1::PluginTensorDesc& in = in_out[pos];
+  if (pos == 0) {
+#ifdef SUPPORTS_CUDA_FP16
+    return (in.type == nvinfer1::DataType::kFLOAT ||
+            in.type == nvinfer1::DataType::kHALF) &&
+           (in.format == nvinfer1::TensorFormat::kLINEAR);
+#else
+    return (in.type == nvinfer1::DataType::kFLOAT) &&
+           (in.format == nvinfer1::TensorFormat::kLINEAR);
+#endif
+  }
+  const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
+  // output
+  return in.type == prev.type && in.format == prev.format;
+}
+
+nvinfer1::DataType StackPluginDynamic::getOutputDataType(
+    int index, const nvinfer1::DataType* input_types, int nb_inputs) const {
+  PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
+                                  "The index should be equal to 0"));
+  return input_types[0];
+}
+
+template <typename T>
+__global__ void StackKernel(const T* const* input, T* output, int num_stack,
+                            int base_unit) {
+  int stack_id = blockIdx.x;
+  int lead_id = blockIdx.y;
+
+  for (int i = threadIdx.x; i < base_unit; i += blockDim.x) {
+    output[lead_id * num_stack * base_unit + stack_id * base_unit + i] =
+        input[stack_id][lead_id * base_unit + i];
+  }
+}
+
+int StackPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
+                                const nvinfer1::PluginTensorDesc* output_desc,
+                                const void* const* inputs, void* const* outputs,
+                                void* workspace, cudaStream_t stream) {
+  auto input_dims = input_desc[0].dims;  // (batch, seq, seq)
+  auto out_dims = output_desc[0].dims;   // (batch, num_head, seq, seq)
+  auto out_num_dims = out_dims.nbDims;
+
+  int base_unit = 1;
+  for (int i = axis_ + 1; i < out_num_dims; ++i) {
+    PADDLE_ENFORCE_GT(out_dims.d[i], 0,
+                      platform::errors::InvalidArgument(
+                          "Input dimensions should be greater than 0"));
+    base_unit *= out_dims.d[i];
+  }
+
+  int lead_unit = 1;
+  for (int i = 0; i < axis_; ++i) {
+    PADDLE_ENFORCE_GT(out_dims.d[i], 0,
+                      platform::errors::InvalidArgument(
+                          "Input dimensions should be greater than 0"));
+    lead_unit *= out_dims.d[i];
+  }
+
+  PADDLE_ENFORCE_EQ(
+      out_dims.d[axis_], num_stack_,
+      platform::errors::InvalidArgument(
+          "The dimension of the stack axis should be equal to the number "
+          "of stacked inputs."));
+
+  // Copy the host-side array of input pointers into the device workspace.
+  cudaMemcpyAsync(workspace, reinterpret_cast<const void*>(inputs),
+                  sizeof(void*) * out_dims.d[axis_], cudaMemcpyHostToDevice,
+                  stream);
+
+  const int num_stacks = out_dims.d[axis_];
+  dim3 num_blocks(num_stacks, lead_unit);
+  const int num_threads = 256;
+  auto infer_type = input_desc[0].type;
+
+  if (infer_type == nvinfer1::DataType::kFLOAT) {
+    float* output = static_cast<float*>(outputs[0]);
+    StackKernel<float><<<num_blocks, num_threads, 0, stream>>>(
+        reinterpret_cast<const float* const*>(workspace), output, num_stacks,
+        base_unit);
+  } else if (infer_type == nvinfer1::DataType::kHALF) {
+#ifdef SUPPORTS_CUDA_FP16
+    __half* output = static_cast<__half*>(outputs[0]);
+    StackKernel<__half><<<num_blocks, num_threads, 0, stream>>>(
+        reinterpret_cast<const __half* const*>(workspace), output, num_stacks,
+        base_unit);
+#else
+    PADDLE_THROW(platform::errors::Fatal(
+        "The CUDA architecture you specified should be greater than 600."));
+#endif
+  } else {
+    PADDLE_THROW(
+        platform::errors::Fatal("The Stack TRT Plugin's input type only "
+                                "supports float or half currently."));
+  }
+  return cudaGetLastError() != cudaSuccess;
+}
+
+StackPluginDynamicCreator::StackPluginDynamicCreator() {}
+
+const char* StackPluginDynamicCreator::getPluginName() const {
+  return "stack_plugin";
+}
+
+const char* StackPluginDynamicCreator::getPluginVersion() const { return "1"; }
+
+const nvinfer1::PluginFieldCollection*
+StackPluginDynamicCreator::getFieldNames() {
+  return &field_collection_;
+}
+
+nvinfer1::IPluginV2* StackPluginDynamicCreator::createPlugin(
+    const char* name, const nvinfer1::PluginFieldCollection* fc) {
+  int axis = -1;
+  int num_stack = -1;
+
+  for (int i = 0; i < fc->nbFields; ++i) {
+    const std::string name(fc->fields[i].name);
+    if (name == "axis") {
+      axis = static_cast<const int*>(fc->fields[i].data)[0];
+    } else if (name == "num_stack") {
+      num_stack = static_cast<const int*>(fc->fields[i].data)[0];
+    } else {
+      PADDLE_THROW(platform::errors::Fatal(
+          "Encountered an unknown plugin field '" + name +
+          "' when creating stack op plugin."));
+    }
+  }
+  return new StackPluginDynamic(axis, num_stack);
+}
+
+nvinfer1::IPluginV2* StackPluginDynamicCreator::deserializePlugin(
+    const char* name, const void* serial_data, size_t serial_length) {
+  auto plugin = new StackPluginDynamic(serial_data, serial_length);
+  return plugin;
+}
+
+void StackPluginDynamicCreator::setPluginNamespace(const char* lib_namespace) {
+  plugin_namespace_ = lib_namespace;
+}
+
+const char* StackPluginDynamicCreator::getPluginNamespace() const {
+  return plugin_namespace_.c_str();
+}
+
+#endif
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h
new file mode 100644
index 00000000000..f4f6cde6f87
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h
@@ -0,0 +1,96 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <stdio.h>
+#include <cassert>
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {
+
+#if IS_TRT_VERSION_GE(6000)
+class StackPluginDynamic : public DynamicPluginTensorRT {
+ public:
+  explicit StackPluginDynamic(int axis, int num_stack);
+  StackPluginDynamic(void const* serial_data, size_t serial_length);
+  ~StackPluginDynamic();
+  nvinfer1::IPluginV2DynamicExt* clone() const override;
+  nvinfer1::DimsExprs getOutputDimensions(
+      int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
+      nvinfer1::IExprBuilder& exprBuilder) override;
+  bool supportsFormatCombination(int pos,
+                                 const nvinfer1::PluginTensorDesc* inOut,
+                                 int nbInputs, int nbOutputs) override;
+  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
+                       int nbInputs,
+                       const nvinfer1::DynamicPluginTensorDesc* out,
+                       int nbOutputs) override;
+  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
+                          int nbInputs,
+                          const nvinfer1::PluginTensorDesc* outputs,
+                          int nbOutputs) const override;
+  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
+              const nvinfer1::PluginTensorDesc* outputDesc,
+              const void* const* inputs, void* const* outputs, void* workspace,
+              cudaStream_t stream) override;
+
+  nvinfer1::DataType getOutputDataType(int index,
+                                       const nvinfer1::DataType* inputTypes,
+                                       int nbInputs) const override;
+
+  const char* getPluginType() const override;
+  int getNbOutputs() const override;
+  int initialize() override;
+  void terminate() override;
+  size_t getSerializationSize() const override;
+  void serialize(void* buffer) const override;
+  void destroy() override;
+
+ private:
+  int axis_;
+  int num_stack_;
+};
+
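+// StackPluginDynamicCreator is registered through REGISTER_TRT_PLUGIN_V2 and
+// re-creates StackPluginDynamic either from the "axis"/"num_stack" plugin
+// fields or from a serialized plugin blob.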
+class StackPluginDynamicCreator : public nvinfer1::IPluginCreator {
+ public:
+  StackPluginDynamicCreator();
+  const char* getPluginName() const override;
+  const char* getPluginVersion() const override;
+  const nvinfer1::PluginFieldCollection* getFieldNames() override;
+  nvinfer1::IPluginV2* createPlugin(
+      const char* name, const nvinfer1::PluginFieldCollection* fc) override;
+  nvinfer1::IPluginV2* deserializePlugin(const char* name,
+                                         const void* serial_data,
+                                         size_t serial_length) override;
+  void setPluginNamespace(const char* lib_namespace) override;
+  const char* getPluginNamespace() const override;
+
+ private:
+  std::string plugin_namespace_;
+  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
+  std::vector<nvinfer1::PluginField> plugin_attributes_;
+};
+REGISTER_TRT_PLUGIN_V2(StackPluginDynamicCreator);
+#endif
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc
index 7e5dfa2424d..524e08891f4 100644
--- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc
+++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc
@@ -90,7 +90,6 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
 
   config.SwitchUseFeedFetchOps(false);
 
-  int head_number = 12;
   int batch = 1;
   int min_seq_len = 1;
   int max_seq_len = 128;
@@ -104,17 +103,17 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
       {"read_file_0.tmp_0", min_shape},
       {"read_file_0.tmp_1", min_shape},
       {"read_file_0.tmp_2", min_shape},
-      {"stack_0.tmp_0", {batch, head_number, min_seq_len, min_seq_len}}};
+      {"matmul_0.tmp_0", {batch, min_seq_len, min_seq_len}}};
   std::map<std::string, std::vector<int>> max_input_shape = {
       {"read_file_0.tmp_0", max_shape},
       {"read_file_0.tmp_1", max_shape},
       {"read_file_0.tmp_2", max_shape},
-      {"stack_0.tmp_0", {batch, head_number, max_seq_len, max_seq_len}}};
+      {"matmul_0.tmp_0", {batch, max_seq_len, max_seq_len}}};
   std::map<std::string, std::vector<int>> opt_input_shape = {
       {"read_file_0.tmp_0", opt_shape},
       {"read_file_0.tmp_1", opt_shape},
       {"read_file_0.tmp_2", opt_shape},
-      {"stack_0.tmp_0", {batch, head_number, opt_seq_len, opt_seq_len}}};
+      {"matmul_0.tmp_0", {batch, opt_seq_len, opt_seq_len}}};
 
   auto precision = AnalysisConfig::Precision::kFloat32;
   if (with_fp16) {
diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
index c99ebcdcb5f..17fedc3d3b8 100644
--- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
+++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
@@ -90,7 +90,6 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
 
   config.SwitchUseFeedFetchOps(false);
 
-  int head_number = 12;
   int batch = 1;
   int min_seq_len = 1;
   int max_seq_len = 128;
@@ -104,17 +103,17 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
       {"read_file_0.tmp_0", min_shape},
       {"read_file_0.tmp_1", min_shape},
       {"read_file_0.tmp_2", min_shape},
-      {"stack_0.tmp_0", {batch, head_number, min_seq_len, min_seq_len}}};
+      {"matmul_0.tmp_0", {batch, min_seq_len, min_seq_len}}};
   std::map<std::string, std::vector<int>> max_input_shape = {
       {"read_file_0.tmp_0", max_shape},
       {"read_file_0.tmp_1", max_shape},
       {"read_file_0.tmp_2", max_shape},
-      {"stack_0.tmp_0", {batch, head_number, max_seq_len, max_seq_len}}};
+      {"matmul_0.tmp_0", {batch, max_seq_len, max_seq_len}}};
   std::map<std::string, std::vector<int>> opt_input_shape = {
       {"read_file_0.tmp_0", opt_shape},
       {"read_file_0.tmp_1", opt_shape},
      {"read_file_0.tmp_2", opt_shape},
-      {"stack_0.tmp_0", {batch, head_number, opt_seq_len, opt_seq_len}}};
+      {"matmul_0.tmp_0", {batch, opt_seq_len, opt_seq_len}}};
 
   auto precision = AnalysisConfig::Precision::kFloat32;
   if (with_fp16) {
-- 
GitLab
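
A note on the new plugin's data layout: StackPluginDynamic flattens every output dimension after the stack axis into base_unit and every dimension before it into lead_unit, copies the host-side array of input pointers into the per-invocation workspace (hence the num_stack_ * sizeof(uintptr_t) workspace size), and launches StackKernel with a (num_stack, lead_unit) grid so that each block copies one contiguous base_unit-sized slice. The standalone C++ sketch below reproduces the same index mapping on the CPU; it is illustrative only, and the helper name StackCPU and the explicit lead_unit/base_unit parameters do not appear in the patch.

// Illustrative CPU reference for the index mapping used by StackKernel.
#include <cassert>
#include <cstdio>
#include <vector>

// Stack num_stack equally shaped inputs along a new axis. Each input is
// viewed as (lead_unit, base_unit): lead_unit is the product of the output
// dims before the stack axis, base_unit the product of the dims after it.
std::vector<float> StackCPU(const std::vector<std::vector<float>>& inputs,
                            int lead_unit, int base_unit) {
  const int num_stack = static_cast<int>(inputs.size());
  std::vector<float> output(
      static_cast<size_t>(lead_unit) * num_stack * base_unit);
  for (int lead_id = 0; lead_id < lead_unit; ++lead_id) {
    for (int stack_id = 0; stack_id < num_stack; ++stack_id) {
      for (int i = 0; i < base_unit; ++i) {
        // Same addressing as the CUDA kernel:
        //   output[lead_id * num_stack * base_unit + stack_id * base_unit + i]
        //       = input[stack_id][lead_id * base_unit + i];
        output[(static_cast<size_t>(lead_id) * num_stack + stack_id) *
                   base_unit + i] =
            inputs[stack_id][static_cast<size_t>(lead_id) * base_unit + i];
      }
    }
  }
  return output;
}

int main() {
  // Two (2, 3) inputs stacked on axis 1 -> output shape (2, 2, 3),
  // so lead_unit = 2 and base_unit = 3.
  std::vector<std::vector<float>> inputs = {{0, 1, 2, 3, 4, 5},
                                            {10, 11, 12, 13, 14, 15}};
  std::vector<float> out = StackCPU(inputs, /*lead_unit=*/2, /*base_unit=*/3);
  assert(out[3] == 10.f);  // output[0][1][0] comes from the second input
  for (float v : out) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}

Compiling this with any C++11 compiler and comparing against the kernel's output is a quick way to sanity-check the axis, lead_unit, and base_unit bookkeeping when porting or reviewing the plugin.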