diff --git a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc
index 79d27948954278227f07ba044bf955426bf75862..4c0457ac55915aa4f32cd7d2cfdc2ff832b44ce7 100644
--- a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc
@@ -115,6 +115,24 @@ void PrelnResidualBias::operator()(PDNode *x, PDNode *y) {
 
 }  // namespace patterns
 
+void setIntermediateOut(OpDesc *desc,
+                        const std::string &out_name,
+                        const std::string &scope_name) {
+  std::string new_name = scope_name + "/at." + out_name + ".new";
+  desc->SetOutput(out_name, {new_name});
+}
+
+void addIntermediateOut(Node *op_node,
+                        const std::string &out_name,
+                        const std::string &scope_name,
+                        Graph *graph) {
+  std::string new_name = scope_name + "/at." + out_name + ".new";
+  VarDesc out_var(new_name);
+  out_var.SetPersistable(false);
+  auto *node_var = graph->CreateVarNode(&out_var);
+  IR_NODE_LINK_TO(op_node, node_var);
+}
+
 void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::PreconditionNotMet("graph should not be null."));
@@ -168,7 +186,7 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
     // on each other, so we make below check to ensure only one
    // PrelnResidualBias pattern is delalted with.
     for (auto op : elementwise1_out->inputs) {
-      if (op->Name() == "preln_residual_bias") return;
+      if (op->Name() == "fused_bias_dropout_residual_layer_norm") return;
     }
 
     if (!IsCompat(subgraph, graph)) {
@@ -179,27 +197,32 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
     std::unordered_set<const Node *> del_node_set;
     // Create an PrelnResidualBias op node
     OpDesc new_desc;
-    new_desc.SetType("preln_residual_bias");
+    new_desc.SetType("fused_bias_dropout_residual_layer_norm");
     // inputs
     new_desc.SetInput("X", {subgraph.at(x)->Name()});
-    new_desc.SetInput("Y", {subgraph.at(y)->Name()});
-    new_desc.SetInput("Scale", {layer_norm_scale->Name()});
-    new_desc.SetInput("Bias", {layer_norm_bias->Name()});
-    new_desc.SetInput("EleBias", {elementwise_bias->Name()});
+    new_desc.SetInput("Residual", {subgraph.at(y)->Name()});
+    new_desc.SetInput("LnScale", {layer_norm_scale->Name()});
+    new_desc.SetInput("LnBias", {layer_norm_bias->Name()});
+    new_desc.SetInput("Bias", {elementwise_bias->Name()});
     // outputs
-    new_desc.SetOutput("Out_0", {layer_norm_out->Name()});
-    new_desc.SetOutput("Out_1", {elementwise1_out->Name()});
+    new_desc.SetOutput("Y", {layer_norm_out->Name()});
+    new_desc.SetOutput("BiasDropoutResidualOut", {elementwise1_out->Name()});
+    new_desc.SetOutput("LnMean", {layer_norm_mean->Name()});
+    new_desc.SetOutput("LnVariance", {layer_norm_variance->Name()});
+    setIntermediateOut(&new_desc, "DropoutMaskOut", "preln_residual_bias_fuse");
     // attrs
-    new_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon"));
+    new_desc.SetAttr("ln_epsilon", layer_norm->Op()->GetAttr("epsilon"));
+    new_desc.SetAttr("dropout_rate", 0.0f);
+    new_desc.SetAttr("is_test", true);
     new_desc.SetAttr("begin_norm_axis",
                      layer_norm->Op()->GetAttr("begin_norm_axis"));
     auto fused_node = graph->CreateOpNode(&new_desc);  // OpDesc will be copied.
+    addIntermediateOut(
+        fused_node, "DropoutMaskOut", "preln_residual_bias_fuse", graph);
     del_node_set.insert(elementwise0);
     del_node_set.insert(elementwise1);
     del_node_set.insert(elementwise0_out);
     del_node_set.insert(layer_norm);
-    del_node_set.insert(layer_norm_mean);
-    del_node_set.insert(layer_norm_variance);
     GraphSafeRemoveNodes(graph, del_node_set);
     IR_NODE_LINK_TO(subgraph.at(x), fused_node);
     IR_NODE_LINK_TO(subgraph.at(y), fused_node);
@@ -208,6 +231,9 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
     IR_NODE_LINK_TO(layer_norm_bias, fused_node);
     IR_NODE_LINK_TO(fused_node, layer_norm_out);
     IR_NODE_LINK_TO(fused_node, elementwise1_out);
+    IR_NODE_LINK_TO(fused_node, layer_norm_mean);
+    IR_NODE_LINK_TO(fused_node, layer_norm_variance);
+
     found_subgraph_count++;
   };
 
