Unverified commit fd0d4fa4, authored by Wang Bojun, committed by GitHub

[TRT] elementwise_add+transpose fusion (#50081)

* eleadd_trans first version

log fix

* refine code for linear format, add pass check

* linear format refine and UT fix

* fix UT

* Windows UT

* Windows UT 2

* move tensorMeta and alloc to configurePlugin
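
For reference, the fused subgraph performs an elementwise add in (n, h*w, c) layout, a reshape to (n, h, w, c), and an NHWC-to-NCHW transpose. A minimal C++ sketch of the reference semantics the fused op must reproduce (the function name and raw-pointer interface are illustrative, not part of this patch):

// elementwise_add + reshape2 + transpose2({0, 3, 1, 2}) on contiguous
// buffers: x and y are (n, h*w, c); out is (n, c, h, w).
void EleaddTransposeRef(const float* x, const float* y, float* out,
                        int n, int h, int w, int c) {
  for (int ni = 0; ni < n; ++ni) {
    for (int hw = 0; hw < h * w; ++hw) {
      for (int ci = 0; ci < c; ++ci) {
        // the reshape (n, h*w, c) -> (n, h, w, c) is a no-op on memory;
        // the transpose moves the channel dim ahead of h and w
        out[(ni * c + ci) * h * w + hw] =
            x[(ni * h * w + hw) * c + ci] + y[(ni * h * w + hw) * c + ci];
      }
    }
  }
}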
Parent 22bcb75a
......@@ -140,6 +140,7 @@ if(WITH_TENSORRT)
pass_library(delete_remove_padding_recover_padding_pass inference)
pass_library(layernorm_shift_partition_fuse_pass inference)
pass_library(reverse_roll_fuse_pass inference)
pass_library(elementwiseadd_transpose_pass inference)
pass_library(preln_layernorm_x_fuse_pass inference)
pass_library(trt_support_nhwc_pass inference)
pass_library(elementwise_groupnorm_act_pass inference)
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/elementwiseadd_transpose_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// input_x  input_y    (input_x and input_y are both (n, h*w, c))
//    |        |
//  elementwise_add    (n, h*w, c)
//        |
//     reshape         (n, h, w, c)
//        |
//    transpose        (n, c, h, w)
// |
//
// fuse ->
//
// |
// elementwiseadd_transpose
// |
struct ElementwiseAddTransposePattern : public PatternBase {
ElementwiseAddTransposePattern(PDPattern *pattern,
const std::string &name_scope)
: PatternBase(pattern, name_scope, "elementwiseadd_transpose") {}
void operator()(PDNode *x, PDNode *y);
PATTERN_DECL_NODE(elementwise);
PATTERN_DECL_NODE(elementwise_out);
PATTERN_DECL_NODE(reshape);
PATTERN_DECL_NODE(reshape_out);
PATTERN_DECL_NODE(transpose);
PATTERN_DECL_NODE(transpose_out);
};
void ElementwiseAddTransposePattern::operator()(PDNode *x, PDNode *y) {
auto *elementwise = pattern->NewNode(elementwise_repr())
->assert_is_op("elementwise_add")
->assert_has_n_outputs(1);
auto *elementwise_out = pattern->NewNode(elementwise_out_repr())
->assert_is_op_output("elementwise_add")
->assert_is_op_input("reshape2")
->AsIntermediate();
elementwise->LinksFrom({x, y}).LinksTo({elementwise_out});
auto *reshape = pattern->NewNode(reshape_repr())->assert_is_op("reshape2");
auto *reshape_out = pattern->NewNode(reshape_out_repr())
->assert_is_op_output("reshape2")
->assert_is_op_input("transpose2")
->AsIntermediate();
reshape->LinksFrom({elementwise_out}).LinksTo({reshape_out});
auto *transpose =
pattern->NewNode(transpose_repr())->assert_is_op("transpose2");
auto *transpose_out = pattern->NewNode(transpose_out_repr())
->assert_is_op_output("transpose2")
->AsOutput();
transpose->LinksFrom({reshape_out}).LinksTo({transpose_out});
}
} // namespace patterns
int ElementwiseAddTransposeFusePass::ApplyEleTransPattern(
ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init("eleadd_transpose_fuse", graph);
int found_subgraph_count = 0;
GraphPatternDetector gpd;
PDNode *x = nullptr;
PDNode *y = nullptr;
x = gpd.mutable_pattern()
->NewNode("eleadd_transpose/x")
->AsInput()
->assert_var_not_persistable()
->assert_is_op_input("elementwise_add", "X");
y = gpd.mutable_pattern()
->NewNode("eleadd_transpose/y")
->AsInput()
->assert_var_not_persistable()
->assert_is_op_input("elementwise_add", "Y");
patterns::ElementwiseAddTransposePattern fused_pattern(
gpd.mutable_pattern(), "eleadd_transpose_fuse");
fused_pattern(x, y);
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *graph) {
if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) {
LOG(WARNING) << "The subgraph is empty.";
return;
}
VLOG(4) << "handle elementwiseadd transpose fuse";
GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape, reshape, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose, transpose, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose_out, transpose_out, fused_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "elementwiseadd transpose pass in op compat failed.";
return;
}
// elementwiseadd_transpose is suited for an NHWC-to-NCHW transpose
// after elementwise_add; check for it
std::vector<int> trans_axis =
PADDLE_GET_CONST(std::vector<int>, transpose->Op()->GetAttr("axis"));
if (trans_axis != std::vector<int>{0, 3, 1, 2}) {
VLOG(1)
<< "elementwiseadd transpose fuse pass: transpose axis check failed, "
"stopping fusion";
return;
}
if (!reshape->Op()->HasAttr("shape")) {
VLOG(1) << "reshape op in elementwise_add_transpose fusion do not found "
"shape attr, the fusion will be stoped.";
return;
}
std::vector<int> shape_attr =
PADDLE_GET_CONST(std::vector<int>, reshape->Op()->GetAttr("shape"));
VLOG(4) << "Fuse elementwiseadd transpose, with reshape attr:"
<< shape_attr[0] << ", " << shape_attr[1] << ", " << shape_attr[2]
<< ", " << shape_attr[3];
if (shape_attr[1] <= 0 || shape_attr[2] <= 0 || shape_attr.size() != 4) {
VLOG(1) << "found that shape_attr[1] and shape_attr[2]<=0 for reshape op "
"in elementwise_add_transpose, "
"currently, the elementwiseadd transpose pass only support "
"reshape bay shape attr rather than shape tensor."
"Therefore, the fusion will be stoped.";
return;
}
if (shape_attr[3] % 8 != 0) {
VLOG(1)
<< "found that shape_attr[3] (channel size) mod 8 != 0 for the "
"reshape op in elementwise_add_transpose; currently, the "
"elementwiseadd transpose pass only supports channel size mod "
"8 == 0 for the kHWC8 TRT format. Therefore, the fusion will "
"be stopped.";
return;
}
std::unordered_set<const Node *> del_node_set;
OpDesc new_desc;
new_desc.SetType("fuse_eleadd_transpose");
new_desc.SetInput("X", {subgraph.at(x)->Name()});
new_desc.SetInput("Y", {subgraph.at(y)->Name()});
new_desc.SetOutput("Out", {transpose_out->Name()});
new_desc.SetAttr("axis", elementwise->Op()->GetAttr("axis"));
new_desc.SetAttr("output_shape", shape_attr);
new_desc.Flush();
auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied.
del_node_set.insert(elementwise);
del_node_set.insert(elementwise_out);
del_node_set.insert(reshape);
del_node_set.insert(reshape_out);
del_node_set.insert(transpose);
GraphSafeRemoveNodes(graph, del_node_set);
IR_NODE_LINK_TO(subgraph.at(x), fused_node);
IR_NODE_LINK_TO(subgraph.at(y), fused_node);
IR_NODE_LINK_TO(fused_node, transpose_out);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
void ElementwiseAddTransposeFusePass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
FusePassBase::Init("elementwiseadd_transpose_fuse_pass", graph);
int found_subgraph_count = ApplyEleTransPattern(graph);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(elementwiseadd_transpose_pass,
paddle::framework::ir::ElementwiseAddTransposeFusePass);
REGISTER_PASS_CAPABILITY(elementwiseadd_transpose_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("reshape2", 0)
.EQ("transpose2", 0));
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
// input_x  input_y    (input_x and input_y are both (n, h*w, c))
//    |        |
//  elementwise_add    (n, h*w, c)
//        |
//     reshape         (n, h, w, c)
//        |
//    transpose        (n, c, h, w)
// |
//
// fuse ->
//
// |
// elementwiseadd_transpose
// |
class Graph;
class ElementwiseAddTransposeFusePass : public FusePassBase {
public:
ElementwiseAddTransposeFusePass() {
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1})
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("ShapeTensor")
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>() // 0,3,2,1 nchw->nhwc
.End();
}
virtual ~ElementwiseAddTransposeFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
int ApplyEleTransPattern(ir::Graph* graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -2514,6 +2514,7 @@ USE_TRT_CONVERTER(trans_layernorm)
USE_TRT_CONVERTER(skip_merge_layernorm)
USE_TRT_CONVERTER(generic_plugin_creater)
USE_TRT_CONVERTER(custom_plugin_creater)
USE_TRT_CONVERTER(fuse_eleadd_transpose)
USE_TRT_CONVERTER(tanh_shrink)
USE_TRT_CONVERTER(logsigmoid)
USE_TRT_CONVERTER(lookup_table)
......
......@@ -141,6 +141,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"elementwise_groupnorm_act_pass", //
"preln_elementwise_groupnorm_act_pass", //
"groupnorm_act_pass", //
"elementwiseadd_transpose_pass", //
#endif
"tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", //
......
......@@ -97,6 +97,7 @@ list(
skip_merge_layernorm_op.cc
generic_and_custom_plugin_creater.cc
fused_lookup_tables_op.cc
elementwiseadd_transpose_op.cc
skip_groupnorm_act_op.cc
preln_groupnorm_act_op.cc
expand_v2_op.cc)
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/elementwiseadd_transpose_op_plugin.h"
#include <vector>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
class ElementwiseaddTransposeOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(4) << "convert a fuse_elementwiseadd_transpose op to tensorrt "
"elementwiseadd_transpose plugin";
framework::OpDesc op_desc(op, nullptr);
auto* input_x = engine_->GetITensor(op_desc.Input("X").front());
auto* input_y = engine_->GetITensor(op_desc.Input("Y").front());
std::vector<nvinfer1::ITensor*> inputs{input_x, input_y};
int axis = PADDLE_GET_CONST(int, op_desc.GetAttr("axis"));
std::vector<int> output_shape =
PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("output_shape"));
if (engine_->with_dynamic_shape()) {
plugin::ElementwiseAddTransposePluginDynamic* plugin =
new plugin::ElementwiseAddTransposePluginDynamic(axis, output_shape);
nvinfer1::ILayer* elementwise_layer =
engine_->AddDynamicPlugin(inputs.data(), 2, plugin);
std::vector<std::string> output_names;
output_names.emplace_back(op_desc.Output("Out").front());
RreplenishLayerAndOutput(elementwise_layer,
"fuse_elementwiseadd_transpose",
output_names,
test_mode);
}
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(fuse_eleadd_transpose,
ElementwiseaddTransposeOpConverter);
......@@ -2496,6 +2496,13 @@ struct SimpleOpTypeSetTeller : public Teller {
return false;
}
}
if (op_type == "fuse_eleadd_transpose") {
if (!with_dynamic_shape) {
VLOG(3) << "The fuse_eleadd_transpose op does not support "
"static shape yet";
return false;
}
}
if (op_type == "lookup_table") {
if (!with_dynamic_shape) {
VLOG(3) << "the lookup_table does not support "
......@@ -2670,6 +2677,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"skip_merge_layernorm",
"lookup_table_v2",
"expand_v2",
"fuse_eleadd_transpose",
"skip_groupnorm_act",
"preln_groupnorm_act"};
......@@ -2821,6 +2829,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"lookup_table",
"lookup_table_v2",
"expand_v2",
"fuse_eleadd_transpose",
"skip_groupnorm_act",
"preln_groupnorm_act"};
};
......
......@@ -38,6 +38,7 @@ list(
skip_merge_layernorm_op_plugin.cu
skip_groupnorm_act_op_plugin.cu
preln_groupnorm_act_op_plugin.cu
elementwiseadd_transpose_op_plugin.cu
generic_plugin.cu
lookup_table.cu
many_emb_layernorm_plugin.cu
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/elementwiseadd_transpose_op_plugin.h"
#include <glog/logging.h>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
int ElementwiseAddTransposePluginDynamic::initialize() TRT_NOEXCEPT {
return 0;
}
size_t ElementwiseAddTransposePluginDynamic::getSerializationSize() const
TRT_NOEXCEPT {
return SerializedSize(axis_) + SerializedSize(output_shape_);
}
void ElementwiseAddTransposePluginDynamic::serialize(void *buffer) const
TRT_NOEXCEPT {
SerializeValue(&buffer, axis_);
SerializeValue(&buffer, output_shape_);
}
nvinfer1::DimsExprs ElementwiseAddTransposePluginDynamic::getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs *inputs,
int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
nvinfer1::DimsExprs ret;
ret.nbDims = 4;
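// input 0 is (n, h*w, c); the fused op outputs NCHW, so the channel dim
// (inputs[0].d[2]) becomes dim 1, and h/w come from the reshape shape
// attr captured in output_shape_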
ret.d[0] = inputs[0].d[0];
ret.d[1] = inputs[0].d[2];
ret.d[2] = expr_builder.constant(output_shape_[1]);
ret.d[3] = expr_builder.constant(output_shape_[2]);
return ret;
}
bool ElementwiseAddTransposePluginDynamic::supportsFormatCombination(
int pos,
const nvinfer1::PluginTensorDesc *in_out,
int nb_inputs,
int nb_outputs) TRT_NOEXCEPT {
PADDLE_ENFORCE_NOT_NULL(
in_out,
platform::errors::InvalidArgument("The input of elementwiseadd_transpose "
"plugin should not be nullptr."));
PADDLE_ENFORCE_LT(
pos,
nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos,
nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc &in = in_out[pos];
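// fp32 runs the kLINEAR path only; fp16 may additionally produce the
// output (pos == 2) in kHWC8, which enqueue handles in a separate branch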
// input 0
if (pos == 0) {
return (in.type == nvinfer1::DataType::kHALF ||
in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
}
// input 1
if (pos == 1) {
return (in.type == in_out[0].type) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
}
// output 0
if (pos == 2) {
return (in.type == in_out[0].type) &&
(in.format == nvinfer1::TensorFormat::kLINEAR ||
in.format == nvinfer1::TensorFormat::kHWC8);
}
return false;
}
void ElementwiseAddTransposePluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc *input_desc,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *output_desc,
int nbOutputs) TRT_NOEXCEPT {
const auto &x_dims = input_desc[0].desc.dims;
const auto &y_dims = input_desc[1].desc.dims;
const auto &out_dims = output_desc[0].desc.dims;
const auto &x_type = input_desc[0].desc.type;
std::vector<int> x_shape;
int x_numel = 1;
for (int i = 0; i < x_dims.nbDims; i++) {
x_shape.push_back(x_dims.d[i]);
x_numel *= x_dims.d[i];
}
std::vector<int> y_shape;
int y_numel = 1;
for (int i = 0; i < y_dims.nbDims; i++) {
y_shape.push_back(y_dims.d[i]);
y_numel *= y_dims.d[i];
}
std::vector<int> out_shape;
int out_numel = 1;
for (int i = 0; i < out_dims.nbDims; i++) {
out_shape.push_back(out_dims.d[i]);
out_numel *= out_dims.d[i];
}
x_numel_ = x_numel;
y_numel_ = y_numel;
out_numel_ = out_numel;
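// per the commit message, the tensor metas and the intermediate
// add-result allocation are set up here in configurePlugin so that
// enqueue stays allocation-free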
if (x_numel <= 0) {
return;
}
ele_out_tensor_.Resize(phi::make_ddim(x_shape));
paddle::platform::DeviceContextPool &pool =
paddle::platform::DeviceContextPool::Instance();
platform::CUDAPlace place(platform::GetCurrentDeviceId());
auto *device_context = static_cast<phi::GPUContext *>(pool.Get(place));
const phi::GPUContext &dev_ctx = *device_context;
if (x_type == nvinfer1::DataType::kFLOAT) {
x_meta_ =
phi::DenseTensorMeta(phi::DataType::FLOAT32, phi::make_ddim(x_shape));
y_meta_ =
phi::DenseTensorMeta(phi::DataType::FLOAT32, phi::make_ddim(y_shape));
out_meta_ =
phi::DenseTensorMeta(phi::DataType::FLOAT32, phi::make_ddim(out_shape));
dev_ctx.template Alloc<float>(&ele_out_tensor_, x_numel * sizeof(float));
} else if (x_type == nvinfer1::DataType::kHALF) {
x_meta_ =
phi::DenseTensorMeta(phi::DataType::FLOAT16, phi::make_ddim(x_shape));
y_meta_ =
phi::DenseTensorMeta(phi::DataType::FLOAT16, phi::make_ddim(y_shape));
out_meta_ =
phi::DenseTensorMeta(phi::DataType::FLOAT16, phi::make_ddim(out_shape));
dev_ctx.template Alloc<phi::dtype::float16>(
&ele_out_tensor_, x_numel * sizeof(phi::dtype::float16));
}
}
nvinfer1::DataType ElementwiseAddTransposePluginDynamic::getOutputDataType(
int index,
const nvinfer1::DataType *input_types,
int nb_inputs) const TRT_NOEXCEPT {
PADDLE_ENFORCE_EQ(index,
0,
platform::errors::InvalidArgument(
"The ElementwiseAddTranspose plugin only has one "
"output, so the index value should be 0, but got %d.",
index));
return input_types[0];
}
int ElementwiseAddTransposePluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc,
const void *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream) TRT_NOEXCEPT {
paddle::platform::DeviceContextPool &pool =
paddle::platform::DeviceContextPool::Instance();
platform::CUDAPlace place(platform::GetCurrentDeviceId());
auto *device_context = static_cast<phi::GPUContext *>(pool.Get(place));
const phi::GPUContext &dev_ctx = *device_context;
auto input_type = input_desc[0].type;
auto output_format = output_desc[0].format;
if (input_type == nvinfer1::DataType::kFLOAT) {
VLOG(1) << "TRT Plugin DataType selected. elementwiseadd_transpose-->fp32";
const float *x = static_cast<const float *>(inputs[0]);
const float *y = static_cast<const float *>(inputs[1]);
float *out = static_cast<float *>(outputs[0]);
VLOG(1) << "TRT Plugin format selected. elementwiseadd_transpose-->kLINEAR";
std::shared_ptr<phi::Allocation> x_alloc(new phi::Allocation(
static_cast<void *>(const_cast<float *>(x)), // NOLINT
x_numel_ * sizeof(float),
place));
std::shared_ptr<phi::Allocation> y_alloc(new phi::Allocation(
static_cast<void *>(const_cast<float *>(y)), // NOLINT
y_numel_ * sizeof(float),
place));
std::shared_ptr<phi::Allocation> out_alloc(
new phi::Allocation(static_cast<void *>(out), // NOLINT
out_numel_ * sizeof(float),
place));
const phi::DenseTensor x_tensor = phi::DenseTensor(x_alloc, x_meta_);
const phi::DenseTensor y_tensor = phi::DenseTensor(y_alloc, y_meta_);
phi::DenseTensor out_tensor = phi::DenseTensor(out_alloc, out_meta_);
phi::AddKernel<float, phi::GPUContext>(
dev_ctx, x_tensor, y_tensor, &ele_out_tensor_);
phi::TransposeKernel<float, phi::GPUContext>(
dev_ctx, ele_out_tensor_, std::vector<int>{0, 2, 1}, &out_tensor);
} else if (input_type == nvinfer1::DataType::kHALF) {
VLOG(1) << "TRT Plugin DataType selected. elementwiseadd_transpose-->fp16";
const half *x = static_cast<const half *>(inputs[0]);
const half *y = static_cast<const half *>(inputs[1]);
half *out = static_cast<half *>(outputs[0]);
if (output_format == nvinfer1::PluginFormat::kLINEAR) {
VLOG(1)
<< "TRT Plugin format selected. elementwiseadd_transpose-->kLINEAR";
std::shared_ptr<phi::Allocation> x_alloc(new phi::Allocation(
static_cast<void *>(const_cast<half *>(x)), // NOLINT
x_numel_ * sizeof(half),
place));
std::shared_ptr<phi::Allocation> y_alloc(new phi::Allocation(
static_cast<void *>(const_cast<half *>(y)), // NOLINT
y_numel_ * sizeof(half),
place));
std::shared_ptr<phi::Allocation> out_alloc(
new phi::Allocation(static_cast<void *>(out), // NOLINT
out_numel_ * sizeof(half),
place));
const phi::DenseTensor x_tensor = phi::DenseTensor(x_alloc, x_meta_);
const phi::DenseTensor y_tensor = phi::DenseTensor(y_alloc, y_meta_);
phi::DenseTensor out_tensor = phi::DenseTensor(out_alloc, out_meta_);
phi::AddKernel<phi::dtype::float16, phi::GPUContext>(
dev_ctx, x_tensor, y_tensor, &ele_out_tensor_);
phi::TransposeKernel<phi::dtype::float16, phi::GPUContext>(
dev_ctx, ele_out_tensor_, std::vector<int>{0, 2, 1}, &out_tensor);
} else if (output_format == nvinfer1::PluginFormat::kHWC8) {
VLOG(1) << "TRT Plugin format selected. elementwiseadd_transpose-->kHWC8";
std::shared_ptr<phi::Allocation> x_alloc(new phi::Allocation(
static_cast<void *>(const_cast<half *>(x)), // NOLINT
x_numel_ * sizeof(half),
place));
std::shared_ptr<phi::Allocation> y_alloc(new phi::Allocation(
static_cast<void *>(const_cast<half *>(y)), // NOLINT
y_numel_ * sizeof(half),
place));
std::shared_ptr<phi::Allocation> out_alloc(
new phi::Allocation(static_cast<void *>(out), // NOLINT
out_numel_ * sizeof(half),
place));
const phi::DenseTensor x_tensor = phi::DenseTensor(x_alloc, x_meta_);
const phi::DenseTensor y_tensor = phi::DenseTensor(y_alloc, y_meta_);
phi::DenseTensor out_tensor = phi::DenseTensor(out_alloc, out_meta_);
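// kHWC8 stores the NCHW output channel-last; since the pass enforces
// c % 8 == 0 there is no channel padding, so the layout matches the flat
// NHWC add result and AddKernel can write the output directly, with no
// transpose kernel needed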
phi::AddKernel<phi::dtype::float16, phi::GPUContext>(
dev_ctx, x_tensor, y_tensor, &out_tensor);
}
}
return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class ElementwiseAddTransposePluginDynamic : public DynamicPluginTensorRT {
public:
explicit ElementwiseAddTransposePluginDynamic(int axis,
std::vector<int> output_shape)
: axis_(axis), output_shape_(output_shape) {}
ElementwiseAddTransposePluginDynamic(void const* serialData,
size_t serialLength) {
DeserializeValue(&serialData, &serialLength, &axis_);
DeserializeValue(&serialData, &serialLength, &output_shape_);
}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
return new ElementwiseAddTransposePluginDynamic(axis_, output_shape_);
}
const char* getPluginType() const TRT_NOEXCEPT override {
return "elementwise_add_transpose_plugin_dynamic";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
int initialize() TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void* buffer) const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) // NOLINT
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* input_desc,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* output_desc,
int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const
TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override { delete this; }
private:
int axis_;
std::vector<int> output_shape_;
phi::DenseTensorMeta x_meta_;
phi::DenseTensorMeta y_meta_;
phi::DenseTensorMeta out_meta_;
phi::DenseTensor ele_out_tensor_;
int x_numel_ = -1;
int y_numel_ = -1;
int out_numel_ = -1;
};
class ElementwiseAddTransposePluginDynamicCreator
: public nvinfer1::IPluginCreator {
public:
ElementwiseAddTransposePluginDynamicCreator() {}
const char* getPluginName() const TRT_NOEXCEPT override {
return "elementwise_add_transpose_plugin_dynamic";
}
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
return &field_collection_;
}
nvinfer1::IPluginV2* createPlugin(const char* name,
const nvinfer1::PluginFieldCollection* fc)
TRT_NOEXCEPT override {
return nullptr;
}
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length)
TRT_NOEXCEPT override {
auto plugin =
new ElementwiseAddTransposePluginDynamic(serial_data, serial_length);
return plugin;
}
void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
plugin_namespace_ = lib_namespace;
}
const char* getPluginNamespace() const TRT_NOEXCEPT override {
return plugin_namespace_.c_str();
}
private:
std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(ElementwiseAddTransposePluginDynamicCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
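
For orientation, a minimal sketch of the serialize/deserialize round-trip that the creator registered via REGISTER_TRT_PLUGIN_V2 supports; `plugin` is a hypothetical stand-in for an already constructed ElementwiseAddTransposePluginDynamic*:

// serialize the plugin into a buffer, then rebuild it through the creator,
// as TensorRT does when deserializing a saved engine
std::vector<char> buf(plugin->getSerializationSize());
plugin->serialize(buf.data());
ElementwiseAddTransposePluginDynamicCreator creator;
nvinfer1::IPluginV2* restored = creator.deserializePlugin(
    "elementwise_add_transpose_plugin_dynamic", buf.data(), buf.size());
// restored carries the same axis_ and output_shape_ as the original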
......@@ -36,6 +36,12 @@ if(WIN32)
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_trt_convert_trans_layernorm")
list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_fused_token_prune")
list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_fused_token_prune")
list(REMOVE_ITEM TEST_TRT_IR_PASSES
"test_trt_convert_elementwiseadd_transpose")
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES
"test_trt_convert_elementwiseadd_transpose")
list(REMOVE_ITEM TEST_TRT_CONVERTER
"test_trt_convert_elementwiseadd_transpose")
endif()
# Only for cpu(mkl + openblas)
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
from typing import List
import numpy as np
from program_config import ProgramConfig, TensorConfig
from trt_layer_auto_scan_test import TrtLayerAutoScanTest
import paddle.inference as paddle_infer
class TrtConvertElementwiseaddTransposeTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_configs(self):
def conv_filter_datagen(dics):
c = dics["c"]
x = (np.random.randn(c, c, 1, 1)) / np.sqrt(c)
return x.astype(np.float32)
def conv_elementwise_bias_datagen(dics):
c = dics["c"]
x = np.random.random([dics["c"]]) * 0.1
return x.astype(np.float32)
def ele1_input_datagen(dics):
x = np.random.random(
[dics["batch"], dics["h"] * dics["w"], dics["c"]]
)
x = (x - np.mean(x)) / (np.std(x))
return x.astype(np.float32)
def ele2_input_datagen(dics):
x = np.random.random(
[dics["batch"], dics["h"] * dics["w"], dics["c"]]
)
x = (x - np.mean(x)) / (np.std(x))
return x.astype(np.float32)
for batch in [2]:
for h in [32, 64]:
for w in [32, 64]:
for c in [128, 320, 255, 133]:
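# c includes non-multiples of 8 (255, 133) to exercise the pass's
# channel % 8 check alongside the fusing cases (128, 320)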
dics = {"batch": batch, "h": h, "w": w, "c": c}
ops_config = [
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["ele_input_1"],
"Y": ["ele_input_2"],
},
"op_outputs": {"Out": ["elementwise_out"]},
"op_attrs": {"axis": -1},
},
{
"op_type": "reshape",
"op_inputs": {"X": ["elementwise_out"]},
"op_outputs": {
"Out": ["reshape_out"],
},
"op_attrs": {"shape": [-1, h, w, c]},
},
{
"op_type": "transpose2",
"op_inputs": {
"X": ["reshape_out"],
},
"op_outputs": {
"Out": ["transpose2_out"],
},
"op_attrs": {"axis": [0, 3, 1, 2]},
},
{
"op_type": "conv2d",
"op_inputs": {
"Input": ["transpose2_out"],
"Filter": ["conv2d_filter"],
},
"op_outputs": {
"Output": ["conv2d_output"],
},
"op_attrs": {
"dilations": [1, 1],
"padding_algorithm": "EXPLICIT",
"groups": 1,
"paddings": [0, 0],
"strides": [1, 1],
"data_format": "NCHW",
},
},
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"conv2d_filter": TensorConfig(
data_gen=partial(conv_filter_datagen, dics)
),
"elementwise_bias": TensorConfig(
data_gen=partial(
conv_elementwise_bias_datagen, dics
)
),
},
inputs={
"ele_input_1": TensorConfig(
data_gen=partial(ele1_input_datagen, dics)
),
"ele_input_2": TensorConfig(
data_gen=partial(ele2_input_datagen, dics)
),
},
outputs=["conv2d_output"],
)
yield program_config
def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs, inputs):
channel = inputs['ele_input_1'].shape[2]
self.dynamic_shape.min_input_shape = {
"ele_input_1": [1, 32 * 32, channel],
"ele_input_2": [1, 32 * 32, channel],
}
self.dynamic_shape.max_input_shape = {
"ele_input_1": [4, 64 * 64, channel],
"ele_input_2": [4, 64 * 64, channel],
}
self.dynamic_shape.opt_input_shape = {
"ele_input_1": [4, 64 * 64, channel],
"ele_input_2": [4, 64 * 64, channel],
}
def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 3
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
inputs = program_config.inputs
# only dynamic shape is supported
generate_dynamic_shape(attrs, inputs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), (
1e-2,
1e-2,
) # tol 1e-2 for half
def add_skip_trt_case(self):
pass
def test(self):
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__":
unittest.main()