Unverified · Commit a7778930 authored by shentanyue, committed by GitHub

[TensorRT] Support yolov5s (#42688)

* support yolov5s static/int8

* fix eltwise_sub and div weight compute

* fix delete_fill_constant_pass
Parent 9f4d342c
......@@ -89,6 +89,7 @@ pass_library(delete_quant_dequant_filter_op_pass inference)
pass_library(delete_weight_dequant_linear_op_pass inference)
pass_library(delete_quant_dequant_linear_op_pass inference)
pass_library(delete_dropout_op_pass inference)
pass_library(delete_fill_constant_op_pass inference)
pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
pass_library(skip_layernorm_fuse_pass base)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_fill_constant_op_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
template <typename T>
void FillConstData(LoDTensor* out_t, T value) {
auto output_data = out_t->mutable_data<T>(platform::CPUPlace());
for (int i = 0; i < out_t->numel(); i++) {
output_data[i] = value;
}
}
void DeleteFillConstantOpPass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("delete_fill_constant_op_pass", graph);
GraphPatternDetector detector;
auto fill_constant_op = detector.mutable_pattern()
->NewNode("fill_constant")
->assert_is_op("fill_constant")
->assert_is_not_op_input("ValueTensor")
->assert_is_not_op_input("str_value")
->assert_is_not_op_input("ShapeTensor")
->assert_is_not_op_input("ShapeTensorList");
auto fill_constant_out =
detector.mutable_pattern()
->NewNode("fill_constant_out")
->assert_is_op_output("fill_constant")
->assert_more([](Node* x) { return x->outputs.size() == 1UL; });
auto next_op = detector.mutable_pattern()
->NewNode("next_op")
->assert_is_not_op_type("conditional_block")
->assert_is_not_op_type("while");
// Create the topological connections for the above pattern nodes.
fill_constant_op->LinksTo({fill_constant_out});
next_op->LinksFrom({fill_constant_out});
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
Node* fill_constant_op_node = subgraph.at(fill_constant_op);
Node* fill_constant_out_node = subgraph.at(fill_constant_out);
// Get fill_constant's attr
auto fill_constant = fill_constant_op_node->Op();
auto value = BOOST_GET_CONST(float, fill_constant->GetAttr("value"));
auto shape =
BOOST_GET_CONST(std::vector<int64_t>, fill_constant->GetAttr("shape"));
auto* scope = param_scope();
auto fill_constant_out_desc = fill_constant_out_node->Var();
fill_constant_out_desc->SetShape(shape);
fill_constant_out_desc->SetPersistable(true);
auto* fill_constant_out_tensor =
scope->Var(fill_constant_out_desc->Name())->GetMutable<LoDTensor>();
auto dtype =
framework::TransToPhiDataType(fill_constant_out_desc->GetDataType());
fill_constant_out_tensor->Resize(phi::make_ddim(shape));
switch (dtype) {
case paddle::experimental::DataType::BOOL:
FillConstData<bool>(fill_constant_out_tensor, static_cast<bool>(value));
break;
case paddle::experimental::DataType::INT32:
FillConstData<int32_t>(fill_constant_out_tensor,
static_cast<int32_t>(value));
break;
case paddle::experimental::DataType::INT64:
FillConstData<int64_t>(fill_constant_out_tensor,
static_cast<int64_t>(value));
break;
case paddle::experimental::DataType::FLOAT32:
FillConstData<float>(fill_constant_out_tensor,
static_cast<float>(value));
break;
default:
LOG(WARNING) << "Unsupported dtype for fill_constant op: " << dtype;
return;
}
// Remove links in graph
GraphSafeRemoveNodes(graph, {fill_constant_op_node});
};
detector(graph, handler);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_fill_constant_op_pass,
paddle::framework::ir::DeleteFillConstantOpPass);
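
Note: the net effect of this pass is to materialize each matched fill_constant into a persistable weight, so downstream passes (such as the TensorRT subgraph pass) can treat it as ordinary parameter data. A minimal sketch of what the handler computes, using hypothetical attribute values:

// Illustrative only: fill_constant(value=1.0, shape=[2, 3], dtype=float32)
// becomes a persistable CPU tensor, and the fill_constant op node is removed.
LoDTensor weight;
weight.Resize(phi::make_ddim({2, 3}));
FillConstData<float>(&weight, 1.0f);  // six elements, all 1.0f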
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
class DeleteFillConstantOpPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
virtual ~DeleteFillConstantOpPass() = default;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -408,6 +408,13 @@ PDNode *PDNode::assert_is_op(const std::string &op_type) {
return this;
}
PDNode *PDNode::assert_is_not_op_type(const std::string &op_type) {
asserts_.emplace_back([op_type](Node *x) {
return x && x->IsOp() && x->Op()->Type() != op_type;
});
return this;
}
PDNode *PDNode::assert_is_var() {
asserts_.emplace_back([](Node *x) { return x && x->IsVar(); });
return this;
......
......@@ -110,6 +110,7 @@ struct PDNode {
// Assertions, helper functions to simplify the pattern definition.
PDNode* assert_is_op();
PDNode* assert_is_op(const std::string& op_type);
PDNode* assert_is_not_op_type(const std::string& op_type);
PDNode* assert_is_var();
PDNode* assert_var_dtype(proto::VarType::Type dtype);
PDNode* assert_is_not_ctrl_var();
......
......@@ -633,6 +633,11 @@ void AnalysisConfig::Update() {
(pass == "conv_bn_fuse_pass")) {
continue;
}
// delete_fill_constant_op_pass is not used under trt dynamic shape
if ((!min_input_shape_.empty() || trt_tuned_dynamic_shape_) &&
pass == "delete_fill_constant_op_pass") {
continue;
}
pass_builder()->AppendPass(pass);
}
}
......
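For context, this skip path triggers whenever TRT dynamic shape is configured. A hedged sketch of a config that takes it (the tensor name and shape values are illustrative, not from this PR):

AnalysisConfig config;
config.EnableUseGpu(100, 0);
config.EnableTensorRtEngine(1 << 30, 1, 3, AnalysisConfig::Precision::kFloat32,
                            false, false);
// Registering dynamic shape info makes min_input_shape_ non-empty, so
// delete_fill_constant_op_pass is not appended to the pass list.
std::map<std::string, std::vector<int>> min_shape{{"images", {1, 3, 320, 320}}};
std::map<std::string, std::vector<int>> max_shape{{"images", {1, 3, 640, 640}}};
std::map<std::string, std::vector<int>> opt_shape{{"images", {1, 3, 640, 640}}};
config.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);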
......@@ -1731,6 +1731,10 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<AnalysisConfig>(
#if PADDLE_WITH_TENSORRT
USE_TRT_CONVERTER(elementwise_add_weight);
USE_TRT_CONVERTER(elementwise_sub_weight);
USE_TRT_CONVERTER(elementwise_mul_weight);
USE_TRT_CONVERTER(elementwise_div_weight);
USE_TRT_CONVERTER(elementwise_pow_weight);
USE_TRT_CONVERTER(elementwise_add_tensor);
USE_TRT_CONVERTER(elementwise_sub_tensor);
USE_TRT_CONVERTER(elementwise_div_tensor);
......
......@@ -85,6 +85,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"adaptive_pool2d_convert_global_pass",
"shuffle_channel_detect_pass", //
"quant_conv2d_dequant_fuse_pass", //
"delete_fill_constant_op_pass", //
"delete_quant_dequant_op_pass", //
"delete_quant_dequant_filter_op_pass", //
"delete_weight_dequant_linear_op_pass", //
......
......@@ -53,20 +53,14 @@ class ElementwiseWeightOpConverter : public OpConverter {
     auto output_name = op_desc.Output("Out")[0];
     weight_data = engine_->GetWeightCPUData(op_desc.Input("Y").front(), Y_t);
     nvinfer1::Dims dims_x = X->getDimensions();
+    std::vector<int> dims_y = phi::vectorize<int>(Y_t->dims());
     auto regist_eltwise_weight = [&](nvinfer1::ScaleMode scale_mode) {
-      TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT,
-                                           static_cast<void*>(weight_data),
-                                           static_cast<size_t>(Y_t->numel())};
-      TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT, nullptr,
-                                           0};
-      TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
-                                           0};
       nvinfer1::IShuffleLayer* expand_layer = nullptr;
       nvinfer1::IShuffleLayer* squeeze_layer = nullptr;
       int dynamic_shape_offset = engine_->with_dynamic_shape() ? 1 : 0;
       auto input_dim = X->getDimensions();
       // reshape
       if (input_dim.nbDims < 3 + dynamic_shape_offset) {
         nvinfer1::Dims expand_shape;
         expand_shape.nbDims = 3 + dynamic_shape_offset;
......@@ -85,17 +79,45 @@ class ElementwiseWeightOpConverter : public OpConverter {
         expand_layer->setName(
             ("Elewise: Shuffle: (Output: " + output_name + ")").c_str());
       }
+      // eltwise_ops
+      TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT, nullptr,
+                                           0};
+      TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT, nullptr,
+                                           0};
+      TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
+                                           0};
       if (op_type_ == "add") {
-        nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER(
-            engine_, ScaleNd, *X, scale_mode, shift_weights.get(),
-            scale_weights.get(), power_weights.get(), dynamic_shape_offset);
-        layer = scale_layer;
+        shift_weights = TensorRTEngine::Weight(
+            nvinfer1::DataType::kFLOAT, static_cast<void*>(weight_data),
+            static_cast<size_t>(Y_t->numel()));
+      } else if (op_type_ == "sub") {
+        for (int i = 0; i < Y_t->numel(); i++) {
+          weight_data[i] = -weight_data[i];
+        }
+        shift_weights = TensorRTEngine::Weight(
+            nvinfer1::DataType::kFLOAT, static_cast<void*>(weight_data),
+            static_cast<size_t>(Y_t->numel()));
       } else if (op_type_ == "mul") {
-        nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER(
-            engine_, Scale, *X, scale_mode, scale_weights.get(),
-            shift_weights.get(), power_weights.get());
-        layer = scale_layer;
+        scale_weights = TensorRTEngine::Weight(
+            nvinfer1::DataType::kFLOAT, static_cast<void*>(weight_data),
+            static_cast<size_t>(Y_t->numel()));
+      } else if (op_type_ == "div") {
+        for (int i = 0; i < Y_t->numel(); i++) {
+          weight_data[i] = 1.f / weight_data[i];
+        }
+        scale_weights = TensorRTEngine::Weight(
+            nvinfer1::DataType::kFLOAT, static_cast<void*>(weight_data),
+            static_cast<size_t>(Y_t->numel()));
+      } else if (op_type_ == "pow") {
+        power_weights = TensorRTEngine::Weight(
+            nvinfer1::DataType::kFLOAT, static_cast<void*>(weight_data),
+            static_cast<size_t>(Y_t->numel()));
       }
+      nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER(
+          engine_, ScaleNd, *X, scale_mode, shift_weights.get(),
+          scale_weights.get(), power_weights.get(), dynamic_shape_offset);
+      layer = scale_layer;
       // reshape
       if (input_dim.nbDims < 3 + dynamic_shape_offset) {
         nvinfer1::Dims squeeze_shape;
         squeeze_shape.nbDims = input_dim.nbDims;
......@@ -113,71 +135,43 @@ class ElementwiseWeightOpConverter : public OpConverter {
       }
     };
+    // dynamic shape
     if (engine_->with_dynamic_shape()) {
-      if (Y_t->dims().size() == 1) {
-        auto scale_mode = nvinfer1::ScaleMode::kCHANNEL;
-        PADDLE_ENFORCE_EQ(Y_t->dims()[0], dims_x.d[1],
-                          platform::errors::InvalidArgument(
-                              "The Bias's size(%d) should be equal to the "
-                              "first dim(%d) of the Input.",
-                              Y_t->dims()[0], dims_x.d[1]));
-        regist_eltwise_weight(scale_mode);
+      if (dims_y.size() == 1 && dims_y[0] == dims_x.d[1]) {
+        regist_eltwise_weight(nvinfer1::ScaleMode::kCHANNEL);
+      } else if (dims_y.size() == 1 && dims_y[0] == 1) {
+        regist_eltwise_weight(nvinfer1::ScaleMode::kUNIFORM);
+      } else if (dims_y.size() == static_cast<size_t>(dims_x.nbDims)) {
+        regist_eltwise_weight(nvinfer1::ScaleMode::kELEMENTWISE);
       } else {
         PADDLE_THROW(platform::errors::InvalidArgument(
-            "The size of input bias's dims is %d, but TensorRT dynamic shape "
-            "only support size = 1 for Elementwise op!",
-            Y_t->dims().size()));
+            "The size of input_y's dims is %d, but TensorRT dynamic shape "
+            "only support size = 1 or size = input_x.size() for Elementwise "
+            "op!",
+            dims_y.size()));
       }
       return;
     }
+    // static shape with dynamic batch
     std::vector<int> no_batch_dims;
     int start_index = 0;
-    for (; start_index < dims_x.nbDims; start_index++)
+    for (; start_index < dims_x.nbDims; start_index++) {
       no_batch_dims.push_back(dims_x.d[start_index]);
-    auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
-    std::vector<int> dims_y = phi::vectorize<int>(Y_t->dims());
-    if (dims_y.size() == no_batch_dims.size() + 1) {
-      if (dims_y[0] == 1) dims_y.erase(dims_y.begin());
     }
     if (dims_y.size() == 1 && dims_y[0] == no_batch_dims[0]) {
-      scale_mode = nvinfer1::ScaleMode::kCHANNEL;
-    } else if (dims_y.size() == no_batch_dims.size() &&
-               dims_y[0] == no_batch_dims[0]) {
-      scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
-      for (size_t i = 1; i < no_batch_dims.size(); i++) {
-        if (dims_y[i] != no_batch_dims[i]) {
-          scale_mode = nvinfer1::ScaleMode::kCHANNEL;
-          break;
-        }
-      }
-      if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) {
-        for (size_t i = 1; i < no_batch_dims.size(); i++) {
-          if (dims_y[i] != 1)
-            PADDLE_THROW(platform::errors::InvalidArgument(
-                "The bias's %d dim is %d, but TensorRT dynamic shape only "
-                "support it equals to 1 for Elementwise op!",
-                i, dims_y[i]));
-        }
-      }
+      regist_eltwise_weight(nvinfer1::ScaleMode::kCHANNEL);
+    } else if (dims_y.size() == 1 && dims_y[0] == 1) {
+      regist_eltwise_weight(nvinfer1::ScaleMode::kUNIFORM);
+    } else if (dims_y.size() == no_batch_dims.size() + 1) {
+      regist_eltwise_weight(nvinfer1::ScaleMode::kELEMENTWISE);
     } else {
-      if (dims_y.size() >= 1) {
-        PADDLE_THROW(platform::errors::InvalidArgument(
-            "The size of bias's dims is %d and bias's size is %d. TensorRT "
-            "doesn't support this shape for Elementwise op!",
-            dims_y.size(), dims_y[0]));
-      } else {
-        PADDLE_THROW(platform::errors::InvalidArgument(
-            "The size of bias's dims is %d. TensorRT doesn't support "
-            "this shape for Elementwise op!",
-            dims_y.size()));
-      }
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "The size of input_y's dims is %d, but TensorRT dynamic shape "
+          "only support size = 1 or size = input_x.size() for Elementwise "
+          "op!",
+          dims_y.size()));
     }
-    regist_eltwise_weight(scale_mode);
   }
protected:
......@@ -215,7 +209,6 @@ class ElementwiseTensorOpConverter : public OpConverter {
auto common_func = [&](nvinfer1::ILayer* layer) {
RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode);
};
if (dims_x.nbDims == dims_y.nbDims) {
// The two input tensor should have the same dims
VLOG(3) << "Convert a fluid elementwise op to TensorRT IElementWiseLayer";
......@@ -244,7 +237,6 @@ class ElementwiseTensorOpConverter : public OpConverter {
auto* plugin_layer = engine_->AddPlugin(
inputs.data(), inputs.size(),
reinterpret_cast<plugin::PluginTensorRT*>(plugin));
layer = plugin_layer;
}
}
......@@ -278,6 +270,21 @@ class ElementwiseWeightMulOpConverter : public ElementwiseWeightOpConverter {
ElementwiseWeightMulOpConverter() { op_type_ = "mul"; }
};
class ElementwiseWeightSubOpConverter : public ElementwiseWeightOpConverter {
public:
ElementwiseWeightSubOpConverter() { op_type_ = "sub"; }
};
class ElementwiseWeightDivOpConverter : public ElementwiseWeightOpConverter {
public:
ElementwiseWeightDivOpConverter() { op_type_ = "div"; }
};
class ElementwiseWeightPowOpConverter : public ElementwiseWeightOpConverter {
public:
ElementwiseWeightPowOpConverter() { op_type_ = "pow"; }
};
class ElementwiseTensorAddOpConverter : public ElementwiseTensorOpConverter {
public:
ElementwiseTensorAddOpConverter() { op_type_ = "add"; }
......@@ -321,6 +328,12 @@ REGISTER_TRT_OP_CONVERTER(elementwise_add_weight,
ElementwiseWeightAddOpConverter);
REGISTER_TRT_OP_CONVERTER(elementwise_mul_weight,
ElementwiseWeightMulOpConverter);
REGISTER_TRT_OP_CONVERTER(elementwise_sub_weight,
ElementwiseWeightSubOpConverter);
REGISTER_TRT_OP_CONVERTER(elementwise_div_weight,
ElementwiseWeightDivOpConverter);
REGISTER_TRT_OP_CONVERTER(elementwise_pow_weight,
ElementwiseWeightPowOpConverter);
REGISTER_TRT_OP_CONVERTER(elementwise_add_tensor,
ElementwiseTensorAddOpConverter);
......
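All five weight converters lower onto TensorRT's Scale layer, which computes out = (x * scale + shift) ^ power per element. The sub/div fix in this commit works by pre-transforming the weight buffer on the CPU so each op fits that form; a minimal standalone sketch of the mapping (the helper name is hypothetical):

// Sketch, not the converter itself: fold an elementwise op with a constant
// weight w of length n into Scale-layer terms. Unset weights default to
// scale = 1, shift = 0, power = 1.
void FoldWeightForScaleLayer(const std::string& op, float* w, int n) {
  if (op == "sub") {
    for (int i = 0; i < n; ++i) w[i] = -w[i];       // x - w == x + (-w): shift
  } else if (op == "div") {
    for (int i = 0; i < n; ++i) w[i] = 1.f / w[i];  // x / w == x * (1/w): scale
  }
  // "add" uses w as shift, "mul" as scale, "pow" as power, all unmodified.
}

This trick only applies when Y is a persistable parameter (the *_weight converters above); tensor-tensor variants go through IElementWiseLayer or the plugin instead.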
......@@ -67,10 +67,8 @@ class OpConverter {
if (op_desc.Type().find("elementwise") != std::string::npos) {
static std::unordered_set<std::string> add_tensor_op_set{
"add", "mul", "sub", "div", "max", "min", "pow"};
-      // TODO(xingzhaolong): all mul, sub, div
-      // static std::unordered_set<std::string> add_weight_op_set {"add", "mul",
-      // "sub", "div"};
-      static std::unordered_set<std::string> add_weight_op_set{"add", "mul"};
+      static std::unordered_set<std::string> add_weight_op_set{
+          "add", "mul", "sub", "div", "pow"};
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL,
platform::errors::InvalidArgument(
"The input op's Input(\"Y\")."
......
......@@ -39,7 +39,7 @@ class StridedSliceOpConverter : public OpConverter {
framework::OpDesc op_desc(op, nullptr);
auto* input = engine_->GetITensor(op_desc.Input("Input")[0]);
nvinfer1::Dims input_dims = input->getDimensions();
auto output_name = op_desc.Output("Out")[0];
std::vector<int> axes =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("axes"));
std::vector<int> starts =
......@@ -48,79 +48,116 @@ class StridedSliceOpConverter : public OpConverter {
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("ends"));
std::vector<int> strides =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("strides"));
-    nvinfer1::Dims start;
-    start.nbDims = input_dims.nbDims;
     int axes_size = axes.size();
-    for (int i = 0; i < start.nbDims; i++) {
-      start.d[i] = 0;
-    }
-    for (int i = 0; i < axes_size; i++) {
-      start.d[axes[i]] = starts[i];
-    }
+    nvinfer1::Dims start;
     nvinfer1::Dims stride;
-    stride.nbDims = input_dims.nbDims;
-    for (int i = 0; i < stride.nbDims; i++) {
-      stride.d[i] = 1;
-    }
-    for (int i = 0; i < axes_size; i++) {
-      stride.d[axes[i]] = strides[i];
-    }
     nvinfer1::Dims size;
+    start.nbDims = input_dims.nbDims;
+    stride.nbDims = input_dims.nbDims;
     size.nbDims = input_dims.nbDims;
-    for (int i = 0; i < size.nbDims; i++) {
-      size.d[i] = 1;
+    for (int i = 0; i < input_dims.nbDims; i++) {
+      start.d[i] = 0;
+      stride.d[i] = 1;
+      size.d[i] = input_dims.d[i];
     }
-    auto output_name = op_desc.Output("Out")[0];
-    auto create_weights = [&](const std::vector<int>& data,
-                              const std::string& type) -> int* {
-      std::unique_ptr<framework::Tensor> tmp_tensor(new framework::Tensor());
-      int data_size = data.size();
-      tmp_tensor->Resize({data_size});
-      auto* tmp_data = tmp_tensor->mutable_data<int>(platform::CPUPlace());
-      for (int i = 0; i < data_size; i++) {
-        tmp_data[i] = data[i];
-      }
-      engine_->SetWeights(output_name + "_add_slice_op_" + type,
-                          std::move(tmp_tensor));
-      return tmp_data;
-    };
-    std::vector<int> const_weight(input_dims.nbDims, 1);
-    for (int i = 0; i < axes_size; i++) {
-      const_weight[axes[i]] = strides[i];
-    }
-    int* weight_data = create_weights(const_weight, "size");
-    TensorRTEngine::Weight weight{nvinfer1::DataType::kINT32,
-                                  static_cast<void*>(weight_data),
-                                  static_cast<size_t>(input_dims.nbDims)};
-    int input_dim_size = input_dims.nbDims;
-    nvinfer1::Dims input_shape;
-    input_shape.nbDims = 1;
-    input_shape.d[0] = input_dim_size;
-    auto const_layer =
-        TRT_ENGINE_ADD_LAYER(engine_, Constant, input_shape, weight.get());
-    auto shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input);
-    auto size_layer = TRT_ENGINE_ADD_LAYER(
-        engine_, ElementWise, *shape_layer->getOutput(0),
-        *const_layer->getOutput(0), nvinfer1::ElementWiseOperation::kDIV);
-    auto* layer =
-        TRT_ENGINE_ADD_LAYER(engine_, Slice, *input, start, size, stride);
-    layer->setInput(2, *size_layer->getOutput(0));
-    RreplenishLayerAndOutput(layer, "strided_slice", {output_name}, test_mode);
+    if (!engine_->with_dynamic_shape()) {
+      for (int i = 0; i < axes_size; i++) {
+        start.d[axes[i] - 1] = starts[i];
+      }
+      for (int i = 0; i < axes_size; i++) {
+        stride.d[axes[i] - 1] = strides[i];
+      }
+      for (int i = 0; i < axes_size; ++i) {
+        int dim = size.d[axes[i] - 1];
+        if (dim > 0) {
+          int start = starts[i] < 0 ? (starts[i] + dim) : starts[i];
+          int end = ends[i] < 0 ? (ends[i] + dim) : ends[i];
+          int stride = std::abs(strides[i]);
+          start = std::max(start, 0);
+          end = std::max(end, 0);
+          end = std::min(end, dim);
+          size.d[axes[i] - 1] = (std::abs(end - start) + stride - 1) / stride;
+        }
+      }
+      auto* layer =
+          TRT_ENGINE_ADD_LAYER(engine_, Slice, *input, start, size, stride);
+      RreplenishLayerAndOutput(layer, "strided_slice", {output_name},
+                               test_mode);
+    } else {
+      for (int i = 0; i < axes_size; i++) {
+        start.d[axes[i]] = starts[i];
+      }
+      for (int i = 0; i < axes_size; i++) {
+        stride.d[axes[i]] = strides[i];
+      }
+      for (int i = 0; i < axes_size; ++i) {
+        int dim = size.d[axes[i]];
+        if (dim > 0) {
+          int start = starts[i] < 0 ? (starts[i] + dim) : starts[i];
+          int end = ends[i] < 0 ? (ends[i] + dim) : ends[i];
+          int stride = std::abs(strides[i]);
+          start = std::max(start, 0);
+          end = std::max(end, 0);
+          end = std::min(end, dim);
+          size.d[axes[i]] = (std::abs(end - start) + stride - 1) / stride;
+        }
+      }
+      auto create_weights = [&](const std::vector<int>& data,
+                                const std::string& type) -> int* {
+        std::unique_ptr<framework::Tensor> tmp_tensor(new framework::Tensor());
+        int data_size = data.size();
+        tmp_tensor->Resize({data_size});
+        auto* tmp_data = tmp_tensor->mutable_data<int>(platform::CPUPlace());
+        for (int i = 0; i < data_size; i++) {
+          tmp_data[i] = data[i];
+        }
+        engine_->SetWeights(output_name + "_add_slice_op_" + type,
+                            std::move(tmp_tensor));
+        return tmp_data;
+      };
+      std::vector<int> const_weight(input_dims.nbDims, 0);
+      for (int i = 0; i < axes_size; i++) {
+        int dim = input_dims.d[axes[i]];
+        int start = starts[i] < 0 ? (starts[i] + dim) : starts[i];
+        int end = ends[i] < 0 ? (ends[i] + dim) : ends[i];
+        int stride = std::abs(strides[i]);
+        start = std::max(start, 0);
+        end = std::max(end, 0);
+        end = std::min(end, dim);
+        const_weight[axes[i]] =
+            dim - ((std::abs(end - start) + stride - 1) / stride);
+      }
+      int* weight_data = create_weights(const_weight, "size");
+      TensorRTEngine::Weight weight{nvinfer1::DataType::kINT32,
+                                    static_cast<void*>(weight_data),
+                                    static_cast<size_t>(input_dims.nbDims)};
+      int input_dim_size = input_dims.nbDims;
+      nvinfer1::Dims input_shape;
+      input_shape.nbDims = 1;
+      input_shape.d[0] = input_dim_size;
+      auto const_layer =
+          TRT_ENGINE_ADD_LAYER(engine_, Constant, input_shape, weight.get());
+      auto shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input);
+      // slice layer
+      auto* layer =
+          TRT_ENGINE_ADD_LAYER(engine_, Slice, *input, start, size, stride);
+      // elementwise layer for get size tensor
+      auto size_layer = TRT_ENGINE_ADD_LAYER(
+          engine_, ElementWise, *shape_layer->getOutput(0),
+          *const_layer->getOutput(0), nvinfer1::ElementWiseOperation::kSUB);
+      layer->setInput(2, *size_layer->getOutput(0));
+      RreplenishLayerAndOutput(layer, "strided_slice", {output_name},
+                               test_mode);
+    }
   }
 };
......
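A worked example of the size arithmetic above, with hypothetical values: slicing a dimension of extent 32 with start=0, end=16, stride=4 keeps (|16 - 0| + 4 - 1) / 4 = 4 elements; the dynamic branch stores dim - 4 = 28 in the Constant layer, so Shape(input) - const recovers the same 4 at runtime.

// Hypothetical numbers, checking the formulas used by the converter.
int dim = 32, start = 0, end = 16, stride = 4;
int size = (std::abs(end - start) + stride - 1) / stride;  // == 4
int const_weight = dim - size;                             // == 28
// dynamic shape: runtime size = Shape(input)[axis] - const_weight == 4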
......@@ -79,6 +79,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"elementwise_sub",
"elementwise_mul",
"elementwise_div",
"elementwise_pow",
"dropout",
"prelu",
"conv2d_transpose",
......@@ -145,6 +146,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"elementwise_sub",
"elementwise_mul",
"elementwise_div",
"elementwise_pow",
"dropout",
"prelu",
"conv2d_transpose",
......@@ -958,9 +960,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
<< "strided_slice converter does not support trt versions below 7.0";
return false;
#endif
-    if (!with_dynamic_shape) {
-      return false;
-    }
if (!desc.HasAttr("axes") || !desc.HasAttr("starts") ||
!desc.HasAttr("ends") || !desc.HasAttr("strides")) {
VLOG(3)
......@@ -1026,7 +1025,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
if (op_type == "elementwise_add" || op_type == "elementwise_mul" ||
op_type == "elementwise_sub" || op_type == "elementwise_div") {
op_type == "elementwise_sub" || op_type == "elementwise_div" ||
op_type == "elementwise_pow") {
if (desc.Input("X").size() != 1) {
VLOG(3) << "The input op's Input(\"X\").size() "
"should equal to 1, but received Input(\"X\").size() = "
......@@ -1056,32 +1056,15 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
auto* y_var_desc = block->FindVar(desc.Input("Y")[0]);
const auto x_shape = x_var_desc->GetShape();
const auto y_shape = y_var_desc->GetShape();
if (op_type == "elementwise_add" && y_var_desc->Persistable()) {
if (y_shape.size() != 1) {
return false;
}
if (y_shape[0] != x_shape[1]) {
return false;
}
}
if (x_shape.size() == 1 && y_shape.size() == 1) {
VLOG(3) << "Now trt may not support two 1d tensor elementwise op.";
return false;
}
if (op_type == "elementwise_add" || op_type == "elementwise_mul") {
if (x_var_desc->Persistable()) {
VLOG(3) << "Input X is a parameter which is not supported for "
"elementwise_add/elementwise_mul in tensorrt, swap x and "
"y will work";
return false;
}
}
if (op_type == "elementwise_sub" || op_type == "elementwise_div") {
if (x_var_desc->Persistable() || y_var_desc->Persistable()) {
VLOG(3) << "Input X or Input Y is a parameter which is not supported "
"for elementwise_sub/elementwise_div in tensorrt";
return false;
}
if (x_var_desc->Persistable()) {
VLOG(3) << "Input X is a parameter which is not supported for "
"elementwise_add/elementwise_mul in tensorrt, swap x and "
"y will work";
return false;
}
}
......
......@@ -35,6 +35,19 @@ template <typename T>
struct Div {
__device__ T operator()(const T &a, const T &b) const { return a / b; }
};
template <typename T>
struct Sub {
__device__ T operator()(const T &a, const T &b) const { return a - b; }
};
template <typename T>
struct Pow {
__device__ T operator()(const T &a, const T &b) const {
return static_cast<T>(::powf(static_cast<float>(a), static_cast<float>(b)));
}
};
} // namespace details
template <typename T, typename Operator>
......@@ -139,6 +152,14 @@ int ElementWisePlugin::enqueue(int batch_size, const void *const *inputs,
elementwise_kernel<<<block, thread, 0, stream>>>(
num, x, y, out, prev_size_, batch_size * midd_size_, post_size_,
details::Div<float>());
} else if (type_ == "sub") {
elementwise_kernel<<<block, thread, 0, stream>>>(
num, x, y, out, prev_size_, batch_size * midd_size_, post_size_,
details::Sub<float>());
} else if (type_ == "pow") {
elementwise_kernel<<<block, thread, 0, stream>>>(
num, x, y, out, prev_size_, batch_size * midd_size_, post_size_,
details::Pow<float>());
} else {
PADDLE_THROW(platform::errors::Fatal(
"The %s type elementwise is not implemented in trt plugin.", type_));
......@@ -254,12 +275,18 @@ int ElementwisePluginDynamic::enqueue(
} else if (type_ == "div") {
elementwise_kernel<<<block, thread, 0, stream>>>(
num, x, y, out, prev_size, midd_size, post_size, details::Div<float>());
} else if (type_ == "sub") {
elementwise_kernel<<<block, thread, 0, stream>>>(
num, x, y, out, prev_size, midd_size, post_size, details::Sub<float>());
} else if (type_ == "pow") {
elementwise_kernel<<<block, thread, 0, stream>>>(
num, x, y, out, prev_size, midd_size, post_size, details::Pow<float>());
} else {
-      PADDLE_THROW(
-          platform::errors::Unimplemented("Paddle-TRT only support elementwise "
-                                          "operation: {add, mul, div} currently, "
-                                          "but got %s.",
-                                          type_));
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Paddle-TRT only support elementwise "
+          "operation: {add, mul, div, sub, pow} currently, "
+          "but got %s.",
+          type_));
}
return cudaGetLastError() != cudaSuccess;
......
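The kernel body lies outside these hunks; judging from the prev_size/midd_size/post_size arguments at the call sites, a plausible host-side model of the broadcast it performs (an assumption, not the committed CUDA code) is:

// Assumed semantics: y holds midd elements, tiled over prev * post
// positions of x; out and x are indexed linearly.
template <typename T, typename Op>
void ElementwiseRef(int num, const T* x, const T* y, T* out,
                    int prev, int midd, int post, Op op) {
  for (int i = 0; i < num; ++i) {
    out[i] = op(x[i], y[(i / post) % midd]);
  }
}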
......@@ -150,7 +150,7 @@ class TrtConvertElementwiseTest_two_input_without_broadcast(
for shape in [[4], [4, 32], [2, 64, 32], [1, 8, 16, 32]]:
for op_type in [
"elementwise_add", "elementwise_mul", "elementwise_sub",
"elementwise_div"
"elementwise_div", "elementwise_pow"
]:
for axis in [0, -1]:
self.dims = len(shape)
......@@ -309,7 +309,7 @@ class TrtConvertElementwiseTest_two_input_with_broadcast(TrtLayerAutoScanTest):
input2_shape = input2_shape_list[j][i]
for op_type in [
"elementwise_add", "elementwise_mul", "elementwise_sub",
"elementwise_div"
"elementwise_div", "elementwise_pow"
]:
for axis in axis_list[j][i]:
self.shape1 = input1_shape
......@@ -411,7 +411,7 @@ class TrtConvertElementwiseTest_one_input_corner_case(TrtLayerAutoScanTest):
[batch, 32, 16, 32]]:
for op_type in [
"elementwise_add", "elementwise_mul", "elementwise_sub",
"elementwise_div"
"elementwise_div", "elementwise_pow"
]:
for axis in [-1 if len(shape) == 1 else 1]:
self.dims = len(shape)
......@@ -511,18 +511,11 @@ class TrtConvertElementwiseTest_one_input_corner_case(TrtLayerAutoScanTest):
for weight_name in program_config.weights:
if weight_name in input_x_names:
return True
-            op_type = program_config.ops[0].type
-            if op_type in ["elementwise_sub", "elementwise_div"]:
-                input_y_names = program_config.ops[0].inputs["Y"]
-                for weight_name in program_config.weights:
-                    if weight_name in input_y_names:
-                        return True
             return False

         self.add_skip_case(
             teller1, SkipReasons.TRT_NOT_SUPPORT,
-            "Input X should not be parameters in elementwise op and Input Y should not be parameters in elementwise_sub or elementwise_div op"
-        )
+            "Input X should not be parameters in elementwise op.")
def test(self):
self.add_skip_trt_case()
......
......@@ -113,6 +113,12 @@ class TrtConvertStridedSliceTest(TrtLayerAutoScanTest):
for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
......@@ -121,3 +127,7 @@ class TrtConvertStridedSliceTest(TrtLayerAutoScanTest):
def test(self):
self.run_test()
if __name__ == "__main__":
unittest.main()