diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 85deab25dee44b163bc8b412bec5a72db4e4e521..18ffa2661da48e5c10a8e462925cc37114232c28 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -135,6 +135,7 @@ if(WITH_TENSORRT) pass_library(remove_padding_recover_padding_pass inference) pass_library(delete_remove_padding_recover_padding_pass inference) pass_library(layernorm_shift_partition_fuse_pass inference) + pass_library(reverse_roll_fuse_pass inference) pass_library(preln_layernorm_x_fuse_pass inference) endif() diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index dd5edaaa9c821f35f54877ad948aaaa28bbfb996..9912cee3838db41e9c455439ef39c6a2f11dce96 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -3799,6 +3799,70 @@ PDNode *patterns::LayernormShiftPartitionPattern::operator()() { return reshape4_out; } +PDNode *patterns::ReverseRollPattern::operator()(PDNode *in) { + in->AsInput(); + auto reshape2_00_op = + pattern->NewNode(reshape2_00_op_repr())->assert_is_op("reshape2"); + + auto reshape2_00_out = pattern->NewNode(reshape2_00_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("reshape2", "Out") + ->assert_is_op_input("reshape2", "X"); + + auto reshape2_10_op = + pattern->NewNode(reshape2_10_op_repr())->assert_is_op("reshape2"); + auto reshape2_10_out = pattern->NewNode(reshape2_10_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("reshape2", "Out") + ->assert_is_op_input("transpose2", "X"); + + auto transpose2_20_op = + pattern->NewNode(transpose2_20_op_repr())->assert_is_op("transpose2"); + auto transpose2_20_out = pattern->NewNode(transpose2_20_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("transpose2", "Out") + ->assert_is_op_input("reshape2", "X"); + + auto reshape2_30_op = + pattern->NewNode(reshape2_30_op_repr())->assert_is_op("reshape2"); + auto reshape2_30_out = pattern->NewNode(reshape2_30_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("reshape2", "Out"); + PDNode *roll_40_op = nullptr; + PDNode *roll_40_out = nullptr; + if (with_roll_) { + reshape2_30_out->assert_is_op_input("roll", "X"); + roll_40_op = pattern->NewNode(roll_40_op_repr())->assert_is_op("roll"); + roll_40_out = pattern->NewNode(roll_40_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("roll", "Out") + ->assert_is_op_input("reshape2", "X"); + } else { + reshape2_30_out->assert_is_op_input("reshape2", "X"); + } + auto reshape2_50_op = + pattern->NewNode(reshape2_50_op_repr())->assert_is_op("reshape2"); + auto reshape2_50_out = pattern->NewNode(reshaep2_50_out_repr()) + ->assert_is_op_output("reshape2", "Out") + ->AsOutput(); + reshape2_00_op->LinksFrom({in}); + reshape2_00_out->LinksFrom({reshape2_00_op}); + reshape2_10_op->LinksFrom({reshape2_00_out}); + reshape2_10_out->LinksFrom({reshape2_10_op}); + transpose2_20_op->LinksFrom({reshape2_10_out}); + transpose2_20_out->LinksFrom({transpose2_20_op}); + reshape2_30_op->LinksFrom({transpose2_20_out}); + reshape2_30_out->LinksFrom({reshape2_30_op}); + if (with_roll_) { + roll_40_op->LinksFrom({reshape2_30_out}); + roll_40_out->LinksFrom({roll_40_op}); + reshape2_50_op->LinksFrom({roll_40_out}); + } else { + reshape2_50_op->LinksFrom({reshape2_30_out}); + } + reshape2_50_out->LinksFrom({reshape2_50_op}); + return reshape2_50_out; +} PDNode 
*patterns::MergeLayernormPattern::operator()(PDNode *in) {
   in->AsInput();
   auto reshape2_00_op =
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index f8f985fa5994ec015dee8cc2a01a8331b1461135..8263a19756b1d3854fcd780f88914b01e8af45fb 100755
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -2053,6 +2053,34 @@ struct LayernormShiftPartitionPattern : public PatternBase {
   PATTERN_DECL_NODE(reshape4_out);
 };
 
+//
+// \brief Pattern looking for the reverse cyclic shift in window attention.
+// The reverse cyclic shift is implemented with the roll op, so reverse_roll
+// is adopted as the name of both the pattern and the fused op.
+//
+struct ReverseRollPattern : public PatternBase {
+  ReverseRollPattern(PDPattern* pattern,
+                     const std::string& name_scope,
+                     bool with_roll)
+      : PatternBase(pattern, name_scope, "reverse_roll"),
+        with_roll_(with_roll) {}
+
+  PDNode* operator()(PDNode* in);
+  bool with_roll_;
+  PATTERN_DECL_NODE(reshape2_00_op);
+  PATTERN_DECL_NODE(reshape2_00_out);
+  PATTERN_DECL_NODE(reshape2_10_op);
+  PATTERN_DECL_NODE(reshape2_10_out);
+  PATTERN_DECL_NODE(transpose2_20_op);
+  PATTERN_DECL_NODE(transpose2_20_out);
+  PATTERN_DECL_NODE(reshape2_30_op);
+  PATTERN_DECL_NODE(reshape2_30_out);
+  PATTERN_DECL_NODE(roll_40_op);
+  PATTERN_DECL_NODE(roll_40_out);
+  PATTERN_DECL_NODE(reshape2_50_op);
+  PATTERN_DECL_NODE(reshaep2_50_out);
+};
+
 // pattern for merge_layernorm
 struct MergeLayernormPattern : public PatternBase {
   MergeLayernormPattern(PDPattern* pattern, const std::string& name_scope)
diff --git a/paddle/fluid/framework/ir/reverse_roll_fuse_pass.cc b/paddle/fluid/framework/ir/reverse_roll_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a48720f2bbb948e5213c1101db0c04e8f7399364
--- /dev/null
+++ b/paddle/fluid/framework/ir/reverse_roll_fuse_pass.cc
@@ -0,0 +1,202 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/framework/ir/reverse_roll_fuse_pass.h" + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/op_proto_maker.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +#define GET_IR_NODE(node__) \ + GET_IR_NODE_FROM_SUBGRAPH(node__, node__, reverse_roll_pattern); +#define GET_NODES \ + GET_IR_NODE(reshape2_00_op); \ + GET_IR_NODE(reshape2_00_out); \ + GET_IR_NODE(reshape2_10_op); \ + GET_IR_NODE(reshape2_10_out); \ + GET_IR_NODE(transpose2_20_op); \ + GET_IR_NODE(transpose2_20_out); \ + GET_IR_NODE(reshape2_30_op); \ + GET_IR_NODE(reshape2_30_out); \ + GET_IR_NODE(reshape2_50_op); \ + GET_IR_NODE(reshaep2_50_out); + +namespace paddle { +namespace framework { +namespace ir { +class Node; +ReverseRollFusePass::ReverseRollFusePass() { + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("shape") + .IsType>() + .End(); + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("axis") + .IsType>() + .End(); + AddOpCompat(OpCompat("roll")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsType>() + .End() + .AddAttr("shifts") + .IsType>() + .End(); +} +int ReverseRollFusePass::ApplyPattern(ir::Graph* graph, bool with_roll) const { + PADDLE_ENFORCE_NOT_NULL( + graph, + platform::errors::InvalidArgument( + "The input graph of ReverseRollFusePass should not be " + "nullptr.")); + GraphPatternDetector gpd; + FusePassBase::Init(scope_name_, graph); + PDNode* x = gpd.mutable_pattern() + ->NewNode("x") + ->assert_is_op_input("reshape2", "X") + ->AsInput(); + patterns::ReverseRollPattern reverse_roll_pattern( + gpd.mutable_pattern(), scope_name_, with_roll); + reverse_roll_pattern(x); + int fuse_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "reverse roll in op compat failed."; + return; + } + if (with_roll) { + VLOG(4) << "reverse_roll_fuse pass, shift_size>0, with roll op"; + } else { + VLOG(4) << "reverse_roll_fuse pass, shift_size=0, without roll op"; + } + GET_NODES; + Node* roll_40_op = nullptr; + Node* roll_40_out = nullptr; + if (with_roll) { + GET_IR_NODE_FROM_SUBGRAPH( + tmp_roll_40_op, roll_40_op, reverse_roll_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + tmp_roll_40_out, roll_40_out, reverse_roll_pattern); + roll_40_op = tmp_roll_40_op; + roll_40_out = tmp_roll_40_out; + } + + std::unordered_set del_node_set = {reshape2_00_op, + reshape2_00_out, + reshape2_10_op, + reshape2_10_out, + transpose2_20_op, + transpose2_20_out, + reshape2_30_op, + reshape2_30_out, + reshape2_50_op}; + if (with_roll) { + del_node_set.insert(roll_40_op); + del_node_set.insert(roll_40_out); + } + + std::vector reshape2_10_attr_shape = PADDLE_GET_CONST( + std::vector, reshape2_10_op->Op()->GetAttr("shape")); + if (reshape2_10_attr_shape[1] <= 0) { + return; + } + if (reshape2_10_attr_shape[1] != reshape2_10_attr_shape[2]) { + return; + } + int window_number = reshape2_10_attr_shape[1] * reshape2_10_attr_shape[2]; + std::vector reshape_2_00_attr_shape = PADDLE_GET_CONST( + std::vector, reshape2_00_op->Op()->GetAttr("shape")); + int window_size_h = reshape_2_00_attr_shape[1]; + if 
(window_size_h <= 0) { + return; + } + int window_size_w = reshape_2_00_attr_shape[2]; + if (window_size_h != window_size_w) { + return; + } + int window_size = window_size_h; + int window_len = window_size_h * window_size_w; + int input_resolution = reshape2_10_attr_shape[1] * window_size_h; + + auto shift_size = 0; + if (with_roll) { + std::vector roll_40_op_attr_shifts = PADDLE_GET_CONST( + std::vector, roll_40_op->Op()->GetAttr("shifts")); + if (roll_40_op_attr_shifts[0] != roll_40_op_attr_shifts[1]) { + return; + } + shift_size = roll_40_op_attr_shifts[0]; + } + OpDesc reverse_roll_desc(reshape2_00_op->Op()->Block()); + reverse_roll_desc.SetType("reverse_roll"); + reverse_roll_desc.SetInput("X", {subgraph.at(x)->Name()}); + reverse_roll_desc.SetOutput("Out", {reshaep2_50_out->Name()}); + reverse_roll_desc.SetAttr("window_number", window_number); + reverse_roll_desc.SetAttr("window_size", window_size); + reverse_roll_desc.SetAttr("window_len", window_len); + reverse_roll_desc.SetAttr("shift_size", static_cast(shift_size)); + reverse_roll_desc.SetAttr("input_resolution", input_resolution); + auto reverse_roll_node = graph->CreateOpNode(&reverse_roll_desc); + IR_NODE_LINK_TO(subgraph.at(x), reverse_roll_node); + IR_NODE_LINK_TO(reverse_roll_node, reshaep2_50_out); + GraphSafeRemoveNodes(graph, del_node_set); + ++fuse_count; + }; + gpd(graph, handler); + return fuse_count; +} +void ReverseRollFusePass::ApplyImpl(ir::Graph* graph) const { + int fuse_count = 0; + fuse_count += ApplyPattern(graph, true); + fuse_count += ApplyPattern(graph, false); + AddStatis(fuse_count); +} +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(reverse_roll_fuse_pass, + paddle::framework::ir::ReverseRollFusePass); +REGISTER_PASS_CAPABILITY(reverse_roll_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("transpose2", 0) + .EQ("reshape2", 0)); diff --git a/paddle/fluid/framework/ir/reverse_roll_fuse_pass.h b/paddle/fluid/framework/ir/reverse_roll_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..f7c8229311ed97f6f5c347f3ddd4852dcfbed58b --- /dev/null +++ b/paddle/fluid/framework/ir/reverse_roll_fuse_pass.h @@ -0,0 +1,68 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +// | +// reshape2 +// | +// reshape2 +// | +// transpose2 -> reverse_roll (shift_size=0) +// | fuse +// reshape2 +// | +// reshape2 +// | +// +// or +// +// | +// reshape2 +// | +// reshape2 +// | -> reverse_roll (shift_size>0) +// transpose2 fuse +// | +// reshape2 +// | +// roll +// | +// reshape2 +// | + +class ReverseRollFusePass : public FusePassBase { + public: + ReverseRollFusePass(); + virtual ~ReverseRollFusePass() {} + + protected: + void ApplyImpl(ir::Graph *graph) const override; + int ApplyPattern(ir::Graph *graph, bool with_roll) const; + + private: + const std::string scope_name_{"reverse_roll_fuse"}; +}; +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 67e6478bffa7061fd4745e1e9341226f6aed7533..da09ecff079bd33e6b47bdd36360c6e3596f4e90 100755 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2338,6 +2338,7 @@ USE_TRT_CONVERTER(fill_constant) USE_TRT_CONVERTER(fused_token_prune) USE_TRT_CONVERTER(celu) USE_TRT_CONVERTER(layernorm_shift_partition) +USE_TRT_CONVERTER(reverse_roll) USE_TRT_CONVERTER(preln_layernorm_shift_partition) USE_TRT_CONVERTER(merge_layernorm) USE_TRT_CONVERTER(skip_merge_layernorm) diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 16db8bee9ecdae5efc359ca19a7361124ac28b55..ce55ba81200c62cf743e0b44a45f92adab40c030 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -115,6 +115,7 @@ const std::vector kTRTSubgraphPasses({ "merge_layernorm_fuse_pass", // "preln_residual_bias_fuse_pass", // "preln_layernorm_x_fuse_pass", // + "reverse_roll_fuse_pass", // // "set_transformer_input_convert_pass", // "conv_bn_fuse_pass", // "unsqueeze2_eltwise_fuse_pass", // diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index cec617c2f56a55751b7e51f9cfe95479dca1f0f8..2598b4c2ae0f0f6d97fd966146a1065f9ede5acd 100755 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -82,6 +82,7 @@ list( fused_token_prune_op.cc celu_op.cc layernorm_shift_partition_op.cc + reverse_roll_op.cc tanhshrink_op.cc take_along_axis_op.cc logsigmoid_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/reverse_roll_op.cc b/paddle/fluid/inference/tensorrt/convert/reverse_roll_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..9be2ad266e36475722b1156cb4ee73043d0c3daf --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/reverse_roll_op.cc @@ -0,0 +1,79 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/plugin/reverse_roll_op_plugin.h"
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+class ReverseRollOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope,
+                  bool test_mode) override {
+    VLOG(4) << "convert a reverse_roll op to tensorrt "
+               "reverse_roll plugin";
+    framework::OpDesc op_desc(op, nullptr);
+
+    auto* X = engine_->GetITensor(op_desc.Input("X").front());
+    const int window_number =
+        PADDLE_GET_CONST(int, op_desc.GetAttr("window_number"));
+    const int window_size =
+        PADDLE_GET_CONST(int, op_desc.GetAttr("window_size"));
+    const int window_len = PADDLE_GET_CONST(int, op_desc.GetAttr("window_len"));
+    const int shift_size = PADDLE_GET_CONST(int, op_desc.GetAttr("shift_size"));
+    const int input_resolution =
+        PADDLE_GET_CONST(int, op_desc.GetAttr("input_resolution"));
+    PADDLE_ENFORCE_EQ(window_size * window_size,
+                      window_len,
+                      platform::errors::InvalidArgument(
+                          "The window_len should be equal to window_size * "
+                          "window_size, but got window_size:%d, window_len:%d",
+                          window_size,
+                          window_len));
+    PADDLE_ENFORCE_EQ(
+        window_number * window_len,
+        input_resolution * input_resolution,
+        platform::errors::InvalidArgument(
+            "input_resolution * input_resolution should be equal to "
+            "window_number * window_len, but got window_len:%d, "
+            "window_number:%d, input_resolution:%d",
+            window_len,
+            window_number,
+            input_resolution));
+    bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
+    nvinfer1::ILayer* reverse_roll_layer = nullptr;
+    if (engine_->with_dynamic_shape()) {
+      plugin::ReverseRollPluginDynamic* plugin =
+          new plugin::ReverseRollPluginDynamic(window_number,
+                                               window_len,
+                                               window_size,
+                                               input_resolution,
+                                               shift_size,
+                                               with_fp16);
+      reverse_roll_layer = engine_->AddDynamicPlugin(&X, 1, plugin);
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "The ReverseRoll TRT Plugin should run in dynamic shape mode."));
+    }
+    auto output_name = op_desc.Output("Out").front();
+    RreplenishLayerAndOutput(
+        reverse_roll_layer, "reverse_roll", {output_name}, test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(reverse_roll, ReverseRollOpConverter);
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index 7344755790fb1b063c4d66eab2b78d0178178f9d..c68ab4d6acacb8ebc001a62380dfe1fe99d9f176 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -2324,6 +2324,13 @@ struct SimpleOpTypeSetTeller : public Teller {
       }
     }
 
+    if (op_type == "reverse_roll") {
+      if (!with_dynamic_shape) {
+        VLOG(3) << "The reverse_roll fused op does not support static shape "
+                   "mode yet.";
+        return false;
+      }
+    }
     if (op_type == "skip_merge_layernorm") {
       if (!with_dynamic_shape) {
         VLOG(3) << "The merge_layernorm op does not support "
@@ -2499,6 +2506,7 @@ struct SimpleOpTypeSetTeller : public Teller {
           "squeeze2",
"unsqueeze2", "layernorm_shift_partition", + "reverse_roll", "take_along_axis", "tanh_shrink", "logsigmoid", @@ -2639,6 +2647,7 @@ struct SimpleOpTypeSetTeller : public Teller { "unsqueeze2", "fused_token_prune", "layernorm_shift_partition", + "reverse_roll", "tanh_shrink", "take_along_axis", "logsigmoid", diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 2ecb8c8c71bcf9b6d58b810c65ea5e69c0bc3e87..875b2c3fc523abf381ebebd775b4e812d4131c27 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -33,6 +33,7 @@ list( preln_residual_bias_plugin.cu fused_token_prune_op_plugin.cu layernorm_shift_partition_op.cu + reverse_roll_op_plugin.cu prelnlayernorm_shift_partition_op.cu merge_layernorm_op_plugin.cu skip_merge_layernorm_op_plugin.cu diff --git a/paddle/fluid/inference/tensorrt/plugin/reverse_roll_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/reverse_roll_op_plugin.cu new file mode 100644 index 0000000000000000000000000000000000000000..8e27fbd6ef8f8b4980adef5f96d6dd26aec8cdc8 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/reverse_roll_op_plugin.cu @@ -0,0 +1,260 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "paddle/fluid/inference/tensorrt/plugin/reverse_roll_op_plugin.h" +#include + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +/******************* invokeReverseRoll ***********************/ +// src is [batch*window_num, window_len, dim] +// dst is [batch, H, W, dim] + rolled +// grid(W, H, batch) +// block(min(1024, dim)) + +template +__global__ void reverse_roll(T *dst, + const T *src, + const int batch, + const int window_num, + const int window_len, + const int window_size, + const int H, + const int W, + const int shift_size, + const int dim) { + const int batch_idx = blockIdx.z; + const int H_idx_shifted = (blockIdx.y + shift_size) % H; + const int W_idx_shifted = (blockIdx.x + shift_size) % W; + const int H_idx = blockIdx.y; + const int W_idx = blockIdx.x; + const int window_idx = + H_idx / window_size * (W / window_size) + W_idx / window_size; + const int idx_in_window = + (H_idx % window_size) * window_size + (W_idx % window_size); + const int input_offset = + (batch_idx * window_num + window_idx) * window_len + idx_in_window; + const int output_offset = (batch_idx * H + H_idx_shifted) * W + W_idx_shifted; + for (int tid = threadIdx.x; tid < dim; tid += blockDim.x) { + dst[output_offset * dim + tid] = src[input_offset * dim + tid]; + } +} + +// src is [batch*window_num, window_len, dim] +// dst is [batch, H, W, dim] + rolled +// grid(W, H, batch) +// block(min(1024, dim)) +template +void invokeReverseRoll(T *dst, + const T *src, + int batch, + int window_num, + int window_len, + int window_size, + int H, + int W, + int dim, + int shift_size, + cudaStream_t stream) { + dim3 grid(W, H, batch); + int blockSize = dim; + if (std::is_same::value && (dim % 2 == 0)) { + blockSize = dim / 2; + if (blockSize > 1024) { + blockSize = 1024; + } + using T2 = half2; + reverse_roll<<>>( + reinterpret_cast(dst), + reinterpret_cast(src), + batch, + window_num, + window_len, + window_size, + H, + W, + shift_size, + dim / 2); + } else { + if (blockSize > 1024) { + blockSize = 1024; + } + reverse_roll<<>>(dst, + src, + batch, + window_num, + window_len, + window_size, + H, + W, + shift_size, + dim); + } +} + +template void invokeReverseRoll(float *dst, + const float *src, + int batch, + int window_num, + int window_len, + int window_size, + int H, + int W, + int dim, + int shift_size, + cudaStream_t stream); + +template void invokeReverseRoll(half *dst, + const half *src, + int batch, + int window_num, + int window_len, + int window_size, + int H, + int W, + int dim, + int shift_size, + cudaStream_t stream); + +void ReverseRollPluginDynamic::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc *in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *out, + int nbOutputs) TRT_NOEXCEPT {} +bool ReverseRollPluginDynamic::supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc *in_out, + int nb_inputs, + int nb_outputs) TRT_NOEXCEPT { + PADDLE_ENFORCE_NOT_NULL( + in_out, + platform::errors::InvalidArgument("The input of ReverseRoll " + "plugin shoule not be nullptr.")); + PADDLE_ENFORCE_LT( + pos, + nb_inputs + nb_outputs, + platform::errors::InvalidArgument("The pos(%d) should be less than the " + "num(%d) of the input and the output.", + pos, + nb_inputs + nb_outputs)); + const nvinfer1::PluginTensorDesc &in = in_out[pos]; + if (pos == 0) { + if (with_fp16_) { + return in.type == nvinfer1::DataType::kHALF && + in.format == nvinfer1::TensorFormat::kLINEAR; + } else { + return in.type == nvinfer1::DataType::kFLOAT && + 
in.format == nvinfer1::TensorFormat::kLINEAR; + } + } + const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1]; + // output + return in.type == prev.type && in.format == prev.format; +} + +nvinfer1::DataType ReverseRollPluginDynamic::getOutputDataType( + int index, + const nvinfer1::DataType *input_types, + int nb_inputs) const TRT_NOEXCEPT { + PADDLE_ENFORCE_EQ(index, + 0, + platform::errors::InvalidArgument( + "The ReverseRoll only has one input, so the " + "index value should be 0, but get %d.", + index)); + return input_types[0]; +} + +nvinfer1::DimsExprs ReverseRollPluginDynamic::getOutputDimensions( + int output_index, + const nvinfer1::DimsExprs *inputs, + int nb_inputs, + nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT { + PADDLE_ENFORCE_EQ(output_index, + 0, + platform::errors::InvalidArgument( + "There is only one output of the ReverseRoll, " + "so the index should be zero," + "but it's (%d)", + output_index)); + PADDLE_ENFORCE_EQ( + nb_inputs, + 1, + platform::errors::InvalidArgument( + "The Input of the ReverseRoll should be 1, but we found " + "it has (%d) inputs", + nb_inputs)); + + nvinfer1::DimsExprs ret; + ret.nbDims = 3; + ret.d[0] = expr_builder.operation(nvinfer1::DimensionOperation::kFLOOR_DIV, + *inputs[0].d[0], + *expr_builder.constant(window_num_)); + ret.d[1] = expr_builder.operation(nvinfer1::DimensionOperation::kPROD, + *inputs[0].d[1], + *expr_builder.constant(window_num_)); + ret.d[2] = inputs[0].d[2]; + return ret; +} +int ReverseRollPluginDynamic::enqueue( + const nvinfer1::PluginTensorDesc *input_desc, + const nvinfer1::PluginTensorDesc *output_desc, + const void *const *inputs, + void *const *outputs, + void *workspace, + cudaStream_t stream) TRT_NOEXCEPT { + const auto &input_dims = input_desc[0].dims; + auto input_type = input_desc[0].type; + int batch = input_dims.d[0] / window_num_; + int dim = input_dims.d[2]; + if (input_type == nvinfer1::DataType::kFLOAT) { + VLOG(3) << "TRT Plugin DataType selected. ReverseRoll-->fp32"; + invokeReverseRoll(reinterpret_cast(outputs[0]), + reinterpret_cast(inputs[0]), + batch, + window_num_, + window_len_, + window_size_, + input_resolution_, + input_resolution_, + dim, + shift_size_, + stream); + } else if (input_type == nvinfer1::DataType::kHALF) { + VLOG(3) << "TRT Plugin DataType selected. ReverseRoll-->fp16"; + invokeReverseRoll(reinterpret_cast(outputs[0]), + reinterpret_cast(inputs[0]), + batch, + window_num_, + window_len_, + window_size_, + input_resolution_, + input_resolution_, + dim, + shift_size_, + stream); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The ReverseRoll TRT Plugin's input type should be float or half.")); + } + return cudaGetLastError() != cudaSuccess; +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/reverse_roll_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/reverse_roll_op_plugin.h new file mode 100644 index 0000000000000000000000000000000000000000..328b596594006b015e0c468d8beed2df58fa9de0 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/reverse_roll_op_plugin.h @@ -0,0 +1,138 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +class ReverseRollPluginDynamic : public DynamicPluginTensorRT { + public: + ReverseRollPluginDynamic(int window_num, + int window_len, + int window_size, + int input_resolution, + int shift_size, + bool with_fp16) + : window_num_(window_num), + window_len_(window_len), + window_size_(window_size), + input_resolution_(input_resolution), + shift_size_(shift_size), + with_fp16_(with_fp16) {} + ReverseRollPluginDynamic(void const* serialData, size_t serialLength) { + DeserializeValue(&serialData, &serialLength, &window_num_); + DeserializeValue(&serialData, &serialLength, &window_len_); + DeserializeValue(&serialData, &serialLength, &window_size_); + DeserializeValue(&serialData, &serialLength, &input_resolution_); + DeserializeValue(&serialData, &serialLength, &shift_size_); + DeserializeValue(&serialData, &serialLength, &with_fp16_); + } + + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { + return new ReverseRollPluginDynamic(window_num_, + window_len_, + window_size_, + input_resolution_, + shift_size_, + with_fp16_); + } + const char* getPluginType() const TRT_NOEXCEPT override { + return "reverse_roll_dynamic"; + } + int getNbOutputs() const TRT_NOEXCEPT override { return 1; } + int initialize() TRT_NOEXCEPT override { return 0; } + size_t getSerializationSize() const TRT_NOEXCEPT override { + return SerializedSize(window_num_) + SerializedSize(window_len_) + + SerializedSize(window_size_) + SerializedSize(input_resolution_) + + SerializedSize(shift_size_) + SerializedSize(with_fp16_); + } + void serialize(void* buffer) const TRT_NOEXCEPT override { + SerializeValue(&buffer, window_num_); + SerializeValue(&buffer, window_len_); + SerializeValue(&buffer, window_size_); + SerializeValue(&buffer, input_resolution_); + SerializeValue(&buffer, shift_size_); + SerializeValue(&buffer, with_fp16_); + } + + nvinfer1::DimsExprs getOutputDimensions( + int output_index, + const nvinfer1::DimsExprs* inputs, + int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) // NOLINT + TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override { + return 0; + } + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + nvinfer1::DataType 
getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const + TRT_NOEXCEPT override; + + void destroy() TRT_NOEXCEPT override { delete this; } + + private: + int window_num_; + int window_len_; + int window_size_; + int input_resolution_; + int shift_size_; + bool with_fp16_; +}; + +class ReverseRollPluginDynamicCreater : public TensorRTPluginCreator { + public: + const char* getPluginName() const TRT_NOEXCEPT override { + return "reverse_roll_dynamic"; + } + const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serial_data, + size_t serial_length) + TRT_NOEXCEPT override { + return new ReverseRollPluginDynamic(serial_data, serial_length); + } +}; +REGISTER_TRT_PLUGIN_V2(ReverseRollPluginDynamicCreater); + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index 006814a56fc4f3cc3caff1d6867369d38e4bc4e7..ea4a5e0e5dd0119a55873dd223634fa9211809a3 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -161,7 +161,10 @@ if(WITH_GPU AND TENSORRT_FOUND) AND WITH_GPU) set_tests_properties(test_emb_eltwise_layernorm_fuse_pass PROPERTIES TIMEOUT 120) + set_tests_properties(test_fc_fuse_pass PROPERTIES TIMEOUT 240) + set_tests_properties(test_reverse_roll_fuse_pass PROPERTIES TIMEOUT 120) + set_tests_properties(test_simplify_with_basic_ops_pass_autoscan PROPERTIES TIMEOUT 60) set_tests_properties(test_adaptive_pool2d_convert_global_pass_autoscan diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_reverse_roll_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_reverse_roll_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..50cbb46da940f8c5c34fbbe8e13aae60cccec5ce --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_reverse_roll_fuse_pass.py @@ -0,0 +1,398 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +import unittest +from functools import partial + +import hypothesis.strategies as st +import numpy as np +from auto_scan_test import PassAutoScanTest +from program_config import OpConfig, ProgramConfig, TensorConfig + +import paddle.inference as paddle_infer + + +class ReverseRollPass(PassAutoScanTest): + """ + | + reshape2 + | + reshape2 + | + transpose2 + | + reshape2 + | + roll + | + reshape2 + | + """ + + def sample_predictor_configs(self, program_config): + # trt with dynamic_shape + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + max_batch_size=4, + workspace_size=102400, + min_subgraph_size=0, + precision_mode=paddle_infer.PrecisionType.Float32, + use_static=False, + use_calib_mode=False, + ) + config.set_trt_dynamic_shape_info( + { + "input0": [64, 9, 96], + }, + { + "input0": [512, 144, 768], + }, + { + "input0": [64, 49, 96], + }, + ) + + yield config, ['reverse_roll'], (1e-5, 1e-5) + + # trt dynamic_shape + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + max_batch_size=1, + workspace_size=102400, + min_subgraph_size=0, + precision_mode=paddle_infer.PrecisionType.Half, + use_static=False, + use_calib_mode=False, + ) + config.set_trt_dynamic_shape_info( + { + "input0": [64, 9, 96], + }, + { + "input0": [512, 144, 768], + }, + { + "input0": [64, 49, 96], + }, + ) + + yield config, ['reverse_roll'], (1e-3, 1e-3) + + def sample_program_config(self, draw): + batch_size = draw(st.integers(min_value=1, max_value=4)) + window_size = draw(st.sampled_from([3, 5, 7, 12])) + dim = draw(st.sampled_from([96, 192, 384, 768])) + window_number = 64 + + def generate_input(attrs): + return np.random.random( + [ + attrs[0]["batch_size"] * attrs[1]["window_number"], + attrs[1]["window_size"] * attrs[1]["window_size"], + attrs[1]["dim"], + ] + ).astype(np.float32) + + attrs = [ + {"batch_size": batch_size}, + { + "window_number": window_number, + "window_size": window_size, + "dim": dim, + }, + ] + reshape2_00 = OpConfig( + type="reshape2", + inputs={"X": ["input0"]}, + outputs={ + "Out": ["reshape2_00_out"], + "XShape": ["reshape2_00_outXshape"], + }, + attrs={"shape": [-1, window_size, window_size, dim]}, + ) + reshape2_10 = OpConfig( + type="reshape2", + inputs={"X": ["reshape2_00_out"]}, + outputs={ + "Out": ["reshape2_10_out"], + "XShape": ["reshape2_10_outXshape"], + }, + attrs={ + "shape": [ + -1, + int(math.sqrt(window_number)), + int(math.sqrt(window_number)), + window_size, + window_size, + dim, + ] + }, + ) + transpose2_20 = OpConfig( + type="transpose2", + inputs={"X": ["reshape2_10_out"]}, + outputs={ + "Out": ["transpose2_20_out"], + "XShape": ["transpose2_20_outXshape"], + }, + attrs={"axis": [0, 1, 3, 2, 4, 5]}, + ) + reshape2_30 = OpConfig( + type="reshape2", + inputs={"X": ["transpose2_20_out"]}, + outputs={ + "Out": ["reshape2_30_out"], + "XShape": ["reshape2_30_outXshape"], + }, + attrs={ + "shape": [ + -1, + int(math.sqrt(window_number)) * window_size, + int(math.sqrt(window_number)) * window_size, + dim, + ] + }, + ) + roll_30_1 = OpConfig( + type="roll", + inputs={"X": ["reshape2_30_out"]}, + outputs={"Out": ["roll_30_1_out"]}, + attrs={ + "axis": [1, 2], + "shifts": [ + math.floor(window_size // 2), + math.floor(window_size // 2), + ], + }, + ) + reshape2_40 = OpConfig( + type="reshape2", + inputs={"X": ["roll_30_1_out"]}, + outputs={ + "Out": ["reshape2_40_out"], + "XShape": ["reshape2_40_outXshape"], + }, + attrs={ + "shape": [-1, window_number * window_size * window_size, dim] + }, + ) + + 
program_config = ProgramConfig( + ops=[ + reshape2_00, + reshape2_10, + transpose2_20, + reshape2_30, + roll_30_1, + reshape2_40, + ], + weights={}, + inputs={ + "input0": TensorConfig(data_gen=partial(generate_input, attrs)), + }, + outputs=["reshape2_40_out"], + ) + + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=50, + passes=["reverse_roll_fuse_pass"], + max_duration=250, + min_success_num=50, + ) + + +class ReverseRoll2Pass(PassAutoScanTest): + """ + | + reshape2 + | + reshape2 + | + transpose2 + | + reshape2 + | + reshape2 + | + """ + + def sample_predictor_configs(self, program_config): + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + max_batch_size=4, + workspace_size=102400, + min_subgraph_size=0, + precision_mode=paddle_infer.PrecisionType.Float32, + use_static=False, + use_calib_mode=False, + ) + config.set_trt_dynamic_shape_info( + { + "input0": [64, 9, 96], + }, + { + "input0": [512, 144, 768], + }, + { + "input0": [64, 49, 96], + }, + ) + + yield config, ['reverse_roll'], (1e-5, 1e-5) + + # trt dynamic_shape + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + max_batch_size=1, + workspace_size=102400, + min_subgraph_size=0, + precision_mode=paddle_infer.PrecisionType.Half, + use_static=False, + use_calib_mode=False, + ) + config.set_trt_dynamic_shape_info( + { + "input0": [64, 9, 96], + }, + { + "input0": [512, 144, 768], + }, + { + "input0": [64, 49, 96], + }, + ) + + yield config, ['reverse_roll'], (1e-3, 1e-3) + + def sample_program_config(self, draw): + batch_size = draw(st.integers(min_value=1, max_value=4)) + window_size = draw(st.sampled_from([3, 5, 7, 12])) + dim = draw(st.sampled_from([96, 192, 384, 768])) + window_number = 64 + + def generate_input(attrs): + return np.random.random( + [ + attrs[0]["batch_size"] * attrs[1]["window_number"], + attrs[1]["window_size"] * attrs[1]["window_size"], + attrs[1]["dim"], + ] + ).astype(np.float32) + + attrs = [ + {"batch_size": batch_size}, + { + "window_number": window_number, + "window_size": window_size, + "dim": dim, + }, + ] + reshape2_00 = OpConfig( + type="reshape2", + inputs={"X": ["input0"]}, + outputs={ + "Out": ["reshape2_00_out"], + "XShape": ["reshape2_00_outXshape"], + }, + attrs={"shape": [-1, window_size, window_size, dim]}, + ) + reshape2_10 = OpConfig( + type="reshape2", + inputs={"X": ["reshape2_00_out"]}, + outputs={ + "Out": ["reshape2_10_out"], + "XShape": ["reshape2_10_outXshape"], + }, + attrs={ + "shape": [ + -1, + int(math.sqrt(window_number)), + int(math.sqrt(window_number)), + window_size, + window_size, + dim, + ] + }, + ) + transpose2_20 = OpConfig( + type="transpose2", + inputs={"X": ["reshape2_10_out"]}, + outputs={ + "Out": ["transpose2_20_out"], + "XShape": ["transpose2_20_outXshape"], + }, + attrs={"axis": [0, 1, 3, 2, 4, 5]}, + ) + reshape2_30 = OpConfig( + type="reshape2", + inputs={"X": ["transpose2_20_out"]}, + outputs={ + "Out": ["reshape2_30_out"], + "XShape": ["reshape2_30_outXshape"], + }, + attrs={ + "shape": [ + -1, + int(math.sqrt(window_number)) * window_size, + int(math.sqrt(window_number)) * window_size, + dim, + ] + }, + ) + reshape2_40 = OpConfig( + type="reshape2", + inputs={"X": ["reshape2_30_out"]}, + outputs={ + "Out": ["reshape2_40_out"], + "XShape": ["reshape2_40_outXshape"], + }, + attrs={ + "shape": [-1, window_number * window_size * window_size, dim] + }, + ) + + program_config = ProgramConfig( + ops=[ + reshape2_00, + reshape2_10, + transpose2_20, + 
reshape2_30, + reshape2_40, + ], + weights={}, + inputs={ + "input0": TensorConfig(data_gen=partial(generate_input, attrs)), + }, + outputs=["reshape2_40_out"], + ) + + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=50, + passes=["reverse_roll_fuse_pass"], + max_duration=250, + min_success_num=50, + ) + + +if __name__ == "__main__": + unittest.main()
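
For reference, the subgraph being fused (reshape2 -> reshape2 -> transpose2 -> reshape2 -> [roll] -> reshape2) computes the window reverse plus the reverse cyclic shift of Swin-style window attention. The NumPy sketch below illustrates those semantics; it is not part of the patch, the helper name reverse_roll_reference is made up here, and it assumes square windows with H = W = sqrt(window_number) * window_size and shift_size = 0 when the roll op is absent.

# Illustrative reference for the fused reverse_roll op (a sketch, not patch code).
import numpy as np


def reverse_roll_reference(x, window_number, window_size, shift_size):
    # x: [batch * window_number, window_size * window_size, dim]
    nw_side = int(round(window_number ** 0.5))   # windows per spatial side
    resolution = nw_side * window_size           # H == W == input_resolution
    batch = x.shape[0] // window_number
    dim = x.shape[-1]
    # Undo the window partition: -> [batch, H, W, dim]
    x = x.reshape(batch, nw_side, nw_side, window_size, window_size, dim)
    x = x.transpose(0, 1, 3, 2, 4, 5).reshape(batch, resolution, resolution, dim)
    # Reverse cyclic shift; with shift_size == 0 this is a plain window reverse.
    x = np.roll(x, shift=(shift_size, shift_size), axis=(1, 2))
    # Flatten back to [batch, window_number * window_len, dim]
    return x.reshape(batch, resolution * resolution, dim)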