diff --git a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc
index d33adab8b3ea7846eec22260d032001f5652e4f6..dd3d113f34fa33f8fd0f7370cb96c8186ca523a7 100644
--- a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc
@@ -169,8 +169,18 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
 
     // attrs
     new_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon"));
-    new_desc.SetAttr("begin_norm_axis",
-                     layer_norm->Op()->GetAttr("begin_norm_axis"));
+    if (layer_norm->Op()->HasAttr("begin_norm_axis")) {
+      int32_t begin_norm_axis = PADDLE_GET_CONST(
+          int32_t, layer_norm->Op()->GetAttr("begin_norm_axis"));
+      int32_t input_rank =
+          static_cast<int32_t>(elementwise_out->Var()->GetShape().size());
+      if ((begin_norm_axis != -1) && (begin_norm_axis != input_rank - 1)) {
+        LOG(WARNING) << "skip_layernorm pass only support "
+                        "layer_norm'begin_norm_axis == input_rank - 1.";
+        return;
+      }
+      new_desc.SetAttr("begin_norm_axis", begin_norm_axis);
+    }
 
     auto fused_node = graph->CreateOpNode(&new_desc);  // OpDesc will be copied.
 
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 7bd14ca05ecdde9e72c9f72042e0b5dacc806b80..6d2b8fa7eff62e334747b0f644b8af36578e2c7f 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -2250,7 +2250,7 @@ USE_TRT_CONVERTER(deformable_conv);
 USE_TRT_CONVERTER(pool3d)
 USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm)
 USE_TRT_CONVERTER(preln_skip_layernorm)
-USE_TRT_CONVERTER(preln_residual_bias)
+USE_TRT_CONVERTER(fused_bias_dropout_residual_layer_norm)
 USE_TRT_CONVERTER(c_allreduce_sum)
 USE_TRT_CONVERTER(roll)
 USE_TRT_CONVERTER(strided_slice)
diff --git a/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc b/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc
index bdcb54cfe2e8e2b288402816376ced61855747a9..3c9a905f9f680dafd44e1fcda40ccbf73c8f7232 100644
--- a/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc
+++ b/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc
@@ -26,15 +26,12 @@ class PrelnResidualBiasOpConverter : public OpConverter {
   void operator()(const framework::proto::OpDesc& op,
                   const framework::Scope& scope,
                   bool test_mode) override {
-    VLOG(4) << "convert fused preln_residual_bias op to tensorrt layer";
-    if (!engine_->with_dynamic_shape()) {
-      PADDLE_THROW(platform::errors::Fatal(
-          "Unsupported static mode. Please set dynamic shape of inputs."));
-    }
+    VLOG(4) << "convert fused_bias_dropout_residual_layer_norm op with "
+               "drop_rate = 0 to preln_residual_bias tensorrt layer";
     framework::OpDesc op_desc(op, nullptr);
     // Declare inputs
     auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
-    auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
+    auto* input2 = engine_->GetITensor(op_desc.Input("Residual")[0]);
     std::vector<nvinfer1::ITensor*> inputs;
     inputs.push_back(input1);
     inputs.push_back(input2);
@@ -49,15 +46,15 @@ class PrelnResidualBiasOpConverter : public OpConverter {
       return temp_data;
     };
     framework::DDim bias_dims, scale_dims, ele_bias_dims;
-    auto* bias = get_persistable_data("Bias", &bias_dims);
-    auto* scale = get_persistable_data("Scale", &scale_dims);
-    auto* ele_bias = get_persistable_data("EleBias", &ele_bias_dims);
+    auto* bias = get_persistable_data("LnBias", &bias_dims);
+    auto* scale = get_persistable_data("LnScale", &scale_dims);
+    auto* ele_bias = get_persistable_data("Bias", &ele_bias_dims);
 
     int bias_size = phi::product(bias_dims);
     int scale_size = phi::product(scale_dims);
     int ele_bias_size = phi::product(ele_bias_dims);
 
-    float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"));
+    float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("ln_epsilon"));
     bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
     if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
       with_fp16 = true;
@@ -94,8 +91,8 @@ class PrelnResidualBiasOpConverter : public OpConverter {
     plugin_inputs.emplace_back(input2);
     layer = engine_->AddDynamicPlugin(plugin_inputs.data(), 2, plugin);
     std::vector<std::string> output_names;
-    output_names.push_back(op_desc.Output("Out_0")[0]);
-    output_names.push_back(op_desc.Output("Out_1")[0]);
+    output_names.push_back(op_desc.Output("Y")[0]);
+    output_names.push_back(op_desc.Output("BiasDropoutResidualOut")[0]);
     RreplenishLayerAndOutput(
         layer, "preln_residual_bias", output_names, test_mode);
   }
@@ -105,4 +102,5 @@ class PrelnResidualBiasOpConverter : public OpConverter {
 }  // namespace inference
 }  // namespace paddle
 
