未验证 提交 4283be52 编写于 作者: P Pei Yang 提交者: GitHub

[cherry-pick][Paddle-TRT] Stack op plugin (#25605) (#27365)

* [Paddle-TRT] Stack op plugin (#25605)

* add stack_op to CMakeLists

* add dim=3 support for scale op

* add trt stack op, test=develop

* remove debug message

* add stack plugin serialize

* remove slice, scale op, will add later

* enhence error message

* revise trt ernie test to conver the stack op CI testi, test=develop

* add stack op serialization

* fix test shape after adding stack op

* remove slice op, will add after implementing serialization

* roll back to min_graph=5 to avoid using slice op

* fix scale op output layer

* implement stack op createPlugin

* use workspace and move the defination to .cu

* move stack plugin creator definition to .cu, test=develop

* sync ut with develop
Co-authored-by: Nzlsh80826 <zlsh80826@gmail.com>
上级 7dbc5126
......@@ -981,4 +981,5 @@ USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm);
USE_TRT_CONVERTER(skip_layernorm);
USE_TRT_CONVERTER(slice);
USE_TRT_CONVERTER(scale);
USE_TRT_CONVERTER(stack);
#endif
......@@ -3,8 +3,8 @@ nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
......
......@@ -58,6 +58,24 @@ class ScaleOpConverter : public OpConverter {
TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
0};
nvinfer1::ILayer* layer = nullptr;
auto input_dim = input->getDimensions();
PADDLE_ENFORCE_GE(input_dim.nbDims, 3,
platform::errors::Fatal(
"Paddle-TRT scale mode only support dimension >= 3"));
nvinfer1::IShuffleLayer* expand_layer = nullptr;
nvinfer1::IShuffleLayer* squeeze_layer = nullptr;
if (input_dim.nbDims == 3) {
// TensorRT scale layer is not supporting input dims < 4 when using
// explicit batch
expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
nvinfer1::Dims4 target_shape(0, 0, 0, 1); // expand 1 dims
expand_layer->setReshapeDimensions(target_shape);
input = expand_layer->getOutput(0);
}
if (bias_after_scale) {
layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *input, nvinfer1::ScaleMode::kUNIFORM,
......@@ -73,6 +91,18 @@ class ScaleOpConverter : public OpConverter {
power_weights.get(), scale_weights.get(), power_weights.get());
}
PADDLE_ENFORCE_EQ(layer != nullptr, true,
platform::errors::Fatal("Create scale layer failed."));
if (input_dim.nbDims == 3) {
// TensorRT scale layer is not supporting input dims < 4 when using
// explicit batch
squeeze_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(layer->getOutput(0)));
nvinfer1::Dims3 target_shape(0, 0, 0); // expand 1 dims
squeeze_layer->setReshapeDimensions(target_shape);
layer = static_cast<nvinfer1::ILayer*>(squeeze_layer);
}
RreplenishLayerAndOutput(layer, "scale", {out_name}, test_mode);
}
};
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* Stack converter from fluid to tensorRT.
*/
class StackOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert fluid stack op to tensorrt stack layer";
framework::OpDesc op_desc(op, nullptr);
auto input = op_desc.Input("X");
int input_num = input.size();
nvinfer1::ITensor** inputs =
(nvinfer1::ITensor**)malloc(input_num * sizeof(nvinfer1::ITensor*));
for (int i = 0; i < input_num; ++i) {
inputs[i] = engine_->GetITensor(input[i]);
}
int axis = boost::get<int>(op_desc.GetAttr("axis"));
if (axis < 0) {
axis = axis + inputs[0]->getDimensions().nbDims + 1;
}
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
plugin::StackPluginDynamic* plugin =
new plugin::StackPluginDynamic(axis, input_num);
layer = engine_->AddPluginV2(inputs, input_num, plugin);
assert(layer != nullptr);
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"));
#endif
} else {
PADDLE_THROW(platform::errors::Fatal(
"You are running the Ernie(Bert) model in static"
"shape mode, which is not supported for the time being.\n"
"You can use the config.SetTRTDynamicShapeInfo(...) interface"
" to set the shape information to run the dynamic shape mode."));
}
auto output_name = op_desc.Output("Y").front();
RreplenishLayerAndOutput(layer, "stack", {output_name}, test_mode);
free(inputs);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(stack, StackOpConverter);
......@@ -186,6 +186,7 @@ void TensorRTEngine::FreezeNetwork() {
Vec2TRT_Dims(optim_input_shape_[input.first], input.first, true));
}
infer_builder_config_->addOptimizationProfile(optim_profile_);
infer_builder_config_->setMaxWorkspaceSize(max_workspace_);
if (WithFp16()) {
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kFP16);
if (disable_trt_plugin_fp16()) {
......
......@@ -85,6 +85,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"gelu",
"layer_norm",
"scale",
"stack",
};
};
......
nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu
hard_swish_op_plugin.cu stack_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
StackPluginDynamic::StackPluginDynamic(int axis, int num_stack)
: axis_(axis), num_stack_(num_stack) {}
StackPluginDynamic::StackPluginDynamic(void const* serial_data,
size_t serial_length) {
DeserializeValue(&serial_data, &serial_length, &axis_);
DeserializeValue(&serial_data, &serial_length, &num_stack_);
}
StackPluginDynamic::~StackPluginDynamic() {}
nvinfer1::IPluginV2DynamicExt* StackPluginDynamic::clone() const {
return new StackPluginDynamic(axis_, num_stack_);
}
const char* StackPluginDynamic::getPluginType() const { return "stack_plugin"; }
int StackPluginDynamic::getNbOutputs() const { return 1; }
int StackPluginDynamic::initialize() { return 0; }
size_t StackPluginDynamic::getSerializationSize() const {
size_t serialize_size = 0;
serialize_size += SerializedSize(axis_);
serialize_size += SerializedSize(num_stack_);
return serialize_size;
}
void StackPluginDynamic::serialize(void* buffer) const {
SerializeValue(&buffer, axis_);
SerializeValue(&buffer, num_stack_);
}
nvinfer1::DimsExprs StackPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) {
nvinfer1::DimsExprs output(inputs[0]);
output.nbDims = inputs[0].nbDims + 1;
for (int i = inputs[0].nbDims; i > axis_; --i) {
output.d[i] = inputs[0].d[i - 1];
}
output.d[axis_] = expr_builder.constant(nb_inputs);
return output;
}
void StackPluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {}
size_t StackPluginDynamic::getWorkspaceSize(
const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
return num_stack_ * sizeof(uintptr_t);
}
void StackPluginDynamic::destroy() { delete this; }
void StackPluginDynamic::terminate() {}
bool StackPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
int nb_outputs) {
PADDLE_ENFORCE_NOT_NULL(
in_out, platform::errors::InvalidArgument(
"The input of stack plugin should not be nullptr."));
PADDLE_ENFORCE_LT(
pos, nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos, nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc& in = in_out[pos];
if (pos == 0) {
#ifdef SUPPORTS_CUDA_FP16
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#else
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
}
const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
// output
return in.type == prev.type && in.format == prev.format;
}
nvinfer1::DataType StackPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType* input_types, int nb_inputs) const {
PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
"The index should be equal to 0"));
return input_types[0];
}
template <typename T>
__global__ void StackKernel(const T* const* input, T* output, int num_stack,
int base_unit) {
int stack_id = blockIdx.x;
int lead_id = blockIdx.y;
for (int i = threadIdx.x; i < base_unit; i += blockDim.x) {
output[lead_id * num_stack * base_unit + stack_id * base_unit + i] =
input[stack_id][lead_id * base_unit + i];
}
}
int StackPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs, void* const* outputs,
void* workspace, cudaStream_t stream) {
auto input_dims = input_desc[0].dims; // (batch, seq, seq)
auto out_dims = output_desc[0].dims; // (batch, num_head, seq, seq)
auto out_num_dims = out_dims.nbDims;
int base_unit = 1;
for (int i = axis_ + 1; i < out_num_dims; ++i) {
PADDLE_ENFORCE_GT(out_dims.d[i], 0,
platform::errors::InvalidArgument(
"Input dimensions should be greater than 0"));
base_unit *= out_dims.d[i];
}
int lead_unit = 1;
for (int i = 0; i < axis_; ++i) {
PADDLE_ENFORCE_GT(out_dims.d[i], 0,
platform::errors::InvalidArgument(
"Input dimensions should be greater than 0"));
lead_unit *= out_dims.d[i];
}
PADDLE_ENFORCE_EQ(
out_dims.d[axis_], num_stack_,
platform::errors::InvalidArgument("number of stack axis should be same"));
cudaMemcpyAsync(workspace, reinterpret_cast<const void* const>(inputs),
sizeof(void*) * out_dims.d[axis_], cudaMemcpyHostToDevice,
stream);
const int num_stacks = out_dims.d[axis_];
dim3 num_blocks(num_stacks, lead_unit);
const int num_threads = 256;
auto infer_type = input_desc[0].type;
if (infer_type == nvinfer1::DataType::kFLOAT) {
float* output = static_cast<float*>(outputs[0]);
StackKernel<float><<<num_blocks, num_threads, 0, stream>>>(
reinterpret_cast<const float* const*>(workspace), output, num_stacks,
base_unit);
} else if (infer_type == nvinfer1::DataType::kHALF) {
#ifdef SUPPORTS_CUDA_FP16
__half* output = static_cast<__half*>(outputs[0]);
StackKernel<__half><<<num_blocks, num_threads, 0, stream>>>(
reinterpret_cast<const __half* const*>(workspace), output, num_stacks,
base_unit);
#else
PADDLE_THROW(platform::errors::Fatal(
"The cuda archs you specific should greater than 600."));
#endif
} else {
PADDLE_THROW(
platform::errors::Fatal("The Stack TRT Plugin's input type only "
"support float or half currently."));
}
return cudaGetLastError() != cudaSuccess;
}
StackPluginDynamicCreator::StackPluginDynamicCreator() {}
const char* StackPluginDynamicCreator::getPluginName() const {
return "stack_plugin";
}
const char* StackPluginDynamicCreator::getPluginVersion() const { return "1"; }
const nvinfer1::PluginFieldCollection*
StackPluginDynamicCreator::getFieldNames() {
return &field_collection_;
}
nvinfer1::IPluginV2* StackPluginDynamicCreator::createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) {
int axis = -1;
int num_stack = -1;
for (int i = 0; i < fc->nbFields; ++i) {
const std::string name(fc->fields[i].name);
if (name == "axis") {
axis = static_cast<const int*>(fc->fields[i].data)[0];
} else if (name == "num_stack") {
num_stack = static_cast<const int*>(fc->fields[i].data)[0];
} else {
PADDLE_THROW(platform::errors::Fatal("Meet an unknown plugin field '" +
name +
"' when creating stack op plugin."));
}
}
return new StackPluginDynamic(axis, num_stack);
}
nvinfer1::IPluginV2* StackPluginDynamicCreator::deserializePlugin(
const char* name, const void* serial_data, size_t serial_length) {
auto plugin = new StackPluginDynamic(serial_data, serial_length);
return plugin;
}
void StackPluginDynamicCreator::setPluginNamespace(const char* lib_namespace) {
plugin_namespace_ = lib_namespace;
}
const char* StackPluginDynamicCreator::getPluginNamespace() const {
return plugin_namespace_.c_str();
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
class StackPluginDynamic : public DynamicPluginTensorRT {
public:
explicit StackPluginDynamic(int axis, int num_stack);
StackPluginDynamic(void const* serial_data, size_t serial_length);
~StackPluginDynamic();
nvinfer1::IPluginV2DynamicExt* clone() const override;
nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs, int nbOutputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const override;
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const override;
const char* getPluginType() const override;
int getNbOutputs() const override;
int initialize() override;
void terminate() override;
size_t getSerializationSize() const override;
void serialize(void* buffer) const override;
void destroy() override;
private:
int axis_;
int num_stack_;
};
class StackPluginDynamicCreator : public nvinfer1::IPluginCreator {
public:
StackPluginDynamicCreator();
const char* getPluginName() const override;
const char* getPluginVersion() const override;
const nvinfer1::PluginFieldCollection* getFieldNames() override;
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override;
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override;
void setPluginNamespace(const char* lib_namespace) override;
const char* getPluginNamespace() const override;
private:
std::string plugin_namespace_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(StackPluginDynamicCreator);
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -426,22 +426,14 @@ if(WITH_GPU AND TENSORRT_FOUND)
ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4)
set(TEST_TRT_ERNIE_UNSER_MODEL "${TRT_MODEL_INSTALL_DIR}/ernie_test/ernie_model_4_unserialized/")
if (NOT EXISTS ${TEST_TRT_ERNIE_UNSER_MODEL})
if (NOT EXISTS ${TEST_TRT_ERNIE_UNSER_MODEL}/ernie_model_4_unserialized.tgz)
inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_unserialized.tgz")
endif()
inference_analysis_test(test_trt_dynamic_shape_ernie_serialize SRCS trt_dynamic_shape_ernie_deserialize_test.cc
inference_analysis_test(test_trt_dynamic_shape_ernie_ser_deser SRCS trt_dynamic_shape_ernie_deserialize_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_unserialized)
set(TEST_TRT_ERNIE_SER_MODEL "${TRT_MODEL_INSTALL_DIR}/ernie_model_4_serialized/")
if (NOT EXISTS ${TEST_TRT_ERNIE_SER_MODEL})
inference_download_and_uncompress(${TRT_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_serialized.tgz")
endif()
inference_analysis_test(test_trt_dynamic_shape_ernie_deserialize SRCS trt_dynamic_shape_ernie_deserialize_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/ernie_model_4_serialized)
endif()
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <dirent.h>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <unistd.h>
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
namespace paddle {
namespace inference {
int DeleteCache(std::string path) {
DIR* dir = opendir(path.c_str());
if (dir == NULL) return 0;
struct dirent* ptr;
while ((ptr = readdir(dir)) != NULL) {
if (std::strcmp(ptr->d_name, ".") == 0 ||
std::strcmp(ptr->d_name, "..") == 0) {
continue;
} else if (ptr->d_type == 8) {
std::string file_rm = path + "/" + ptr->d_name;
return remove(file_rm.c_str());
}
}
return 0;
}
void run(const AnalysisConfig& config, std::vector<float>* out_data) {
auto predictor = CreatePaddlePredictor(config);
auto input_names = predictor->GetInputNames();
......@@ -86,11 +101,15 @@ void run(const AnalysisConfig& config, std::vector<float>* out_data) {
void trt_ernie(bool with_fp16, std::vector<float> result) {
AnalysisConfig config;
std::string model_dir = FLAGS_infer_model;
// Delete serialization cache to perform serialization first rather than
// deserialization.
std::string opt_cache_dir = FLAGS_infer_model + "/_opt_cache";
DeleteCache(opt_cache_dir);
SetConfig(&config, model_dir, true /* use_gpu */);
config.SwitchUseFeedFetchOps(false);
int head_number = 12;
int batch = 1;
int min_seq_len = 1;
int max_seq_len = 128;
......@@ -104,17 +123,17 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
{"read_file_0.tmp_0", min_shape},
{"read_file_0.tmp_1", min_shape},
{"read_file_0.tmp_2", min_shape},
{"stack_0.tmp_0", {batch, head_number, min_seq_len, min_seq_len}}};
{"matmul_0.tmp_0", {batch, min_seq_len, min_seq_len}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"read_file_0.tmp_0", max_shape},
{"read_file_0.tmp_1", max_shape},
{"read_file_0.tmp_2", max_shape},
{"stack_0.tmp_0", {batch, head_number, max_seq_len, max_seq_len}}};
{"matmul_0.tmp_0", {batch, max_seq_len, max_seq_len}}};
std::map<std::string, std::vector<int>> opt_input_shape = {
{"read_file_0.tmp_0", opt_shape},
{"read_file_0.tmp_1", opt_shape},
{"read_file_0.tmp_2", opt_shape},
{"stack_0.tmp_0", {batch, head_number, opt_seq_len, opt_seq_len}}};
{"matmul_0.tmp_0", {batch, opt_seq_len, opt_seq_len}}};
auto precision = AnalysisConfig::Precision::kFloat32;
if (with_fp16) {
......@@ -123,8 +142,11 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
config.EnableTensorRtEngine(1 << 30, 1, 5, precision, true, false);
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape);
AnalysisConfig* config_deser = new AnalysisConfig(config);
std::vector<float> out_data;
run(config, &out_data);
run(config, &out_data); // serialize
run(*config_deser, &out_data); // deserialize
for (size_t i = 0; i < out_data.size(); i++) {
EXPECT_NEAR(result[i], out_data[i], 1e-6);
}
......
......@@ -90,7 +90,6 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
config.SwitchUseFeedFetchOps(false);
int head_number = 12;
int batch = 1;
int min_seq_len = 1;
int max_seq_len = 128;
......@@ -104,17 +103,17 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
{"read_file_0.tmp_0", min_shape},
{"read_file_0.tmp_1", min_shape},
{"read_file_0.tmp_2", min_shape},
{"stack_0.tmp_0", {batch, head_number, min_seq_len, min_seq_len}}};
{"matmul_0.tmp_0", {batch, min_seq_len, min_seq_len}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"read_file_0.tmp_0", max_shape},
{"read_file_0.tmp_1", max_shape},
{"read_file_0.tmp_2", max_shape},
{"stack_0.tmp_0", {batch, head_number, max_seq_len, max_seq_len}}};
{"matmul_0.tmp_0", {batch, max_seq_len, max_seq_len}}};
std::map<std::string, std::vector<int>> opt_input_shape = {
{"read_file_0.tmp_0", opt_shape},
{"read_file_0.tmp_1", opt_shape},
{"read_file_0.tmp_2", opt_shape},
{"stack_0.tmp_0", {batch, head_number, opt_seq_len, opt_seq_len}}};
{"matmul_0.tmp_0", {batch, opt_seq_len, opt_seq_len}}};
auto precision = AnalysisConfig::Precision::kFloat32;
if (with_fp16) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册