Unverified Commit 591be3bd authored by wenbin, committed by GitHub

Preln groupnorm (#49463)

* skip_groupnorm

* init

* preln

* add ut

* more assert

* set timeout

* fix windows ci issue
Parent aaa25222
@@ -139,6 +139,8 @@ if(WITH_TENSORRT)
pass_library(layernorm_shift_partition_fuse_pass inference)
pass_library(reverse_roll_fuse_pass inference)
pass_library(preln_layernorm_x_fuse_pass inference)
pass_library(elementwise_groupnorm_act_pass inference)
pass_library(preln_elementwise_groupnorm_act_pass inference)
endif()
if(WITH_TENSORRT)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/elementwise_groupnorm_act_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct SkipGroupNormAct : public PatternBase {
SkipGroupNormAct(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "skip_groupnorm_act") {}
void operator()(PDNode *x, PDNode *y);
// declare operator node's name
PATTERN_DECL_NODE(elementwise);
PATTERN_DECL_NODE(group_norm);
// declare variable node's name
PATTERN_DECL_NODE(elementwise_out);
PATTERN_DECL_NODE(group_norm_bias);
PATTERN_DECL_NODE(group_norm_scale);
PATTERN_DECL_NODE(group_norm_out);
PATTERN_DECL_NODE(act);
PATTERN_DECL_NODE(act_out);
};
void SkipGroupNormAct::operator()(PDNode *x, PDNode *y) {
auto *elementwise = pattern->NewNode(elementwise_repr())
->assert_is_op("elementwise_add")
->assert_has_n_outputs(1);
auto *elementwise_out_var =
pattern->NewNode(elementwise_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_has_n_outputs(1)
->assert_is_op_input("group_norm", "X");
elementwise->LinksFrom({x, y}).LinksTo({elementwise_out_var});
// Create nodes for group_norm op.
auto *group_norm =
pattern->NewNode(group_norm_repr())->assert_is_op("group_norm");
auto *group_norm_bias_var = pattern->NewNode(group_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("group_norm", "Bias");
auto *group_norm_scale_var = pattern->NewNode(group_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("group_norm", "Scale");
auto *group_norm_out_var = pattern->NewNode(group_norm_out_repr())
->AsOutput()
->assert_is_op_output("group_norm", "Y")
->assert_is_op_input("silu", "X");
// Add links for group_norm op.
group_norm
->LinksFrom(
{elementwise_out_var, group_norm_bias_var, group_norm_scale_var})
.LinksTo({group_norm_out_var});
auto *act = pattern->NewNode(act_repr())->assert_is_op("silu");
auto *act_out = pattern->NewNode(act_out_repr())
->AsOutput()
->assert_is_op_output("silu", "Out");
act->LinksFrom({group_norm_out_var}).LinksTo({act_out});
}
} // namespace patterns
int SkipGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init("skip_groupnorm_silu_fuse", graph);
int found_subgraph_count = 0;
GraphPatternDetector gpd;
PDNode *x = nullptr;
PDNode *y = nullptr;
x = gpd.mutable_pattern()
->NewNode("skip_groupnorm_act_fuse/x")
->AsInput()
->assert_var_not_persistable()
->assert_is_op_input("elementwise_add", "X");
y = gpd.mutable_pattern()
->NewNode("skip_groupnorm_act_fuse/y")
->AsInput()
->assert_var_not_persistable()
->assert_is_op_input("elementwise_add", "Y")
->assert_more([&](Node *x) {
auto shape = x->Var()->GetShape();
if (shape.size() == 2 ||
(shape.size() == 4 && shape[3] == 1 && shape[2] == 1))
return true;
else
return false;
});
patterns::SkipGroupNormAct fused_pattern(gpd.mutable_pattern(),
"skip_groupnorm_act_fuse");
fused_pattern(x, y);
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *graph) {
if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) {
LOG(WARNING) << "The subgraph is empty.";
return;
}
VLOG(4) << "handle skip groupnorm act fuse";
GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(group_norm, group_norm, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(group_norm_bias, group_norm_bias, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
group_norm_scale, group_norm_scale, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(group_norm_out, group_norm_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act, act, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fused_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "skip groupnorm act pass in op compat failed.";
return;
}
std::unordered_set<const Node *> del_node_set;
// Create a skip_groupnorm_act op node
OpDesc new_desc(*group_norm->Op());
new_desc.SetType("skip_groupnorm_act");
new_desc.SetInput("X", {subgraph.at(x)->Name()});
new_desc.SetInput("Y", {subgraph.at(y)->Name()});
new_desc.SetOutput("Out", {act_out->Name()});
new_desc.RemoveOutput("Y");
new_desc.Flush();
auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied.
del_node_set.insert(elementwise);
del_node_set.insert(group_norm);
del_node_set.insert(elementwise_out);
del_node_set.insert(group_norm_out);
del_node_set.insert(act);
GraphSafeRemoveNodes(graph, del_node_set);
IR_NODE_LINK_TO(subgraph.at(x), fused_node);
IR_NODE_LINK_TO(subgraph.at(y), fused_node);
IR_NODE_LINK_TO(group_norm_scale, fused_node);
IR_NODE_LINK_TO(group_norm_bias, fused_node);
IR_NODE_LINK_TO(fused_node, act_out);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
void SkipGroupNormActFusePass::ApplyImpl(ir::Graph *graph) const {
FusePassBase::Init("skip_groupnorm_act_fuse_pass", graph);
int found_subgraph_count = ApplyGNSiluPattern(graph);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(elementwise_groupnorm_act_pass,
paddle::framework::ir::SkipGroupNormActFusePass);
REGISTER_PASS_CAPABILITY(elementwise_groupnorm_act_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("silu", 0)
.EQ("group_norm", 0));
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
//
//       |      |
//  elementwise_add      fuse       |     |
//       |                ->     skip_gn_act
//   group_norm                      |
//       |
//      silu
//       |
//
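// A minimal sketch (pseudocode, an assumption about the fused op's semantics)
// of what skip_gn_act stands for:
//   out = silu(group_norm(x + y, Scale, Bias, groups, epsilon))
// where y is rank-2 or has H == W == 1, so the add broadcasts over H and W.
//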
class Graph;
class SkipGroupNormActFusePass : public FusePassBase {
public:
SkipGroupNormActFusePass() {
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({0, -1})
.End();
AddOpCompat(OpCompat("group_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(1.0f)
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("data_layout")
.IsStringIn({"NCHW"})
.End();
AddOpCompat(OpCompat("silu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
virtual ~SkipGroupNormActFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
int ApplyGNSiluPattern(ir::Graph* graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct PrelnGroupNormAct : public PatternBase {
PrelnGroupNormAct(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "preln_groupnorm_act") {}
void operator()(PDNode *x, PDNode *y);
// declare operator node's name
PATTERN_DECL_NODE(elementwise);
PATTERN_DECL_NODE(group_norm);
// declare variable node's name
PATTERN_DECL_NODE(elementwise_out);
PATTERN_DECL_NODE(group_norm_bias);
PATTERN_DECL_NODE(group_norm_scale);
PATTERN_DECL_NODE(group_norm_out);
PATTERN_DECL_NODE(act);
PATTERN_DECL_NODE(act_out);
};
void PrelnGroupNormAct::operator()(PDNode *x, PDNode *y) {
auto *elementwise =
pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add");
auto *elementwise_out_var =
pattern->NewNode(elementwise_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("group_norm", "X");
elementwise->LinksFrom({x, y}).LinksTo({elementwise_out_var});
// Create nodes for group_norm op.
auto *group_norm =
pattern->NewNode(group_norm_repr())->assert_is_op("group_norm");
auto *group_norm_bias_var = pattern->NewNode(group_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("group_norm", "Bias");
auto *group_norm_scale_var = pattern->NewNode(group_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("group_norm", "Scale");
auto *group_norm_out_var = pattern->NewNode(group_norm_out_repr())
->AsOutput()
->assert_is_op_output("group_norm", "Y")
->assert_is_op_input("silu", "X");
// Add links for group_norm op.
group_norm
->LinksFrom(
{elementwise_out_var, group_norm_bias_var, group_norm_scale_var})
.LinksTo({group_norm_out_var});
auto *act = pattern->NewNode(act_repr())->assert_is_op("silu");
auto *act_out = pattern->NewNode(act_out_repr())
->AsOutput()
->assert_is_op_output("silu", "Out");
act->LinksFrom({group_norm_out_var}).LinksTo({act_out});
}
} // namespace patterns
int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init("preln_groupnorm_silu_fuse", graph);
int found_subgraph_count = 0;
GraphPatternDetector gpd;
PDNode *x = nullptr;
PDNode *y = nullptr;
x = gpd.mutable_pattern()
->NewNode("preln_groupnorm_act_fuse/x")
->AsInput()
->assert_var_not_persistable()
->assert_is_op_input("elementwise_add", "X");
y = gpd.mutable_pattern()
->NewNode("preln_groupnorm_act_fuse/y")
->AsInput()
->assert_var_not_persistable()
->assert_is_op_input("elementwise_add", "Y");
patterns::PrelnGroupNormAct fused_pattern(gpd.mutable_pattern(),
"preln_groupnorm_act_fuse");
fused_pattern(x, y);
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *graph) {
if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) {
LOG(WARNING) << "The subgraph is empty.";
return;
}
VLOG(4) << "handle preln groupnorm act fuse";
GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(group_norm, group_norm, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(group_norm_bias, group_norm_bias, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
group_norm_scale, group_norm_scale, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(group_norm_out, group_norm_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act, act, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fused_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "preln groupnorm act pass in op compat failed.";
return;
}
std::unordered_set<const Node *> del_node_set;
// Create a preln_groupnorm_act op node
OpDesc new_desc(*group_norm->Op());
new_desc.SetType("preln_groupnorm_act");
new_desc.SetInput("X", {subgraph.at(x)->Name()});
new_desc.SetInput("Y", {subgraph.at(y)->Name()});
new_desc.SetOutput("Out_0", {elementwise_out->Name()});
new_desc.SetOutput("Out_1", {act_out->Name()});
new_desc.RemoveOutput("Y");
new_desc.Flush();
auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied.
del_node_set.insert(elementwise);
del_node_set.insert(group_norm);
del_node_set.insert(group_norm_out);
del_node_set.insert(act);
GraphSafeRemoveNodes(graph, del_node_set);
IR_NODE_LINK_TO(subgraph.at(x), fused_node);
IR_NODE_LINK_TO(subgraph.at(y), fused_node);
IR_NODE_LINK_TO(group_norm_scale, fused_node);
IR_NODE_LINK_TO(group_norm_bias, fused_node);
IR_NODE_LINK_TO(fused_node, act_out);
IR_NODE_LINK_TO(fused_node, elementwise_out);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
void PrelnGroupNormActFusePass::ApplyImpl(ir::Graph *graph) const {
FusePassBase::Init("preln_groupnorm_act_fuse_pass", graph);
int found_subgraph_count = ApplyGNSiluPattern(graph);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(preln_elementwise_groupnorm_act_pass,
paddle::framework::ir::PrelnGroupNormActFusePass);
REGISTER_PASS_CAPABILITY(preln_elementwise_groupnorm_act_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("silu", 0)
.EQ("group_norm", 0));
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
//
//        |      |
//   elementwise_add      fuse        |      |
//      |        |         ->     preln_gn_act
//  other op   group_norm            |      |
//               |                 other op
//              silu
//               |
//
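// A minimal sketch (pseudocode, an assumption about the fused op's semantics)
// of the two outputs preln_gn_act stands for:
//   Out_0 = x + y                          // consumed by the other op
//   Out_1 = silu(group_norm(Out_0, Scale, Bias, groups, epsilon))
//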
class Graph;
class PrelnGroupNormActFusePass : public FusePassBase {
public:
PrelnGroupNormActFusePass() {
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({0, -1})
.End();
AddOpCompat(OpCompat("group_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(1.0f)
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("data_layout")
.IsStringIn({"NCHW"})
.End();
AddOpCompat(OpCompat("silu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
virtual ~PrelnGroupNormActFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
int ApplyGNSiluPattern(ir::Graph* graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -2420,6 +2420,8 @@ USE_TRT_CONVERTER(logsigmoid)
USE_TRT_CONVERTER(lookup_table)
USE_TRT_CONVERTER(expand_v2)
USE_TRT_CONVERTER(take_along_axis)
USE_TRT_CONVERTER(skip_groupnorm_act)
USE_TRT_CONVERTER(preln_groupnorm_act)
#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
USE_TRT_CONVERTER(sparse_fc)
USE_TRT_CONVERTER(sparse_multihead_matmul)
......
@@ -106,8 +106,8 @@ const std::vector<std::string> kTRTSubgraphPasses({
"vit_attention_fuse_pass", //
#if defined _WIN32 // Windows CI is TensorRT7.0. Remove this after upgrading.
#else
"trt_skip_layernorm_fuse_pass", //
"preln_skip_layernorm_fuse_pass", //
"trt_skip_layernorm_fuse_pass", //
"preln_skip_layernorm_fuse_pass", //
#endif
"layernorm_shift_partition_fuse_pass", //
"merge_layernorm_fuse_pass", //
@@ -128,8 +128,13 @@ const std::vector<std::string> kTRTSubgraphPasses({
// "yolo_box_fuse_pass", //
"dense_fc_to_sparse_pass", //
"dense_multihead_matmul_to_sparse_pass", //
"tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", //
#if defined _WIN32 // Windows CI is TensorRT7.0. Remove this after upgrading.
#else
"elementwise_groupnorm_act_pass", //
"preln_elementwise_groupnorm_act_pass", //
#endif
"tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
// guaranteed at least v7
// cudnn8.0 has memory leak problem in conv + eltwise + act, so we
......
@@ -94,6 +94,8 @@ list(
skip_merge_layernorm_op.cc
generic_and_custom_plugin_creater.cc
fused_lookup_tables_op.cc
skip_groupnorm_act_op.cc
preln_groupnorm_act_op.cc
expand_v2_op.cc)
if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h"
#include <vector>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
class PrelnGroupnormActOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(4) << "convert a fluid preln_groupnorm_act op to tensorrt "
"preln_groupnorm_act plugin";
framework::OpDesc op_desc(op, nullptr);
auto* input_x = engine_->GetITensor(op_desc.Input("X").front());
auto* input_y = engine_->GetITensor(op_desc.Input("Y").front());
std::vector<nvinfer1::ITensor*> inputs{input_x, input_y};
int groups = PADDLE_GET_CONST(int, op_desc.GetAttr("groups"));
float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"));
std::string scale_name = op_desc.Input("Scale").front();
std::string bias_name = op_desc.Input("Bias").front();
// get the persistable var's data
auto GetWeight = [&](const std::string& var_name,
framework::DDim* dims) -> TensorRTEngine::Weight {
auto* temp_var = scope.FindVar(var_name);
auto* temp_tensor = temp_var->GetMutable<phi::DenseTensor>();
(*dims) = temp_tensor->dims();
auto weight = engine_->GetTrtWeight(var_name, *temp_tensor);
return weight;
};
framework::DDim scale_dims;
framework::DDim bias_dims;
auto scale_weights = GetWeight(scale_name, &scale_dims);
auto bias_weights = GetWeight(bias_name, &bias_dims);
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (engine_->with_dynamic_shape()) {
plugin::PrelnGroupnormActPluginDynamic* plugin =
new plugin::PrelnGroupnormActPluginDynamic(
static_cast<const float*>(scale_weights.get().values),
scale_weights.get().count,
static_cast<const float*>(bias_weights.get().values),
bias_weights.get().count,
epsilon,
groups,
with_fp16);
nvinfer1::ILayer* groupnorm_layer =
engine_->AddDynamicPlugin(inputs.data(), 2, plugin);
std::vector<std::string> output_names;
output_names.emplace_back(op_desc.Output("Out_0").front());
output_names.emplace_back(op_desc.Output("Out_1").front());
RreplenishLayerAndOutput(
groupnorm_layer, "preln_groupnorm_act", output_names, test_mode);
}
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(preln_groupnorm_act, PrelnGroupnormActOpConverter);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h"
#include <vector>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
class SkipGroupnormActOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(4) << "convert a fluid skip_groupnorm_act op to tensorrt "
"skip_groupnorm_act plugin";
framework::OpDesc op_desc(op, nullptr);
auto* inputx = engine_->GetITensor(op_desc.Input("X").front());
auto* inputy = engine_->GetITensor(op_desc.Input("Y").front());
std::vector<nvinfer1::ITensor*> inputs{inputx, inputy};
int groups = PADDLE_GET_CONST(int, op_desc.GetAttr("groups"));
float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"));
std::string scale_name = op_desc.Input("Scale").front();
std::string bias_name = op_desc.Input("Bias").front();
// get the persistable var's data
auto GetWeight = [&](const std::string& var_name,
framework::DDim* dims) -> TensorRTEngine::Weight {
auto* temp_var = scope.FindVar(var_name);
auto* temp_tensor = temp_var->GetMutable<phi::DenseTensor>();
(*dims) = temp_tensor->dims();
auto weight = engine_->GetTrtWeight(var_name, *temp_tensor);
return weight;
};
framework::DDim scale_dims;
framework::DDim bias_dims;
auto scale_weights = GetWeight(scale_name, &scale_dims);
auto bias_weights = GetWeight(bias_name, &bias_dims);
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (engine_->with_dynamic_shape()) {
plugin::SkipGroupnormActPluginDynamic* plugin =
new plugin::SkipGroupnormActPluginDynamic(
static_cast<const float*>(scale_weights.get().values),
scale_weights.get().count,
static_cast<const float*>(bias_weights.get().values),
bias_weights.get().count,
epsilon,
groups,
with_fp16);
nvinfer1::ILayer* groupnorm_layer =
engine_->AddDynamicPlugin(inputs.data(), 2, plugin);
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(
groupnorm_layer, "skip_groupnorm_act", {output_name}, test_mode);
}
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(skip_groupnorm_act, SkipGroupnormActOpConverter);
@@ -2390,6 +2390,22 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}
if (op_type == "skip_groupnorm_act") {
if (!with_dynamic_shape) {
VLOG(3) << "The skip_groupnorm_act op does not support "
"static shape yet";
return false;
}
}
if (op_type == "preln_groupnorm_act") {
if (!with_dynamic_shape) {
VLOG(3) << "The preln_groupnorm_act op does not support "
"static shape yet";
return false;
}
}
if (op_type == "lookup_table") {
if (!with_dynamic_shape) {
VLOG(3) << "the lookup_table does not support "
@@ -2561,7 +2577,9 @@ struct SimpleOpTypeSetTeller : public Teller {
"merge_layernorm",
"skip_merge_layernorm",
"lookup_table_v2",
"expand_v2"};
"expand_v2",
"skip_groupnorm_act",
"preln_groupnorm_act"};
std::unordered_set<std::string> teller_set{
"mul",
@@ -2709,7 +2727,9 @@ struct SimpleOpTypeSetTeller : public Teller {
"skip_merge_layernorm",
"lookup_table",
"lookup_table_v2",
"expand_v2"};
"expand_v2",
"skip_groupnorm_act",
"preln_groupnorm_act"};
};
struct GenericPluginTeller : public Teller {
......
@@ -35,6 +35,8 @@ list(
prelnlayernorm_shift_partition_op.cu
merge_layernorm_op_plugin.cu
skip_merge_layernorm_op_plugin.cu
skip_groupnorm_act_op_plugin.cu
preln_groupnorm_act_op_plugin.cu
generic_plugin.cu
lookup_table.cu
many_emb_layernorm_plugin.cu
......
@@ -27,6 +27,8 @@ namespace plugin {
struct GroupNormNHWCParams {
// The output buffer. Layout NHWC.
__half* dst;
// The output buffer for the elementwise_add result. Layout NHWC.
__half* eleOut;
// The input buffer. Layout NHWC.
__half const* srcX;
// The input buffer. Layout NHWC.
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES.
All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h"
#include <cub/cub.cuh>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
nvinfer1::DimsExprs PrelnGroupnormActPluginDynamic::getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs *inputDims,
int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
return inputDims[0];
}
bool PrelnGroupnormActPluginDynamic::supportsFormatCombination(
int pos,
const nvinfer1::PluginTensorDesc *in_out,
int nb_inputs,
int nb_outputs) TRT_NOEXCEPT {
PADDLE_ENFORCE_NOT_NULL(
in_out,
platform::errors::InvalidArgument(
"The input of prelnGroupnormAct plugin shoule not be nullptr."));
PADDLE_ENFORCE_LT(
pos,
nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos,
nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) {
if (with_fp16_) {
return ((in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::PluginFormat::kHWC8));
} else {
PADDLE_THROW(platform::errors::Fatal(
"PrelnGroupnormAct TRT Plugin is fp16 only so far"));
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
}
}
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
// output
return in.type == prev.type && in.format == prev.format;
}
nvinfer1::DataType PrelnGroupnormActPluginDynamic::getOutputDataType(
int index,
const nvinfer1::DataType *input_types,
int nb_inputs) const TRT_NOEXCEPT {
return input_types[0];
}
int PrelnGroupnormActPluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
static inline int32_t divUp(int32_t m, int32_t n) { return (m + n - 1) / n; }
static int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) {
int32_t maxDivisor = -1;
for (int32_t i = 1; i <= std::sqrt(n); i++) {
if (n % i == 0) {
int32_t divisor1 = n / i;
int32_t divisor2 = i;
if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) {
maxDivisor = divisor1;
}
if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) {
maxDivisor = divisor2;
}
}
}
return maxDivisor;
}
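// Illustrative example (assumed values): for hw = 64 * 64 = 4096 and
// maxAllowedDivisor = 1024, the largest divisor below 1024 is 512, so the
// caller gets blocksPerHW = 512 and hwPerBlock = divUp(4096, 512) = 8.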
static inline __device__ __host__ float sigmoid(float x) {
return 1.F / (1.F + expf(-x));
}
struct GroupSums {
// Is it the 1st element of the group?
int32_t flag;
// The sum.
float sum;
// The sum of squares.
float sumSq;
};
struct GroupSumsOp {
inline __device__ GroupSums operator()(GroupSums const &a,
GroupSums const &b) {
GroupSums dst;
dst.sum = b.flag ? b.sum : (a.sum + b.sum);
dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq);
dst.flag = a.flag + b.flag;
return dst;
}
};
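// Used with the segmented InclusiveScan below: the flag marks the first
// channel pair of a group, so the running (sum, sumSq) restarts at each group
// boundary and the last thread of a group holds that group's partial sums.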
template <int32_t tTHREADS_PER_BLOCK>
__global__ void prelnGroupNormNHWCSumKernel(GroupNormNHWCParams params) {
// The object in charge of doing the sums for the different blocks.
typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;
// Allocate shared memory for BlockScan.
__shared__ typename BlockScan::TempStorage tempStorage;
// Allocate shared memory for the groups. We could reduce the amount of shared
// memory reserved.
__shared__ float2 smem[tTHREADS_PER_BLOCK];
// The instance in the batch.
int32_t ni = blockIdx.z;
// The channel loaded by that thread (2 channels per thread for F16x2).
int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
// The first activation loaded by that block.
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
// The last activation loaded by that block.
int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
// The sums.
float sum = 0.F;
float sumSq = 0.F;
// Iterate over the activations to compute the sums.
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
// The offset.
int64_t offset = static_cast<int64_t>(ni) * params.hwc +
static_cast<int64_t>(hwi) * params.c + ci;
// Fetch two channels per thread.
__half2 h2(0, 0);
if (ci < params.c) {
// int64_t offsetY = static_cast<int64_t>(ni) * params.c + ci;
__half2 y = *reinterpret_cast<__half2 const *>(&params.srcY[offset]);
h2 = *reinterpret_cast<__half2 const *>(&params.srcX[offset]);
h2 = __hadd2(h2, y);
// elementwise_add
*reinterpret_cast<__half2 *>(&params.eleOut[offset]) = h2;
}
// Extract the two half values.
float2 f2 = __half22float2(h2);
// Update the sum.
sum += f2.x + f2.y;
// Update the sum of squares.
sumSq += f2.x * f2.x + f2.y * f2.y;
}
// The group that thread works on and the channel in the group (modulus).
int32_t gi = threadIdx.x * 2 / params.cPerGroup;
int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi;
// The data for the summations.
GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq};
// Do the segmented scan.
GroupSums out;
BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());
// Store the results for the groups in shared memory (to produce coalesced
// stores later).
if (cj == params.cPerGroup - 2 /* 2 channels per thread */) {
smem[gi] = make_float2(out.sum, out.sumSq);
}
// Make sure the data is in shared memory.
__syncthreads();
// The global group index.
int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x;
// Threads that have nothing left to do, exit.
if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) {
return;
}
// The first threads (those storing to global memory) load the values.
float2 sums = smem[threadIdx.x];
// Store to global memory.
atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
}
void prelnGroupNormNHWCSum(GroupNormNHWCParams const &params,
cudaStream_t stream) {
// Make sure the values are as we expect.
PADDLE_ENFORCE_EQ(params.c % params.cPerBlock,
0,
platform::errors::InvalidArgument(
"The groupNormNHWCSum of prelnGroupnormAct Plugin got "
"wrong parameters"
"params.c %% params.cPerBlock should be 0, but get %d.",
params.c % params.cPerBlock));
PADDLE_ENFORCE_EQ(
params.hw % params.hwPerBlock,
0,
platform::errors::InvalidArgument(
"The groupNormNHWCSum of prelnGroupnormAct Plugin got wrong "
"parameters"
"params.hw %% params.hwPerBlock should be 0, but get %d.",
params.hw % params.hwPerBlock));
// Make sure a group does not span multiple blocks.
PADDLE_ENFORCE_EQ(
params.cPerBlock % params.cPerGroup,
0,
platform::errors::InvalidArgument(
"The groupNormNHWCSum of prelnGroupnormAct Plugin got wrong "
"parameters"
"params.cPerBlock %% params.cPerGroup should be 0, but get %d.",
params.cPerBlock % params.cPerGroup));
dim3 grid;
// The number of blocks to compute all the channels.
grid.x = params.c / params.cPerBlock;
// The number of blocks to compute all the activations in a given instance.
grid.y = divUp(params.hw, params.hwPerBlock);
// The number of instances.
grid.z = params.n;
switch (params.cPerBlock) {
case 320:
prelnGroupNormNHWCSumKernel<160><<<grid, 160, 0, stream>>>(params);
break;
case 480:
prelnGroupNormNHWCSumKernel<256><<<grid, 256, 0, stream>>>(params);
break;
case 256:
prelnGroupNormNHWCSumKernel<128><<<grid, 128, 0, stream>>>(params);
break;
case 128:
prelnGroupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params);
break;
default:
PADDLE_THROW(platform::errors::Fatal(
"The function groupNormNHWCSum of prelnGroupnormAct TRT Plugin "
"encounter error"));
}
}
template <int32_t tTHREADS_PER_BLOCK>
__global__ void prelnGroupNormNHWCScaleKernel(GroupNormNHWCParams params) {
// The instance in the batch.
int32_t ni = blockIdx.z;
// The channel loaded by that thread (2 channels per thread for F16x2).
int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
// The group that thread works on and the channel in the group (modulus).
int32_t gi = ci / params.cPerGroup;
// Load the sum and sum of squares for the group.
float sum = 0.F, sumSq = 0.F;
if (gi < params.groups) {
sum = params.redBuffer[(2 * ni + 0) * params.groups + gi];
sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi];
}
// Load gamma/beta.
float2 gammaF2, betaF2;
if (ci < params.c) {
gammaF2 = *reinterpret_cast<float2 const *>(
reinterpret_cast<float const *>(params.gamma) + ci);
betaF2 = *reinterpret_cast<float2 const *>(
reinterpret_cast<float const *>(params.beta) + ci);
}
// Compute the mean.
float mean = sum * params.invHWC;
// Compute the variance.
float var = sumSq * params.invHWC - (mean * mean);
// Compute the inverse of the stddev.
float invStdDev = rsqrtf(var + params.eps);
// The first activation loaded by that block.
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
// The last activation loaded by that block.
int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
// Iterate over the activations to compute the sums.
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
// The src/dst offset.
int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci;
// Fetch two channels per thread.
__half2 h2(0, 0);
if (ci < params.c) {
h2 = *reinterpret_cast<__half2 const *>(&params.eleOut[offset]);
}
// Extract the two half values.
float2 f2 = __half22float2(h2);
// Normalize the channels.
f2.x = (f2.x - mean) * invStdDev;
f2.y = (f2.y - mean) * invStdDev;
// Scale by gamma and add beta.
f2.x = gammaF2.x * f2.x + betaF2.x;
f2.y = gammaF2.y * f2.y + betaF2.y;
// Apply Swish if needed.
if (params.withSwish) {
f2.x = f2.x * sigmoid(f2.x);
f2.y = f2.y * sigmoid(f2.y);
}
// Store the scaled values.
if (ci < params.c) {
*reinterpret_cast<__half2 *>(&params.dst[offset]) = __float22half2_rn(f2);
}
}
}
void prelnGroupNormNHWCScale(GroupNormNHWCParams const &params,
cudaStream_t stream) {
// Make sure the dimensions are aligned with what we expect.
PADDLE_ENFORCE_EQ(
params.c % params.cPerBlock,
0,
platform::errors::InvalidArgument(
"The groupNormNHWCScale of prelnGroupnormAct Plugin got "
"wrong parameters"
"params.c %% params.cPerBlock should be 0, but get %d.",
params.c % params.cPerBlock));
// Make sure a group does not span multiple blocks.
PADDLE_ENFORCE_EQ(
params.cPerBlock % params.cPerGroup,
0,
platform::errors::InvalidArgument(
"The groupNormNHWCScale of prelnGroupnormAct Plugin got wrong "
"parameters"
"params.cPerBlock %% params.cPerGroup should be 0, but get %d.",
params.cPerBlock % params.cPerGroup));
dim3 grid;
// The number of blocks to compute all the channels.
grid.x = params.c / params.cPerBlock;
// The number of blocks to compute all the activations in a given instance.
grid.y = divUp(params.hw, params.hwPerBlock);
// The number of instances.
grid.z = params.n;
switch (params.cPerBlock) {
case 320:
prelnGroupNormNHWCScaleKernel<160><<<grid, 160, 0, stream>>>(params);
break;
case 480:
prelnGroupNormNHWCScaleKernel<256><<<grid, 256, 0, stream>>>(params);
break;
case 256:
prelnGroupNormNHWCScaleKernel<128><<<grid, 128, 0, stream>>>(params);
break;
case 128:
prelnGroupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params);
break;
default:
PADDLE_THROW(platform::errors::Fatal(
"The function groupNormNHWCSum of prelnGroupnormAct TRT Plugin "
"encounter error"));
}
}
int PrelnGroupnormActPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc,
const void *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream) TRT_NOEXCEPT {
auto input_type = input_desc[0].type;
if (input_type == nvinfer1::DataType::kFLOAT) {
VLOG(1) << "TRT Plugin DataType selected. prelnGroupnormAct-->fp32";
PADDLE_THROW(platform::errors::Fatal(
"The prelnGroupnormAct TRT Plugin's only support fp16 input"));
} else if (input_type == nvinfer1::DataType::kHALF) {
VLOG(1) << "TRT Plugin DataType selected. prelnGroupnormAct-->fp16";
int32_t cPerBlock = 320;
int32_t maxBlocksPerHW = 1024;
switch (input_desc[0].dims.d[1]) {
case 960:
case 1920:
cPerBlock = 480;
break;
case 512:
case 256:
cPerBlock = 256;
break;
case 128:
cPerBlock = 128;
break;
default:
cPerBlock = 320;
}
params_.withSwish = true;
params_.dst = static_cast<half *>(outputs[1]);
params_.eleOut = static_cast<half *>(outputs[0]);
params_.srcX = static_cast<half const *>(inputs[0]);
params_.srcY = static_cast<half const *>(inputs[1]);
params_.gamma = scale_gpu_.get();
params_.beta = bias_gpu_.get();
params_.redBuffer = static_cast<float *>(workspace);
params_.n = input_desc[0].dims.d[0];
params_.h = input_desc[0].dims.d[2];
params_.w = input_desc[0].dims.d[3];
params_.c = input_desc[0].dims.d[1];
params_.groups = groups_;
params_.hw = params_.h * params_.w;
const int32_t blocksPerHW = findMaxDivisor(params_.hw, maxBlocksPerHW);
params_.hwPerBlock = divUp(params_.hw, blocksPerHW);
params_.cPerBlock = cPerBlock;
params_.cPerGroup = params_.c / params_.groups;
params_.hwc = params_.hw * params_.c;
params_.invHWC = 1.F / static_cast<float>(params_.hw * params_.cPerGroup);
params_.groupsPerBlock = cPerBlock / params_.cPerGroup;
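// Illustrative example (assumed shapes): c = 320 and groups_ = 32 give
// cPerGroup = 10 and groupsPerBlock = 32; the matching 160-thread kernel
// then covers all 320 channels at 2 channels per thread.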
params_.eps = eps_;
cudaMemsetAsync(params_.redBuffer, 0, ws_, stream);
prelnGroupNormNHWCSum(params_, stream);
prelnGroupNormNHWCScale(params_, stream);
} else {
// input not fp16
PADDLE_THROW(platform::errors::Fatal(
"The PrelnGroupnormAct TRT Plugin's only support fp16 input"));
}
return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
public:
PrelnGroupnormActPluginDynamic(const float* scale,
const int scale_num,
const float* bias,
const int bias_num,
float eps,
int groups,
bool with_fp16,
std::shared_ptr<void> scale_gpu = nullptr,
std::shared_ptr<void> bias_gpu = nullptr)
: scale_gpu_(scale_gpu),
bias_gpu_(bias_gpu),
groups_(groups),
eps_(eps),
with_fp16_(with_fp16) {
scale_.resize(scale_num);
bias_.resize(bias_num);
std::copy(scale, scale + scale_num, scale_.data());
std::copy(bias, bias + bias_num, bias_.data());
if (scale_gpu_ == nullptr) {
void* p;
cudaMalloc(reinterpret_cast<void**>(&p), scale_num * sizeof(float));
scale_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); });
cudaMemcpy(
p, scale_.data(), scale_num * sizeof(float), cudaMemcpyHostToDevice);
}
if (bias_gpu_ == nullptr) {
void* p;
cudaMalloc(reinterpret_cast<void**>(&p), bias_num * sizeof(float));
bias_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); });
cudaMemcpy(
p, bias_.data(), bias_num * sizeof(float), cudaMemcpyHostToDevice);
}
}
PrelnGroupnormActPluginDynamic(void const* serialData, size_t serialLength) {
DeserializeValue(&serialData, &serialLength, &scale_);
DeserializeValue(&serialData, &serialLength, &bias_);
DeserializeValue(&serialData, &serialLength, &eps_);
DeserializeValue(&serialData, &serialLength, &groups_);
DeserializeValue(&serialData, &serialLength, &with_fp16_);
{
void* p;
cudaMalloc(reinterpret_cast<void**>(&p), scale_.size() * sizeof(float));
scale_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); });
cudaMemcpy(p,
scale_.data(),
scale_.size() * sizeof(float),
cudaMemcpyHostToDevice);
}
{
void* p;
cudaMalloc(reinterpret_cast<void**>(&p), bias_.size() * sizeof(float));
bias_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); });
cudaMemcpy(p,
bias_.data(),
bias_.size() * sizeof(float),
cudaMemcpyHostToDevice);
}
}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
auto* ptr = new PrelnGroupnormActPluginDynamic(scale_.data(),
scale_.size(),
bias_.data(),
bias_.size(),
eps_,
groups_,
with_fp16_,
scale_gpu_,
bias_gpu_);
return ptr;
}
const char* getPluginType() const TRT_NOEXCEPT override {
return "preln_groupnorm_act_plugin_dynamic";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 2; }
int initialize() TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override {
return SerializedSize(scale_) + SerializedSize(bias_) +
SerializedSize(eps_) + SerializedSize(groups_) +
SerializedSize(with_fp16_);
}
void serialize(void* buffer) const TRT_NOEXCEPT override {
SerializeValue(&buffer, scale_);
SerializeValue(&buffer, bias_);
SerializeValue(&buffer, eps_);
SerializeValue(&buffer, groups_);
SerializeValue(&buffer, with_fp16_);
}
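// Note: the serialization order above must stay in sync with the
// DeserializeValue calls in the (serialData, serialLength) constructor.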
nvinfer1::DimsExprs getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) // NOLINT
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) TRT_NOEXCEPT override {
// sizeof(float2) * maxBatchSize * maxNumberOfGroups; float2
// contains two buffers, for the sum and the squared sum.
ws_ = sizeof(float) * 2 * in[0].max.d[0] * groups_;
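// Illustrative sizing (assumed values): a max batch of 4 with groups_ = 32
// gives ws_ = 4 * 2 * 4 * 32 = 1024 bytes of reduction buffer.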
}
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT override {
return ws_;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const
TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override { delete this; }
void terminate() TRT_NOEXCEPT override{};
private:
size_t ws_;
std::vector<float> scale_;
std::vector<float> bias_;
std::shared_ptr<void> scale_gpu_;
std::shared_ptr<void> bias_gpu_;
GroupNormNHWCParams params_;
int groups_;
float eps_;
bool with_fp16_;
};
class PrelnGroupnormActPluginDynamicCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const TRT_NOEXCEPT override {
return "preln_groupnorm_act_plugin_dynamic";
}
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length)
TRT_NOEXCEPT override {
return new PrelnGroupnormActPluginDynamic(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(PrelnGroupnormActPluginDynamicCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES.
All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h"
#include <cub/cub.cuh>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
nvinfer1::DimsExprs SkipGroupnormActPluginDynamic::getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs *inputDims,
int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
return inputDims[0];
}
bool SkipGroupnormActPluginDynamic::supportsFormatCombination(
int pos,
const nvinfer1::PluginTensorDesc *in_out,
int nb_inputs,
int nb_outputs) TRT_NOEXCEPT {
PADDLE_ENFORCE_NOT_NULL(
in_out,
platform::errors::InvalidArgument(
"The input of SkipGroupnormAct plugin shoule not be nullptr."));
PADDLE_ENFORCE_LT(
pos,
nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos,
nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) {
if (with_fp16_) {
return ((in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::PluginFormat::kHWC8));
} else {
PADDLE_THROW(platform::errors::Fatal(
"SkipGroupnormAct TRT Plugin is fp16 only so far"));
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
}
}
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
// output
return in.type == prev.type && in.format == prev.format;
}
nvinfer1::DataType SkipGroupnormActPluginDynamic::getOutputDataType(
int index,
const nvinfer1::DataType *input_types,
int nb_inputs) const TRT_NOEXCEPT {
PADDLE_ENFORCE_EQ(
index,
0,
platform::errors::InvalidArgument(
"The SkipGroupnormAct Plugin only has one input, so the "
"index value should be 0, but get %d.",
index));
PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT ||
input_types[0] == nvinfer1::DataType::kHALF),
true,
platform::errors::InvalidArgument(
"The input type should be half or float"));
return input_types[0];
}
int SkipGroupnormActPluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
static inline int32_t divUp(int32_t m, int32_t n) { return (m + n - 1) / n; }
static int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) {
int32_t maxDivisor = -1;
for (int32_t i = 1; i <= std::sqrt(n); i++) {
if (n % i == 0) {
int32_t divisor1 = n / i;
int32_t divisor2 = i;
if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) {
maxDivisor = divisor1;
}
if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) {
maxDivisor = divisor2;
}
}
}
return maxDivisor;
}
static inline __device__ __host__ float sigmoid(float x) {
return 1.F / (1.F + expf(-x));
}
struct GroupSums {
// Is it the 1st element of the group?
int32_t flag;
// The sum.
float sum;
// The sum of squares.
float sumSq;
};
struct GroupSumsOp {
inline __device__ GroupSums operator()(GroupSums const &a,
GroupSums const &b) {
GroupSums dst;
dst.sum = b.flag ? b.sum : (a.sum + b.sum);
dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq);
dst.flag = a.flag + b.flag;
return dst;
}
};
template <int32_t tTHREADS_PER_BLOCK>
__global__ void skipGroupNormNHWCSumKernel(GroupNormNHWCParams params) {
// The object in charge of doing the sums for the different blocks.
typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;
// Allocate shared memory for BlockScan.
__shared__ typename BlockScan::TempStorage tempStorage;
// Allocate shared memory for the groups. We could reduce the amount of shared
// memory reserved.
__shared__ float2 smem[tTHREADS_PER_BLOCK];
// The instance in the batch.
int32_t ni = blockIdx.z;
// The channel loaded by that thread (2 channels per thread for F16x2).
int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
// The first activation loaded by that block.
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
// The last activation loaded by that block.
int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
// The sums.
float sum = 0.F;
float sumSq = 0.F;
// Iterate over the activations to compute the sums.
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
// The offset.
int64_t offset = static_cast<int64_t>(ni) * params.hwc +
static_cast<int64_t>(hwi) * params.c + ci;
// Fetch two channels per thread.
__half2 h2(0, 0);
if (ci < params.c) {
// W = 1, H = 1
int64_t offsetY = static_cast<int64_t>(ni) * params.c + ci;
__half2 y = *reinterpret_cast<__half2 const *>(&params.srcY[offsetY]);
h2 = *reinterpret_cast<__half2 const *>(&params.srcX[offset]);
h2 = __hadd2(h2, y);
// elementwise_add
*reinterpret_cast<__half2 *>(&params.dst[offset]) = h2;
}
// Extract the two half values.
float2 f2 = __half22float2(h2);
// Update the sum.
sum += f2.x + f2.y;
// Update the sum of squares.
sumSq += f2.x * f2.x + f2.y * f2.y;
}
// The group that thread works on and the channel in the group (modulus).
int32_t gi = threadIdx.x * 2 / params.cPerGroup;
int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi;
// The data for the summations.
GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq};
// Do the segmented scan.
GroupSums out;
BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());
// Store the results for the groups in shared memory (to produce coalesced
// stores later).
if (cj == params.cPerGroup - 2 /* 2 channels per thread */) {
smem[gi] = make_float2(out.sum, out.sumSq);
}
// Make sure the data is in shared memory.
__syncthreads();
// The global group index.
int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x;
// Threads that have nothing left to do, exit.
if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) {
return;
}
// The first threads (those storing to global memory) load the values.
float2 sums = smem[threadIdx.x];
// Store to global memory.
atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
}
void skipGroupNormNHWCSum(GroupNormNHWCParams const &params,
cudaStream_t stream) {
// Make sure the values are as we expect.
PADDLE_ENFORCE_EQ(
params.c % params.cPerBlock,
0,
platform::errors::InvalidArgument(
"The groupNormNHWCSum of SkipGroupnormAct Plugin got wrong parameters"
"params.c %% params.cPerBlock should be 0, but get %d.",
params.c % params.cPerBlock));
PADDLE_ENFORCE_EQ(
params.hw % params.hwPerBlock,
0,
platform::errors::InvalidArgument(
"The groupNormNHWCSum of SkipGroupnormAct Plugin got wrong parameters"
"params.hw %% params.hwPerBlock should be 0, but get %d.",
params.hw % params.hwPerBlock));
// Make sure a group does not span multiple blocks.
PADDLE_ENFORCE_EQ(
params.cPerBlock % params.cPerGroup,
0,
platform::errors::InvalidArgument(
"The groupNormNHWCSum of SkipGroupnormAct Plugin got wrong parameters"
"params.cPerBlock %% params.cPerGroup should be 0, but get %d.",
params.cPerBlock % params.cPerGroup));
dim3 grid;
// The number of blocks to compute all the channels.
grid.x = params.c / params.cPerBlock;
// The number of blocks to compute all the activations in a given instance.
grid.y = divUp(params.hw, params.hwPerBlock);
// The number of instances.
grid.z = params.n;
switch (params.cPerBlock) {
case 320:
skipGroupNormNHWCSumKernel<160><<<grid, 160, 0, stream>>>(params);
break;
case 480:
skipGroupNormNHWCSumKernel<256><<<grid, 256, 0, stream>>>(params);
break;
case 256:
skipGroupNormNHWCSumKernel<128><<<grid, 128, 0, stream>>>(params);
break;
case 128:
skipGroupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params);
break;
default:
PADDLE_THROW(platform::errors::Fatal(
"The function groupNormNHWCSum of SkipGroupnormAct TRT Plugin "
"encounter error"));
}
}
template <int32_t tTHREADS_PER_BLOCK>
__global__ void skipGroupNormNHWCScaleKernel(GroupNormNHWCParams params) {
// The instance in the batch.
int32_t ni = blockIdx.z;
// The channel loaded by that thread (2 channels per thread for F16x2).
int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
// The group that thread works on and the channel in the group (modulus).
int32_t gi = ci / params.cPerGroup;
// Load the sum and sum of squares for the group.
float sum = 0.F, sumSq = 0.F;
if (gi < params.groups) {
sum = params.redBuffer[(2 * ni + 0) * params.groups + gi];
sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi];
}
// Load gamma/beta.
float2 gammaF2, betaF2;
if (ci < params.c) {
gammaF2 = *reinterpret_cast<float2 const *>(
reinterpret_cast<float const *>(params.gamma) + ci);
betaF2 = *reinterpret_cast<float2 const *>(
reinterpret_cast<float const *>(params.beta) + ci);
}
// Compute the mean.
float mean = sum * params.invHWC;
// Compute the variance.
float var = sumSq * params.invHWC - (mean * mean);
// Compute the inverse of the stddev.
float invStdDev = rsqrtf(var + params.eps);
// The first activation loaded by that block.
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
// The last activation loaded by that block.
int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
// Iterate over the activations to compute the sums.
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
// The src/dst offset.
int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci;
// Fetch two channels per thread.
__half2 h2(0, 0);
if (ci < params.c) {
h2 = *reinterpret_cast<__half2 const *>(&params.dst[offset]);
}
// Extract the two half values.
float2 f2 = __half22float2(h2);
// Normalize the channels.
f2.x = (f2.x - mean) * invStdDev;
f2.y = (f2.y - mean) * invStdDev;
// Scale by gamma and add beta.
f2.x = gammaF2.x * f2.x + betaF2.x;
f2.y = gammaF2.y * f2.y + betaF2.y;
// Apply Swish if needed.
if (params.withSwish) {
f2.x = f2.x * sigmoid(f2.x);
f2.y = f2.y * sigmoid(f2.y);
}
// Store the scaled values.
if (ci < params.c) {
*reinterpret_cast<__half2 *>(&params.dst[offset]) = __float22half2_rn(f2);
}
}
}
void skipGroupNormNHWCScale(GroupNormNHWCParams const &params,
cudaStream_t stream) {
// Make sure the dimensions are aligned with what we expect.
PADDLE_ENFORCE_EQ(params.c % params.cPerBlock,
0,
platform::errors::InvalidArgument(
"The groupNormNHWCScale of SkipGroupnormAct Plugin got "
"wrong parameters"
"params.c %% params.cPerBlock should be 0, but get %d.",
params.c % params.cPerBlock));
// Make sure a group does not span multiple blocks.
PADDLE_ENFORCE_EQ(
params.cPerBlock % params.cPerGroup,
0,
platform::errors::InvalidArgument(
"The groupNormNHWCScale of SkipGroupnormAct Plugin got wrong "
"parameters"
"params.cPerBlock %% params.cPerGroup should be 0, but get %d.",
params.cPerBlock % params.cPerGroup));
dim3 grid;
// The number of blocks to compute all the channels.
grid.x = params.c / params.cPerBlock;
// The number of blocks to compute all the activations in a given instance.
grid.y = divUp(params.hw, params.hwPerBlock);
// The number of instances.
grid.z = params.n;
switch (params.cPerBlock) {
case 320:
skipGroupNormNHWCScaleKernel<160><<<grid, 160, 0, stream>>>(params);
break;
case 480:
skipGroupNormNHWCScaleKernel<256><<<grid, 256, 0, stream>>>(params);
break;
case 256:
skipGroupNormNHWCScaleKernel<128><<<grid, 128, 0, stream>>>(params);
break;
case 128:
skipGroupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params);
break;
default:
PADDLE_THROW(platform::errors::Fatal(
"The function groupNormNHWCSum of SkipGroupnormAct TRT Plugin "
"encounter error"));
}
}
int SkipGroupnormActPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc,
const void *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream) TRT_NOEXCEPT {
auto input_type = input_desc[0].type;
if (input_type == nvinfer1::DataType::kFLOAT) {
VLOG(1) << "TRT Plugin DataType selected. SkipGroupnormAct-->fp32";
PADDLE_THROW(platform::errors::Fatal(
"The SkipGroupnormAct TRT Plugin's only support fp16 input"));
} else if (input_type == nvinfer1::DataType::kHALF) {
VLOG(1) << "TRT Plugin DataType selected. SkipGroupnormAct-->fp16";
int32_t cPerBlock = 320;
int32_t maxBlocksPerHW = 1024;
switch (input_desc[0].dims.d[1]) {
case 960:
case 1920:
cPerBlock = 480;
break;
case 512:
case 256:
cPerBlock = 256;
break;
case 128:
cPerBlock = 128;
break;
default:
cPerBlock = 320;
}
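    // Illustrative example (values assumed for explanation, not taken from a
    // specific model): with c = 320 channels and groups = 32, cPerGroup = 10
    // and cPerBlock = 320, so one block of 160 threads covers a full block of
    // channels with 2 fp16 channels per thread.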
params_.withSwish = true;
params_.dst = static_cast<half *>(outputs[0]);
params_.srcX = static_cast<half const *>(inputs[0]);
params_.srcY = static_cast<half const *>(inputs[1]);
params_.gamma = scale_gpu_.get();
params_.beta = bias_gpu_.get();
params_.redBuffer = static_cast<float *>(workspace);
params_.n = input_desc[0].dims.d[0];
params_.h = input_desc[0].dims.d[2];
params_.w = input_desc[0].dims.d[3];
params_.c = input_desc[0].dims.d[1];
params_.groups = groups_;
params_.hw = params_.h * params_.w;
const int32_t blocksPerHW = findMaxDivisor(params_.hw, maxBlocksPerHW);
params_.hwPerBlock = divUp(params_.hw, blocksPerHW);
params_.cPerBlock = cPerBlock;
params_.cPerGroup = params_.c / params_.groups;
params_.hwc = params_.hw * params_.c;
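    // invHWC (set below) is 1 / (H * W * channels-per-group), i.e. the
    // reciprocal of the number of values reduced per group; it turns the
    // accumulated sum and sum of squares into the mean and E[x^2].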
params_.invHWC = 1.F / static_cast<float>(params_.hw * params_.cPerGroup);
params_.groupsPerBlock = cPerBlock / params_.cPerGroup;
params_.eps = eps_;
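    // Zero the per-group reduction buffer before the sum kernel accumulates
    // into it.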
cudaMemsetAsync(params_.redBuffer, 0, ws_, stream);
skipGroupNormNHWCSum(params_, stream);
skipGroupNormNHWCScale(params_, stream);
} else {
// input not fp16
PADDLE_THROW(platform::errors::Fatal(
"The SkipGroupnormAct TRT Plugin's only support fp16 input"));
}
return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class SkipGroupnormActPluginDynamic : public DynamicPluginTensorRT {
public:
SkipGroupnormActPluginDynamic(const float* scale,
const int scale_num,
const float* bias,
const int bias_num,
float eps,
int groups,
bool with_fp16,
std::shared_ptr<void> scale_gpu = nullptr,
std::shared_ptr<void> bias_gpu = nullptr)
: scale_gpu_(scale_gpu),
bias_gpu_(bias_gpu),
groups_(groups),
eps_(eps),
with_fp16_(with_fp16) {
scale_.resize(scale_num);
bias_.resize(bias_num);
std::copy(scale, scale + scale_num, scale_.data());
std::copy(bias, bias + bias_num, bias_.data());
if (scale_gpu_ == nullptr) {
void* p;
cudaMalloc(reinterpret_cast<void**>(&p), scale_num * sizeof(float));
scale_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); });
cudaMemcpy(
p, scale_.data(), scale_num * sizeof(float), cudaMemcpyHostToDevice);
}
if (bias_gpu_ == nullptr) {
void* p;
cudaMalloc(reinterpret_cast<void**>(&p), bias_num * sizeof(float));
bias_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); });
cudaMemcpy(
p, bias_.data(), bias_num * sizeof(float), cudaMemcpyHostToDevice);
}
}
SkipGroupnormActPluginDynamic(void const* serialData, size_t serialLength) {
DeserializeValue(&serialData, &serialLength, &scale_);
DeserializeValue(&serialData, &serialLength, &bias_);
DeserializeValue(&serialData, &serialLength, &eps_);
DeserializeValue(&serialData, &serialLength, &groups_);
DeserializeValue(&serialData, &serialLength, &with_fp16_);
{
void* p;
cudaMalloc(reinterpret_cast<void**>(&p), scale_.size() * sizeof(float));
scale_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); });
cudaMemcpy(p,
scale_.data(),
scale_.size() * sizeof(float),
cudaMemcpyHostToDevice);
}
{
void* p;
cudaMalloc(reinterpret_cast<void**>(&p), bias_.size() * sizeof(float));
bias_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); });
cudaMemcpy(p,
bias_.data(),
bias_.size() * sizeof(float),
cudaMemcpyHostToDevice);
}
}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
auto* ptr = new SkipGroupnormActPluginDynamic(scale_.data(),
scale_.size(),
bias_.data(),
bias_.size(),
eps_,
groups_,
with_fp16_,
scale_gpu_,
bias_gpu_);
return ptr;
}
const char* getPluginType() const TRT_NOEXCEPT override {
return "skip_groupnorm_act_plugin_dynamic";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
int initialize() TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override {
return SerializedSize(scale_) + SerializedSize(bias_) +
SerializedSize(eps_) + SerializedSize(groups_) +
SerializedSize(with_fp16_);
}
void serialize(void* buffer) const TRT_NOEXCEPT override {
SerializeValue(&buffer, scale_);
SerializeValue(&buffer, bias_);
SerializeValue(&buffer, eps_);
SerializeValue(&buffer, groups_);
SerializeValue(&buffer, with_fp16_);
}
nvinfer1::DimsExprs getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) // NOLINT
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) TRT_NOEXCEPT override {
    // The workspace holds sizeof(float2) * maxBatchSize * maxNumberOfGroups
    // bytes; each float2 stores a group's sum and sum of squares.
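    // Illustrative sizing (assumed values): a max batch of 4 with 32 groups
    // needs 4 * 32 * 2 floats = 1 KB of workspace.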
ws_ = sizeof(float) * 2 * in[0].max.d[0] * groups_;
}
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT override {
return ws_;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const
TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override { delete this; }
void terminate() TRT_NOEXCEPT override{};
private:
size_t ws_;
std::vector<float> scale_;
std::vector<float> bias_;
std::shared_ptr<void> scale_gpu_;
std::shared_ptr<void> bias_gpu_;
GroupNormNHWCParams params_;
int groups_;
float eps_;
bool with_fp16_;
};
class SkipGroupnormActPluginDynamicCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const TRT_NOEXCEPT override {
return "skip_groupnorm_act_plugin_dynamic";
}
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length)
TRT_NOEXCEPT override {
return new SkipGroupnormActPluginDynamic(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(SkipGroupnormActPluginDynamicCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
type: "group_norm"
def {
inputs {
name: "X"
}
inputs {
name: "Scale"
}
inputs {
name: "Bias"
}
outputs {
name: "Y"
}
outputs {
name: "Mean"
}
outputs {
name: "Variance"
}
attrs {
name: "epsilon"
type: FLOAT
}
attrs {
name: "groups"
type: INT
}
attrs {
name: "data_layout"
type: STRING
}
}
extra {
attrs {
name: "use_mkldnn"
type: BOOLEAN
}
attrs {
name: "mkldnn_data_type"
type: STRING
}
attrs {
name: "is_test"
type: BOOLEAN
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
type: "silu"
def {
inputs {
name: "X"
}
outputs {
name: "Out"
}
}
extra {
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
......@@ -35,6 +35,10 @@ endif()
if(WIN32)
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES
"test_trt_convert_fused_token_prune")
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES
"test_preln_groupnorm_act_fuse_pass")
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES
"test_element_groupnorm_act_fuse_pass")
list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_fused_token_prune")
list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_fused_token_prune")
endif()
......@@ -217,6 +221,10 @@ if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(test_map_matmul_v2_to_mul_pass PROPERTIES TIMEOUT
120)
set_tests_properties(test_map_matmul_to_mul_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_element_groupnorm_act_fuse_pass
PROPERTIES TIMEOUT 120)
set_tests_properties(test_preln_groupnorm_act_fuse_pass PROPERTIES TIMEOUT
120)
endif()
endif()
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
import paddle.inference as paddle_infer
class TestElementGNActPass(PassAutoScanTest):
#
# | | | |
# other_op1 other_op2 other_op1 other_op2
# | | fuse \ /
# elementwise_add -> skip_groupnorm_act
# | |
# groupnorm
# |
# silu
def sample_predictor_configs(self, program_config):
# trt dynamic_shape
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=1,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Half,
use_static=False,
use_calib_mode=False,
)
config.set_trt_dynamic_shape_info(
{
"input_data_x": [1, 160, 1, 1],
"input_data_y": [1, 160, 1, 1],
},
{
"input_data_x": [4, 1280, 64, 64],
"input_data_y": [4, 1280, 1, 1],
},
{
"input_data_x": [1, 320, 1, 1],
"input_data_y": [1, 320, 1, 1],
},
)
yield config, ['skip_groupnorm_act'], (3e-3, 1e-3)
def sample_program_config(self, draw):
axis = draw(st.sampled_from([0, -1]))
epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
batch_size = draw(st.integers(min_value=1, max_value=4))
groups = draw(st.sampled_from([4, 8, 16, 32]))
hw = draw(st.sampled_from([1, 8, 16, 32]))
channel = draw(st.sampled_from([320, 1280]))
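        # The channel counts are chosen to be divisible by every candidate
        # group count sampled above.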
def generate_input_x(attrs):
return np.random.random(
[attrs[1]["batch_size"], *attrs[1]["input_dim_x"]]
).astype(np.float32)
def generate_input_y(attrs):
return np.random.random(
[attrs[1]["batch_size"], *attrs[1]["input_dim_y"]]
).astype(np.float32)
def generate_weight(attrs):
return np.random.random(attrs[1]['input_dim_x'][0]).astype(
np.float32
)
attrs = [
{
'axis': axis,
'epsilon': epsilon,
'groups': groups,
},
{
'batch_size': batch_size,
'input_dim_x': [channel, hw, hw],
'input_dim_y': [channel, 1, 1],
},
]
elementwise_add_op = OpConfig(
type="elementwise_add",
inputs={"X": ["input_data_x"], "Y": ["input_data_y"]},
outputs={"Out": ["ele_out"]},
attrs={"axis": attrs[0]['axis']},
)
group_norm_op = OpConfig(
type="group_norm",
inputs={
"X": ["ele_out"],
"Bias": ["group_norm_bias"],
"Scale": ["group_norm_scale"],
},
outputs={
"Y": ["group_norm_output1"],
"Mean": ["group_norm_output2"],
"Variance": ["group_norm_output3"],
},
attrs={
"data_layout": "NCHW",
"groups": attrs[0]["groups"],
"epsilon": attrs[0]["epsilon"],
},
)
silu_op = OpConfig(
type="silu",
inputs={
"X": ["group_norm_output1"],
},
outputs={
"Out": ["silu_output"],
},
)
program_config = ProgramConfig(
ops=[
elementwise_add_op,
group_norm_op,
silu_op,
],
weights={
"group_norm_bias": TensorConfig(
data_gen=partial(generate_weight, attrs)
),
"group_norm_scale": TensorConfig(
data_gen=partial(generate_weight, attrs)
),
},
inputs={
"input_data_x": TensorConfig(
data_gen=partial(generate_input_x, attrs)
),
"input_data_y": TensorConfig(
data_gen=partial(generate_input_y, attrs)
),
},
outputs=["silu_output"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=50,
passes=["elementwise_groupnorm_act_pass"],
max_duration=250,
min_success_num=50,
)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
import paddle.inference as paddle_infer
class TestPrelnGNActPass(PassAutoScanTest):
#
# | | | |
# other_op1 other_op2 other_op1 other_op2
# | | fuse \ /
# elementwise_add -> preln_groupnorm_act
# | | | |
# other_op3 groupnorm other_op3
# |
# silu
def sample_predictor_configs(self, program_config):
# trt dynamic_shape
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=1,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Half,
use_static=False,
use_calib_mode=False,
)
config.set_trt_dynamic_shape_info(
{
"input_data_x": [1, 160, 1, 1],
"input_data_y": [1, 160, 1, 1],
},
{
"input_data_x": [4, 1280, 64, 64],
"input_data_y": [4, 1280, 64, 64],
},
{
"input_data_x": [1, 320, 32, 32],
"input_data_y": [1, 320, 32, 32],
},
)
yield config, ['preln_groupnorm_act'], (3e-3, 1e-3)
def sample_program_config(self, draw):
axis = draw(st.sampled_from([0, -1]))
epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
batch_size = draw(st.integers(min_value=1, max_value=4))
groups = draw(st.sampled_from([4, 8, 16, 32]))
hw = draw(st.sampled_from([1, 8, 16, 32]))
channel = draw(st.sampled_from([320, 1280]))
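        # The channel counts are chosen to be divisible by every candidate
        # group count sampled above.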
def generate_input_x(attrs):
return np.random.random(
[attrs[1]["batch_size"], *attrs[1]["input_dim_x"]]
).astype(np.float32)
def generate_input_y(attrs):
return np.random.random(
[attrs[1]["batch_size"], *attrs[1]["input_dim_y"]]
).astype(np.float32)
def generate_weight(attrs):
return np.random.random(attrs[1]['input_dim_x'][0]).astype(
np.float32
)
attrs = [
{
'axis': axis,
'epsilon': epsilon,
'groups': groups,
},
{
'batch_size': batch_size,
'input_dim_x': [channel, hw, hw],
'input_dim_y': [channel, hw, hw],
},
]
elementwise_add_op = OpConfig(
type="elementwise_add",
inputs={"X": ["input_data_x"], "Y": ["input_data_y"]},
outputs={"Out": ["ele_out"]},
attrs={"axis": attrs[0]['axis']},
)
group_norm_op = OpConfig(
type="group_norm",
inputs={
"X": ["ele_out"],
"Bias": ["group_norm_bias"],
"Scale": ["group_norm_scale"],
},
outputs={
"Y": ["group_norm_output1"],
"Mean": ["group_norm_output2"],
"Variance": ["group_norm_output3"],
},
attrs={
"data_layout": "NCHW",
"groups": attrs[0]["groups"],
"epsilon": attrs[0]["epsilon"],
},
)
silu_op = OpConfig(
type="silu",
inputs={
"X": ["group_norm_output1"],
},
outputs={
"Out": ["silu_output"],
},
)
program_config = ProgramConfig(
ops=[
elementwise_add_op,
group_norm_op,
silu_op,
],
weights={
"group_norm_bias": TensorConfig(
data_gen=partial(generate_weight, attrs)
),
"group_norm_scale": TensorConfig(
data_gen=partial(generate_weight, attrs)
),
},
inputs={
"input_data_x": TensorConfig(
data_gen=partial(generate_input_x, attrs)
),
"input_data_y": TensorConfig(
data_gen=partial(generate_input_y, attrs)
),
},
outputs=["ele_out", "silu_output"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=50,
passes=["preln_elementwise_groupnorm_act_pass"],
max_duration=250,
min_success_num=50,
)
if __name__ == "__main__":
unittest.main()