-REGISTER_TRT_OP_CONVERTER(preln_residual_bias, PrelnResidualBiasOpConverter);
+REGISTER_TRT_OP_CONVERTER(fused_bias_dropout_residual_layer_norm,
+                          PrelnResidualBiasOpConverter);
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index ecc089e134a3ff93546c57a57ba84c29c95a515e..8e9854f77252cf78ba376e968a643f05ea40c028 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -1316,7 +1316,21 @@ struct SimpleOpTypeSetTeller : public Teller {
         return false;
       }
     }
-
+    if (op_type == "fused_bias_dropout_residual_layer_norm") {
+      if (!with_dynamic_shape) {
+        VLOG(3) << "fused_bias_dropout_residual_layer_norm should run on "
+                   "dynamic shape mode.";
+        return false;
+      }
+      float dropout_rate =
+          PADDLE_GET_CONST(float, desc.GetAttr("dropout_rate"));
+      if (dropout_rate != 0.0f) {
+        VLOG(4) << "preln_residual_bias trt layer can not work with "
+                   "fused_bias_dropout_residual_layer_norm op in which the "
+                   "dropout_rate != 0, stop convert";
+        return false;
+      }
+    }
     if (op_type == "fused_preln_embedding_eltwise_layernorm") {
       if (!with_dynamic_shape) {
         VLOG(3) << "fused_preln_embedding_eltwise_layernorm should run on "
@@ -2223,7 +2237,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "slice",
       "strided_slice",
       "fused_preln_embedding_eltwise_layernorm",
-      "preln_residual_bias",
+      "fused_bias_dropout_residual_layer_norm",
       "c_allreduce_sum",
       "c_allreduce_min",
       "c_allreduce_max",
@@ -2337,7 +2351,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "strided_slice",
       "fused_preln_embedding_eltwise_layernorm",
       "preln_skip_layernorm",
-      "preln_residual_bias",
+      "fused_bias_dropout_residual_layer_norm",
       "c_allreduce_sum",
       "c_allreduce_min",
      "c_allreduce_max",
diff --git a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc
index 081a1ab0a0d2ab6b1ea413ffbb2dba877f729d7a..2512fe3504165b7672c23dadcc47670c2cbe2462 100644
--- a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc
+++ b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc
@@ -37,16 +37,17 @@ class FusedBiasDropoutResidualLnOp : public framework::OperatorWithKernel {
                    "Output",
                    "LnVariance",
                    "FusedBiasDropoutResidualLnOp");
-    OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"),
-                   "Output",
-                   "BiasDropoutResidualOut",
-                   "FusedBiasDropoutResidualLnOp");
     OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"),
                    "Output",
                    "DropoutMaskOut",
                    "FusedBiasDropoutResidualLnOp");
