Unverified commit 007f3614, authored by W whs, committed by GitHub

Add passes and plugins for distributed inference of NLU (#43049)

Parent ec3e0a13
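The two new IR passes (delete_c_identity_op_pass, preln_residual_bias_fuse_pass) are appended to kTRTSubgraphPasses below, and the c_allreduce_* / preln_residual_bias plugins only run inside the TensorRT engine in dynamic shape mode, so the change is exercised simply by enabling TensorRT on a tensor-parallel NLU model. A minimal Python sketch of such a setup (the model path, the input name "data" and the shapes are placeholders, not part of this change):

from paddle.inference import Config, PrecisionType, create_predictor

# Placeholder paths: an exported tensor-parallel NLU model for this rank.
config = Config("model_dir/model.pdmodel", "model_dir/model.pdiparams")
config.enable_use_gpu(1000, 0)
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=8,
    min_subgraph_size=1,
    precision_mode=PrecisionType.Float32,
    use_static=False,
    use_calib_mode=False)
# Dynamic shape info is required; the new plugins reject static shape mode.
config.set_trt_dynamic_shape_info(
    {"data": [1, 1, 768]}, {"data": [8, 128, 768]}, {"data": [4, 64, 768]})
predictor = create_predictor(config)

Each rank builds its own predictor; the ring_id and use_calc_stream attributes used by the c_allreduce plugin are taken from the original distributed program's op descriptions.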
......@@ -143,6 +143,8 @@ pass_library(delete_quant_dequant_filter_op_pass inference)
pass_library(delete_weight_dequant_linear_op_pass inference)
pass_library(delete_quant_dequant_linear_op_pass inference)
pass_library(delete_dropout_op_pass inference)
pass_library(delete_c_identity_op_pass inference)
pass_library(preln_residual_bias_fuse_pass inference)
pass_library(delete_fill_constant_op_pass inference)
pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_c_identity_op_pass.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
void DeleteCIdentityOpPattern::operator()() {
auto any_op_out = pattern->NewNode(any_op_out_repr())
->assert_is_op_input("c_identity", "X")
->AsInput();
auto c_identity_op =
pattern->NewNode(c_identity_op_repr())->assert_is_op("c_identity");
auto c_identity_op_out = pattern->NewNode(c_identity_op_out_repr())
->assert_is_op_output("c_identity", "Out")
->AsIntermediate();
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
c_identity_op->LinksFrom({any_op_out});
c_identity_op_out->LinksFrom({c_identity_op});
any_op2->LinksFrom({c_identity_op_out});
}
} // namespace patterns
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(any_op_out); \
GET_IR_NODE(c_identity_op); \
GET_IR_NODE(c_identity_op_out); \
GET_IR_NODE(any_op2);
void DeleteCIdentityOpPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "delete_c_identity_op_pattern";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
patterns::DeleteCIdentityOpPattern pattern(gpd.mutable_pattern(),
pattern_name);
pattern();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
IR_NODE_LINK_TO(any_op_out, any_op2);
std::string any_op_out_name = any_op_out->Var()->Name();
std::string c_identity_op_out_name = c_identity_op_out->Var()->Name();
auto* any_op2_desc = any_op2->Op();
auto var_map = any_op2_desc->Inputs();
std::string arg_name = "";
for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(), name_m.second.end(),
c_identity_op_out_name) != name_m.second.end()) {
arg_name = name_m.first;
}
}
if (arg_name.size() == 0) {
LOG(INFO) << "Delete c_identity op pass: can not find the input "
<< c_identity_op_out_name;
return;
}
// modify the any_op2's inputs
for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(), name_m.second.end(),
c_identity_op_out_name) != name_m.second.end()) {
std::vector<std::string> new_inputs;
for (auto& i_n : name_m.second) {
if (i_n != c_identity_op_out_name) {
new_inputs.push_back(i_n);
}
}
new_inputs.push_back(any_op_out_name);
any_op2_desc->SetInput(name_m.first, new_inputs);
any_op2_desc->Flush();
}
}
any_op2_desc->Flush();
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph, {c_identity_op, c_identity_op_out});
};
gpd(graph, handler);
}
DeleteCIdentityOpPass::DeleteCIdentityOpPass() {
AddOpCompat(OpCompat("c_identity"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_c_identity_op_pass,
paddle::framework::ir::DeleteCIdentityOpPass);
REGISTER_PASS_CAPABILITY(delete_c_identity_op_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().LE(
"c_identity", 1));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct DeleteCIdentityOpPattern : public PatternBase {
DeleteCIdentityOpPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "delete_c_identity_op_pattern") {}
void operator()();
PATTERN_DECL_NODE(any_op_out);
PATTERN_DECL_NODE(c_identity_op);
PATTERN_DECL_NODE(c_identity_op_out);
PATTERN_DECL_NODE(any_op2);
};
} // namespace patterns
class Graph;
class DeleteCIdentityOpPass : public FusePassBase {
public:
DeleteCIdentityOpPass();
virtual ~DeleteCIdentityOpPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
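For intuition, the handler above rewires every consumer of the c_identity output to read the op's input variable instead, then removes the c_identity op and its output (c_identity is an identity in the forward/inference pass, so this is safe). A toy Python restatement of that rewiring on a list of op descriptions (illustrative data structures only, not Paddle APIs):

def delete_identity(ops):
    """Toy version of delete_c_identity_op_pass on a list of op dicts."""
    # Map each c_identity output variable back to its input variable.
    out2in = {op["outputs"]["Out"]: op["inputs"]["X"]
              for op in ops if op["type"] == "c_identity"}
    kept = [op for op in ops if op["type"] != "c_identity"]
    for op in kept:
        # Any consumer of a c_identity output now reads the original input.
        op["inputs"] = {k: out2in.get(v, v) for k, v in op["inputs"].items()}
    return kept

ops = [
    {"type": "relu", "inputs": {"X": "relu_x"}, "outputs": {"Out": "relu_out"}},
    {"type": "c_identity", "inputs": {"X": "relu_out"}, "outputs": {"Out": "id_out"}},
    {"type": "softmax", "inputs": {"X": "id_out"}, "outputs": {"Out": "sm_out"}},
]
print(delete_identity(ops))  # softmax now reads "relu_out" directly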
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct PrelnResidualBias : public PatternBase {
PrelnResidualBias(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "preln_residual_bias") {}
void operator()(PDNode *x, PDNode *y);
// declare operator node's name
PATTERN_DECL_NODE(elementwise_bias);
PATTERN_DECL_NODE(elementwise0);
PATTERN_DECL_NODE(elementwise1);
PATTERN_DECL_NODE(layer_norm);
// declare variable node's name
PATTERN_DECL_NODE(elementwise0_out);
PATTERN_DECL_NODE(elementwise1_out);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
};
void PrelnResidualBias::operator()(PDNode *x, PDNode *y) {
// Create nodes for elementwise add op.
x->assert_is_op_input("elementwise_add");
y->assert_is_op_input("elementwise_add", "X");
auto *elementwise0 =
pattern->NewNode(elementwise0_repr())->assert_is_op("elementwise_add");
auto *elementwise_bias_var = pattern->NewNode(elementwise_bias_repr())
->assert_is_op_input("elementwise_add", "Y");
auto *elementwise0_out_var = pattern->NewNode(elementwise0_out_repr())
->assert_is_op_output("elementwise_add")
->assert_is_op_input("elementwise_add")
->assert_more([](Node *x) { return x->outputs.size() == 1; });
auto *elementwise1 =
pattern->NewNode(elementwise1_repr())->assert_is_op("elementwise_add");
auto *elementwise1_out_var = pattern->NewNode(elementwise1_out_repr())
->assert_is_op_output("elementwise_add")
->assert_is_op_input("layer_norm", "X");
// Add links for elementwise_add op.
elementwise0->LinksFrom({y, elementwise_bias_var})
.LinksTo({elementwise0_out_var});
elementwise1->LinksFrom({x, elementwise0_out_var})
.LinksTo({elementwise1_out_var});
// Create nodes for layer_norm op.
auto *layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto *layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto *layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto *layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Y");
auto *layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto *layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
// Add links for layer_norm op.
layer_norm
->LinksFrom(
{elementwise1_out_var, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo(
{layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
}
} // namespace patterns
void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init("preln_residual_bias_fuse", graph);
int found_subgraph_count = 0;
GraphPatternDetector gpd;
auto *x = gpd.mutable_pattern()
->NewNode("preln_residual_bias_fuse/x")
->AsInput()
->assert_is_op_input("elementwise_add")
->assert_var_not_persistable();
auto *y = gpd.mutable_pattern()
->NewNode("preln_residual_bias_fuse/y")
->AsInput()
->assert_is_op_input("elementwise_add", "X")
->assert_var_not_persistable();
patterns::PrelnResidualBias fused_pattern(gpd.mutable_pattern(),
"preln_residual_bias_fuse");
fused_pattern(x, y);
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *graph) {
if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) {
LOG(WARNING) << "The subgraph is empty.";
return;
}
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "preln_residual_bias pass in op compat failed.";
return;
}
VLOG(4) << "handle PrelnResidualBias fuse";
GET_IR_NODE_FROM_SUBGRAPH(elementwise_bias, elementwise_bias,
fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise0, elementwise0, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise0_out, elementwise0_out,
fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise1, elementwise1, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise1_out, elementwise1_out,
fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
fused_pattern);
std::unordered_set<const Node *> del_node_set;
// Create a PrelnResidualBias op node
OpDesc new_desc;
new_desc.SetType("preln_residual_bias");
// inputs
new_desc.SetInput("X", {subgraph.at(x)->Name()});
new_desc.SetInput("Y", {subgraph.at(y)->Name()});
new_desc.SetInput("Scale", {layer_norm_scale->Name()});
new_desc.SetInput("Bias", {layer_norm_bias->Name()});
new_desc.SetInput("EleBias", {elementwise_bias->Name()});
// outputs
new_desc.SetOutput("Out_0", {layer_norm_out->Name()});
new_desc.SetOutput("Out_1", {elementwise1_out->Name()});
// attrs
new_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon"));
new_desc.SetAttr("begin_norm_axis",
layer_norm->Op()->GetAttr("begin_norm_axis"));
auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied.
del_node_set.insert(elementwise0);
del_node_set.insert(elementwise1);
del_node_set.insert(elementwise0_out);
del_node_set.insert(layer_norm);
del_node_set.insert(layer_norm_mean);
del_node_set.insert(layer_norm_variance);
GraphSafeRemoveNodes(graph, del_node_set);
IR_NODE_LINK_TO(subgraph.at(x), fused_node);
IR_NODE_LINK_TO(subgraph.at(y), fused_node);
IR_NODE_LINK_TO(elementwise_bias, fused_node);
IR_NODE_LINK_TO(layer_norm_scale, fused_node);
IR_NODE_LINK_TO(layer_norm_bias, fused_node);
IR_NODE_LINK_TO(fused_node, layer_norm_out);
IR_NODE_LINK_TO(fused_node, elementwise1_out);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(preln_residual_bias_fuse_pass,
paddle::framework::ir::PrelnResidualBiasFusePass);
REGISTER_PASS_CAPABILITY(preln_residual_bias_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("layer_norm", 0));
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
//            other_op2
//     |           |
// other_op1  elementwise_add            other_op1    other_op2
//     |           |             fuse         \          /
//     |------elementwise_add     ->      preln_residual_bias
//     |           |                         |          |
// other_op4   layer_norm                other_op4   other_op3
//                 |
//             other_op3
class Graph;
class PrelnResidualBiasFusePass : public FusePassBase {
public:
PrelnResidualBiasFusePass() {
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({0, -1})
.End();
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
}
virtual ~PrelnResidualBiasFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
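For reference, the subgraph matched above computes Out_1 = X + (Y + EleBias) followed by Out_0 = layer_norm(Out_1), which is what the fused op and its TensorRT plugin produce. A minimal numpy sketch of those semantics, assuming normalization over the last (hidden) axis of a [batch, seq_len, hidden] input, i.e. begin_norm_axis pointing at the hidden dimension; shapes are illustrative:

import numpy as np

def preln_residual_bias_ref(x, y, ele_bias, scale, bias, epsilon=1e-5):
    # Out_1: residual branch, x + (y + ele_bias)
    out1 = x + y + ele_bias
    # Out_0: layer_norm of Out_1 over the hidden axis
    mean = out1.mean(axis=-1, keepdims=True)
    var = out1.var(axis=-1, keepdims=True)
    out0 = (out1 - mean) / np.sqrt(var + epsilon) * scale + bias
    return out0, out1

x = np.random.rand(2, 4, 8).astype("float32")   # [batch, seq_len, hidden]
y = np.random.rand(2, 4, 8).astype("float32")
out0, out1 = preln_residual_bias_ref(
    x, y, ele_bias=np.zeros(8, "float32"),
    scale=np.ones(8, "float32"), bias=np.zeros(8, "float32"))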
......@@ -1897,6 +1897,7 @@ USE_TRT_CONVERTER(elementwise_max_tensor);
USE_TRT_CONVERTER(elementwise_min_tensor);
USE_TRT_CONVERTER(elementwise_pow_tensor);
USE_TRT_CONVERTER(transpose);
USE_TRT_CONVERTER(transpose2);
USE_TRT_CONVERTER(flatten);
USE_TRT_CONVERTER(flatten_contiguous_range);
USE_TRT_CONVERTER(matmul);
......@@ -1945,6 +1946,7 @@ USE_TRT_CONVERTER(nearest_interp);
USE_TRT_CONVERTER(nearest_interp_v2);
USE_TRT_CONVERTER(bilinear_interp_v2);
USE_TRT_CONVERTER(reshape);
USE_TRT_CONVERTER(reshape2);
USE_TRT_CONVERTER(reduce_sum);
USE_TRT_CONVERTER(gather_nd);
USE_TRT_CONVERTER(reduce_mean);
......@@ -1956,6 +1958,8 @@ USE_TRT_CONVERTER(deformable_conv);
USE_TRT_CONVERTER(pool3d)
USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm)
USE_TRT_CONVERTER(preln_skip_layernorm)
USE_TRT_CONVERTER(preln_residual_bias)
USE_TRT_CONVERTER(c_allreduce_sum)
USE_TRT_CONVERTER(roll)
USE_TRT_CONVERTER(strided_slice)
USE_TRT_CONVERTER(transformer_input_convert)
......
......@@ -97,10 +97,12 @@ const std::vector<std::string> kTRTSubgraphPasses({
"simplify_with_basic_ops_pass", //
"trt_embedding_eltwise_layernorm_fuse_pass", //
"preln_embedding_eltwise_layernorm_fuse_pass", //
"delete_c_identity_op_pass", //
"trt_multihead_matmul_fuse_pass_v2", //
"trt_multihead_matmul_fuse_pass_v3", //
"trt_skip_layernorm_fuse_pass", //
"preln_skip_layernorm_fuse_pass", //
"preln_residual_bias_fuse_pass", //
// "set_transformer_input_convert_pass", //
"conv_bn_fuse_pass", //
"unsqueeze2_eltwise_fuse_pass", //
......
......@@ -62,6 +62,8 @@ list(
transformer_input_convert_op.cc
remove_padding_op.cc
recover_padding_op.cc
preln_residual_bias.cc
c_allreduce_op.cc
top_k_op.cc
squeeze2_op.cc
unsqueeze2_op.cc)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
using ReduceType = paddle::inference::tensorrt::plugin::ReduceType;
std::map<std::string, ReduceType> op_to_reduce_type = {
{"c_allreduce_sum", paddle::inference::tensorrt::plugin::kRedSum},
{"c_allreduce_max", paddle::inference::tensorrt::plugin::kRedMax},
{"c_allreduce_min", paddle::inference::tensorrt::plugin::kRedMin},
{"c_allreduce_prod", paddle::inference::tensorrt::plugin::kRedProd}};
class CAllReduceOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert fluid callreduce op to tensorrt layer";
if (!engine_->with_dynamic_shape()) {
PADDLE_THROW(platform::errors::Fatal(
"Unsupported static mode. Please set dynamic shape of inputs."));
}
ReduceType red_type = op_to_reduce_type[op.type()];
std::string name = op.type();
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
int input_num = op_desc.Input("X").size();
PADDLE_ENFORCE_EQ(
input_num, 1,
platform::errors::InvalidArgument(
"The input X's size must equal to 1 in TRT c_allreduce op."
" But received X's size %d.",
input_num));
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
// Get output
size_t output_num = op_desc.Output("Out").size();
PADDLE_ENFORCE_EQ(
output_num, 1UL,
platform::errors::InvalidArgument(
"The ouput Out's size must equal to 1 in TRT c_allreduce op. "
"But received Out's size %u.",
output_num));
// Get attrs
int ring_id = BOOST_GET_CONST(int, op_desc.GetAttr("ring_id"));
bool use_calc_stream =
BOOST_GET_CONST(bool, op_desc.GetAttr("use_calc_stream"));
nvinfer1::ILayer* layer = nullptr;
#if IS_TRT_VERSION_GE(6000)
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
with_fp16 = true;
}
plugin::CAllReducePluginDynamic* plugin =
new plugin::CAllReducePluginDynamic(ring_id, use_calc_stream, red_type,
with_fp16);
layer = engine_->AddDynamicPlugin(&input, input_num, plugin);
#else
PADDLE_THROW(platform::errors::Fatal(
"The TRT dynamic shape mode requires TensorRT version 6.0 or "
"higher. Please check your TensorRT version."));
#endif
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, name, {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(c_allreduce_sum, CAllReduceOpConverter);
REGISTER_TRT_OP_CONVERTER(c_allreduce_max, CAllReduceOpConverter);
REGISTER_TRT_OP_CONVERTER(c_allreduce_min, CAllReduceOpConverter);
REGISTER_TRT_OP_CONVERTER(c_allreduce_prod, CAllReduceOpConverter);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
using half = paddle::platform::float16;
class PrelnResidualBiasOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert fused preln_residual_bias op to tensorrt layer";
if (!engine_->with_dynamic_shape()) {
PADDLE_THROW(platform::errors::Fatal(
"Unsupported static mode. Please set dynamic shape of inputs."));
}
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
std::vector<nvinfer1::ITensor*> inputs;
inputs.push_back(input1);
inputs.push_back(input2);
auto get_persistable_data = [&](const std::string& arg_name,
framework::DDim* dims) -> float* {
std::string var_name = op_desc.Input(arg_name).front();
auto* temp_var = scope.FindVar(var_name);
auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
(*dims) = temp_tensor->dims();
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor);
return temp_data;
};
framework::DDim bias_dims, scale_dims, ele_bias_dims;
auto* bias = get_persistable_data("Bias", &bias_dims);
auto* scale = get_persistable_data("Scale", &scale_dims);
auto* ele_bias = get_persistable_data("EleBias", &ele_bias_dims);
int bias_size = phi::product(bias_dims);
int scale_size = phi::product(scale_dims);
int ele_bias_size = phi::product(ele_bias_dims);
float epsilon = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon"));
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
with_fp16 = true;
}
nvinfer1::ILayer* layer = nullptr;
plugin::DynamicPluginTensorRT* plugin = nullptr;
if (with_fp16) {
// ele_bias holds ele_bias_size elements; convert exactly that many to half.
auto half_ele_bias_data = new half[ele_bias_size];
for (int i = 0; i < ele_bias_size; i++) {
half_ele_bias_data[i] = static_cast<half>(ele_bias[i]);
}
plugin = new plugin::PrelnResidualBiasPluginDynamic(
bias, scale, half_ele_bias_data, bias_size, scale_size, ele_bias_size,
epsilon, with_fp16);
} else {
plugin = new plugin::PrelnResidualBiasPluginDynamic(
bias, scale, ele_bias, bias_size, scale_size, ele_bias_size, epsilon,
with_fp16);
}
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(input1);
plugin_inputs.emplace_back(input2);
layer = engine_->AddDynamicPlugin(plugin_inputs.data(), 2, plugin);
std::vector<std::string> output_names;
output_names.push_back(op_desc.Output("Out_0")[0]);
output_names.push_back(op_desc.Output("Out_1")[0]);
RreplenishLayerAndOutput(layer, "preln_residual_bias", output_names,
test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(preln_residual_bias, PrelnResidualBiasOpConverter);
......@@ -61,3 +61,4 @@ class ReshapeOpConverter : public OpConverter {
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(reshape, ReshapeOpConverter);
REGISTER_TRT_OP_CONVERTER(reshape2, ReshapeOpConverter);
......@@ -60,3 +60,4 @@ class TransposeOpConverter : public OpConverter {
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(transpose, TransposeOpConverter);
REGISTER_TRT_OP_CONVERTER(transpose2, TransposeOpConverter);
......@@ -156,6 +156,11 @@ struct SimpleOpTypeSetTeller : public Teller {
"slice",
"strided_slice",
"fused_preln_embedding_eltwise_layernorm",
"preln_residual_bias",
"c_allreduce_sum",
"c_allreduce_min",
"c_allreduce_max",
"c_allreduce_prod",
"roll",
"preln_skip_layernorm",
"transformer_input_convert",
......@@ -254,6 +259,11 @@ struct SimpleOpTypeSetTeller : public Teller {
"strided_slice",
"fused_preln_embedding_eltwise_layernorm",
"preln_skip_layernorm",
"preln_residual_bias",
"c_allreduce_sum",
"c_allreduce_min",
"c_allreduce_max",
"c_allreduce_prod",
"roll",
"multiclass_nms3",
"transformer_input_convert",
......@@ -1994,9 +2004,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false;
}
OpTeller::OpTeller() { tellers_.emplace_back(new SimpleOpTypeSetTeller); }
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -27,7 +27,9 @@ list(
matmul_op_int8_plugin.cu
transformer_input_convert_plugin.cu
remove_padding_plugin.cu
recover_padding_plugin.cu)
recover_padding_plugin.cu
c_allreduce_op_plugin.cu
preln_residual_bias_plugin.cu)
if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8)
list(APPEND TRT_FILES spmm_plugin.cu)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstring>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.h"
#include "paddle/fluid/platform/collective_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if defined(PADDLE_WITH_NCCL)
inline ncclDataType_t NvInferDtypeToNCCLDType(nvinfer1::DataType type) {
if (type == nvinfer1::DataType::kFLOAT) {
return ncclFloat;
} else if (type == nvinfer1::DataType::kHALF) {
return ncclFloat16;
} else if (type == nvinfer1::DataType::kINT8) {
return ncclInt8;
} else if (type == nvinfer1::DataType::kINT32) {
return ncclInt32;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"This datatype in nccl is not supported."));
}
}
#endif
CAllReducePluginDynamic::CAllReducePluginDynamic(void const* serialData,
size_t serialLength) {
DeserializeValue(&serialData, &serialLength, &ring_id_);
DeserializeValue(&serialData, &serialLength, &use_calc_stream_);
DeserializeValue(&serialData, &serialLength, &red_type_);
DeserializeValue(&serialData, &serialLength, &with_fp16_);
}
nvinfer1::IPluginV2DynamicExt* CAllReducePluginDynamic::clone() const
TRT_NOEXCEPT {
return new CAllReducePluginDynamic(ring_id_, use_calc_stream_, red_type_,
with_fp16_);
}
const char* CAllReducePluginDynamic::getPluginType() const TRT_NOEXCEPT {
return "c_allreduce_plugin_dynamic";
}
int CAllReducePluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; }
int CAllReducePluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
size_t CAllReducePluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
// with_fp16_ is written by serialize(), so it must be counted here as well.
return SerializedSize(ring_id_) + SerializedSize(use_calc_stream_) +
SerializedSize(red_type_) + SerializedSize(with_fp16_);
}
void CAllReducePluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT {
SerializeValue(&buffer, ring_id_);
SerializeValue(&buffer, use_calc_stream_);
SerializeValue(&buffer, red_type_);
SerializeValue(&buffer, with_fp16_);
}
nvinfer1::DimsExprs CAllReducePluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT {
return inputs[0];
}
bool CAllReducePluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
int nb_outputs) TRT_NOEXCEPT {
PADDLE_ENFORCE_NOT_NULL(
in_out, platform::errors::InvalidArgument(
"The input of CAllReduce plugin shoule not be nullptr."));
PADDLE_ENFORCE_LT(
pos, nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos, nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc& in = in_out[pos];
if (pos == 0 || pos == 1) {
if (with_fp16_) {
return (in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
} else {
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
}
}
// There is exactly one input and one output, so pos is always 0 or 1;
// return false for any other position.
return false;
}
void CAllReducePluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT {}
size_t CAllReducePluginDynamic::getWorkspaceSize(
const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT {
return 0;
}
void CAllReducePluginDynamic::destroy() TRT_NOEXCEPT { delete this; }
nvinfer1::DataType CAllReducePluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType* input_types,
int nb_inputs) const TRT_NOEXCEPT {
PADDLE_ENFORCE_EQ(index, 0,
platform::errors::InvalidArgument(
"The CAllReduce Plugin only has one input, so the "
"index value should be 0, but get %d.",
index));
return input_types[0];
}
int CAllReducePluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs,
void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT {
#if defined(PADDLE_WITH_NCCL)
auto input_dims = input_desc[0].dims;
size_t numel = ProductDim(input_dims);
auto input_type = input_desc[0].type;
void* sendbuff = const_cast<void*>(inputs[0]);
void* recvbuff = outputs[0];
ncclDataType_t dtype = NvInferDtypeToNCCLDType(input_type);
ncclRedOp_t nccl_red_type = ncclSum;
switch (red_type_) {
case kRedSum:
nccl_red_type = ncclSum;
break;
case kRedMax:
nccl_red_type = ncclMax;
break;
case kRedMin:
nccl_red_type = ncclMin;
break;
case kRedProd:
nccl_red_type = ncclProd;
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument("Invalid reduce type: %d",
red_type_));
}
auto comm = platform::NCCLCommContext::Instance().Get(ring_id_);
cudaStream_t custream = use_calc_stream_ ? stream : comm->stream();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), custream));
#endif
return (cudaGetLastError() != cudaSuccess);
}
const char* CAllReducePluginDynamicCreator::getPluginName() const TRT_NOEXCEPT {
return "c_allreduce_plugin_dynamic";
}
const char* CAllReducePluginDynamicCreator::getPluginVersion() const
TRT_NOEXCEPT {
return "1";
}
nvinfer1::IPluginV2* CAllReducePluginDynamicCreator::deserializePlugin(
const char* name, const void* serial_data,
size_t serial_length) TRT_NOEXCEPT {
auto plugin = new CAllReducePluginDynamic(serial_data, serial_length);
return plugin;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
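The enqueue above is the in-engine equivalent of a plain collective all-reduce over the ranks that share ring_id_. A small eager-mode analogue using the public distributed API (two processes launched with python -m paddle.distributed.launch; this only illustrates the collective semantics, not the plugin itself):

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
x = paddle.to_tensor([float(dist.get_rank() + 1)])
dist.all_reduce(x, op=dist.ReduceOp.SUM)  # what c_allreduce_sum does via NCCL
print(x.numpy())  # with two ranks, both processes print [3.]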
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd };
class CAllReducePluginDynamic : public DynamicPluginTensorRT {
private:
int ring_id_;
bool use_calc_stream_;
ReduceType red_type_;
public:
explicit CAllReducePluginDynamic(const int ring_id,
const bool use_calc_stream,
const ReduceType red_type,
const bool with_fp16) {
ring_id_ = ring_id;
use_calc_stream_ = use_calc_stream;
red_type_ = red_type;
with_fp16_ = with_fp16;
}
CAllReducePluginDynamic(void const* serialData, size_t serialLength);
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
const char* getPluginType() const TRT_NOEXCEPT override;
int getNbOutputs() const TRT_NOEXCEPT override;
int initialize() TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void* buffer) const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(
int index, const nvinfer1::DataType* inputTypes,
int nbInputs) const TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override;
};
class CAllReducePluginDynamicCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const TRT_NOEXCEPT override;
const char* getPluginVersion() const TRT_NOEXCEPT override;
nvinfer1::IPluginV2* deserializePlugin(
const char* name, const void* serial_data,
size_t serial_length) TRT_NOEXCEPT override;
};
REGISTER_TRT_PLUGIN_V2(CAllReducePluginDynamicCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <stdio.h>
#include <cassert>
#include <cub/cub.cuh> // NOLINT
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.h"
#include "paddle/fluid/operators/fused/fused_dropout_common.h"
#include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
using half = phi::dtype::float16;
#if IS_TRT_VERSION_GE(6000)
int PrelnResidualBiasPluginDynamic::initialize() TRT_NOEXCEPT {
cudaMalloc(&bias_gpu_, sizeof(float) * bias_size_);
cudaMemcpy(bias_gpu_, bias_.data(), bias_size_ * sizeof(float),
cudaMemcpyHostToDevice);
cudaMalloc(&scale_gpu_, sizeof(float) * scale_size_);
cudaMemcpy(scale_gpu_, scale_.data(), scale_size_ * sizeof(float),
cudaMemcpyHostToDevice);
if (with_fp16_) {
cudaMalloc(&ele_bias_gpu_, sizeof(half) * ele_bias_size_);
cudaMemcpy(ele_bias_gpu_, fp16_ele_bias_.data(),
ele_bias_size_ * sizeof(half), cudaMemcpyHostToDevice);
} else {
cudaMalloc(&ele_bias_gpu_, sizeof(float) * ele_bias_size_);
cudaMemcpy(ele_bias_gpu_, fp32_ele_bias_.data(),
ele_bias_size_ * sizeof(float), cudaMemcpyHostToDevice);
}
return 0;
}
void PrelnResidualBiasPluginDynamic::terminate() TRT_NOEXCEPT {
if (bias_gpu_) {
cudaFree(bias_gpu_);
bias_gpu_ = nullptr;
}
if (scale_gpu_) {
cudaFree(scale_gpu_);
scale_gpu_ = nullptr;
}
if (ele_bias_gpu_) {
cudaFree(ele_bias_gpu_);
ele_bias_gpu_ = nullptr;
}
}
nvinfer1::IPluginV2DynamicExt *PrelnResidualBiasPluginDynamic::clone() const
TRT_NOEXCEPT {
PrelnResidualBiasPluginDynamic *ptr = nullptr;
if (with_fp16_) {
ptr = new PrelnResidualBiasPluginDynamic(
bias_.data(), scale_.data(), fp16_ele_bias_.data(), bias_size_,
scale_size_, ele_bias_size_, eps_, with_fp16_);
} else {
ptr = new PrelnResidualBiasPluginDynamic(
bias_.data(), scale_.data(), fp32_ele_bias_.data(), bias_size_,
scale_size_, ele_bias_size_, eps_, with_fp16_);
}
ptr->bias_gpu_ = bias_gpu_;
ptr->scale_gpu_ = scale_gpu_;
ptr->ele_bias_gpu_ = ele_bias_gpu_;
return ptr;
}
const char *PrelnResidualBiasPluginDynamic::getPluginType() const TRT_NOEXCEPT {
return "preln_residual_bias_plugin_dynamic";
}
int PrelnResidualBiasPluginDynamic::getNbOutputs() const TRT_NOEXCEPT {
return 2;
}
size_t PrelnResidualBiasPluginDynamic::getSerializationSize() const
TRT_NOEXCEPT {
size_t ser_size = SerializedSize(bias_) + SerializedSize(scale_) +
SerializedSize(fp32_ele_bias_) +
SerializedSize(fp16_ele_bias_) +
SerializedSize(bias_size_) + SerializedSize(scale_size_) +
SerializedSize(ele_bias_size_) + SerializedSize(eps_) +
SerializedSize(with_fp16_);
return ser_size;
}
void PrelnResidualBiasPluginDynamic::serialize(void *buffer) const
TRT_NOEXCEPT {
SerializeValue(&buffer, bias_);
SerializeValue(&buffer, scale_);
SerializeValue(&buffer, fp32_ele_bias_);
SerializeValue(&buffer, fp16_ele_bias_);
SerializeValue(&buffer, bias_size_);
SerializeValue(&buffer, scale_size_);
SerializeValue(&buffer, ele_bias_size_);
SerializeValue(&buffer, eps_);
SerializeValue(&buffer, with_fp16_);
}
nvinfer1::DimsExprs PrelnResidualBiasPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
if (output_index < 2) {
return inputs[0];
} else { // moving mean and var
nvinfer1::DimsExprs ret;
ret.nbDims = 1;
ret.d[0] = inputs[0].d[2];
return ret;
}
}
bool PrelnResidualBiasPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
int nb_outputs) TRT_NOEXCEPT {
PADDLE_ENFORCE_NOT_NULL(
in_out, platform::errors::InvalidArgument(
"The input of the PrelnResidualBias plugin should not be nullptr."));
PADDLE_ENFORCE_LT(
pos, nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos, nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) {
if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
return (in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#else
PADDLE_THROW(
platform::errors::Fatal("FP16 support is not available in this TRT "
"plugin build while with_fp16 is set to true."));
#endif
} else {
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
}
}
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
if (pos == 1) {
return in.type == prev.type && in.format == prev.format;
}
// output
return in.type == prev.type && in.format == prev.format;
}
void PrelnResidualBiasPluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc *in, int nb_inputs,
const nvinfer1::DynamicPluginTensorDesc *out, int nb_outputs) TRT_NOEXCEPT {
}
size_t PrelnResidualBiasPluginDynamic::getWorkspaceSize(
const nvinfer1::PluginTensorDesc *inputs, int nb_inputs,
const nvinfer1::PluginTensorDesc *outputs,
int nb_outputs) const TRT_NOEXCEPT {
return 0;
}
nvinfer1::DataType PrelnResidualBiasPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType *input_types,
int nb_inputs) const TRT_NOEXCEPT {
return input_types[0];
}
void PrelnResidualBiasPluginDynamic::destroy() TRT_NOEXCEPT { delete this; }
int PrelnResidualBiasPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT {
auto input_dims = input_desc[0].dims;
int hidden = input_dims.d[2];
const size_t rows = static_cast<size_t>(
input_dims.d[0] * input_dims.d[1]); // batch * seq_length
const size_t cols = static_cast<size_t>(input_dims.d[2]);
auto input_type = input_desc[0].type;
if (input_type == nvinfer1::DataType::kFLOAT) {
VLOG(1) << "TRT Plugin DataType selected. PrelnResidualBias-->fp32";
const float *input1 = static_cast<const float *>(inputs[0]);
const float *input2 = static_cast<const float *>(inputs[1]);
uint64_t seed = 0;
const float dropout_prob = 0.;
const bool is_upscale_in_train = false;
const bool is_test = true;
const uint64_t increment = 0;
const float epsilon = eps_;
const float *src = input2;
const float *residual = input1;
const float *bias = static_cast<float *>(ele_bias_gpu_);
const float *scale = scale_gpu_;
const float *layernorm_bias = bias_gpu_;
uint8_t *mask_data = nullptr;
float *dst = static_cast<float *>(outputs[1]);
float *layernorm_dst = static_cast<float *>(outputs[0]);
float *mean = nullptr;
float *var = nullptr;
const int VecSize = 8;
paddle::operators::FusedLayernormResidualDropoutBiasFunctor<
float, uint8_t, VecSize, float, false>()(
rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, increment,
epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst,
layernorm_dst, mean, var, stream);
} else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
VLOG(1) << "TRT Plugin DataType selected. PrelnResidualBias-->fp16";
const half *input1 = static_cast<const half *>(inputs[0]);
const half *input2 = static_cast<const half *>(inputs[1]);
uint64_t seed = 0;
const float dropout_prob = 0.;
const bool is_upscale_in_train = false;
const bool is_test = true;
const uint64_t increment = 0;
const float epsilon = eps_;
const half *src = input2;
const half *residual = input1;
const half *bias = static_cast<half *>(ele_bias_gpu_);
const float *scale = scale_gpu_;
const float *layernorm_bias = bias_gpu_;
uint8_t *mask_data = nullptr;
half *dst = static_cast<half *>(outputs[1]);
half *layernorm_dst = static_cast<half *>(outputs[0]);
float *mean = nullptr;
float *var = nullptr;
const int VecSize = 8;
paddle::operators::FusedLayernormResidualDropoutBiasFunctor<
half, uint8_t, VecSize, float, false>()(
rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, increment,
epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst,
layernorm_dst, mean, var, stream);
#else
PADDLE_THROW(platform::errors::Fatal(
"The Ernie(Bert) TensorRT plugin should be "
"compiled with CUDA version >= 10.0 when running with fp16. "
"Please recompile it or try to use fp32 by setting "
"config.SetTRTDynamicShapeInfo(min_input_shape, "
"max_input_shape, opt_input_shape, true)."));
#endif
} else {
PADDLE_THROW(
platform::errors::Fatal("The PrelnResidualBias TRT Plugin's input type "
"should be float or half."));
}
return cudaGetLastError() != cudaSuccess;
}
const char *PrelnResidualBiasPluginDynamicCreator::getPluginName() const
TRT_NOEXCEPT {
return "preln_residual_bias_plugin_dynamic";
}
const char *PrelnResidualBiasPluginDynamicCreator::getPluginVersion() const
TRT_NOEXCEPT {
return "1";
}
nvinfer1::IPluginV2 *PrelnResidualBiasPluginDynamicCreator::deserializePlugin(
const char *name, const void *serial_data,
size_t serial_length) TRT_NOEXCEPT {
return new PrelnResidualBiasPluginDynamic(serial_data, serial_length);
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
using half = phi::dtype::float16;
#if IS_TRT_VERSION_GE(6000)
class PrelnResidualBiasPluginDynamic : public DynamicPluginTensorRT {
public:
explicit PrelnResidualBiasPluginDynamic(const float* bias, const float* scale,
const half* ele_bias, int bias_size,
int scale_size, int ele_bias_size,
const float eps, bool with_fp16)
: bias_size_(bias_size),
scale_size_(scale_size),
ele_bias_size_(ele_bias_size),
eps_(eps) {
with_fp16_ = with_fp16;
bias_.resize(bias_size);
scale_.resize(scale_size);
fp16_ele_bias_.resize(ele_bias_size);
std::copy(ele_bias, ele_bias + ele_bias_size, fp16_ele_bias_.data());
std::copy(bias, bias + bias_size, bias_.data());
std::copy(scale, scale + scale_size, scale_.data());
}
explicit PrelnResidualBiasPluginDynamic(const float* bias, const float* scale,
const float* ele_bias, int bias_size,
int scale_size, int ele_bias_size,
const float eps, bool with_fp16)
: bias_size_(bias_size),
scale_size_(scale_size),
ele_bias_size_(ele_bias_size),
eps_(eps) {
with_fp16_ = with_fp16;
bias_.resize(bias_size);
scale_.resize(scale_size);
fp32_ele_bias_.resize(ele_bias_size);
std::copy(ele_bias, ele_bias + ele_bias_size, fp32_ele_bias_.data());
std::copy(bias, bias + bias_size, bias_.data());
std::copy(scale, scale + scale_size, scale_.data());
}
PrelnResidualBiasPluginDynamic(void const* serial_data,
size_t serial_length) {
DeserializeValue(&serial_data, &serial_length, &bias_);
DeserializeValue(&serial_data, &serial_length, &scale_);
DeserializeValue(&serial_data, &serial_length, &fp32_ele_bias_);
DeserializeValue(&serial_data, &serial_length, &fp16_ele_bias_);
DeserializeValue(&serial_data, &serial_length, &bias_size_);
DeserializeValue(&serial_data, &serial_length, &scale_size_);
DeserializeValue(&serial_data, &serial_length, &ele_bias_size_);
DeserializeValue(&serial_data, &serial_length, &eps_);
DeserializeValue(&serial_data, &serial_length, &with_fp16_);
}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
const char* getPluginType() const TRT_NOEXCEPT override;
int getNbOutputs() const TRT_NOEXCEPT override;
int initialize() TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void* buffer) const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* in_out,
int nb_inputs,
int nb_outputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nb_inputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nb_outputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nb_inputs,
const nvinfer1::PluginTensorDesc* outputs,
int nb_outputs) const TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(
int index, const nvinfer1::DataType* input_types,
int nb_inputs) const TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override;
void terminate() TRT_NOEXCEPT override;
private:
std::vector<float> bias_;
std::vector<float> scale_;
std::vector<float> fp32_ele_bias_;
std::vector<half> fp16_ele_bias_;
float* bias_gpu_{nullptr};
float* scale_gpu_{nullptr};
void* ele_bias_gpu_{nullptr};
int bias_size_;
int scale_size_;
int ele_bias_size_;
float eps_;
bool with_fp16_;
};
class PrelnResidualBiasPluginDynamicCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const TRT_NOEXCEPT override;
const char* getPluginVersion() const TRT_NOEXCEPT override;
nvinfer1::IPluginV2* deserializePlugin(
const char* name, const void* serial_data,
size_t serial_length) TRT_NOEXCEPT override;
};
REGISTER_TRT_PLUGIN_V2(PrelnResidualBiasPluginDynamicCreator);
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -155,6 +155,109 @@ __global__ void FusedLayernormResidualDropoutBias(
invvar);
}
/**
* @brief layernorm(residual + dropout(src + bias));
* @param
* rows: batch_size * seq_len
* cols: feature_size or hidden_size
* src: [rows, cols], inputs
* bias: [cols], linear bias, can be null
* residual:[rows, cols]
* mask: [rows, cols], dropout result
* dst: [rows, cols], residual + dropout(src+bias)
* layernorm_dst: [rows, cols], layernorm result
* layernorm_bias: [cols], layernorm bias, can be null
* scale: [cols]: layernorm scale, can be null
*/
template <typename T, typename MaskType, int VecSize, typename U,
bool ScaleBiasWithSameTypeX = false>
__global__ void FusedLayernormResidualDropoutBiasInfer(
const size_t rows, const size_t cols, uint64_t seed,
const float dropout_prob, const bool is_upscale_in_train,
const bool is_test, const uint64_t increment, const float epsilon,
const T *src, const T *residual, const T *bias,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *layernorm_bias,
MaskType *mask, T *dst, T *layernorm_dst) {
int col_id = threadIdx.x;
int row_id = blockIdx.x;
int idx = row_id * cols + col_id;
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state);
T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
__shared__ U mean_share;
__shared__ U var_share;
__shared__ U shared_mean[32];
__shared__ U shared_var[32];
phi::funcs::ReluFunctor<T> relu;
U mean_val = 0;
U var_val = 0;
for (int i = col_id * VecSize; i < cols; i += blockDim.x * VecSize) {
FusedResidualDropoutBiasOneThread<T, MaskType, VecSize, true, false,
phi::funcs::ReluFunctor<T>>(
row_id, i, cols, &state, dropout_prob, factor, src, residual, bias, dst,
mask, is_test, &mean_val, &var_val, relu);
}
mean_val = BlockReduceSum<U>(mean_val, shared_mean);
var_val = BlockReduceSum<U>(var_val, shared_var);
if (threadIdx.x == 0) {
auto scale = static_cast<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>>(
static_cast<float>(1.) / static_cast<float>(cols));
auto tmp = mean_val * static_cast<U>(scale);
mean_share = static_cast<U>(tmp);
var_share = static_cast<U>(var_val * static_cast<U>(scale) -
mean_share * mean_share);
var_share = var_share > U(0) ? var_share : U(0);
}
__syncthreads();
mean_val = mean_share;
U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));
// calculate layernorm_dst
CalcLayernormY<T, VecSize, U, ScaleBiasWithSameTypeX>(
scale, layernorm_bias, dst, layernorm_dst, row_id, col_id, cols, mean_val,
invvar);
}
template <typename T, typename MaskType, int VecSize, typename U,
bool ScaleBiasWithSameTypeX = false>
struct FusedLayernormResidualDropoutBiasFunctor {
void operator()(
const size_t rows, const size_t cols, uint64_t seed,
const float dropout_prob, const bool is_upscale_in_train,
const bool is_test, const uint64_t increment, const float epsilon,
const T *src, const T *residual, const T *bias,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *layernorm_bias,
MaskType *mask, T *dst, T *layernorm_dst, LayerNormParamType<T> *mean,
LayerNormParamType<T> *var, cudaStream_t stream) {
int blockDim = GetDesiredBlockDim(cols / VecSize);
if (mean != nullptr && var != nullptr) {
FusedLayernormResidualDropoutBias<T, MaskType, VecSize, U,
ScaleBiasWithSameTypeX>
<<<rows, blockDim, 0, stream>>>(
rows, cols, seed, dropout_prob, is_upscale_in_train, is_test,
increment, epsilon, src, residual, bias, scale, layernorm_bias,
mask, dst, layernorm_dst, mean, var);
} else {
FusedLayernormResidualDropoutBiasInfer<T, MaskType, VecSize, U,
ScaleBiasWithSameTypeX>
<<<rows, blockDim, 0, stream>>>(
rows, cols, seed, dropout_prob, is_upscale_in_train, is_test,
increment, epsilon, src, residual, bias, scale, layernorm_bias,
mask, dst, layernorm_dst);
}
}
};
template struct FusedLayernormResidualDropoutBiasFunctor<
paddle::platform::float16, uint8_t, 8, float, false>;
/*
* @brief layernorm(residual + dropout(x));
* Conditions:
......
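In the Infer kernel added above, each thread accumulates the sum of its output values in mean_val and the sum of their squares in var_val; after the block reduction, the threadIdx.x == 0 branch derives the row mean as sum / cols and the variance as sum_sq / cols - mean^2, clamped at zero. A short numpy restatement of that per-row computation (a reference for the formula only, not the CUDA code):

import numpy as np

dst = np.random.rand(4, 768).astype("float32")   # [rows, cols]
cols = dst.shape[1]
s = dst.sum(axis=1)                              # mean_val after block reduce
sq = (dst * dst).sum(axis=1)                     # var_val after block reduce
mean = s / cols                                  # mean_share
var = np.maximum(sq / cols - mean * mean, 0.0)   # var_share, clamped at 0
# matches np.var(dst, axis=1) up to floating point error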
......@@ -231,7 +231,7 @@ if(NOT WITH_DISTRIBUTE OR WIN32)
list(REMOVE_ITEM TEST_OPS test_fleet_rolemaker_2)
list(REMOVE_ITEM TEST_OPS test_fleet_utils)
list(REMOVE_ITEM TEST_OPS test_collective_cpu_barrier_with_gloo)
list(REMOVE_ITEM TEST_OPS test_delete_c_identity_op_pass)
# TODO: Fix these unittests failed on Windows
list(REMOVE_ITEM TEST_OPS test_fake_init_op)
endif()
......@@ -244,6 +244,7 @@ endif()
if(WIN32)
list(REMOVE_ITEM TEST_OPS test_complex_matmul)
list(REMOVE_ITEM TEST_OPS test_ops_nms)
list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias)
endif()
list(REMOVE_ITEM TEST_OPS test_fleet_checkpoint)
......
......@@ -16,6 +16,18 @@ file(
"test_trt_convert_*.py")
string(REPLACE ".py" "" TEST_TRT_CONVERTER "${TEST_TRT_CONVERTER}")
if(NOT WITH_DISTRIBUTE)
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_delete_c_identity_op_pass")
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES
"test_trt_convert_preln_residual_bias")
list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_preln_residual_bias")
list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_preln_residual_bias")
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_trt_convert_c_allreduce")
list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_c_allreduce")
list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_c_allreduce")
endif()
# Only for cpu(mkl + openblas)
set(TEST_INFERENCE_CPU_UT "test_mul_lstm_fuse_pass" "test_mul_gru_fuse_pass")
......
......@@ -605,7 +605,7 @@ class TrtLayerAutoScanTest(AutoScanTest):
dic['use_trt'] = False
return str(dic)
def run_test(self, quant=False, *args, **kwargs):
def run_test(self, quant=False, skip_baseline=False, *args, **kwargs):
status = True
run_flags = []
for prog_config in self.sample_program_configs(*args, **kwargs):
......@@ -636,14 +636,14 @@ class TrtLayerAutoScanTest(AutoScanTest):
}
results: List[Dict[str, np.ndarray]] = []
# baseline: gpu run
logging.info('RUN program_config: ' + str(prog_config))
gpu_config = self.create_inference_config(use_trt=False)
results.append(
self.run_test_config(model, params, prog_config, gpu_config,
feed_data))
self.success_log('RUN_GPU_BASELINE done')
if not skip_baseline:
# baseline: gpu run
logging.info('RUN program_config: ' + str(prog_config))
gpu_config = self.create_inference_config(use_trt=False)
results.append(
self.run_test_config(model, params, prog_config, gpu_config,
feed_data))
self.success_log('RUN_GPU_BASELINE done')
for pred_config, nodes_num, threshold in self.sample_predictor_configs(
prog_config):
......
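With the new skip_baseline flag, a converter test whose result depends on the number of ranks (for instance the new c_allreduce converter tests, which have no meaningful single-process GPU baseline) can presumably skip the baseline comparison, e.g.:

# inside a TrtLayerAutoScanTest subclass
def test(self):
    self.run_test(skip_baseline=True)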
......@@ -226,6 +226,7 @@ def create_fake_model(program_config):
var_desc.set_type(core.VarDesc.VarType.LOD_TENSOR)
var_desc.set_dtype(convert_np_dtype_to_dtype_(tensor_config.dtype))
var_desc.set_shape(tensor_config.shape)
print(f"name: {name}; shape: {tensor_config.shape}")
var_desc.set_need_check_feed(True)
if tensor_config.lod is not None:
var_desc.set_lod_level(len(tensor_config.lod))
......@@ -323,6 +324,7 @@ def create_fake_model(program_config):
with fluid.scope_guard(scope):
executor.run(util_program)
params = scope.find_var("out_var_0").get_bytes()
return model, params
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import paddle.inference as paddle_infer
import unittest
import hypothesis.strategies as st
class TestDeleteCIdentityPass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=8,
workspace_size=0,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Float32,
use_static=False,
use_calib_mode=False)
yield config, ['relu'], (1e-5, 1e-5)
def sample_program_config(self, draw):
n = draw(st.integers(min_value=1, max_value=2))
relu_op = OpConfig("relu",
inputs={"X": ["relu_x"]},
outputs={"Out": ["relu_out"]})
c_identity_op = OpConfig("c_identity",
inputs={"X": ["relu_out"]},
outputs={"Out": ["id_out"]})
program_config = ProgramConfig(
ops=[relu_op, c_identity_op],
weights={},
inputs={"relu_x": TensorConfig(shape=[n])},
outputs=["id_out"])
return program_config
def test(self):
self.run_and_statis(max_examples=2,
min_success_num=2,
passes=["delete_c_identity_op_pass"])
if __name__ == "__main__":
unittest.main()
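Conceptually, the pass exercised by this test bypasses c_identity (an identity op inserted for tensor model parallelism) so that consumers read the producer's output directly and the node can be dropped. A plain-Python sketch of that rewrite, not the Paddle IR API, with a hypothetical downstream consumer added for illustration:

graph = [("relu", "relu_x", "relu_out"),
         ("c_identity", "relu_out", "id_out"),
         ("consumer", "id_out", "final_out")]  # hypothetical consumer op

def delete_c_identity(ops):
    # Map each c_identity output back to its input, drop the node, rewire users.
    alias = {out: inp for op, inp, out in ops if op == "c_identity"}
    return [(op, alias.get(inp, inp), out)
            for op, inp, out in ops if op != "c_identity"]

print(delete_c_identity(graph))
# [('relu', 'relu_x', 'relu_out'), ('consumer', 'relu_out', 'final_out')]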
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import tempfile
import paddle
import paddle.distributed.fleet as fleet
from paddle.distributed import ReduceOp
from paddle.distributed import init_parallel_env
from paddle.inference import Config
from paddle.inference import create_predictor
from paddle.inference import PrecisionType
from paddle.fluid import core
def run(op_type, precision):
fleet.init(is_collective=True)
paddle.enable_static()
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
block = main_program.blocks[0]
with paddle.static.program_guard(main_program, startup_program):
data = paddle.static.data(name='data', shape=[3, 4], dtype='float32')
c_data = block.create_var(
shape=data.shape,
dtype=data.dtype,
type=data.type,
lod_level=data.lod_level,
persistable=False,
is_data=False,
initializer=paddle.nn.initializer.Constant(value=1.0))
block.append_op(type=op_type,
inputs={'X': data},
outputs={'Out': c_data},
attrs={
'ring_id': 0,
'use_calc_stream': True,
'use_model_parallel': True
})
out = paddle.static.nn.fc(
x=c_data,
size=1,
weight_attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=0.5)))
mean = paddle.mean(out)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_program)
nranks = 2
current_endpoint = "127.0.0.1:600" + str(fleet.worker_index())
trainer_endpoints = ["127.0.0.1:6000", "127.0.0.1:6001"]
dist_config = core.DistConfig()
dist_config.set_carrier_id("inference")
dist_config.set_endpoints(trainer_endpoints, current_endpoint)
dist_config.set_ranks(nranks, fleet.worker_index())
dist_config.enable_dist_model(True)
with tempfile.TemporaryDirectory(prefix="allreduce_") as tmpdir:
paddle.static.save_inference_model(os.path.join(tmpdir, "model"),
[data], [mean],
exe,
program=main_program)
config = Config(os.path.join(tmpdir, "model.pdmodel"),
os.path.join(tmpdir, "model.pdiparams"))
config.enable_memory_optim()
config.enable_use_gpu(1000, fleet.worker_index())
config.set_dist_config(dist_config)
config.enable_tensorrt_engine(
workspace_size=1 << 30,
max_batch_size=1,
min_subgraph_size=1,
precision_mode=PrecisionType.Half
if precision == "fp16" else PrecisionType.Int8,
use_static=False,
use_calib_mode=False)
config.set_trt_dynamic_shape_info({"data": [3, 4]}, {"data": [3, 4]},
{"data": [3, 4]})
predictor = create_predictor(config)
input_names = predictor.get_input_names()
input_tensor = predictor.get_input_handle("data")
input_tensor.reshape([3, 4])
input_tensor.copy_from_cpu(np.ones([3, 4]).astype(np.float32))
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
        output_data = output_handle.copy_to_cpu()  # numpy.ndarray type
print(f"c_allreduce_out={output_data[0]}")
if __name__ == "__main__":
if len(sys.argv) < 2:
        # This script is only meant to be invoked by test_trt_convert_c_allreduce.py
sys.exit(0)
op_type = sys.argv[1]
precision = sys.argv[2]
run(op_type, precision)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import sys
import pickle
import os
import unittest
import paddle
class TestDistTRT(unittest.TestCase):
def setUp(self):
self.init_case()
self.script = "test_trt_c_allreduce_infer_script.py"
def init_case(self):
self.op_type = "c_allreduce_sum"
self.target_value = 4.
self.precision = "fp16"
def test_run(self):
env = dict(os.environ)
env["CUDA_VISIBLE_DEVICES"] = "0,1"
cmd = f"python -u -m paddle.distributed.fleet.launch --gpus 0,1 {self.script} {self.op_type} {self.precision}"
cmd = cmd.split(" ")
local_proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env)
local_out, local_err = local_proc.communicate()
for line in local_out.decode("utf-8").split("\n"):
results = line.split("=")
if len(results) == 2 and results[0] == "c_allreduce_out":
self.assertEqual(float(results[1]), self.target_value)
class TestMin(TestDistTRT):
def init_case(self):
self.op_type = "c_allreduce_min"
self.target_value = 2.
self.precision = "int8"
#class TestMax(TestDistTRT):
#
# def init_case(self):
# self.op_type = "c_allreduce_max"
# self.target_value = 2.
# self.precision = "fp16"
#
#
#class TestProd(TestDistTRT):
#
# def init_case(self):
# self.op_type = "c_allreduce_prod"
# self.target_value = 2.
# self.precision = "fp16"
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
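The expected values above follow from the constants in the companion script: both ranks feed an all-ones [3, 4] tensor, so c_allreduce_sum yields 2.0 per element and c_allreduce_min yields 1.0, and the fc layer with constant weight 0.5 then gives 4 * value * 0.5 per row, whose mean is printed. A quick NumPy check of that arithmetic (a sketch mirroring the script's constants, not part of the test):

import numpy as np

nranks = 2
data = np.ones([3, 4], dtype=np.float32)
reduced = {"c_allreduce_sum": data * nranks, "c_allreduce_min": data}
for op_type, value in reduced.items():
    fc_out = value @ np.full([4, 1], 0.5, dtype=np.float32)  # fc(size=1, w=0.5)
    print(op_type, float(fc_out.mean()))  # 4.0 for sum, 2.0 for min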
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
class TrtConvertSkipLayernormTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
inputs = program_config.inputs
weights = program_config.weights
outputs = program_config.outputs
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
        # begin_norm_axis must be smaller than the input rank.
if 'begin_norm_axis' in attrs[0] and attrs[0]['begin_norm_axis'] >= 0:
if len(inputs['inputX_data'].shape) <= attrs[0]['begin_norm_axis']:
return False
return True
def sample_program_configs(self):
def generate_input1(attrs: List[Dict[str, Any]], batch):
return np.ones([batch, 128, 768]).astype(np.float32)
def generate_input2(attrs: List[Dict[str, Any]], batch):
return np.ones([batch, 128, 768]).astype(np.float32)
def generate_weight1(attrs: List[Dict[str, Any]]):
return np.random.random([768]).astype(np.float32)
def generate_weight2(attrs: List[Dict[str, Any]]):
return np.random.random([768]).astype(np.float32)
for batch in [4]:
for epsilon in [1e-5]:
for begin_norm_axis in [2]:
for enable_int8 in [False, True]:
dics = [{
"epsilon": epsilon,
"begin_norm_axis": begin_norm_axis,
}, {}]
ops_config = [{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["inputX_data"],
"Y": ["EleBias"]
},
"op_outputs": {
"Out": ["bias_out"]
},
"op_attrs": {
"axis": -1
}
}, {
"op_type": "elementwise_add",
"op_inputs": {
"X": ["bias_out"],
"Y": ["inputY_data"]
},
"op_outputs": {
"Out": ["ele_out"]
},
"op_attrs": {
"axis": -1
}
}, {
"op_type": "layer_norm",
"op_inputs": {
"X": ["ele_out"],
"Bias": ["Bias"],
"Scale": ["Scale"]
},
"op_outputs": {
"Y": ["layernorm_out"],
"Mean": ["Mean"],
"Variance": ["Variance"]
},
"op_attrs": dics[0]
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"Bias":
TensorConfig(
data_gen=partial(generate_weight1, dics)),
"Scale":
TensorConfig(
data_gen=partial(generate_weight2, dics)),
"EleBias":
TensorConfig(
data_gen=partial(generate_weight2, dics))
},
inputs={
"inputX_data":
TensorConfig(data_gen=partial(
generate_input1, dics, batch)),
"inputY_data":
TensorConfig(data_gen=partial(
generate_input2, dics, batch))
},
outputs=["ele_out", "layernorm_out"])
yield program_config
def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {
"inputX_data": [4, 128, 768],
"inputY_data": [4, 128, 768],
"Bias": [768],
"Scale": [768]
}
self.dynamic_shape.max_input_shape = {
"inputX_data": [4, 128, 768],
"inputY_data": [4, 128, 768],
"Bias": [768],
"Scale": [768]
}
self.dynamic_shape.opt_input_shape = {
"inputX_data": [4, 128, 768],
"inputY_data": [4, 128, 768],
"Bias": [768],
"Scale": [768]
}
def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 4
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
        # only dynamic shape is supported
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), 1e-2 # atol=1e-2 while rtol is 1e-8
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), 1e-2 # atol=1e-2 while rtol is 1e-8
def add_skip_trt_case(self):
pass
def test(self):
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__":
unittest.main()
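The subgraph built by this test, a bias add, a residual add, and a layer_norm with both the pre-norm sum and the normalized result fetched, is the pattern the preln_residual_bias fusion targets. A NumPy sketch of that reference computation (shapes and epsilon taken from the test; this is not the TensorRT plugin):

import numpy as np

batch, seq, hidden = 4, 128, 768
x = np.ones([batch, seq, hidden], dtype=np.float32)        # inputX_data
y = np.ones([batch, seq, hidden], dtype=np.float32)        # inputY_data
ele_bias = np.random.random([hidden]).astype(np.float32)   # EleBias
scale = np.random.random([hidden]).astype(np.float32)      # Scale
ln_bias = np.random.random([hidden]).astype(np.float32)    # Bias

ele_out = (x + ele_bias) + y                               # the two elementwise_adds
mean = ele_out.mean(axis=-1, keepdims=True)
var = ele_out.var(axis=-1, keepdims=True)
layernorm_out = (ele_out - mean) / np.sqrt(var + 1e-5) * scale + ln_bias
# Both ele_out and layernorm_out are outputs of the fused pattern.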
......@@ -190,15 +190,15 @@ class TrtConvertSkipLayernormTest(TrtLayerAutoScanTest):
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
# # for static_shape
# clear_dynamic_shape()
# self.trt_param.precision = paddle_infer.PrecisionType.Float32
# yield self.create_inference_config(), generate_trt_nodes_num(
# attrs, False), 1e-5
# self.trt_param.precision = paddle_infer.PrecisionType.Half
# yield self.create_inference_config(), generate_trt_nodes_num(
# attrs, False), 1e-5
# for dynamic_shape
generate_dynamic_shape(attrs)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from pass_test import PassTest
import paddle
class PrelnResidualBiasFusePassTest(PassTest):
def setUp(self):
paddle.enable_static()
with paddle.static.program_guard(self.main_program,
self.startup_program):
x = paddle.static.data(name="x",
shape=[128, 768],
dtype="float32",
lod_level=0)
bias = paddle.static.create_parameter(shape=[768], dtype='float32')
y = paddle.static.data(name="y",
shape=[128, 768],
dtype="float32",
lod_level=0)
x = x + bias
elementwise_out = x + y
out = paddle.static.nn.layer_norm(input=elementwise_out)
self.fetch_list = [out, elementwise_out]
self.pass_names = "preln_residual_bias_fuse_pass"
self.fused_op_type = "preln_residual_bias"
self.num_fused_ops = 1
# self.graph_attrs = {
# "embedding_eltwise_layernorm_fuse_pass_flag": True,
# "multihead_matmul_fuse_pass_flag": True
# }
def test_check_program(self):
use_gpu_set = [False]
if paddle.device.is_compiled_with_cuda():
use_gpu_set.append(True)
for use_gpu in use_gpu_set:
place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace()
opt_program = self._apply_ir_passes()
self.check_program(opt_program)
if __name__ == "__main__":
unittest.main()