未验证 提交 d9a134c3 编写于 作者: W Wang Bojun 提交者: GitHub

prefix (#50381)

上级 0e92adce
......@@ -115,6 +115,24 @@ void PrelnResidualBias::operator()(PDNode *x, PDNode *y) {
} // namespace patterns
void setIntermediateOut(OpDesc *desc,
const std::string &out_name,
const std::string &scope_name) {
std::string new_name = scope_name + "/at." + out_name + ".new";
desc->SetOutput(out_name, {new_name});
}
void addIntermediateOut(Node *op_node,
const std::string &out_name,
const std::string &scope_name,
Graph *graph) {
std::string new_name = scope_name + "/at." + out_name + ".new";
VarDesc out_var(new_name);
out_var.SetPersistable(false);
auto *node_var = graph->CreateVarNode(&out_var);
IR_NODE_LINK_TO(op_node, node_var);
}
void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
......@@ -168,7 +186,7 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
// on each other, so we make below check to ensure only one
// PrelnResidualBias pattern is delalted with.
for (auto op : elementwise1_out->inputs) {
if (op->Name() == "preln_residual_bias") return;
if (op->Name() == "fused_bias_dropout_residual_layer_norm") return;
}
if (!IsCompat(subgraph, graph)) {
......@@ -179,27 +197,32 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
std::unordered_set<const Node *> del_node_set;
// Create an PrelnResidualBias op node
OpDesc new_desc;
new_desc.SetType("preln_residual_bias");
new_desc.SetType("fused_bias_dropout_residual_layer_norm");
// inputs
new_desc.SetInput("X", {subgraph.at(x)->Name()});
new_desc.SetInput("Y", {subgraph.at(y)->Name()});
new_desc.SetInput("Scale", {layer_norm_scale->Name()});
new_desc.SetInput("Bias", {layer_norm_bias->Name()});
new_desc.SetInput("EleBias", {elementwise_bias->Name()});
new_desc.SetInput("Residual", {subgraph.at(y)->Name()});
new_desc.SetInput("LnScale", {layer_norm_scale->Name()});
new_desc.SetInput("LnBias", {layer_norm_bias->Name()});
new_desc.SetInput("Bias", {elementwise_bias->Name()});
// outputs
new_desc.SetOutput("Out_0", {layer_norm_out->Name()});
new_desc.SetOutput("Out_1", {elementwise1_out->Name()});
new_desc.SetOutput("Y", {layer_norm_out->Name()});
new_desc.SetOutput("BiasDropoutResidualOut", {elementwise1_out->Name()});
new_desc.SetOutput("LnMean", {layer_norm_mean->Name()});
new_desc.SetOutput("LnVariance", {layer_norm_variance->Name()});
setIntermediateOut(&new_desc, "DropoutMaskOut", "preln_residual_bias_fuse");
// attrs
new_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon"));
new_desc.SetAttr("ln_epsilon", layer_norm->Op()->GetAttr("epsilon"));
new_desc.SetAttr("dropout_rate", 0.0f);
new_desc.SetAttr("is_test", true);
new_desc.SetAttr("begin_norm_axis",
layer_norm->Op()->GetAttr("begin_norm_axis"));
auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied.
addIntermediateOut(
fused_node, "DropoutMaskOut", "preln_residual_bias_fuse", graph);
del_node_set.insert(elementwise0);
del_node_set.insert(elementwise1);
del_node_set.insert(elementwise0_out);
del_node_set.insert(layer_norm);
del_node_set.insert(layer_norm_mean);
del_node_set.insert(layer_norm_variance);
GraphSafeRemoveNodes(graph, del_node_set);
IR_NODE_LINK_TO(subgraph.at(x), fused_node);
IR_NODE_LINK_TO(subgraph.at(y), fused_node);
......@@ -208,6 +231,9 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
IR_NODE_LINK_TO(layer_norm_bias, fused_node);
IR_NODE_LINK_TO(fused_node, layer_norm_out);
IR_NODE_LINK_TO(fused_node, elementwise1_out);
IR_NODE_LINK_TO(fused_node, layer_norm_mean);
IR_NODE_LINK_TO(fused_node, layer_norm_variance);
found_subgraph_count++;
};
......
......@@ -169,8 +169,18 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
// attrs
new_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon"));
new_desc.SetAttr("begin_norm_axis",
layer_norm->Op()->GetAttr("begin_norm_axis"));
if (layer_norm->Op()->HasAttr("begin_norm_axis")) {
int32_t begin_norm_axis = PADDLE_GET_CONST(
int32_t, layer_norm->Op()->GetAttr("begin_norm_axis"));
int32_t input_rank =
static_cast<int32_t>(elementwise_out->Var()->GetShape().size());
if ((begin_norm_axis != -1) && (begin_norm_axis != input_rank - 1)) {
LOG(WARNING) << "skip_layernorm pass only support "
"layer_norm'begin_norm_axis == input_rank - 1.";
return;
}
new_desc.SetAttr("begin_norm_axis", begin_norm_axis);
}
auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied.
......
......@@ -2250,7 +2250,7 @@ USE_TRT_CONVERTER(deformable_conv);
USE_TRT_CONVERTER(pool3d)
USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm)
USE_TRT_CONVERTER(preln_skip_layernorm)
USE_TRT_CONVERTER(preln_residual_bias)
USE_TRT_CONVERTER(fused_bias_dropout_residual_layer_norm)
USE_TRT_CONVERTER(c_allreduce_sum)
USE_TRT_CONVERTER(roll)
USE_TRT_CONVERTER(strided_slice)
......
......@@ -26,15 +26,12 @@ class PrelnResidualBiasOpConverter : public OpConverter {
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(4) << "convert fused preln_residual_bias op to tensorrt layer";
if (!engine_->with_dynamic_shape()) {
PADDLE_THROW(platform::errors::Fatal(
"Unsupported static mode. Please set dynamic shape of inputs."));
}
VLOG(4) << "convert fused_bias_dropout_residual_layer_norm op with "
"drop_rate = 0 to preln_residual_bias tensorrt layer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
auto* input2 = engine_->GetITensor(op_desc.Input("Residual")[0]);
std::vector<nvinfer1::ITensor*> inputs;
inputs.push_back(input1);
inputs.push_back(input2);
......@@ -49,15 +46,15 @@ class PrelnResidualBiasOpConverter : public OpConverter {
return temp_data;
};
framework::DDim bias_dims, scale_dims, ele_bias_dims;
auto* bias = get_persistable_data("Bias", &bias_dims);
auto* scale = get_persistable_data("Scale", &scale_dims);
auto* ele_bias = get_persistable_data("EleBias", &ele_bias_dims);
auto* bias = get_persistable_data("LnBias", &bias_dims);
auto* scale = get_persistable_data("LnScale", &scale_dims);
auto* ele_bias = get_persistable_data("Bias", &ele_bias_dims);
int bias_size = phi::product(bias_dims);
int scale_size = phi::product(scale_dims);
int ele_bias_size = phi::product(ele_bias_dims);
float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"));
float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("ln_epsilon"));
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
with_fp16 = true;
......@@ -94,8 +91,8 @@ class PrelnResidualBiasOpConverter : public OpConverter {
plugin_inputs.emplace_back(input2);
layer = engine_->AddDynamicPlugin(plugin_inputs.data(), 2, plugin);
std::vector<std::string> output_names;
output_names.push_back(op_desc.Output("Out_0")[0]);
output_names.push_back(op_desc.Output("Out_1")[0]);
output_names.push_back(op_desc.Output("Y")[0]);
output_names.push_back(op_desc.Output("BiasDropoutResidualOut")[0]);
RreplenishLayerAndOutput(
layer, "preln_residual_bias", output_names, test_mode);
}
......@@ -105,4 +102,5 @@ class PrelnResidualBiasOpConverter : public OpConverter {
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(preln_residual_bias, PrelnResidualBiasOpConverter);
REGISTER_TRT_OP_CONVERTER(fused_bias_dropout_residual_layer_norm,
PrelnResidualBiasOpConverter);
......@@ -1316,7 +1316,21 @@ struct SimpleOpTypeSetTeller : public Teller {
return false;
}
}
if (op_type == "fused_bias_dropout_residual_layer_norm") {
if (!with_dynamic_shape) {
VLOG(3) << "fused_bias_dropout_residual_layer_norm should run on "
"dynamic shape mode.";
return false;
}
float dropout_rate =
PADDLE_GET_CONST(float, desc.GetAttr("dropout_rate"));
if (dropout_rate != 0.0f) {
VLOG(4) << "preln_residual_bias trt layer can not work with "
"fused_bias_dropout_residual_layer_norm op in which the "
"dropout_rate != 0, stop convert";
return false;
}
}
if (op_type == "fused_preln_embedding_eltwise_layernorm") {
if (!with_dynamic_shape) {
VLOG(3) << "fused_preln_embedding_eltwise_layernorm should run on "
......@@ -2223,7 +2237,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"slice",
"strided_slice",
"fused_preln_embedding_eltwise_layernorm",
"preln_residual_bias",
"fused_bias_dropout_residual_layer_norm",
"c_allreduce_sum",
"c_allreduce_min",
"c_allreduce_max",
......@@ -2337,7 +2351,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"strided_slice",
"fused_preln_embedding_eltwise_layernorm",
"preln_skip_layernorm",
"preln_residual_bias",
"fused_bias_dropout_residual_layer_norm",
"c_allreduce_sum",
"c_allreduce_min",
"c_allreduce_max",
......
......@@ -37,16 +37,17 @@ class FusedBiasDropoutResidualLnOp : public framework::OperatorWithKernel {
"Output",
"LnVariance",
"FusedBiasDropoutResidualLnOp");
OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"),
"Output",
"BiasDropoutResidualOut",
"FusedBiasDropoutResidualLnOp");
OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"),
"Output",
"DropoutMaskOut",
"FusedBiasDropoutResidualLnOp");
OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"),
"Output",
"BiasDropoutResidualOut",
"FusedBiasDropoutResidualLnOp");
OP_INOUT_CHECK(
ctx->HasOutput("Y"), "Output", "Y", "FusedBiasDropoutResidualLnOp");
auto x_dim = ctx->GetInputDim("X");
int left = 1;
for (int i = 0; i < x_dim.size() - 1; i++) {
......
......@@ -56,8 +56,12 @@ class FusedBiasDropoutResidualLnOpKernel : public framework::OpKernel<T> {
auto *ln_mean_data =
dev_ctx.Alloc<U>(ln_mean, ln_mean->numel() * sizeof(U));
auto *ln_var_data = dev_ctx.Alloc<U>(ln_var, ln_var->numel() * sizeof(U));
auto *dropout_mask_out_data = dev_ctx.Alloc<uint8_t>(
dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t));
auto *dropout_mask_out_data =
(dropout_mask_out == nullptr)
? nullptr
: dev_ctx.Alloc<uint8_t>(
dropout_mask_out,
dropout_mask_out->numel() * sizeof(uint8_t));
auto *y_data = dev_ctx.Alloc<T>(y, y->numel() * sizeof(T));
const auto input_x_dims = input_x->dims();
......
......@@ -767,9 +767,10 @@ void LaunchLayernormResidualDropoutBias(
residual,
rows * cols * sizeof(T),
ctx.stream());
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream()));
if (mask_data != nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream()));
}
// call layernorm forward
switch (GetDesiredBlockDim(cols)) {
FIXED_BLOCK_DIM_CASE(
......
......@@ -18,11 +18,6 @@ string(REPLACE ".py" "" TEST_TRT_CONVERTER "${TEST_TRT_CONVERTER}")
if(NOT WITH_DISTRIBUTE)
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_delete_c_identity_op_pass")
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES
"test_trt_convert_preln_residual_bias")
list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_preln_residual_bias")
list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_preln_residual_bias")
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_trt_convert_c_allreduce")
list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_c_allreduce")
list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_c_allreduce")
......
......@@ -22,7 +22,6 @@ import unittest
class TrtConvertSkipLayernormTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
inputs = program_config.inputs
weights = program_config.weights
......@@ -32,14 +31,13 @@ class TrtConvertSkipLayernormTest(TrtLayerAutoScanTest):
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
#The input dimension should be less than or equal to the set axis.
# The input dimension should be less than or equal to the set axis.
if 'begin_norm_axis' in attrs[0] and attrs[0]['begin_norm_axis'] >= 0:
if len(inputs['inputX_data'].shape) <= attrs[0]['begin_norm_axis']:
return False
return True
def sample_program_configs(self):
def generate_input1(attrs: List[Dict[str, Any]], batch):
return np.ones([batch, 128, 768]).astype(np.float32)
......@@ -56,96 +54,100 @@ class TrtConvertSkipLayernormTest(TrtLayerAutoScanTest):
for epsilon in [1e-5]:
for begin_norm_axis in [2]:
for enable_int8 in [False, True]:
dics = [{
"epsilon": epsilon,
"begin_norm_axis": begin_norm_axis,
}, {}]
ops_config = [{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["inputX_data"],
"Y": ["EleBias"]
},
"op_outputs": {
"Out": ["bias_out"]
dics = [
{
"epsilon": epsilon,
"begin_norm_axis": begin_norm_axis,
},
"op_attrs": {
"axis": -1
}
}, {
"op_type": "elementwise_add",
"op_inputs": {
"X": ["bias_out"],
"Y": ["inputY_data"]
{},
]
ops_config = [
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["inputX_data"],
"Y": ["EleBias"],
},
"op_outputs": {"Out": ["bias_out"]},
"op_attrs": {"axis": -1},
},
"op_outputs": {
"Out": ["ele_out"]
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["bias_out"],
"Y": ["inputY_data"],
},
"op_outputs": {"Out": ["ele_out"]},
"op_attrs": {"axis": -1},
},
"op_attrs": {
"axis": -1
}
}, {
"op_type": "layer_norm",
"op_inputs": {
"X": ["ele_out"],
"Bias": ["Bias"],
"Scale": ["Scale"]
{
"op_type": "layer_norm",
"op_inputs": {
"X": ["ele_out"],
"Bias": ["Bias"],
"Scale": ["Scale"],
},
"op_outputs": {
"Y": ["layernorm_out"],
"Mean": ["Mean"],
"Variance": ["Variance"],
},
"op_attrs": dics[0],
},
"op_outputs": {
"Y": ["layernorm_out"],
"Mean": ["Mean"],
"Variance": ["Variance"]
},
"op_attrs": dics[0]
}]
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"Bias":
TensorConfig(
data_gen=partial(generate_weight1, dics)),
"Scale":
TensorConfig(
data_gen=partial(generate_weight2, dics)),
"EleBias":
TensorConfig(
data_gen=partial(generate_weight2, dics))
"Bias": TensorConfig(
data_gen=partial(generate_weight1, dics)
),
"Scale": TensorConfig(
data_gen=partial(generate_weight2, dics)
),
"EleBias": TensorConfig(
data_gen=partial(generate_weight2, dics)
),
},
inputs={
"inputX_data":
TensorConfig(data_gen=partial(
generate_input1, dics, batch)),
"inputY_data":
TensorConfig(data_gen=partial(
generate_input2, dics, batch))
"inputX_data": TensorConfig(
data_gen=partial(
generate_input1, dics, batch
)
),
"inputY_data": TensorConfig(
data_gen=partial(
generate_input2, dics, batch
)
),
},
outputs=["ele_out", "layernorm_out"])
outputs=["ele_out", "layernorm_out"],
)
yield program_config
def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {
"inputX_data": [4, 128, 768],
"inputY_data": [4, 128, 768],
"Bias": [768],
"Scale": [768]
"Scale": [768],
}
self.dynamic_shape.max_input_shape = {
"inputX_data": [4, 128, 768],
"inputY_data": [4, 128, 768],
"Bias": [768],
"Scale": [768]
"Scale": [768],
}
self.dynamic_shape.opt_input_shape = {
"inputX_data": [4, 128, 768],
"inputY_data": [4, 128, 768],
"Bias": [768],
"Scale": [768]
"Scale": [768],
}
def clear_dynamic_shape():
......@@ -154,20 +156,35 @@ class TrtConvertSkipLayernormTest(TrtLayerAutoScanTest):
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 4
if dynamic_shape:
return 1, 4
else:
return 0, 5
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for static_shape, fall back to fluid fused op
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False
), 1e-2 # atol=1e-2 while rtol is 1e-8
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False
), 1e-2 # atol=1e-2 while rtol is 1e-8
# just support dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), 1e-2 # atol=1e-2 while rtol is 1e-8
attrs, True
), 1e-2 # atol=1e-2 while rtol is 1e-8
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), 1e-2 # atol=1e-2 while rtol is 1e-8
attrs, True
), 1e-2 # atol=1e-2 while rtol is 1e-8
def add_skip_trt_case(self):
pass
......
......@@ -20,27 +20,25 @@ import paddle
class PrelnResidualBiasFusePassTest(PassTest):
def setUp(self):
paddle.enable_static()
with paddle.static.program_guard(self.main_program,
self.startup_program):
x = paddle.static.data(name="x",
shape=[128, 768],
dtype="float32",
lod_level=0)
with paddle.static.program_guard(
self.main_program, self.startup_program
):
x = paddle.static.data(
name="x", shape=[128, 768], dtype="float32", lod_level=0
)
bias = paddle.static.create_parameter(shape=[768], dtype='float32')
y = paddle.static.data(name="y",
shape=[128, 768],
dtype="float32",
lod_level=0)
y = paddle.static.data(
name="y", shape=[128, 768], dtype="float32", lod_level=0
)
x = x + bias
elementwise_out = x + y
out = paddle.static.nn.layer_norm(input=elementwise_out)
self.fetch_list = [out, elementwise_out]
self.pass_names = "preln_residual_bias_fuse_pass"
self.fused_op_type = "preln_residual_bias"
self.fused_op_type = "fused_bias_dropout_residual_layer_norm"
self.num_fused_ops = 1
# self.graph_attrs = {
# "embedding_eltwise_layernorm_fuse_pass_flag": True,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册