+    OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"),
+                   "Output",
+                   "BiasDropoutResidualOut",
+                   "FusedBiasDropoutResidualLnOp");
     OP_INOUT_CHECK(
         ctx->HasOutput("Y"), "Output", "Y", "FusedBiasDropoutResidualLnOp");
+
     auto x_dim = ctx->GetInputDim("X");
     int left = 1;
     for (int i = 0; i < x_dim.size() - 1; i++) {
diff --git a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cu b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cu
index b194f07c848da8f1b23317e4a0677add6aacb4a9..df4b5720892e67b208a7732f185042ca78f67f27 100644
--- a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cu
+++ b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cu
@@ -56,8 +56,12 @@ class FusedBiasDropoutResidualLnOpKernel : public framework::OpKernel<T> {
     auto *ln_mean_data =
         dev_ctx.Alloc<U>(ln_mean, ln_mean->numel() * sizeof(U));
     auto *ln_var_data = dev_ctx.Alloc<U>(ln_var, ln_var->numel() * sizeof(U));
-    auto *dropout_mask_out_data = dev_ctx.Alloc<uint8_t>(
-        dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t));
+    auto *dropout_mask_out_data =
+        (dropout_mask_out == nullptr)
+            ? nullptr
+            : dev_ctx.Alloc<uint8_t>(
+                  dropout_mask_out,
+                  dropout_mask_out->numel() * sizeof(uint8_t));
     auto *y_data = dev_ctx.Alloc<T>(y, y->numel() * sizeof(T));
 
     const auto input_x_dims = input_x->dims();
diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
index 137943afbfb94df80c8b13593493f5a6810b3366..eaf2d40a3db88a0f5fe29896b60b521a18d1089f 100644
--- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
+++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
@@ -767,9 +767,10 @@ void LaunchLayernormResidualDropoutBias(
                  residual,
                  rows * cols * sizeof(T),
                  ctx.stream());
-    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
-        mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream()));
-
+    if (mask_data != nullptr) {
+      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
+          mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream()));
+    }
     // call layernorm forward
     switch (GetDesiredBlockDim(cols)) {
       FIXED_BLOCK_DIM_CASE(
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
index 5f3bfa62ebc1a64683a7b99efb65a6db89929d38..07671cbd3bab95579b5fdbc1f74f44970925b1ba 100755
--- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
@@ -18,11 +18,6 @@ string(REPLACE ".py" "" TEST_TRT_CONVERTER "${TEST_TRT_CONVERTER}")
 
 if(NOT WITH_DISTRIBUTE)
   list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_delete_c_identity_op_pass")
-  list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES
-       "test_trt_convert_preln_residual_bias")
-  list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_preln_residual_bias")
-  list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_preln_residual_bias")
-
   list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_trt_convert_c_allreduce")
   list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_c_allreduce")
   list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_c_allreduce")
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_bias.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_bias.py
index 8e6d1fb58d75eecb20d428c36387fe99694178dc..993e53b1c0ff96bf3e8d02b2c2bedae43a596776 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_bias.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_bias.py
@@ -22,7 +22,6 @@ import unittest
 
 
 class TrtConvertSkipLayernormTest(TrtLayerAutoScanTest):
-
     def is_program_valid(self, program_config: ProgramConfig) -> bool:
         inputs = program_config.inputs
         weights = program_config.weights
@@ -32,14 +31,13 @@ class TrtConvertSkipLayernormTest(TrtLayerAutoScanTest):
             program_config.ops[i].attrs for i in range(len(program_config.ops))
         ]
 
-        #The input dimension should be less than or equal to the set axis.
+        # The input dimension should be less than or equal to the set axis.
         if 'begin_norm_axis' in attrs[0] and attrs[0]['begin_norm_axis'] >= 0:
             if len(inputs['inputX_data'].shape) <= attrs[0]['begin_norm_axis']:
                 return False
         return True
 
     def sample_program_configs(self):
-
         def generate_input1(attrs: List[Dict[str, Any]], batch):
             return np.ones([batch, 128, 768]).astype(np.float32)
 
@@ -56,96 +54,100 @@
             for epsilon in [1e-5]:
                 for begin_norm_axis in [2]:
                     for enable_int8 in [False, True]:
