未验证 提交 4cd8a78a 编写于 作者: W wangxinxin08 提交者: GitHub

[cherry-pick]mish trt plugin (#38866)

* add mish trt plugin, compile & install success, run error. test=develop

* modify code of mish plugin

* upgrade mish trt plugin

* modify code according to review

* add TRT_NOEXCEPT for mish trt plugin

* add unittest for mish trt plugin

* remove unnecessary check of mish in op_teller.cc

* fix some problem of trt8

* add check and modify unittest while converting mish to trt plugin
Co-authored-by: Ndengkaipeng <dengkaipeng@baidu.com>
上级 f4867e57
......@@ -35,7 +35,7 @@ void IsTestPass::ApplyImpl(ir::Graph* graph) const {
"hard_shrink", "hard_sigmoid", "relu6",
"soft_relu", "swish", "thresholded_relu",
"log", "square", "softplus",
"softsign", "silu"};
"softsign", "silu", "mish"};
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
auto* op = n->Op();
......
......@@ -1414,6 +1414,7 @@ USE_TRT_CONVERTER(tile);
USE_TRT_CONVERTER(conv3d);
USE_TRT_CONVERTER(conv3d_transpose);
USE_TRT_CONVERTER(pool3d);
USE_TRT_CONVERTER(mish);
#endif
namespace paddle_infer {
......
......@@ -18,6 +18,7 @@ nv_library(tensorrt_converter
tile_op.cc
conv3d_op.cc
pool3d_op.cc
mish_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/mish_op_plugin.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* Mish converter from fluid to tensorRT.
*/
class MishOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert fluid Mish op to tensorrt Mish plugin";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
int input_num = op_desc.Input("X").size();
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
const float threshold =
op_desc.HasAttr("threshold")
? BOOST_GET_CONST(float, op_desc.GetAttr("threshold"))
: 20.0f;
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
plugin::MishPluginDynamic* plugin =
new plugin::MishPluginDynamic(threshold, with_fp16);
layer = engine_->AddDynamicPlugin(&input, input_num, plugin);
} else {
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
plugin::MishPlugin* plugin = new plugin::MishPlugin(threshold, with_fp16);
layer = engine_->AddPlugin(&input, input_num, plugin);
}
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "mish", {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(mish, MishOpConverter);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(mish_op, test_mish) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("mish-X", nvinfer1::Dims3(3, 2, 2));
validator.DeclOutputVar("mish-Out", nvinfer1::Dims3(3, 2, 2));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("mish");
desc.SetInput("X", {"mish-X"});
desc.SetOutput("Out", {"mish-Out"});
desc.SetAttr("threshold", 20.0f);
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(mish);
......@@ -169,7 +169,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"reduce_mean",
"conv3d",
"conv3d_transpose",
"pool3d"};
"pool3d",
"mish"};
};
bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
......@@ -1160,6 +1161,44 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
#endif
}
if (op_type == "mish") {
if (desc.Input("X").size() != 1) {
VLOG(3) << "Invalid input X's size of mish TRT converter. "
"Expected 1, received "
<< desc.Input("X").size() << ".";
return false;
}
if (desc.Output("Out").size() != 1) {
VLOG(3) << "Invalid output Out's size of mish TRT converter. "
"Expected 1, received "
<< desc.Output("Out").size() << ".";
return false;
}
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() == 1) {
VLOG(3) << "mish op does not support input's dim is 1 in tensorrt.";
return false;
}
if (!with_dynamic_shape) {
if (x_shape.size() == 2) {
VLOG(3) << "mish op does not support input's dim is 2 in tensorrt.";
return false;
}
}
}
if (op_type == "roi_align") {
if (!with_dynamic_shape) {
VLOG(3) << "TRT roi align plugin only accept the dynamic shape, "
......
......@@ -10,6 +10,7 @@ nv_library(tensorrt_plugin
roi_align_op_plugin.cu
gather_nd_op_plugin.cu
pool3d_op_plugin.cu
mish_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstring>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/mish_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
int MishPlugin::initialize() TRT_NOEXCEPT { return 0; }
bool MishPlugin::supportsFormat(
nvinfer1::DataType type, nvinfer1::PluginFormat format) const TRT_NOEXCEPT {
if (with_fp16_) {
return ((type == nvinfer1::DataType::kFLOAT ||
type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::PluginFormat::kLINEAR));
} else {
return ((type == nvinfer1::DataType::kFLOAT) &&
(format == nvinfer1::PluginFormat::kLINEAR));
}
}
nvinfer1::Dims MishPlugin::getOutputDimensions(int index,
const nvinfer1::Dims* in_dims,
int nb_inputs) TRT_NOEXCEPT {
PADDLE_ENFORCE_EQ(nb_inputs, 1, platform::errors::InvalidArgument(
"We expect [number of inputs] == 1"
"in TRT Mish op plugin, but got "
"[number of inputs] = %d.",
nb_inputs));
PADDLE_ENFORCE_LT(index, this->getNbOutputs(),
platform::errors::InvalidArgument(
"We expect [index] < [number of outputs]"
"in TRT Mish op plugin, but got "
"[index] = %d, [number of outputs] = %d.",
index, this->getNbOutputs()));
nvinfer1::Dims const& input_dims = in_dims[0];
nvinfer1::Dims output_dims = input_dims;
return output_dims;
}
template <typename T>
__device__ T kTanh(T x) {
return tanh(x);
}
template <>
__device__ half kTanh<half>(half x) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
const float tmp = tanhf(__half2float(x));
return __float2half(tmp);
#endif
}
template <typename T>
__device__ T kSoftplus(T x, T threshold) {
return x > threshold ? x : log(exp(x) + static_cast<T>(1.0f));
}
template <>
__device__ half kSoftplus<half>(half x, half threshold) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
return x > threshold ? x : hlog(hexp(x) + static_cast<half>(1.0f));
#endif
}
template <typename T>
__global__ void mish_kernel(float threshold, int n, const T* input, T* output) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
const T in = input[idx];
output[idx] = in * kTanh<T>(kSoftplus<T>(in, static_cast<T>(threshold)));
}
}
template <>
__global__ void mish_kernel<half>(float threshold, int n, const half* input,
half* output) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
const half in = input[idx];
output[idx] =
in * kTanh<half>(kSoftplus<half>(in, static_cast<half>(threshold)));
}
#endif
}
#if IS_TRT_VERSION_LT(8000)
int MishPlugin::enqueue(int batchSize, const void* const* inputs,
void** outputs,
#else
int MishPlugin::enqueue(int batchSize, const void* const* inputs,
void* const* outputs,
#endif
void* workspace, cudaStream_t stream) TRT_NOEXCEPT {
const auto& input_dims = this->getInputDims(0);
int num = batchSize;
for (int i = 0; i < input_dims.nbDims; i++) {
num *= input_dims.d[i];
}
const int block_size = 256;
const int grid_size = (num + block_size - 1) / block_size;
auto type = getDataType();
if (type == nvinfer1::DataType::kFLOAT) {
VLOG(1) << "TRT Plugin DataType selected. Mish-->fp32";
const float* input = static_cast<const float*>(inputs[0]);
float* output = static_cast<float*>(outputs[0]);
mish_kernel<float><<<grid_size, block_size, 0, stream>>>(threshold_, num,
input, output);
} else if (type == nvinfer1::DataType::kHALF) {
VLOG(1) << "TRT Plugin DataType selected. Mish-->fp16";
const half* input = static_cast<const half*>(inputs[0]);
half* output = static_cast<half*>(outputs[0]);
mish_kernel<half><<<grid_size, block_size, 0, stream>>>(threshold_, num,
input, output);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The Mish TRT Plugin's input type should be float or half."));
}
return cudaGetLastError() != cudaSuccess;
}
// Dynamic Plugin below.
int MishPluginDynamic::initialize() TRT_NOEXCEPT {
getPluginNamespace();
return 0;
}
size_t MishPluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
return SerializedSize(threshold_) + SerializedSize(with_fp16_);
}
void MishPluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT {
SerializeValue(&buffer, threshold_);
SerializeValue(&buffer, with_fp16_);
}
nvinfer1::DimsExprs MishPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT {
return inputs[0];
}
bool MishPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
int nb_outputs) TRT_NOEXCEPT {
PADDLE_ENFORCE_NOT_NULL(
in_out, platform::errors::InvalidArgument(
"The input of mish plugin shoule not be nullptr."));
PADDLE_ENFORCE_LT(
pos, nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos, nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc& in = in_out[pos];
if (pos == 0) {
if (with_fp16_) {
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
} else {
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
}
}
const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
// output
return in.type == prev.type && in.format == prev.format;
}
nvinfer1::DataType MishPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType* input_types,
int nb_inputs) const TRT_NOEXCEPT {
PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
"The Mish Plugin only has one input, so the "
"index value should be 0, but get %d.",
index));
return input_types[0];
}
int MishPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs, void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT {
auto input_dims = input_desc[0].dims;
size_t num = ProductDim(input_dims);
const int block_size = 256;
const int grid_size = (num + block_size - 1) / block_size;
auto input_type = input_desc[0].type;
if (input_type == nvinfer1::DataType::kFLOAT) {
VLOG(1) << "TRT Plugin DataType selected. Mish-->fp32";
const float* input = static_cast<const float*>(inputs[0]);
float* output = static_cast<float*>(outputs[0]);
mish_kernel<float><<<grid_size, block_size, 0, stream>>>(threshold_, num,
input, output);
} else if (input_type == nvinfer1::DataType::kHALF) {
VLOG(1) << "TRT Plugin DataType selected. Mish-->fp16";
const half* input = static_cast<const half*>(inputs[0]);
half* output = static_cast<half*>(outputs[0]);
mish_kernel<half><<<grid_size, block_size, 0, stream>>>(threshold_, num,
input, output);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The Mish TRT Plugin's input type should be float or half."));
}
return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class MishPlugin : public PluginTensorRT {
private:
float threshold_;
protected:
size_t getSerializationSize() const TRT_NOEXCEPT override {
return getBaseSerializationSize() + SerializedSize(threshold_);
}
// TRT will call this func to serialize the configuration of TRT
// It should not be called by users.
void serialize(void* buffer) const TRT_NOEXCEPT override {
serializeBase(buffer);
SerializeValue(&buffer, threshold_);
}
public:
explicit MishPlugin(const float threshold, const bool with_fp16)
: threshold_(threshold) {
with_fp16_ = with_fp16;
}
// It was used for tensorrt deserialization.
// It should not be called by users.
MishPlugin(void const* serialData, size_t serialLength) {
deserializeBase(serialData, serialLength);
DeserializeValue(&serialData, &serialLength, &threshold_);
}
~MishPlugin() {}
MishPlugin* clone() const TRT_NOEXCEPT override {
return new MishPlugin(threshold_, with_fp16_);
}
const char* getPluginType() const TRT_NOEXCEPT override {
return "mish_plugin";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
int initialize() TRT_NOEXCEPT override;
bool supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format)
const TRT_NOEXCEPT override;
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
int nbInputDims) TRT_NOEXCEPT override;
#if IS_TRT_VERSION_LT(8000)
int enqueue(int batchSize, const void* const* inputs, void** outputs,
#else
int enqueue(int batchSize, const void* const* inputs, void* const* outputs,
#endif
void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;
};
class MishPluginCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const TRT_NOEXCEPT override {
return "mish_plugin";
}
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(
const char* name, const void* serial_data,
size_t serial_length) TRT_NOEXCEPT override {
return new MishPlugin(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(MishPluginCreator);
class MishPluginDynamic : public DynamicPluginTensorRT {
public:
explicit MishPluginDynamic(const float threshold, const bool with_fp16)
: threshold_(threshold) {
with_fp16_ = with_fp16;
}
MishPluginDynamic(void const* serialData, size_t serialLength) {
DeserializeValue(&serialData, &serialLength, &threshold_);
DeserializeValue(&serialData, &serialLength, &with_fp16_);
}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
return new MishPluginDynamic(threshold_, with_fp16_);
}
const char* getPluginType() const TRT_NOEXCEPT override {
return "mish_plugin_dynamic";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
int initialize() TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void* buffer) const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) TRT_NOEXCEPT override {}
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(
int index, const nvinfer1::DataType* inputTypes,
int nbInputs) const TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override { delete this; }
private:
float threshold_;
};
class MishPluginDynamicCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const TRT_NOEXCEPT override {
return "mish_plugin_dynamic";
}
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(
const char* name, const void* serial_data,
size_t serial_length) TRT_NOEXCEPT override {
auto plugin = new MishPluginDynamic(serial_data, serial_length);
return plugin;
}
};
REGISTER_TRT_PLUGIN_V2(MishPluginDynamicCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -139,6 +139,42 @@ class TensorRTSubgraphPassDynamicSwishFp16SerializeTest(
return fluid.layers.swish(x)
class TensorRTSubgraphPassMishTest(TensorRTSubgraphPassActivationTest):
def setUpTensorRTParam(self):
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassActivationTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, True, False)
def append_act(self, x):
return fluid.layers.mish(x)
class TensorRTSubgraphPassMishFp16SerializeTest(
TensorRTSubgraphPassActivationTest):
def setUpTensorRTParam(self):
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassActivationTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Half, True, False)
def append_act(self, x):
return fluid.layers.mish(x)
class TensorRTSubgraphPassDynamicMishFp16SerializeTest(
TensorRTSubgraphPassActivationTest):
def setUpTensorRTParam(self):
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassActivationTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Half, False, False)
self.dynamic_shape_params = TensorRTSubgraphPassActivationTest.DynamicShapeParam(
{
'data': [1, 6, 8, 8]
}, {'data': [1, 6, 512, 512]}, {'data': [1, 6, 256, 256]}, False)
def append_act(self, x):
return fluid.layers.mish(x)
class TensorRTSubgraphPassPreluAllTest(TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.prelu(x, mode='all')
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
class TrtConvertMishTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_configs(self):
def generate_input(batch, dim1, dim2, dim3):
shape = [batch]
if dim1 != 0:
shape.append(dim1)
if dim2 != 0:
shape.append(dim2)
if dim3 != 0:
shape.append(dim3)
return np.random.random(shape).astype(np.float32)
for batch in [1, 4]:
for dim1 in [0, 3]:
for dim2 in [0, 16]:
for dim3 in [0, 32]:
for thre in [5.0, 20.0]:
self.dim1 = dim1
self.dim2 = dim2
self.dim3 = dim3
if dim1 == 0 and dim2 != 0:
continue
if dim1 == 0 and dim2 == 0 and dim3 != 0:
continue
ops_config = [{
"op_type": "mish",
"op_inputs": {
"X": ["input_data"]
},
"op_outputs": {
"Out": ["mish_output_data"]
},
"op_attrs": {
"threshold": thre
}
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data": TensorConfig(
data_gen=partial(generate_input, batch,
dim1, dim2, dim3))
},
outputs=["mish_output_data"])
yield program_config
def sample_predictor_configs(self, program_config):
def generate_dynamic_shape(attrs):
if self.dim1 == 0:
self.dynamic_shape.min_input_shape = {"input_data": [1], }
self.dynamic_shape.max_input_shape = {"input_data": [4], }
self.dynamic_shape.opt_input_shape = {"input_data": [2], }
else:
if self.dim2 == 0 and self.dim3 == 0:
self.dynamic_shape.min_input_shape = {
"input_data": [1, 1],
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 64],
}
self.dynamic_shape.opt_input_shape = {
"input_data": [2, 3],
}
elif self.dim2 != 0 and self.dim3 != 0:
self.dynamic_shape.min_input_shape = {
"input_data": [1, 1, 1, 1],
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 64, 128, 128],
}
self.dynamic_shape.opt_input_shape = {
"input_data": [2, 3, 16, 32],
}
elif self.dim3 == 0:
self.dynamic_shape.min_input_shape = {
"input_data": [1, 1, 1],
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 64, 256],
}
self.dynamic_shape.opt_input_shape = {
"input_data": [2, 3, 128],
}
def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 2
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
True), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
True), 1e-5
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if self.dim1 == 0 and self.dim2 == 0 and self.dim3 == 0:
return True
return False
self.add_skip_case(teller1, SkipReasons.TRT_NOT_SUPPORT,
"Trt does not support 1-dimensional input.")
def teller2(program_config, predictor_config):
if (len(self.dynamic_shape.min_input_shape) == 0):
if self.dim1 != 0 and self.dim2 == 0 and self.dim3 == 0:
return True
return False
self.add_skip_case(
teller2, SkipReasons.TRT_NOT_SUPPORT,
"Need to repair the case: the output of GPU and tensorrt has diff when the input dimension is 2 in static shape mode."
)
def test(self):
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册