Unverified commit feb68dd1, authored by Wang Bojun, committed by GitHub

Reverse roll fuse (#46914)

* pass

* pass

* draft version

* share mem opt

* remove sharemem

* add pattern for the case with circle_shift=0

* add UT

* pass opt

* test_fix

* code-commit

* code-style

* code style

* code-style

* ut-fix

* op teller refine

* resolve conflict

* adjust position op_teller list and pass order for swin

* ut code style update

* adjust paddle pass order

* refine pass order

* refine pass order

* refine pass order
Parent 65ffc3f5
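For orientation, a minimal NumPy sketch (not part of the diff; the function name is ours) of the reshape2 -> reshape2 -> transpose2 -> reshape2 -> (roll) -> reshape2 chain that this pass collapses into a single reverse_roll op. Shapes follow the unit test added below:

import math

import numpy as np

def reverse_roll_reference(x, window_size, window_number, shift_size):
    # x: [batch * window_number, window_size * window_size, dim]
    nwin = int(math.sqrt(window_number))  # windows per side, e.g. 8
    dim = x.shape[-1]
    x = x.reshape(-1, window_size, window_size, dim)                # reshape2_00
    x = x.reshape(-1, nwin, nwin, window_size, window_size, dim)    # reshape2_10
    x = x.transpose(0, 1, 3, 2, 4, 5)                               # transpose2_20
    x = x.reshape(-1, nwin * window_size, nwin * window_size, dim)  # reshape2_30
    if shift_size > 0:                                              # roll_40 (only when shift_size > 0)
        x = np.roll(x, shift=(shift_size, shift_size), axis=(1, 2))
    return x.reshape(-1, window_number * window_size * window_size, dim)  # reshape2_50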
@@ -135,6 +135,7 @@ if(WITH_TENSORRT)
  pass_library(remove_padding_recover_padding_pass inference)
  pass_library(delete_remove_padding_recover_padding_pass inference)
  pass_library(layernorm_shift_partition_fuse_pass inference)
  pass_library(reverse_roll_fuse_pass inference)
  pass_library(preln_layernorm_x_fuse_pass inference)
endif()
......
@@ -3799,6 +3799,70 @@ PDNode *patterns::LayernormShiftPartitionPattern::operator()() {
  return reshape4_out;
}
PDNode *patterns::ReverseRollPattern::operator()(PDNode *in) {
in->AsInput();
auto reshape2_00_op =
pattern->NewNode(reshape2_00_op_repr())->assert_is_op("reshape2");
auto reshape2_00_out = pattern->NewNode(reshape2_00_out_repr())
->AsIntermediate()
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input("reshape2", "X");
auto reshape2_10_op =
pattern->NewNode(reshape2_10_op_repr())->assert_is_op("reshape2");
auto reshape2_10_out = pattern->NewNode(reshape2_10_out_repr())
->AsIntermediate()
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input("transpose2", "X");
auto transpose2_20_op =
pattern->NewNode(transpose2_20_op_repr())->assert_is_op("transpose2");
auto transpose2_20_out = pattern->NewNode(transpose2_20_out_repr())
->AsIntermediate()
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input("reshape2", "X");
auto reshape2_30_op =
pattern->NewNode(reshape2_30_op_repr())->assert_is_op("reshape2");
auto reshape2_30_out = pattern->NewNode(reshape2_30_out_repr())
->AsIntermediate()
->assert_is_op_output("reshape2", "Out");
PDNode *roll_40_op = nullptr;
PDNode *roll_40_out = nullptr;
if (with_roll_) {
reshape2_30_out->assert_is_op_input("roll", "X");
roll_40_op = pattern->NewNode(roll_40_op_repr())->assert_is_op("roll");
roll_40_out = pattern->NewNode(roll_40_out_repr())
->AsIntermediate()
->assert_is_op_output("roll", "Out")
->assert_is_op_input("reshape2", "X");
} else {
reshape2_30_out->assert_is_op_input("reshape2", "X");
}
auto reshape2_50_op =
pattern->NewNode(reshape2_50_op_repr())->assert_is_op("reshape2");
  auto reshape2_50_out = pattern->NewNode(reshape2_50_out_repr())
->assert_is_op_output("reshape2", "Out")
->AsOutput();
reshape2_00_op->LinksFrom({in});
reshape2_00_out->LinksFrom({reshape2_00_op});
reshape2_10_op->LinksFrom({reshape2_00_out});
reshape2_10_out->LinksFrom({reshape2_10_op});
transpose2_20_op->LinksFrom({reshape2_10_out});
transpose2_20_out->LinksFrom({transpose2_20_op});
reshape2_30_op->LinksFrom({transpose2_20_out});
reshape2_30_out->LinksFrom({reshape2_30_op});
if (with_roll_) {
roll_40_op->LinksFrom({reshape2_30_out});
roll_40_out->LinksFrom({roll_40_op});
reshape2_50_op->LinksFrom({roll_40_out});
} else {
reshape2_50_op->LinksFrom({reshape2_30_out});
}
reshape2_50_out->LinksFrom({reshape2_50_op});
return reshape2_50_out;
}
PDNode *patterns::MergeLayernormPattern::operator()(PDNode *in) {
  in->AsInput();
  auto reshape2_00_op =
......
@@ -2053,6 +2053,34 @@ struct LayernormShiftPartitionPattern : public PatternBase {
  PATTERN_DECL_NODE(reshape4_out);
};
//
// \brief Pattern looking for the reverse cyclic shift in window attention.
// The reverse cyclic shift is implemented with the roll op, so reverse_roll
// is adopted as the name of both the pattern and the fused op.
//
struct ReverseRollPattern : public PatternBase {
ReverseRollPattern(PDPattern* pattern,
const std::string& name_scope,
bool with_roll)
: PatternBase(pattern, name_scope, "reverse_roll"),
with_roll_(with_roll) {}
PDNode* operator()(PDNode* in);
bool with_roll_;
PATTERN_DECL_NODE(reshape2_00_op);
PATTERN_DECL_NODE(reshape2_00_out);
PATTERN_DECL_NODE(reshape2_10_op);
PATTERN_DECL_NODE(reshape2_10_out);
PATTERN_DECL_NODE(transpose2_20_op);
PATTERN_DECL_NODE(transpose2_20_out);
PATTERN_DECL_NODE(reshape2_30_op);
PATTERN_DECL_NODE(reshape2_30_out);
PATTERN_DECL_NODE(roll_40_op);
PATTERN_DECL_NODE(roll_40_out);
PATTERN_DECL_NODE(reshape2_50_op);
  PATTERN_DECL_NODE(reshape2_50_out);
};
// pattern for merge_layernorm
struct MergeLayernormPattern : public PatternBase {
  MergeLayernormPattern(PDPattern* pattern, const std::string& name_scope)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/reverse_roll_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#define GET_IR_NODE(node__) \
GET_IR_NODE_FROM_SUBGRAPH(node__, node__, reverse_roll_pattern);
#define GET_NODES \
GET_IR_NODE(reshape2_00_op); \
GET_IR_NODE(reshape2_00_out); \
GET_IR_NODE(reshape2_10_op); \
GET_IR_NODE(reshape2_10_out); \
GET_IR_NODE(transpose2_20_op); \
GET_IR_NODE(transpose2_20_out); \
GET_IR_NODE(reshape2_30_op); \
GET_IR_NODE(reshape2_30_out); \
GET_IR_NODE(reshape2_50_op); \
  GET_IR_NODE(reshape2_50_out);
namespace paddle {
namespace framework {
namespace ir {
class Node;
ReverseRollFusePass::ReverseRollFusePass() {
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("roll"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int64_t>>()
.End()
.AddAttr("shifts")
.IsType<std::vector<int64_t>>()
.End();
}
int ReverseRollFusePass::ApplyPattern(ir::Graph* graph, bool with_roll) const {
PADDLE_ENFORCE_NOT_NULL(
graph,
platform::errors::InvalidArgument(
"The input graph of ReverseRollFusePass should not be "
"nullptr."));
GraphPatternDetector gpd;
FusePassBase::Init(scope_name_, graph);
PDNode* x = gpd.mutable_pattern()
->NewNode("x")
->assert_is_op_input("reshape2", "X")
->AsInput();
patterns::ReverseRollPattern reverse_roll_pattern(
gpd.mutable_pattern(), scope_name_, with_roll);
reverse_roll_pattern(x);
int fuse_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
      LOG(WARNING) << "reverse_roll_fuse_pass op compat check failed.";
return;
}
if (with_roll) {
VLOG(4) << "reverse_roll_fuse pass, shift_size>0, with roll op";
} else {
VLOG(4) << "reverse_roll_fuse pass, shift_size=0, without roll op";
}
GET_NODES;
Node* roll_40_op = nullptr;
Node* roll_40_out = nullptr;
if (with_roll) {
GET_IR_NODE_FROM_SUBGRAPH(
tmp_roll_40_op, roll_40_op, reverse_roll_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
tmp_roll_40_out, roll_40_out, reverse_roll_pattern);
roll_40_op = tmp_roll_40_op;
roll_40_out = tmp_roll_40_out;
}
std::unordered_set<const Node*> del_node_set = {reshape2_00_op,
reshape2_00_out,
reshape2_10_op,
reshape2_10_out,
transpose2_20_op,
transpose2_20_out,
reshape2_30_op,
reshape2_30_out,
reshape2_50_op};
if (with_roll) {
del_node_set.insert(roll_40_op);
del_node_set.insert(roll_40_out);
}
std::vector<int32_t> reshape2_10_attr_shape = PADDLE_GET_CONST(
std::vector<int32_t>, reshape2_10_op->Op()->GetAttr("shape"));
if (reshape2_10_attr_shape[1] <= 0) {
return;
}
if (reshape2_10_attr_shape[1] != reshape2_10_attr_shape[2]) {
return;
}
int window_number = reshape2_10_attr_shape[1] * reshape2_10_attr_shape[2];
std::vector<int> reshape_2_00_attr_shape = PADDLE_GET_CONST(
std::vector<int>, reshape2_00_op->Op()->GetAttr("shape"));
int window_size_h = reshape_2_00_attr_shape[1];
if (window_size_h <= 0) {
return;
}
int window_size_w = reshape_2_00_attr_shape[2];
if (window_size_h != window_size_w) {
return;
}
int window_size = window_size_h;
int window_len = window_size_h * window_size_w;
int input_resolution = reshape2_10_attr_shape[1] * window_size_h;
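    // Illustrative numbers (mirroring the unit test): reshape2_00 shape
    // [-1, 7, 7, 96] and reshape2_10 shape [-1, 8, 8, 7, 7, 96] give
    // window_size = 7, window_number = 64, window_len = 49 and
    // input_resolution = 8 * 7 = 56, i.e. a 56x56 feature map.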
    int64_t shift_size = 0;
if (with_roll) {
std::vector<int64_t> roll_40_op_attr_shifts = PADDLE_GET_CONST(
std::vector<int64_t>, roll_40_op->Op()->GetAttr("shifts"));
if (roll_40_op_attr_shifts[0] != roll_40_op_attr_shifts[1]) {
return;
}
shift_size = roll_40_op_attr_shifts[0];
}
OpDesc reverse_roll_desc(reshape2_00_op->Op()->Block());
reverse_roll_desc.SetType("reverse_roll");
reverse_roll_desc.SetInput("X", {subgraph.at(x)->Name()});
    reverse_roll_desc.SetOutput("Out", {reshape2_50_out->Name()});
reverse_roll_desc.SetAttr("window_number", window_number);
reverse_roll_desc.SetAttr("window_size", window_size);
reverse_roll_desc.SetAttr("window_len", window_len);
reverse_roll_desc.SetAttr("shift_size", static_cast<int>(shift_size));
reverse_roll_desc.SetAttr("input_resolution", input_resolution);
auto reverse_roll_node = graph->CreateOpNode(&reverse_roll_desc);
IR_NODE_LINK_TO(subgraph.at(x), reverse_roll_node);
    IR_NODE_LINK_TO(reverse_roll_node, reshape2_50_out);
GraphSafeRemoveNodes(graph, del_node_set);
++fuse_count;
};
gpd(graph, handler);
return fuse_count;
}
void ReverseRollFusePass::ApplyImpl(ir::Graph* graph) const {
int fuse_count = 0;
fuse_count += ApplyPattern(graph, true);
fuse_count += ApplyPattern(graph, false);
AddStatis(fuse_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(reverse_roll_fuse_pass,
paddle::framework::ir::ReverseRollFusePass);
REGISTER_PASS_CAPABILITY(reverse_roll_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("transpose2", 0)
.EQ("reshape2", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
// |
// reshape2
// |
// reshape2
// |
// transpose2 -> reverse_roll (shift_size=0)
// | fuse
// reshape2
// |
// reshape2
// |
//
// or
//
// |
// reshape2
// |
// reshape2
// | -> reverse_roll (shift_size>0)
// transpose2 fuse
// |
// reshape2
// |
// roll
// |
// reshape2
// |
class ReverseRollFusePass : public FusePassBase {
public:
ReverseRollFusePass();
virtual ~ReverseRollFusePass() {}
protected:
void ApplyImpl(ir::Graph *graph) const override;
int ApplyPattern(ir::Graph *graph, bool with_roll) const;
private:
const std::string scope_name_{"reverse_roll_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -2338,6 +2338,7 @@ USE_TRT_CONVERTER(fill_constant)
USE_TRT_CONVERTER(fused_token_prune)
USE_TRT_CONVERTER(celu)
USE_TRT_CONVERTER(layernorm_shift_partition)
USE_TRT_CONVERTER(reverse_roll)
USE_TRT_CONVERTER(preln_layernorm_shift_partition)
USE_TRT_CONVERTER(merge_layernorm)
USE_TRT_CONVERTER(skip_merge_layernorm)
......
@@ -115,6 +115,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
      "merge_layernorm_fuse_pass",      //
      "preln_residual_bias_fuse_pass",  //
      "preln_layernorm_x_fuse_pass",    //
      "reverse_roll_fuse_pass",         //
      // "set_transformer_input_convert_pass",  //
      "conv_bn_fuse_pass",              //
      "unsqueeze2_eltwise_fuse_pass",   //
......
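Since reverse_roll_fuse_pass is part of kTRTSubgraphPasses, it runs automatically whenever the TensorRT subgraph engine is enabled; the fused op additionally requires dynamic shape (enforced by the op teller further down). A minimal usage sketch, assuming a Swin-style model (paths and shape ranges are placeholders):

import paddle.inference as paddle_infer

config = paddle_infer.Config("swin.pdmodel", "swin.pdiparams")  # hypothetical paths
config.enable_use_gpu(1000, 0)
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=4,
    min_subgraph_size=0,
    precision_mode=paddle_infer.PrecisionType.Half,
    use_static=False,
    use_calib_mode=False,
)
# reverse_roll is only offered to TRT under dynamic shape.
config.set_trt_dynamic_shape_info(
    {"input0": [64, 9, 96]},      # min
    {"input0": [512, 144, 768]},  # max
    {"input0": [64, 49, 96]},     # opt
)
predictor = paddle_infer.create_predictor(config)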
@@ -82,6 +82,7 @@ list(
  fused_token_prune_op.cc
  celu_op.cc
  layernorm_shift_partition_op.cc
  reverse_roll_op.cc
  tanhshrink_op.cc
  take_along_axis_op.cc
  logsigmoid_op.cc
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/reverse_roll_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class ReverseRollOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(4) << "convert a reverse_roll op to tensorrt "
"reverse_roll plugin";
framework::OpDesc op_desc(op, nullptr);
auto* X = engine_->GetITensor(op_desc.Input("X").front());
const int window_number =
PADDLE_GET_CONST(int, op_desc.GetAttr("window_number"));
const int window_size =
PADDLE_GET_CONST(int, op_desc.GetAttr("window_size"));
const int window_len = PADDLE_GET_CONST(int, op_desc.GetAttr("window_len"));
const int shift_size = PADDLE_GET_CONST(int, op_desc.GetAttr("shift_size"));
const int input_resolution =
PADDLE_GET_CONST(int, op_desc.GetAttr("input_resolution"));
PADDLE_ENFORCE_EQ(window_size * window_size,
window_len,
platform::errors::InvalidArgument(
"The window_len should equal to window_size * "
"window_size, but got window_size:%d, window_len:%d",
window_size,
window_len));
PADDLE_ENFORCE_EQ(
window_number * window_len,
input_resolution * input_resolution,
platform::errors::InvalidArgument(
"The input_resolution*input_resolution should equal to "
"window_number * window_len, but got window_len:%d, "
"window_number:%d, input_resolution:%d",
window_len,
window_number,
input_resolution));
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
nvinfer1::ILayer* reverse_roll_layer = nullptr;
if (engine_->with_dynamic_shape()) {
plugin::ReverseRollPluginDynamic* plugin =
new plugin::ReverseRollPluginDynamic(window_number,
window_len,
window_size,
input_resolution,
shift_size,
with_fp16);
reverse_roll_layer = engine_->AddDynamicPlugin(&X, 1, plugin);
} else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "ReverseRoll TRT Plugin should run in dynamic shape mode."));
}
auto output_name = op_desc.Output("Out").front();
RreplenishLayerAndOutput(
reverse_roll_layer, "reverse_roll", {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(reverse_roll, ReverseRollOpConverter);
@@ -2324,6 +2324,13 @@ struct SimpleOpTypeSetTeller : public Teller {
      }
    }
if (op_type == "reverse_roll") {
if (!with_dynamic_shape) {
VLOG(3) << "The reverse roll fused op does not support static shape "
"mode yet.";
return false;
}
}
if (op_type == "skip_merge_layernorm") { if (op_type == "skip_merge_layernorm") {
if (!with_dynamic_shape) { if (!with_dynamic_shape) {
VLOG(3) << "The merge_layernorm op does not support " VLOG(3) << "The merge_layernorm op does not support "
...@@ -2499,6 +2506,7 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2499,6 +2506,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"squeeze2", "squeeze2",
"unsqueeze2", "unsqueeze2",
"layernorm_shift_partition", "layernorm_shift_partition",
"reverse_roll",
"take_along_axis", "take_along_axis",
"tanh_shrink", "tanh_shrink",
"logsigmoid", "logsigmoid",
...@@ -2639,6 +2647,7 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2639,6 +2647,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"unsqueeze2", "unsqueeze2",
"fused_token_prune", "fused_token_prune",
"layernorm_shift_partition", "layernorm_shift_partition",
"reverse_roll",
"tanh_shrink", "tanh_shrink",
"take_along_axis", "take_along_axis",
"logsigmoid", "logsigmoid",
......
@@ -33,6 +33,7 @@ list(
  preln_residual_bias_plugin.cu
  fused_token_prune_op_plugin.cu
  layernorm_shift_partition_op.cu
  reverse_roll_op_plugin.cu
  prelnlayernorm_shift_partition_op.cu
  merge_layernorm_op_plugin.cu
  skip_merge_layernorm_op_plugin.cu
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/reverse_roll_op_plugin.h"
#include <vector>
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
/******************* invokeReverseRoll ***********************/
// src is [batch*window_num, window_len, dim]
// dst is [batch, H, W, dim] + rolled
// grid(W, H, batch)
// block(min(1024, dim))
template <typename T>
__global__ void reverse_roll(T *dst,
const T *src,
const int batch,
const int window_num,
const int window_len,
const int window_size,
const int H,
const int W,
const int shift_size,
const int dim) {
const int batch_idx = blockIdx.z;
const int H_idx_shifted = (blockIdx.y + shift_size) % H;
const int W_idx_shifted = (blockIdx.x + shift_size) % W;
const int H_idx = blockIdx.y;
const int W_idx = blockIdx.x;
const int window_idx =
H_idx / window_size * (W / window_size) + W_idx / window_size;
const int idx_in_window =
(H_idx % window_size) * window_size + (W_idx % window_size);
const int input_offset =
(batch_idx * window_num + window_idx) * window_len + idx_in_window;
const int output_offset = (batch_idx * H + H_idx_shifted) * W + W_idx_shifted;
for (int tid = threadIdx.x; tid < dim; tid += blockDim.x) {
dst[output_offset * dim + tid] = src[input_offset * dim + tid];
}
}
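The kernel above can be cross-checked against a scalar reference; a hedged NumPy transcription of the same index math (names are ours):

import numpy as np

def reverse_roll_ref(src, batch, window_num, window_size, H, W, shift_size):
    # src: [batch * window_num, window_len, dim] -> dst: [batch, H, W, dim]
    dim = src.shape[-1]
    dst = np.empty((batch, H, W, dim), dtype=src.dtype)
    for b in range(batch):
        for h in range(H):
            for w in range(W):
                window_idx = (h // window_size) * (W // window_size) + w // window_size
                idx_in_window = (h % window_size) * window_size + (w % window_size)
                dst[b, (h + shift_size) % H, (w + shift_size) % W] = \
                    src[b * window_num + window_idx, idx_in_window]
    return dst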
// src is [batch*window_num, window_len, dim]
// dst is [batch, H, W, dim] + rolled
// grid(W, H, batch)
// block(min(1024, dim))
template <typename T>
void invokeReverseRoll(T *dst,
const T *src,
int batch,
int window_num,
int window_len,
int window_size,
int H,
int W,
int dim,
int shift_size,
cudaStream_t stream) {
dim3 grid(W, H, batch);
int blockSize = dim;
if (std::is_same<T, half>::value && (dim % 2 == 0)) {
blockSize = dim / 2;
if (blockSize > 1024) {
blockSize = 1024;
}
using T2 = half2;
reverse_roll<<<grid, blockSize, 0, stream>>>(
reinterpret_cast<T2 *>(dst),
reinterpret_cast<const T2 *>(src),
batch,
window_num,
window_len,
window_size,
H,
W,
shift_size,
dim / 2);
} else {
if (blockSize > 1024) {
blockSize = 1024;
}
reverse_roll<<<grid, blockSize, 0, stream>>>(dst,
src,
batch,
window_num,
window_len,
window_size,
H,
W,
shift_size,
dim);
}
}
template void invokeReverseRoll(float *dst,
const float *src,
int batch,
int window_num,
int window_len,
int window_size,
int H,
int W,
int dim,
int shift_size,
cudaStream_t stream);
template void invokeReverseRoll(half *dst,
const half *src,
int batch,
int window_num,
int window_len,
int window_size,
int H,
int W,
int dim,
int shift_size,
cudaStream_t stream);
void ReverseRollPluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc *in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) TRT_NOEXCEPT {}
bool ReverseRollPluginDynamic::supportsFormatCombination(
int pos,
const nvinfer1::PluginTensorDesc *in_out,
int nb_inputs,
int nb_outputs) TRT_NOEXCEPT {
PADDLE_ENFORCE_NOT_NULL(
in_out,
platform::errors::InvalidArgument("The input of ReverseRoll "
"plugin shoule not be nullptr."));
PADDLE_ENFORCE_LT(
pos,
nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos,
nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) {
if (with_fp16_) {
return in.type == nvinfer1::DataType::kHALF &&
in.format == nvinfer1::TensorFormat::kLINEAR;
} else {
return in.type == nvinfer1::DataType::kFLOAT &&
in.format == nvinfer1::TensorFormat::kLINEAR;
}
}
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
// output
return in.type == prev.type && in.format == prev.format;
}
nvinfer1::DataType ReverseRollPluginDynamic::getOutputDataType(
int index,
const nvinfer1::DataType *input_types,
int nb_inputs) const TRT_NOEXCEPT {
PADDLE_ENFORCE_EQ(index,
0,
platform::errors::InvalidArgument(
"The ReverseRoll only has one input, so the "
"index value should be 0, but get %d.",
index));
return input_types[0];
}
nvinfer1::DimsExprs ReverseRollPluginDynamic::getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs *inputs,
int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
PADDLE_ENFORCE_EQ(output_index,
0,
platform::errors::InvalidArgument(
"There is only one output of the ReverseRoll, "
"so the index should be zero,"
"but it's (%d)",
output_index));
PADDLE_ENFORCE_EQ(
nb_inputs,
1,
platform::errors::InvalidArgument(
"The Input of the ReverseRoll should be 1, but we found "
"it has (%d) inputs",
nb_inputs));
nvinfer1::DimsExprs ret;
ret.nbDims = 3;
ret.d[0] = expr_builder.operation(nvinfer1::DimensionOperation::kFLOOR_DIV,
*inputs[0].d[0],
*expr_builder.constant(window_num_));
ret.d[1] = expr_builder.operation(nvinfer1::DimensionOperation::kPROD,
*inputs[0].d[1],
*expr_builder.constant(window_num_));
ret.d[2] = inputs[0].d[2];
return ret;
}
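For the unit-test shapes, this rule maps an input of [batch * 64, 49, 96] to [batch, 64 * 49, 96]; a quick sanity check under those assumed numbers:

batch, window_num = 4, 64
in_shape = (batch * window_num, 49, 96)  # [batch * window_num, window_len, dim]
out_shape = (in_shape[0] // window_num, in_shape[1] * window_num, in_shape[2])
assert out_shape == (4, 3136, 96)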
int ReverseRollPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc,
const void *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream) TRT_NOEXCEPT {
const auto &input_dims = input_desc[0].dims;
auto input_type = input_desc[0].type;
int batch = input_dims.d[0] / window_num_;
int dim = input_dims.d[2];
if (input_type == nvinfer1::DataType::kFLOAT) {
VLOG(3) << "TRT Plugin DataType selected. ReverseRoll-->fp32";
invokeReverseRoll(reinterpret_cast<float *>(outputs[0]),
reinterpret_cast<const float *>(inputs[0]),
batch,
window_num_,
window_len_,
window_size_,
input_resolution_,
input_resolution_,
dim,
shift_size_,
stream);
} else if (input_type == nvinfer1::DataType::kHALF) {
VLOG(3) << "TRT Plugin DataType selected. ReverseRoll-->fp16";
invokeReverseRoll(reinterpret_cast<half *>(outputs[0]),
reinterpret_cast<const half *>(inputs[0]),
batch,
window_num_,
window_len_,
window_size_,
input_resolution_,
input_resolution_,
dim,
shift_size_,
stream);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The ReverseRoll TRT Plugin's input type should be float or half."));
}
return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class ReverseRollPluginDynamic : public DynamicPluginTensorRT {
public:
ReverseRollPluginDynamic(int window_num,
int window_len,
int window_size,
int input_resolution,
int shift_size,
bool with_fp16)
: window_num_(window_num),
window_len_(window_len),
window_size_(window_size),
input_resolution_(input_resolution),
shift_size_(shift_size),
with_fp16_(with_fp16) {}
ReverseRollPluginDynamic(void const* serialData, size_t serialLength) {
DeserializeValue(&serialData, &serialLength, &window_num_);
DeserializeValue(&serialData, &serialLength, &window_len_);
DeserializeValue(&serialData, &serialLength, &window_size_);
DeserializeValue(&serialData, &serialLength, &input_resolution_);
DeserializeValue(&serialData, &serialLength, &shift_size_);
DeserializeValue(&serialData, &serialLength, &with_fp16_);
}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
return new ReverseRollPluginDynamic(window_num_,
window_len_,
window_size_,
input_resolution_,
shift_size_,
with_fp16_);
}
const char* getPluginType() const TRT_NOEXCEPT override {
return "reverse_roll_dynamic";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
int initialize() TRT_NOEXCEPT override { return 0; }
size_t getSerializationSize() const TRT_NOEXCEPT override {
return SerializedSize(window_num_) + SerializedSize(window_len_) +
SerializedSize(window_size_) + SerializedSize(input_resolution_) +
SerializedSize(shift_size_) + SerializedSize(with_fp16_);
}
void serialize(void* buffer) const TRT_NOEXCEPT override {
SerializeValue(&buffer, window_num_);
SerializeValue(&buffer, window_len_);
SerializeValue(&buffer, window_size_);
SerializeValue(&buffer, input_resolution_);
SerializeValue(&buffer, shift_size_);
SerializeValue(&buffer, with_fp16_);
}
nvinfer1::DimsExprs getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) // NOLINT
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const
TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override { delete this; }
private:
int window_num_;
int window_len_;
int window_size_;
int input_resolution_;
int shift_size_;
bool with_fp16_;
};
class ReverseRollPluginDynamicCreater : public TensorRTPluginCreator {
public:
const char* getPluginName() const TRT_NOEXCEPT override {
return "reverse_roll_dynamic";
}
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length)
TRT_NOEXCEPT override {
return new ReverseRollPluginDynamic(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(ReverseRollPluginDynamicCreater);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
@@ -161,7 +161,10 @@ if(WITH_GPU AND TENSORRT_FOUND)
     AND WITH_GPU)
  set_tests_properties(test_emb_eltwise_layernorm_fuse_pass PROPERTIES TIMEOUT
                       120)
  set_tests_properties(test_fc_fuse_pass PROPERTIES TIMEOUT 240)
  set_tests_properties(test_reverse_roll_fuse_pass PROPERTIES TIMEOUT 120)
  set_tests_properties(test_simplify_with_basic_ops_pass_autoscan
                       PROPERTIES TIMEOUT 60)
  set_tests_properties(test_adaptive_pool2d_convert_global_pass_autoscan
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
import paddle.inference as paddle_infer
class ReverseRollPass(PassAutoScanTest):
"""
|
reshape2
|
reshape2
|
transpose2
|
reshape2
|
roll
|
reshape2
|
"""
def sample_predictor_configs(self, program_config):
# trt with dynamic_shape
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=4,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Float32,
use_static=False,
use_calib_mode=False,
)
config.set_trt_dynamic_shape_info(
{
"input0": [64, 9, 96],
},
{
"input0": [512, 144, 768],
},
{
"input0": [64, 49, 96],
},
)
yield config, ['reverse_roll'], (1e-5, 1e-5)
# trt dynamic_shape
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=1,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Half,
use_static=False,
use_calib_mode=False,
)
config.set_trt_dynamic_shape_info(
{
"input0": [64, 9, 96],
},
{
"input0": [512, 144, 768],
},
{
"input0": [64, 49, 96],
},
)
yield config, ['reverse_roll'], (1e-3, 1e-3)
def sample_program_config(self, draw):
batch_size = draw(st.integers(min_value=1, max_value=4))
window_size = draw(st.sampled_from([3, 5, 7, 12]))
dim = draw(st.sampled_from([96, 192, 384, 768]))
window_number = 64
def generate_input(attrs):
return np.random.random(
[
attrs[0]["batch_size"] * attrs[1]["window_number"],
attrs[1]["window_size"] * attrs[1]["window_size"],
attrs[1]["dim"],
]
).astype(np.float32)
attrs = [
{"batch_size": batch_size},
{
"window_number": window_number,
"window_size": window_size,
"dim": dim,
},
]
reshape2_00 = OpConfig(
type="reshape2",
inputs={"X": ["input0"]},
outputs={
"Out": ["reshape2_00_out"],
"XShape": ["reshape2_00_outXshape"],
},
attrs={"shape": [-1, window_size, window_size, dim]},
)
reshape2_10 = OpConfig(
type="reshape2",
inputs={"X": ["reshape2_00_out"]},
outputs={
"Out": ["reshape2_10_out"],
"XShape": ["reshape2_10_outXshape"],
},
attrs={
"shape": [
-1,
int(math.sqrt(window_number)),
int(math.sqrt(window_number)),
window_size,
window_size,
dim,
]
},
)
transpose2_20 = OpConfig(
type="transpose2",
inputs={"X": ["reshape2_10_out"]},
outputs={
"Out": ["transpose2_20_out"],
"XShape": ["transpose2_20_outXshape"],
},
attrs={"axis": [0, 1, 3, 2, 4, 5]},
)
reshape2_30 = OpConfig(
type="reshape2",
inputs={"X": ["transpose2_20_out"]},
outputs={
"Out": ["reshape2_30_out"],
"XShape": ["reshape2_30_outXshape"],
},
attrs={
"shape": [
-1,
int(math.sqrt(window_number)) * window_size,
int(math.sqrt(window_number)) * window_size,
dim,
]
},
)
roll_30_1 = OpConfig(
type="roll",
inputs={"X": ["reshape2_30_out"]},
outputs={"Out": ["roll_30_1_out"]},
attrs={
"axis": [1, 2],
"shifts": [
                    window_size // 2,
                    window_size // 2,
],
},
)
reshape2_40 = OpConfig(
type="reshape2",
inputs={"X": ["roll_30_1_out"]},
outputs={
"Out": ["reshape2_40_out"],
"XShape": ["reshape2_40_outXshape"],
},
attrs={
"shape": [-1, window_number * window_size * window_size, dim]
},
)
program_config = ProgramConfig(
ops=[
reshape2_00,
reshape2_10,
transpose2_20,
reshape2_30,
roll_30_1,
reshape2_40,
],
weights={},
inputs={
"input0": TensorConfig(data_gen=partial(generate_input, attrs)),
},
outputs=["reshape2_40_out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=50,
passes=["reverse_roll_fuse_pass"],
max_duration=250,
min_success_num=50,
)
class ReverseRoll2Pass(PassAutoScanTest):
"""
|
reshape2
|
reshape2
|
transpose2
|
reshape2
|
reshape2
|
"""
def sample_predictor_configs(self, program_config):
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=4,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Float32,
use_static=False,
use_calib_mode=False,
)
config.set_trt_dynamic_shape_info(
{
"input0": [64, 9, 96],
},
{
"input0": [512, 144, 768],
},
{
"input0": [64, 49, 96],
},
)
yield config, ['reverse_roll'], (1e-5, 1e-5)
# trt dynamic_shape
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=1,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Half,
use_static=False,
use_calib_mode=False,
)
config.set_trt_dynamic_shape_info(
{
"input0": [64, 9, 96],
},
{
"input0": [512, 144, 768],
},
{
"input0": [64, 49, 96],
},
)
yield config, ['reverse_roll'], (1e-3, 1e-3)
def sample_program_config(self, draw):
batch_size = draw(st.integers(min_value=1, max_value=4))
window_size = draw(st.sampled_from([3, 5, 7, 12]))
dim = draw(st.sampled_from([96, 192, 384, 768]))
window_number = 64
def generate_input(attrs):
return np.random.random(
[
attrs[0]["batch_size"] * attrs[1]["window_number"],
attrs[1]["window_size"] * attrs[1]["window_size"],
attrs[1]["dim"],
]
).astype(np.float32)
attrs = [
{"batch_size": batch_size},
{
"window_number": window_number,
"window_size": window_size,
"dim": dim,
},
]
reshape2_00 = OpConfig(
type="reshape2",
inputs={"X": ["input0"]},
outputs={
"Out": ["reshape2_00_out"],
"XShape": ["reshape2_00_outXshape"],
},
attrs={"shape": [-1, window_size, window_size, dim]},
)
reshape2_10 = OpConfig(
type="reshape2",
inputs={"X": ["reshape2_00_out"]},
outputs={
"Out": ["reshape2_10_out"],
"XShape": ["reshape2_10_outXshape"],
},
attrs={
"shape": [
-1,
int(math.sqrt(window_number)),
int(math.sqrt(window_number)),
window_size,
window_size,
dim,
]
},
)
transpose2_20 = OpConfig(
type="transpose2",
inputs={"X": ["reshape2_10_out"]},
outputs={
"Out": ["transpose2_20_out"],
"XShape": ["transpose2_20_outXshape"],
},
attrs={"axis": [0, 1, 3, 2, 4, 5]},
)
reshape2_30 = OpConfig(
type="reshape2",
inputs={"X": ["transpose2_20_out"]},
outputs={
"Out": ["reshape2_30_out"],
"XShape": ["reshape2_30_outXshape"],
},
attrs={
"shape": [
-1,
int(math.sqrt(window_number)) * window_size,
int(math.sqrt(window_number)) * window_size,
dim,
]
},
)
reshape2_40 = OpConfig(
type="reshape2",
inputs={"X": ["reshape2_30_out"]},
outputs={
"Out": ["reshape2_40_out"],
"XShape": ["reshape2_40_outXshape"],
},
attrs={
"shape": [-1, window_number * window_size * window_size, dim]
},
)
program_config = ProgramConfig(
ops=[
reshape2_00,
reshape2_10,
transpose2_20,
reshape2_30,
reshape2_40,
],
weights={},
inputs={
"input0": TensorConfig(data_gen=partial(generate_input, attrs)),
},
outputs=["reshape2_40_out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=50,
passes=["reverse_roll_fuse_pass"],
max_duration=250,
min_success_num=50,
)
if __name__ == "__main__":
unittest.main()