diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index cdd703e679d95cbea55dfda96810ad080a309789..8166c43e65db1fa7fb6e78884e67e695a88dfdd1 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -89,6 +89,7 @@ pass_library(delete_quant_dequant_filter_op_pass inference) pass_library(delete_weight_dequant_linear_op_pass inference) pass_library(delete_quant_dequant_linear_op_pass inference) pass_library(delete_dropout_op_pass inference) +pass_library(delete_fill_constant_op_pass inference) pass_library(simplify_with_basic_ops_pass base) pass_library(fc_elementwise_layernorm_fuse_pass base) pass_library(skip_layernorm_fuse_pass base) diff --git a/paddle/fluid/framework/ir/delete_fill_constant_op_pass.cc b/paddle/fluid/framework/ir/delete_fill_constant_op_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..e86bb2926b640b33eed8378166ab417048aa20db --- /dev/null +++ b/paddle/fluid/framework/ir/delete_fill_constant_op_pass.cc @@ -0,0 +1,103 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/delete_fill_constant_op_pass.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +template +void FillConstData(LoDTensor* out_t, T value) { + auto output_data = out_t->mutable_data(platform::CPUPlace()); + for (int i = 0; i < out_t->numel(); i++) { + output_data[i] = value; + } +} + +void DeleteFillConstantOpPass::ApplyImpl(ir::Graph* graph) const { + FusePassBase::Init("delete_fill_constant_op_pass", graph); + GraphPatternDetector detector; + auto fill_constant_op = detector.mutable_pattern() + ->NewNode("fill_constant") + ->assert_is_op("fill_constant") + ->assert_is_not_op_input("ValueTensor") + ->assert_is_not_op_input("str_value") + ->assert_is_not_op_input("ShapeTensor") + ->assert_is_not_op_input("ShapeTensorList"); + auto fill_constant_out = + detector.mutable_pattern() + ->NewNode("fill_constant_out") + ->assert_is_op_output("fill_constant") + ->assert_more([](Node* x) { return x->outputs.size() == 1UL; }); + auto next_op = detector.mutable_pattern() + ->NewNode("next_op") + ->assert_is_not_op_type("conditional_block") + ->assert_is_not_op_type("while"); + // Create the topological connections for the above pattern nodes. + fill_constant_op->LinksTo({fill_constant_out}); + next_op->LinksFrom({fill_constant_out}); + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + Node* fill_constant_op_node = subgraph.at(fill_constant_op); + Node* fill_constant_out_node = subgraph.at(fill_constant_out); + // Get fill_constant's attr + auto fill_constant = fill_constant_op_node->Op(); + auto value = BOOST_GET_CONST(float, fill_constant->GetAttr("value")); + auto shape = + BOOST_GET_CONST(std::vector, fill_constant->GetAttr("shape")); + auto* scope = param_scope(); + auto fill_constant_out_desc = fill_constant_out_node->Var(); + fill_constant_out_desc->SetShape(shape); + fill_constant_out_desc->SetPersistable(true); + auto* fill_constant_out_tensor = + scope->Var(fill_constant_out_desc->Name())->GetMutable(); + auto dtype = + framework::TransToPhiDataType(fill_constant_out_desc->GetDataType()); + fill_constant_out_tensor->Resize(phi::make_ddim(shape)); + switch (dtype) { + case paddle::experimental::DataType::BOOL: + FillConstData(fill_constant_out_tensor, static_cast(value)); + break; + case paddle::experimental::DataType::INT32: + FillConstData(fill_constant_out_tensor, + static_cast(value)); + break; + case paddle::experimental::DataType::INT64: + FillConstData(fill_constant_out_tensor, + static_cast(value)); + break; + case paddle::experimental::DataType::FLOAT32: + FillConstData(fill_constant_out_tensor, + static_cast(value)); + break; + default: + LOG(WARNING) << "Unsupported dtype for fill_constant op: " << dtype; + return; + } + // Remove links in graph + GraphSafeRemoveNodes(graph, {fill_constant_op_node}); + }; + + detector(graph, handler); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(delete_fill_constant_op_pass, + paddle::framework::ir::DeleteFillConstantOpPass); diff --git a/paddle/fluid/framework/ir/delete_fill_constant_op_pass.h b/paddle/fluid/framework/ir/delete_fill_constant_op_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..33d10f4502f2ab7e3c9d4d363361c7ee920070b2 --- /dev/null +++ b/paddle/fluid/framework/ir/delete_fill_constant_op_pass.h @@ -0,0 +1,39 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +class Graph; + +class DeleteFillConstantOpPass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + virtual ~DeleteFillConstantOpPass() = default; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 8c8d9fdddec851c9854ebb0c784d2b56d6dd8526..f7c1a68c826f0935fb6c551a744776679fc0bb69 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -408,6 +408,13 @@ PDNode *PDNode::assert_is_op(const std::string &op_type) { return this; } +PDNode *PDNode::assert_is_not_op_type(const std::string &op_type) { + asserts_.emplace_back([op_type](Node *x) { + return x && x->IsOp() && x->Op()->Type() != op_type; + }); + return this; +} + PDNode *PDNode::assert_is_var() { asserts_.emplace_back([](Node *x) { return x && x->IsVar(); }); return this; diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 9e5a82fc4458603da8b2b51587cad39047bc75e9..cab8f82660d901d0a8318ce4c5079adf6231ab54 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -110,6 +110,7 @@ struct PDNode { // Assertions, helper functions to simplify the pattern definition. PDNode* assert_is_op(); PDNode* assert_is_op(const std::string& op_type); + PDNode* assert_is_not_op_type(const std::string& op_type); PDNode* assert_is_var(); PDNode* assert_var_dtype(proto::VarType::Type dtype); PDNode* assert_is_not_ctrl_var(); diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 735e1b7be4c1fadacb9fc6fe90fb578863a5c32a..adc3fc46f72ac8898d1ba0565eedc3ded4f65989 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -633,6 +633,11 @@ void AnalysisConfig::Update() { (pass == "conv_bn_fuse_pass")) { continue; } + // delete_fill_constant_op_pass is not used under trt dynamic shape + if ((!min_input_shape_.empty() || trt_tuned_dynamic_shape_) && + pass == "delete_fill_constant_op_pass") { + continue; + } pass_builder()->AppendPass(pass); } } diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 6c81997d1356262332464d717044e242b2048811..13f81059df5e3320cb8166708e2f3c795548c504 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1731,6 +1731,10 @@ std::unique_ptr CreatePaddlePredictor( #if PADDLE_WITH_TENSORRT USE_TRT_CONVERTER(elementwise_add_weight); +USE_TRT_CONVERTER(elementwise_sub_weight); +USE_TRT_CONVERTER(elementwise_mul_weight); +USE_TRT_CONVERTER(elementwise_div_weight); +USE_TRT_CONVERTER(elementwise_pow_weight); USE_TRT_CONVERTER(elementwise_add_tensor); USE_TRT_CONVERTER(elementwise_sub_tensor); USE_TRT_CONVERTER(elementwise_div_tensor); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 77203b069e602a828073aa20f4bc0b1a70e64b21..fdb979283f76ecdf38d0082cc5b72470d3032ddf 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -85,6 +85,7 @@ const std::vector kTRTSubgraphPasses({ "adaptive_pool2d_convert_global_pass", "shuffle_channel_detect_pass", // "quant_conv2d_dequant_fuse_pass", // + "delete_fill_constant_op_pass", // "delete_quant_dequant_op_pass", // "delete_quant_dequant_filter_op_pass", // "delete_weight_dequant_linear_op_pass", // diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index 8fd0e1bbd068db709130624fd5c68f008608644f..35d3ead0097203b5b45a96b789fbb45579126d6e 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -53,20 +53,14 @@ class ElementwiseWeightOpConverter : public OpConverter { auto output_name = op_desc.Output("Out")[0]; weight_data = engine_->GetWeightCPUData(op_desc.Input("Y").front(), Y_t); nvinfer1::Dims dims_x = X->getDimensions(); + std::vector dims_y = phi::vectorize(Y_t->dims()); auto regist_eltwise_weight = [&](nvinfer1::ScaleMode scale_mode) { - TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT, - static_cast(weight_data), - static_cast(Y_t->numel())}; - TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT, nullptr, - 0}; - TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr, - 0}; - nvinfer1::IShuffleLayer* expand_layer = nullptr; nvinfer1::IShuffleLayer* squeeze_layer = nullptr; int dynamic_shape_offset = engine_->with_dynamic_shape() ? 1 : 0; auto input_dim = X->getDimensions(); + // reshape if (input_dim.nbDims < 3 + dynamic_shape_offset) { nvinfer1::Dims expand_shape; expand_shape.nbDims = 3 + dynamic_shape_offset; @@ -85,17 +79,45 @@ class ElementwiseWeightOpConverter : public OpConverter { expand_layer->setName( ("Elewise: Shuffle: (Output: " + output_name + ")").c_str()); } + // eltwise_ops + TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT, nullptr, + 0}; + TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT, nullptr, + 0}; + TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr, + 0}; if (op_type_ == "add") { - nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER( - engine_, ScaleNd, *X, scale_mode, shift_weights.get(), - scale_weights.get(), power_weights.get(), dynamic_shape_offset); - layer = scale_layer; + shift_weights = TensorRTEngine::Weight( + nvinfer1::DataType::kFLOAT, static_cast(weight_data), + static_cast(Y_t->numel())); + } else if (op_type_ == "sub") { + for (int i = 0; i < Y_t->numel(); i++) { + weight_data[i] = -weight_data[i]; + } + shift_weights = TensorRTEngine::Weight( + nvinfer1::DataType::kFLOAT, static_cast(weight_data), + static_cast(Y_t->numel())); } else if (op_type_ == "mul") { - nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER( - engine_, Scale, *X, scale_mode, scale_weights.get(), - shift_weights.get(), power_weights.get()); - layer = scale_layer; + scale_weights = TensorRTEngine::Weight( + nvinfer1::DataType::kFLOAT, static_cast(weight_data), + static_cast(Y_t->numel())); + } else if (op_type_ == "div") { + for (int i = 0; i < Y_t->numel(); i++) { + weight_data[i] = 1.f / weight_data[i]; + } + scale_weights = TensorRTEngine::Weight( + nvinfer1::DataType::kFLOAT, static_cast(weight_data), + static_cast(Y_t->numel())); + } else if (op_type_ == "pow") { + power_weights = TensorRTEngine::Weight( + nvinfer1::DataType::kFLOAT, static_cast(weight_data), + static_cast(Y_t->numel())); } + nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER( + engine_, ScaleNd, *X, scale_mode, shift_weights.get(), + scale_weights.get(), power_weights.get(), dynamic_shape_offset); + layer = scale_layer; + // reshape if (input_dim.nbDims < 3 + dynamic_shape_offset) { nvinfer1::Dims squeeze_shape; squeeze_shape.nbDims = input_dim.nbDims; @@ -113,71 +135,43 @@ class ElementwiseWeightOpConverter : public OpConverter { } }; + // dynamic shape if (engine_->with_dynamic_shape()) { - if (Y_t->dims().size() == 1) { - auto scale_mode = nvinfer1::ScaleMode::kCHANNEL; - PADDLE_ENFORCE_EQ(Y_t->dims()[0], dims_x.d[1], - platform::errors::InvalidArgument( - "The Bias's size(%d) should be equal to the " - "first dim(%d) of the Input.", - Y_t->dims()[0], dims_x.d[1])); - regist_eltwise_weight(scale_mode); + if (dims_y.size() == 1 && dims_y[0] == dims_x.d[1]) { + regist_eltwise_weight(nvinfer1::ScaleMode::kCHANNEL); + } else if (dims_y.size() == 1 && dims_y[0] == 1) { + regist_eltwise_weight(nvinfer1::ScaleMode::kUNIFORM); + } else if (dims_y.size() == static_cast(dims_x.nbDims)) { + regist_eltwise_weight(nvinfer1::ScaleMode::kELEMENTWISE); } else { PADDLE_THROW(platform::errors::InvalidArgument( - "The size of input bias's dims is %d, but TensorRT dynamic shape " - "only support size = 1 for Elementwise op!", - Y_t->dims().size())); + "The size of input_y's dims is %d, but TensorRT dynamic shape " + "only support size = 1 or size = input_x.size() for Elementwise " + "op!", + dims_y.size())); } return; } + // static shape with dynamic batch std::vector no_batch_dims; int start_index = 0; - - for (; start_index < dims_x.nbDims; start_index++) + for (; start_index < dims_x.nbDims; start_index++) { no_batch_dims.push_back(dims_x.d[start_index]); - - auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; - - std::vector dims_y = phi::vectorize(Y_t->dims()); - if (dims_y.size() == no_batch_dims.size() + 1) { - if (dims_y[0] == 1) dims_y.erase(dims_y.begin()); } - if (dims_y.size() == 1 && dims_y[0] == no_batch_dims[0]) { - scale_mode = nvinfer1::ScaleMode::kCHANNEL; - } else if (dims_y.size() == no_batch_dims.size() && - dims_y[0] == no_batch_dims[0]) { - scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; - for (size_t i = 1; i < no_batch_dims.size(); i++) { - if (dims_y[i] != no_batch_dims[i]) { - scale_mode = nvinfer1::ScaleMode::kCHANNEL; - break; - } - } - if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) { - for (size_t i = 1; i < no_batch_dims.size(); i++) { - if (dims_y[i] != 1) - PADDLE_THROW(platform::errors::InvalidArgument( - "The bias's %d dim is %d, but TensorRT dynamic shape only " - "support it equals to 1 for Elementwise op!", - i, dims_y[i])); - } - } + regist_eltwise_weight(nvinfer1::ScaleMode::kCHANNEL); + } else if (dims_y.size() == 1 && dims_y[0] == 1) { + regist_eltwise_weight(nvinfer1::ScaleMode::kUNIFORM); + } else if (dims_y.size() == no_batch_dims.size() + 1) { + regist_eltwise_weight(nvinfer1::ScaleMode::kELEMENTWISE); } else { - if (dims_y.size() >= 1) { - PADDLE_THROW(platform::errors::InvalidArgument( - "The size of bias's dims is %d and bias's size is %d. TensorRT " - "doesn't support this shape for Elementwise op!", - dims_y.size(), dims_y[0])); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "The size of bias's dims is %d. TensorRT doesn't support " - "this shape for Elementwise op!", - dims_y.size())); - } + PADDLE_THROW(platform::errors::InvalidArgument( + "The size of input_y's dims is %d, but TensorRT dynamic shape " + "only support size = 1 or size = input_x.size() for Elementwise " + "op!", + dims_y.size())); } - regist_eltwise_weight(scale_mode); } protected: @@ -215,7 +209,6 @@ class ElementwiseTensorOpConverter : public OpConverter { auto common_func = [&](nvinfer1::ILayer* layer) { RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode); }; - if (dims_x.nbDims == dims_y.nbDims) { // The two input tensor should have the same dims VLOG(3) << "Convert a fluid elementwise op to TensorRT IElementWiseLayer"; @@ -244,7 +237,6 @@ class ElementwiseTensorOpConverter : public OpConverter { auto* plugin_layer = engine_->AddPlugin( inputs.data(), inputs.size(), reinterpret_cast(plugin)); - layer = plugin_layer; } } @@ -278,6 +270,21 @@ class ElementwiseWeightMulOpConverter : public ElementwiseWeightOpConverter { ElementwiseWeightMulOpConverter() { op_type_ = "mul"; } }; +class ElementwiseWeightSubOpConverter : public ElementwiseWeightOpConverter { + public: + ElementwiseWeightSubOpConverter() { op_type_ = "sub"; } +}; + +class ElementwiseWeightDivOpConverter : public ElementwiseWeightOpConverter { + public: + ElementwiseWeightDivOpConverter() { op_type_ = "div"; } +}; + +class ElementwiseWeightPowOpConverter : public ElementwiseWeightOpConverter { + public: + ElementwiseWeightPowOpConverter() { op_type_ = "pow"; } +}; + class ElementwiseTensorAddOpConverter : public ElementwiseTensorOpConverter { public: ElementwiseTensorAddOpConverter() { op_type_ = "add"; } @@ -321,6 +328,12 @@ REGISTER_TRT_OP_CONVERTER(elementwise_add_weight, ElementwiseWeightAddOpConverter); REGISTER_TRT_OP_CONVERTER(elementwise_mul_weight, ElementwiseWeightMulOpConverter); +REGISTER_TRT_OP_CONVERTER(elementwise_sub_weight, + ElementwiseWeightSubOpConverter); +REGISTER_TRT_OP_CONVERTER(elementwise_div_weight, + ElementwiseWeightDivOpConverter); +REGISTER_TRT_OP_CONVERTER(elementwise_pow_weight, + ElementwiseWeightPowOpConverter); REGISTER_TRT_OP_CONVERTER(elementwise_add_tensor, ElementwiseTensorAddOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index f7eb7f859afaa3700ff3703992291e02188f1a2a..0a99b12edc25c0b27fbccdc2972f3f653bd2111f 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -67,10 +67,8 @@ class OpConverter { if (op_desc.Type().find("elementwise") != std::string::npos) { static std::unordered_set add_tensor_op_set{ "add", "mul", "sub", "div", "max", "min", "pow"}; - // TODO(xingzhaolong): all mul, sub, div - // static std::unordered_set add_weight_op_set {"add", "mul", - // "sub", "div"}; - static std::unordered_set add_weight_op_set{"add", "mul"}; + static std::unordered_set add_weight_op_set{ + "add", "mul", "sub", "div", "pow"}; PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL, platform::errors::InvalidArgument( "The input op's Input(\"Y\")." diff --git a/paddle/fluid/inference/tensorrt/convert/strided_slice_op.cc b/paddle/fluid/inference/tensorrt/convert/strided_slice_op.cc index 26046d38bcbd9f47dbedc9fdef29280cb69d4055..9680e90b2e29d624e457ba829efdb3c9884f34e3 100644 --- a/paddle/fluid/inference/tensorrt/convert/strided_slice_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/strided_slice_op.cc @@ -39,7 +39,7 @@ class StridedSliceOpConverter : public OpConverter { framework::OpDesc op_desc(op, nullptr); auto* input = engine_->GetITensor(op_desc.Input("Input")[0]); nvinfer1::Dims input_dims = input->getDimensions(); - + auto output_name = op_desc.Output("Out")[0]; std::vector axes = BOOST_GET_CONST(std::vector, op_desc.GetAttr("axes")); std::vector starts = @@ -48,79 +48,116 @@ class StridedSliceOpConverter : public OpConverter { BOOST_GET_CONST(std::vector, op_desc.GetAttr("ends")); std::vector strides = BOOST_GET_CONST(std::vector, op_desc.GetAttr("strides")); - - nvinfer1::Dims start; - start.nbDims = input_dims.nbDims; int axes_size = axes.size(); - for (int i = 0; i < start.nbDims; i++) { - start.d[i] = 0; - } - for (int i = 0; i < axes_size; i++) { - start.d[axes[i]] = starts[i]; - } - + nvinfer1::Dims start; nvinfer1::Dims stride; - stride.nbDims = input_dims.nbDims; - for (int i = 0; i < stride.nbDims; i++) { - stride.d[i] = 1; - } - for (int i = 0; i < axes_size; i++) { - stride.d[axes[i]] = strides[i]; - } - nvinfer1::Dims size; + start.nbDims = input_dims.nbDims; + stride.nbDims = input_dims.nbDims; size.nbDims = input_dims.nbDims; - for (int i = 0; i < size.nbDims; i++) { - size.d[i] = 1; + for (int i = 0; i < input_dims.nbDims; i++) { + start.d[i] = 0; + stride.d[i] = 1; + size.d[i] = input_dims.d[i]; } - auto output_name = op_desc.Output("Out")[0]; - - auto create_weights = [&](const std::vector& data, - const std::string& type) -> int* { - std::unique_ptr tmp_tensor(new framework::Tensor()); - int data_size = data.size(); - tmp_tensor->Resize({data_size}); - auto* tmp_data = tmp_tensor->mutable_data(platform::CPUPlace()); - for (int i = 0; i < data_size; i++) { - tmp_data[i] = data[i]; + if (!engine_->with_dynamic_shape()) { + for (int i = 0; i < axes_size; i++) { + start.d[axes[i] - 1] = starts[i]; + } + for (int i = 0; i < axes_size; i++) { + stride.d[axes[i] - 1] = strides[i]; + } + for (int i = 0; i < axes_size; ++i) { + int dim = size.d[axes[i] - 1]; + if (dim > 0) { + int start = starts[i] < 0 ? (starts[i] + dim) : starts[i]; + int end = ends[i] < 0 ? (ends[i] + dim) : ends[i]; + int stride = std::abs(strides[i]); + start = std::max(start, 0); + end = std::max(end, 0); + end = std::min(end, dim); + size.d[axes[i] - 1] = (std::abs(end - start) + stride - 1) / stride; + } + } + auto* layer = + TRT_ENGINE_ADD_LAYER(engine_, Slice, *input, start, size, stride); + RreplenishLayerAndOutput(layer, "strided_slice", {output_name}, + test_mode); + } else { + for (int i = 0; i < axes_size; i++) { + start.d[axes[i]] = starts[i]; + } + for (int i = 0; i < axes_size; i++) { + stride.d[axes[i]] = strides[i]; + } + for (int i = 0; i < axes_size; ++i) { + int dim = size.d[axes[i]]; + if (dim > 0) { + int start = starts[i] < 0 ? (starts[i] + dim) : starts[i]; + int end = ends[i] < 0 ? (ends[i] + dim) : ends[i]; + int stride = std::abs(strides[i]); + start = std::max(start, 0); + end = std::max(end, 0); + end = std::min(end, dim); + size.d[axes[i]] = (std::abs(end - start) + stride - 1) / stride; + } } - engine_->SetWeights(output_name + "_add_slice_op_" + type, - std::move(tmp_tensor)); - return tmp_data; - }; + auto create_weights = [&](const std::vector& data, + const std::string& type) -> int* { + std::unique_ptr tmp_tensor(new framework::Tensor()); + int data_size = data.size(); + tmp_tensor->Resize({data_size}); + auto* tmp_data = tmp_tensor->mutable_data(platform::CPUPlace()); + for (int i = 0; i < data_size; i++) { + tmp_data[i] = data[i]; + } + + engine_->SetWeights(output_name + "_add_slice_op_" + type, + std::move(tmp_tensor)); + return tmp_data; + }; + + std::vector const_weight(input_dims.nbDims, 0); + for (int i = 0; i < axes_size; i++) { + int dim = input_dims.d[axes[i]]; + int start = starts[i] < 0 ? (starts[i] + dim) : starts[i]; + int end = ends[i] < 0 ? (ends[i] + dim) : ends[i]; + int stride = std::abs(strides[i]); + start = std::max(start, 0); + end = std::max(end, 0); + end = std::min(end, dim); + const_weight[axes[i]] = + dim - ((std::abs(end - start) + stride - 1) / stride); + } - std::vector const_weight(input_dims.nbDims, 1); - for (int i = 0; i < axes_size; i++) { - const_weight[axes[i]] = strides[i]; + int* weight_data = create_weights(const_weight, "size"); + + TensorRTEngine::Weight weight{nvinfer1::DataType::kINT32, + static_cast(weight_data), + static_cast(input_dims.nbDims)}; + + int input_dim_size = input_dims.nbDims; + nvinfer1::Dims input_shape; + input_shape.nbDims = 1; + input_shape.d[0] = input_dim_size; + + auto const_layer = + TRT_ENGINE_ADD_LAYER(engine_, Constant, input_shape, weight.get()); + + auto shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input); + // slice layer + auto* layer = + TRT_ENGINE_ADD_LAYER(engine_, Slice, *input, start, size, stride); + // elementwise layer for get size tensor + auto size_layer = TRT_ENGINE_ADD_LAYER( + engine_, ElementWise, *shape_layer->getOutput(0), + *const_layer->getOutput(0), nvinfer1::ElementWiseOperation::kSUB); + layer->setInput(2, *size_layer->getOutput(0)); + RreplenishLayerAndOutput(layer, "strided_slice", {output_name}, + test_mode); } - - int* weight_data = create_weights(const_weight, "size"); - - TensorRTEngine::Weight weight{nvinfer1::DataType::kINT32, - static_cast(weight_data), - static_cast(input_dims.nbDims)}; - - int input_dim_size = input_dims.nbDims; - nvinfer1::Dims input_shape; - input_shape.nbDims = 1; - input_shape.d[0] = input_dim_size; - - auto const_layer = - TRT_ENGINE_ADD_LAYER(engine_, Constant, input_shape, weight.get()); - - auto shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input); - - auto size_layer = TRT_ENGINE_ADD_LAYER( - engine_, ElementWise, *shape_layer->getOutput(0), - *const_layer->getOutput(0), nvinfer1::ElementWiseOperation::kDIV); - - auto* layer = - TRT_ENGINE_ADD_LAYER(engine_, Slice, *input, start, size, stride); - layer->setInput(2, *size_layer->getOutput(0)); - - RreplenishLayerAndOutput(layer, "strided_slice", {output_name}, test_mode); } }; diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index ba5b28a4dfed9fa114fe38276c6dc18d3931610d..cbe151294db099040c70be79342fd94ef9106658 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -79,6 +79,7 @@ struct SimpleOpTypeSetTeller : public Teller { "elementwise_sub", "elementwise_mul", "elementwise_div", + "elementwise_pow", "dropout", "prelu", "conv2d_transpose", @@ -145,6 +146,7 @@ struct SimpleOpTypeSetTeller : public Teller { "elementwise_sub", "elementwise_mul", "elementwise_div", + "elementwise_pow", "dropout", "prelu", "conv2d_transpose", @@ -958,9 +960,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, << "strided_slice converter does not support trt versions below 7.0"; return false; #endif - if (!with_dynamic_shape) { - return false; - } if (!desc.HasAttr("axes") || !desc.HasAttr("starts") || !desc.HasAttr("ends") || !desc.HasAttr("strides")) { VLOG(3) @@ -1026,7 +1025,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } if (op_type == "elementwise_add" || op_type == "elementwise_mul" || - op_type == "elementwise_sub" || op_type == "elementwise_div") { + op_type == "elementwise_sub" || op_type == "elementwise_div" || + op_type == "elementwise_pow") { if (desc.Input("X").size() != 1) { VLOG(3) << "The input op's Input(\"X\").size() " "should equal to 1, but received Input(\"X\").size() = " @@ -1056,32 +1056,15 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, auto* y_var_desc = block->FindVar(desc.Input("Y")[0]); const auto x_shape = x_var_desc->GetShape(); const auto y_shape = y_var_desc->GetShape(); - if (op_type == "elementwise_add" && y_var_desc->Persistable()) { - if (y_shape.size() != 1) { - return false; - } - if (y_shape[0] != x_shape[1]) { - return false; - } - } if (x_shape.size() == 1 && y_shape.size() == 1) { VLOG(3) << "Now trt may not support two 1d tensor elementwise op."; return false; } - if (op_type == "elementwise_add" || op_type == "elementwise_mul") { - if (x_var_desc->Persistable()) { - VLOG(3) << "Input X is a parameter which is not supported for " - "elementwise_add/elementwise_mul in tensorrt, swap x and " - "y will work"; - return false; - } - } - if (op_type == "elementwise_sub" || op_type == "elementwise_div") { - if (x_var_desc->Persistable() || y_var_desc->Persistable()) { - VLOG(3) << "Input X or Input Y is a parameter which is not supported " - "for elementwise_sub/elementwise_div in tensorrt"; - return false; - } + if (x_var_desc->Persistable()) { + VLOG(3) << "Input X is a parameter which is not supported for " + "elementwise_add/elementwise_mul in tensorrt, swap x and " + "y will work"; + return false; } } diff --git a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu index c9163e62a2e19ea9c4449a5eaffd637844710d6d..1070a88cee7372cdbe6bcbef83681c624b7470a2 100644 --- a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu @@ -35,6 +35,19 @@ template struct Div { __device__ T operator()(const T &a, const T &b) const { return a / b; } }; + +template +struct Sub { + __device__ T operator()(const T &a, const T &b) const { return a - b; } +}; + +template +struct Pow { + __device__ T operator()(const T &a, const T &b) const { + return static_cast(::powf(static_cast(a), static_cast(b))); + } +}; + } // namespace details template @@ -139,6 +152,14 @@ int ElementWisePlugin::enqueue(int batch_size, const void *const *inputs, elementwise_kernel<<>>( num, x, y, out, prev_size_, batch_size * midd_size_, post_size_, details::Div()); + } else if (type_ == "sub") { + elementwise_kernel<<>>( + num, x, y, out, prev_size_, batch_size * midd_size_, post_size_, + details::Sub()); + } else if (type_ == "pow") { + elementwise_kernel<<>>( + num, x, y, out, prev_size_, batch_size * midd_size_, post_size_, + details::Pow()); } else { PADDLE_THROW(platform::errors::Fatal( "The %s type elementwise is not implemented in trt plugin.", type_)); @@ -254,12 +275,18 @@ int ElementwisePluginDynamic::enqueue( } else if (type_ == "div") { elementwise_kernel<<>>( num, x, y, out, prev_size, midd_size, post_size, details::Div()); + } else if (type_ == "sub") { + elementwise_kernel<<>>( + num, x, y, out, prev_size, midd_size, post_size, details::Sub()); + } else if (type_ == "pow") { + elementwise_kernel<<>>( + num, x, y, out, prev_size, midd_size, post_size, details::Pow()); } else { - PADDLE_THROW( - platform::errors::Unimplemented("Paddle-TRT only support elementwise " - "operation: {add, mul, div} currently, " - "but got %s.", - type_)); + PADDLE_THROW(platform::errors::Unimplemented( + "Paddle-TRT only support elementwise " + "operation: {add, mul, div, sub, pow} currently, " + "but got %s.", + type_)); } return cudaGetLastError() != cudaSuccess; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py index ec02a357a48b6a79150bd82705122e354fdc3364..27d8247aded5a26a7f535b6ce99727c995eebc1a 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py @@ -150,7 +150,7 @@ class TrtConvertElementwiseTest_two_input_without_broadcast( for shape in [[4], [4, 32], [2, 64, 32], [1, 8, 16, 32]]: for op_type in [ "elementwise_add", "elementwise_mul", "elementwise_sub", - "elementwise_div" + "elementwise_div", "elementwise_pow" ]: for axis in [0, -1]: self.dims = len(shape) @@ -309,7 +309,7 @@ class TrtConvertElementwiseTest_two_input_with_broadcast(TrtLayerAutoScanTest): input2_shape = input2_shape_list[j][i] for op_type in [ "elementwise_add", "elementwise_mul", "elementwise_sub", - "elementwise_div" + "elementwise_div", "elementwise_pow" ]: for axis in axis_list[j][i]: self.shape1 = input1_shape @@ -411,7 +411,7 @@ class TrtConvertElementwiseTest_one_input_corner_case(TrtLayerAutoScanTest): [batch, 32, 16, 32]]: for op_type in [ "elementwise_add", "elementwise_mul", "elementwise_sub", - "elementwise_div" + "elementwise_div", "elementwise_pow" ]: for axis in [-1 if len(shape) == 1 else 1]: self.dims = len(shape) @@ -511,18 +511,11 @@ class TrtConvertElementwiseTest_one_input_corner_case(TrtLayerAutoScanTest): for weight_name in program_config.weights: if weight_name in input_x_names: return True - op_type = program_config.ops[0].type - if op_type in ["elementwise_sub", "elementwise_div"]: - input_y_names = program_config.ops[0].inputs["Y"] - for weight_name in program_config.weights: - if weight_name in input_y_names: - return True return False self.add_skip_case( teller1, SkipReasons.TRT_NOT_SUPPORT, - "Input X should not be parameters in elementwise op and Input Y should not be parameters in elementwise_sub or elementwise_div op" - ) + "Input X should not be parameters in elementwise op.") def test(self): self.add_skip_trt_case() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_strided_slice.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_strided_slice.py index 6a204ebbad27d7a5738cc28e62c89756502a329f..8bc48047c1397409d843efcbbeca342041ef8b10 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_strided_slice.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_strided_slice.py @@ -113,6 +113,12 @@ class TrtConvertStridedSliceTest(TrtLayerAutoScanTest): for i in range(len(program_config.ops)) ] + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + # for dynamic_shape generate_dynamic_shape(attrs) self.trt_param.precision = paddle_infer.PrecisionType.Float32 @@ -121,3 +127,7 @@ class TrtConvertStridedSliceTest(TrtLayerAutoScanTest): def test(self): self.run_test() + + +if __name__ == "__main__": + unittest.main()