未验证 提交 133f1fc1 编写于 作者: Z Zhaolong Xing 提交者: GitHub

[Eernie TRT]: add slice op and add emb eltwise layernorm fp16 support (#23723)

* refine ernie trt dynamic shape support
1. add slice op converter
2. add emb eltwise layernorm fp16 support
test=develop

* fix dynamic shape test ut
test=develop

* fix comments.
test=develop

* fix comments
test=develop
上级 2b896c1f
......@@ -976,5 +976,6 @@ USE_TRT_CONVERTER(gelu);
USE_TRT_CONVERTER(multihead_matmul);
USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm);
USE_TRT_CONVERTER(skip_layernorm);
USE_TRT_CONVERTER(slice);
USE_TRT_CONVERTER(scale);
#endif
......@@ -4,7 +4,7 @@ nv_library(tensorrt_converter
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc hard_sigmoid_op.cc hard_swish_op.cc
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
......
......@@ -83,10 +83,23 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
plugin::EmbEltwiseLayernormPluginDynamic* plugin =
new plugin::EmbEltwiseLayernormPluginDynamic(input_embs, bias, scale,
emb_sizes, bias_size,
scale_size, hidden, eps);
auto use_fp16 = engine_->WithFp16();
plugin::DynamicPluginTensorRT* plugin = nullptr;
if (use_fp16) {
#ifdef SUPPORTS_CUDA_FP16
plugin = new plugin::EmbEltwiseLayernormPluginDynamic<half>(
input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden,
eps);
#else
PADDLE_THROW(
platform::errors::Fatal("use EmbEltwiseLayernormPluginDynamic "
"FP16, but GPU doesn't have FP16."));
#endif
} else {
plugin = new plugin::EmbEltwiseLayernormPluginDynamic<float>(
input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden,
eps);
}
layer = engine_->AddPluginV2(input_ids.data(), input_num, plugin);
} else {
PADDLE_THROW(platform::errors::Fatal(
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class SliceOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
// This OP is implemented by trt dynamic shpae plugin.
// Dynamic shape plugin requires TRT version greater than 6.0.
#if IS_TRT_VERSION_GE(6000)
VLOG(4) << "convert slice op to tensorrt layer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input = engine_->GetITensor(op_desc.Input("Input")[0]);
std::vector<int> axes =
boost::get<std::vector<int>>(op_desc.GetAttr("axes"));
std::vector<int> starts =
boost::get<std::vector<int>>(op_desc.GetAttr("starts"));
std::vector<int> ends =
boost::get<std::vector<int>>(op_desc.GetAttr("ends"));
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin::SlicePluginDynamic* plugin =
new plugin::SlicePluginDynamic(starts, ends, ends, ban_fp16);
layer = engine_->AddPluginV2(&input, 1, plugin);
} else {
PADDLE_THROW(platform::errors::Fatal(
"You are running the Ernie(Bert) model in static"
"shape mode, which is not supported for the time being.\n"
"You can use the config.SetTRTDynamicShapeInfo(...) interface"
" to set the shape information to run the dynamic shape mode."));
}
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "skip_layernorm", {output_name}, test_mode);
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"));
#endif
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(slice, SliceOpConverter);
......@@ -29,6 +29,7 @@ struct SimpleOpTypeSetTeller : public Teller {
teller_set.insert("fused_embedding_eltwise_layernorm");
teller_set.insert("multihead_matmul");
teller_set.insert("skip_layernorm");
teller_set.insert("slice");
#endif
}
......
......@@ -3,5 +3,5 @@ nv_library(tensorrt_plugin
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu hard_swish_op_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
......@@ -28,14 +28,37 @@ namespace inference {
namespace tensorrt {
namespace plugin {
// Dynamic Plugin below.
// Dynamic shape plugin requires TRT version greater than 6.0.
#if IS_TRT_VERSION_GE(6000)
int EmbEltwiseLayernormPluginDynamic::initialize() {
embs_gpu_.reserve(embs_.size());
template <typename T>
int EmbEltwiseLayernormPluginDynamic<T>::initialize() {
int nb_emb = embs_.size();
std::vector<void *> ptr_vector(nb_emb);
std::vector<std::vector<half>> emb_fp16(nb_emb);
if (sizeof(T) == sizeof(float)) {
// FP32
for (int i = 0; i < nb_emb; ++i) {
ptr_vector[i] = embs_[i];
}
} else {
// FP16
for (int i = 0; i < nb_emb; ++i) {
auto emb_size = emb_sizes_[i];
auto &tmp = emb_fp16[i];
tmp.resize(emb_size);
for (int j = 0; j < emb_size; ++j) {
tmp[j] = static_cast<half>(embs_[i][j]);
}
ptr_vector[i] = tmp.data();
}
}
embs_gpu_.resize(embs_.size());
for (int i = 0; i < embs_.size(); i++) {
cudaMalloc(&embs_gpu_[i], sizeof(float) * emb_sizes_[i]);
cudaMemcpy(embs_gpu_[i], embs_[i], emb_sizes_[i] * sizeof(float),
cudaMalloc(&embs_gpu_[i], sizeof(T) * emb_sizes_[i]);
cudaMemcpy(embs_gpu_[i], ptr_vector[i], emb_sizes_[i] * sizeof(T),
cudaMemcpyHostToDevice);
}
......@@ -49,15 +72,18 @@ int EmbEltwiseLayernormPluginDynamic::initialize() {
return 0;
}
size_t EmbEltwiseLayernormPluginDynamic::getSerializationSize() const {
template <typename T>
size_t EmbEltwiseLayernormPluginDynamic<T>::getSerializationSize() const {
return 0;
}
void EmbEltwiseLayernormPluginDynamic::serialize(void *buffer) const {}
template <typename T>
void EmbEltwiseLayernormPluginDynamic<T>::serialize(void *buffer) const {}
nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic::getOutputDimensions(
template <typename T>
nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic<T>::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) {
nvinfer1::IExprBuilder &expr_builder) { // NOLINT
PADDLE_ENFORCE_EQ(output_index, 0,
platform::errors::InvalidArgument(
"There is only one output of the EmbEltwiseLayernorm, "
......@@ -80,7 +106,8 @@ nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic::getOutputDimensions(
return ret;
}
bool EmbEltwiseLayernormPluginDynamic::supportsFormatCombination(
template <typename T>
bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
int nb_outputs) {
PADDLE_ENFORCE_NOT_NULL(
......@@ -110,11 +137,16 @@ bool EmbEltwiseLayernormPluginDynamic::supportsFormatCombination(
}
if (pos == 3) {
return desc.type == nvinfer1::DataType::kFLOAT;
if (sizeof(T) == sizeof(float)) {
return desc.type == nvinfer1::DataType::kFLOAT;
} else {
return desc.type == nvinfer1::DataType::kHALF;
}
}
}
nvinfer1::DataType EmbEltwiseLayernormPluginDynamic::getOutputDataType(
template <typename T>
nvinfer1::DataType EmbEltwiseLayernormPluginDynamic<T>::getOutputDataType(
int index, const nvinfer1::DataType *input_types, int nb_inputs) const {
PADDLE_ENFORCE_EQ(
index, 0, platform::errors::InvalidArgument(
......@@ -124,7 +156,8 @@ nvinfer1::DataType EmbEltwiseLayernormPluginDynamic::getOutputDataType(
return nvinfer1::DataType::kFLOAT;
}
int EmbEltwiseLayernormPluginDynamic::enqueue(
template <typename T>
int EmbEltwiseLayernormPluginDynamic<T>::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) {
......@@ -160,18 +193,36 @@ int EmbEltwiseLayernormPluginDynamic::enqueue(
const unsigned tpb = 256;
const dim3 grid(seq_len, batch, 1);
const dim3 block(tpb, 1, 1);
PADDLE_ENFORCE_EQ(
out_type == nvinfer1::DataType::kFLOAT, true,
platform::errors::InvalidArgument(
"The EmbEltwiseLayernorm Plugin only only support fp32 input."));
if (sizeof(T) == sizeof(float)) {
PADDLE_ENFORCE_EQ(
out_type == nvinfer1::DataType::kFLOAT, true,
platform::errors::InvalidArgument(
"The EmbEltwiseLayernorm Plugin only support fp32 input."));
} else if (sizeof(T) == sizeof(int16_t)) {
PADDLE_ENFORCE_EQ(
out_type == nvinfer1::DataType::kHALF, true,
platform::errors::InvalidArgument(
"The EmbEltwiseLayernorm Plugin only support fp16 input."));
} else {
PADDLE_THROW(platform::errors::Fatal(
"Unsupport data type, the out type of EmbEltwiseLayernorm should be "
"float or half."));
}
float *output_d = static_cast<float *>(outputs[0]);
operators::math::EmbEltwiseLayerNormFunctor<float> emb_eltwise_layernorm_func;
T *output_d = static_cast<T *>(outputs[0]);
operators::math::EmbEltwiseLayerNormFunctor<T> emb_eltwise_layernorm_func;
emb_eltwise_layernorm_func(batch, seq_len, hidden_size_, in_ptr_gpu_d,
scale_gpu_, bias_gpu_, emb_ptr_gpu_d, output_d,
eps_, input_num, stream);
return cudaGetLastError() != cudaSuccess;
}
template class EmbEltwiseLayernormPluginDynamic<float>;
#ifdef SUPPORTS_CUDA_FP16
template class EmbEltwiseLayernormPluginDynamic<half>;
#endif // SUPPORTS_CUDA_FP16
#endif
} // namespace plugin
......
......@@ -27,6 +27,7 @@ namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
template <typename T>
class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
public:
explicit EmbEltwiseLayernormPluginDynamic(std::vector<float*> input_embs,
......@@ -98,7 +99,7 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
// data on devices
float* bias_gpu_;
float* scale_gpu_;
std::vector<float*> embs_gpu_;
std::vector<T*> embs_gpu_;
std::vector<int> emb_sizes_;
int bias_size_;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <stdio.h>
#include <cassert>
#include <cub/cub.cuh> // NOLINT
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
template <typename T>
__global__ void SliceKernel(int num, int dims, const T *input,
const int *offsets_info, T *output) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
extern __shared__ int shared_data[];
if (threadIdx.x == 0) {
for (int i = 0; i < dims * 3; i++) {
shared_data[i] = offsets_info[i];
}
}
__syncthreads();
if (idx < num) {
int t_idx = idx;
int in_idx = 0;
for (int i = dims - 1; i >= 0; i--) {
// output_shape
auto t = t_idx % shared_data[i * 3 + 1];
// out offset
auto s = t + shared_data[i * 3];
// input_seg_offset
in_idx = in_idx + shared_data[i * 3 + 2] * s;
t_idx = t_idx / shared_data[i * 3 + 1];
}
output[idx] = input[in_idx];
}
}
int SlicePluginDynamic::initialize() { return 0; }
size_t SlicePluginDynamic::getSerializationSize() const { return 0; }
void SlicePluginDynamic::serialize(void *buffer) const {}
nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) {
auto in_dims = inputs[0];
nvinfer1::DimsExprs ret;
for (int i = 0; i < ret.nbDims; i++) {
ret.d[i] = in_dims.d[i];
}
// start, ends should greater 0
for (size_t i = 0; i < axes_.size(); i++) {
int start = starts_[i];
int end = ends_[i];
ret.d[axes_[i]] = expr_builder.constant(end - start);
}
return ret;
}
bool SlicePluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
int nb_outputs) {
PADDLE_ENFORCE_NOT_NULL(
in_out, platform::errors::InvalidArgument(
"The input of swish plugin shoule not be nullptr."));
PADDLE_ENFORCE_LT(
pos, nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos, nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) {
#ifdef SUPPORTS_CUDA_FP16
if (ban_fp16_) {
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
} else {
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
}
#else
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
}
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
// output
return in.type == prev.type && in.format == prev.format;
}
nvinfer1::DataType SlicePluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType *input_types, int nb_inputs) const {
PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
"The Slice Plugin only has one input, so the "
"index value should be 0, but get %d.",
index));
PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT ||
input_types[0] == nvinfer1::DataType::kHALF),
true, platform::errors::InvalidArgument(
"The input type should be half or float"));
return input_types[0];
}
int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc,
const void *const *inputs, void *const *outputs,
void *workspace, cudaStream_t stream) {
auto input_dims = input_desc[0].dims;
auto out_dims = output_desc[0].dims;
auto num_dims = input_dims.nbDims;
size_t out_num = ProductDim(out_dims);
std::vector<int> seg_offsets;
std::vector<int> offsets;
std::vector<int> extends;
offsets.reserve(num_dims);
extends.reserve(num_dims);
seg_offsets.reserve(num_dims);
seg_offsets[num_dims - 1] = 1;
for (int i = num_dims - 2; i >= 0; i--) {
seg_offsets[i] = input_dims.d[i + 1] * seg_offsets[i + 1];
}
for (size_t i = 0; i < num_dims; ++i) {
offsets[i] = 0;
extends[i] = out_dims.d[i];
}
for (size_t i = 0; i < axes_.size(); ++i) {
offsets[axes_[i]] = starts_[i];
}
std::vector<int> offset_info;
for (size_t i = 0; i < num_dims; ++i) {
offset_info.push_back(offsets[i]);
offset_info.push_back(extends[i]);
offset_info.push_back(seg_offsets[i]);
}
framework::Tensor offset_temp_tensor;
int device_id;
cudaGetDevice(&device_id);
offset_temp_tensor.Resize({3 * num_dims});
auto *offset_temp_data =
offset_temp_tensor.mutable_data<int>(platform::CUDAPlace(device_id));
cudaMemcpyAsync(offset_temp_data, offset_info.data(),
sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice, stream);
int threads = 256;
int blocks = (out_num + threads - 1) / threads;
auto input_type = input_desc[0].type;
if (input_type == nvinfer1::DataType::kFLOAT) {
const float *input1 = static_cast<const float *>(inputs[0]);
float *output = static_cast<float *>(outputs[0]);
SliceKernel<float><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
out_num, num_dims, input1, offset_temp_data, output);
} else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef SUPPORTS_CUDA_FP16
const half *input1 = static_cast<const half *>(inputs[0]);
half *output = static_cast<half *>(outputs[0]);
SliceKernel<half><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
out_num, num_dims, input1, offset_temp_data, output);
#else
PADDLE_THROW(platform::errors::Fatal(
"The cuda archs you specific should greater than 600."));
#endif
} else {
PADDLE_THROW(platform::errors::Fatal(
"The Slice TRT Plugin's input type should be float or half."));
}
return cudaGetLastError() != cudaSuccess;
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
class SlicePluginDynamic : public DynamicPluginTensorRT {
public:
explicit SlicePluginDynamic(std::vector<int> starts, std::vector<int> ends,
std::vector<int> axes, bool ban_fp16)
: starts_(starts), ends_(ends), axes_(axes), ban_fp16_(ban_fp16) {}
SlicePluginDynamic(void const* serialData, size_t serialLength) {}
nvinfer1::IPluginV2DynamicExt* clone() const override {
return new SlicePluginDynamic(starts_, ends_, axes_, ban_fp16_);
}
const char* getPluginType() const override { return "slice_plugin"; }
int getNbOutputs() const override { return 1; }
int initialize() override;
size_t getSerializationSize() const override;
void serialize(void* buffer) const override;
nvinfer1::DimsExprs getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs, int nbOutputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) override {}
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const override;
void destroy() override { delete this; }
private:
std::vector<int> starts_;
std::vector<int> ends_;
std::vector<int> axes_;
bool ban_fp16_{false};
};
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -120,7 +120,7 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
if (with_fp16) {
precision = AnalysisConfig::Precision::kHalf;
}
config.EnableTensorRtEngine(1 << 30, 1, 5, precision, false, true);
config.EnableTensorRtEngine(1 << 30, 1, 1, precision, false, true);
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape);
std::vector<float> out_data;
......
......@@ -124,7 +124,7 @@ void TestDynamic2() {
output_t->copy_to_cpu(out_data.data());
std::vector<float> result = {0.617728, 1.63504, 2.15771, 0.535556};
for (size_t i = 0; i < out_data.size(); i++) {
EXPECT_NEAR(result[i], out_data[i], 1e-6);
EXPECT_NEAR(result[i], out_data[i], 1e-5);
}
}
......
......@@ -287,7 +287,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
#if IS_TRT_VERSION_GE(6000)
auto *trt_context = engine->context();
auto dims = trt_context->getBindingDimensions(bind_index);
for (int i = 0; i < dims.nbDims; i++) ddim.push_back(dims.d[i]);
int nb_dims = dims.nbDims;
for (; nb_dims > 0; nb_dims--) {
if (dims.d[nb_dims - 1] != 1) break;
}
for (int i = 0; i < nb_dims; i++) ddim.push_back(dims.d[i]);
#endif
}
auto *fluid_v = scope.FindVar(y);
......@@ -303,24 +307,29 @@ class TensorRTEngineOp : public framework::OperatorBase {
output_index += 1;
}
PADDLE_ENFORCE_LE(
runtime_batch, max_batch_size_,
platform::errors::InvalidArgument(
"The runtime batch size (%d) is greater than the max batch "
"size(%d).\n"
"There are two possible causes for this problem: \n"
"1. Check whether the runtime batch is larger than the max_batch "
"set by EnableTensorrtEngine()\n"
"2. Check whether the model you are running has multiple trt "
"subgraphs: \n "
"\tIf there are multiple trt subgraphs, you need to ensure that "
"the first dimension of the input tensor of these subgraphs is "
"consistent.\n"
"\tIf there are inconsistent subgraphs, you need to filter them by "
"setting min_subgraph_size using EnableTensorrtEngine interface.\n"
"\tThe min_subgraph_size shouble to be greater than the number of "
"nodes in the inconsistent subgraph.\n",
runtime_batch, max_batch_size_));
if (!engine->with_dynamic_shape()) {
PADDLE_ENFORCE_LE(
runtime_batch, max_batch_size_,
platform::errors::InvalidArgument(
"The runtime batch size (%d) is greater than the max batch "
"size(%d).\n"
"There are two possible causes for this problem: \n"
"1. Check whether the runtime batch is larger than the max_batch "
"set by EnableTensorrtEngine()\n"
"2. Check whether the model you are running has multiple trt "
"subgraphs: \n "
"\tIf there are multiple trt subgraphs, you need to ensure that "
"the first dimension of the input tensor of these subgraphs is "
"consistent.\n"
"\tIf there are inconsistent subgraphs, you need to filter them "
"by "
"setting min_subgraph_size using EnableTensorrtEngine "
"interface.\n"
"\tThe min_subgraph_size shouble to be greater than the number "
"of "
"nodes in the inconsistent subgraph.\n",
runtime_batch, max_batch_size_));
}
// Execute the engine.
engine->Execute(runtime_batch, &buffers, stream);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册