-                        dics = [{
-                            "epsilon": epsilon,
-                            "begin_norm_axis": begin_norm_axis,
-                        }, {}]
-
-                        ops_config = [{
-                            "op_type": "elementwise_add",
-                            "op_inputs": {
-                                "X": ["inputX_data"],
-                                "Y": ["EleBias"]
-                            },
-                            "op_outputs": {
-                                "Out": ["bias_out"]
+                        dics = [
+                            {
+                                "epsilon": epsilon,
+                                "begin_norm_axis": begin_norm_axis,
                             },
-                            "op_attrs": {
-                                "axis": -1
-                            }
-                        }, {
-                            "op_type": "elementwise_add",
-                            "op_inputs": {
-                                "X": ["bias_out"],
-                                "Y": ["inputY_data"]
+                            {},
+                        ]
+
+                        ops_config = [
+                            {
+                                "op_type": "elementwise_add",
+                                "op_inputs": {
+                                    "X": ["inputX_data"],
+                                    "Y": ["EleBias"],
+                                },
+                                "op_outputs": {"Out": ["bias_out"]},
+                                "op_attrs": {"axis": -1},
                             },
-                            "op_outputs": {
-                                "Out": ["ele_out"]
+                            {
+                                "op_type": "elementwise_add",
+                                "op_inputs": {
+                                    "X": ["bias_out"],
+                                    "Y": ["inputY_data"],
+                                },
+                                "op_outputs": {"Out": ["ele_out"]},
+                                "op_attrs": {"axis": -1},
                             },
-                            "op_attrs": {
-                                "axis": -1
-                            }
-                        }, {
-                            "op_type": "layer_norm",
-                            "op_inputs": {
-                                "X": ["ele_out"],
-                                "Bias": ["Bias"],
-                                "Scale": ["Scale"]
+                            {
+                                "op_type": "layer_norm",
+                                "op_inputs": {
+                                    "X": ["ele_out"],
+                                    "Bias": ["Bias"],
+                                    "Scale": ["Scale"],
+                                },
+                                "op_outputs": {
+                                    "Y": ["layernorm_out"],
+                                    "Mean": ["Mean"],
+                                    "Variance": ["Variance"],
+                                },
+                                "op_attrs": dics[0],
                             },
-                            "op_outputs": {
-                                "Y": ["layernorm_out"],
-                                "Mean": ["Mean"],
-                                "Variance": ["Variance"]
-                            },
-                            "op_attrs": dics[0]
-                        }]
+                        ]
                         ops = self.generate_op_config(ops_config)
                         program_config = ProgramConfig(
                             ops=ops,
                             weights={
-                                "Bias":
-                                TensorConfig(
-                                    data_gen=partial(generate_weight1, dics)),
-                                "Scale":
-                                TensorConfig(
-                                    data_gen=partial(generate_weight2, dics)),
-                                "EleBias":
-                                TensorConfig(
-                                    data_gen=partial(generate_weight2, dics))
+                                "Bias": TensorConfig(
+                                    data_gen=partial(generate_weight1, dics)
+                                ),
+                                "Scale": TensorConfig(
+                                    data_gen=partial(generate_weight2, dics)
+                                ),
+                                "EleBias": TensorConfig(
+                                    data_gen=partial(generate_weight2, dics)
+                                ),
                             },
                             inputs={
-                                "inputX_data":
-                                TensorConfig(data_gen=partial(
-                                    generate_input1, dics, batch)),
-                                "inputY_data":
-                                TensorConfig(data_gen=partial(
-                                    generate_input2, dics, batch))
+                                "inputX_data": TensorConfig(
+                                    data_gen=partial(
+                                        generate_input1, dics, batch
+                                    )
+                                ),
+                                "inputY_data": TensorConfig(
+                                    data_gen=partial(
+                                        generate_input2, dics, batch
+                                    )
+                                ),
                             },
-                            outputs=["ele_out", "layernorm_out"])
+                            outputs=["ele_out", "layernorm_out"],
+                        )
 
                         yield program_config
 
     def sample_predictor_configs(
-            self, program_config) -> (paddle_infer.Config, List[int], float):
-
+        self, program_config
+    ) -> (paddle_infer.Config, List[int], float):
         def generate_dynamic_shape(attrs):
             self.dynamic_shape.min_input_shape = {
                 "inputX_data": [4, 128, 768],
                 "inputY_data": [4, 128, 768],
                 "Bias": [768],
-                "Scale": [768]
+                "Scale": [768],
             }
             self.dynamic_shape.max_input_shape = {
                 "inputX_data": [4, 128, 768],
                 "inputY_data": [4, 128, 768],
                 "Bias": [768],
-                "Scale": [768]
+                "Scale": [768],
             }
             self.dynamic_shape.opt_input_shape = {
                 "inputX_data": [4, 128, 768],
                 "inputY_data": [4, 128, 768],
                 "Bias": [768],
-                "Scale": [768]
+                "Scale": [768],
             }
 
         def clear_dynamic_shape():
@@ -154,20 +156,35 @@ class TrtConvertSkipLayernormTest(TrtLayerAutoScanTest):
             self.dynamic_shape.opt_input_shape = {}
 
         def generate_trt_nodes_num(attrs, dynamic_shape):
-            return 1, 4
+            if dynamic_shape:
+                return 1, 4
+            else:
+                return 0, 5
 
         attrs = [
             program_config.ops[i].attrs for i in range(len(program_config.ops))
         ]
 
+        # for static_shape, fall back to fluid fused op
+        clear_dynamic_shape()
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-2  # atol=1e-2 while rtol is 1e-8
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-2  # atol=1e-2 while rtol is 1e-8
         # just support dynamic_shape
         generate_dynamic_shape(attrs)
         self.trt_param.precision = paddle_infer.PrecisionType.Float32
         yield self.create_inference_config(), generate_trt_nodes_num(
-            attrs, True), 1e-2  # atol=1e-2 while rtol is 1e-8
+            attrs, True
+        ), 1e-2  # atol=1e-2 while rtol is 1e-8
         self.trt_param.precision = paddle_infer.PrecisionType.Half
         yield self.create_inference_config(), generate_trt_nodes_num(
-            attrs, True), 1e-2  # atol=1e-2 while rtol is 1e-8
+            attrs, True
+        ), 1e-2  # atol=1e-2 while rtol is 1e-8
 
     def add_skip_trt_case(self):
         pass
diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_preln_residual_bias_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_preln_residual_bias_fuse_pass.py
index 071c0803a4910996696d4445d81e7017e850b4e5..068b0073dd5ecb4f348dc1d1c5cb171b4856a866 100644
--- a/python/paddle/fluid/tests/unittests/ir/test_ir_preln_residual_bias_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/test_ir_preln_residual_bias_fuse_pass.py
@@ -20,27 +20,25 @@ import paddle
 
 
 class PrelnResidualBiasFusePassTest(PassTest):
-
     def setUp(self):
         paddle.enable_static()
-        with paddle.static.program_guard(self.main_program,
-                                         self.startup_program):
-            x = paddle.static.data(name="x",
-                                   shape=[128, 768],
-                                   dtype="float32",
-                                   lod_level=0)
+        with paddle.static.program_guard(
+            self.main_program, self.startup_program
+        ):
+            x = paddle.static.data(
+                name="x", shape=[128, 768], dtype="float32", lod_level=0
+            )
             bias = paddle.static.create_parameter(shape=[768], dtype='float32')
-            y = paddle.static.data(name="y",
-                                   shape=[128, 768],
-                                   dtype="float32",
-                                   lod_level=0)
+            y = paddle.static.data(
+                name="y", shape=[128, 768], dtype="float32", lod_level=0
+            )
             x = x + bias
             elementwise_out = x + y
             out = paddle.static.nn.layer_norm(input=elementwise_out)
 
         self.fetch_list = [out, elementwise_out]
         self.pass_names = "preln_residual_bias_fuse_pass"
-        self.fused_op_type = "preln_residual_bias"
+        self.fused_op_type = "fused_bias_dropout_residual_layer_norm"
         self.num_fused_ops = 1
         # self.graph_attrs = {
         #     "embedding_eltwise_layernorm_fuse_pass_flag": True,