From aa0e84e3648c60258527b4ed18e3ff500e895603 Mon Sep 17 00:00:00 2001
From: wenbin
Date: Wed, 21 Sep 2022 14:13:56 +0800
Subject: [PATCH] residual_no_bias (#46129)

* residual_no_bias

* comments

* more ut

* fix input
---
 .../ir/preln_residual_bias_fuse_pass.cc       | 142 ++++++++++-----
 .../ir/preln_residual_bias_fuse_pass.h        |  11 ++
 .../tensorrt/convert/preln_residual_bias.cc   |  33 ++--
 .../plugin/preln_residual_bias_plugin.cu      |  27 +--
 .../fluid/tests/unittests/CMakeLists.txt      |   1 +
 .../unittests/ir/inference/CMakeLists.txt     |   4 +
 ...test_trt_convert_preln_residual_no_bias.py | 166 ++++++++++++++++++
 .../test_ir_preln_residual_bias_fuse_pass.py  |  32 ++++
 8 files changed, 347 insertions(+), 69 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_no_bias.py

diff --git a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc
index 79d27948954..8af2dd2427f 100644
--- a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc
@@ -33,11 +33,16 @@ namespace ir {
 namespace patterns {
 
 struct PrelnResidualBias : public PatternBase {
-  PrelnResidualBias(PDPattern *pattern, const std::string &name_scope)
-      : PatternBase(pattern, name_scope, "preln_residual_bias") {}
+  PrelnResidualBias(PDPattern *pattern,
+                    const std::string &name_scope,
+                    bool with_bias)
+      : PatternBase(pattern, name_scope, "preln_residual_bias") {
+    with_bias_ = with_bias;
+  }
 
   void operator()(PDNode *x, PDNode *y);
 
+  bool with_bias_;
   // declare operator node's name
   PATTERN_DECL_NODE(elementwise_bias);
   PATTERN_DECL_NODE(elementwise0);
@@ -55,32 +60,41 @@ struct PrelnResidualBias : public PatternBase {
 };
 
 void PrelnResidualBias::operator()(PDNode *x, PDNode *y) {
+  PDNode *elementwise0 = nullptr;
+  PDNode *elementwise_bias_var = nullptr;
+  PDNode *elementwise0_out_var = nullptr;
   // Create nodes for elementwise add op.
-  x->assert_is_op_input("elementwise_add");
-  y->assert_is_op_input("elementwise_add", "X");
-  auto *elementwise0 =
-      pattern->NewNode(elementwise0_repr())->assert_is_op("elementwise_add");
-  auto *elementwise_bias_var = pattern->NewNode(elementwise_bias_repr())
-                                   ->assert_is_op_input("elementwise_add", "Y")
-                                   ->assert_is_persistable_var();
-  auto *elementwise0_out_var = pattern->NewNode(elementwise0_out_repr())
-                                   ->assert_is_op_output("elementwise_add")
-                                   ->assert_is_op_input("elementwise_add")
-                                   ->assert_more([](Node *x) {
-                                     if (x->outputs.size() == 1) {
-                                       return true;
-                                     } else {
-                                       return false;
-                                     }
-                                   });
+  if (with_bias_) {
+    elementwise0 =
+        pattern->NewNode(elementwise0_repr())->assert_is_op("elementwise_add");
+    elementwise_bias_var = pattern->NewNode(elementwise_bias_repr())
+                               ->assert_is_op_input("elementwise_add", "Y")
+                               ->assert_is_persistable_var();
+    elementwise0_out_var = pattern->NewNode(elementwise0_out_repr())
+                               ->assert_is_op_output("elementwise_add")
+                               ->assert_is_op_input("elementwise_add")
+                               ->assert_more([](Node *x) {
+                                 if (x->outputs.size() == 1) {
+                                   return true;
+                                 } else {
+                                   return false;
+                                 }
+                               });
+  } else {
+    elementwise0_out_var = y;
+  }
+
   auto *elementwise1 =
       pattern->NewNode(elementwise1_repr())->assert_is_op("elementwise_add");
   auto *elementwise1_out_var = pattern->NewNode(elementwise1_out_repr())
-                                   ->assert_is_op_output("elementwise_add")
                                    ->assert_is_op_input("layer_norm", "X");
 
   // Add links for elementwise_add op.
-  elementwise0->LinksFrom({y, elementwise_bias_var})
-      .LinksTo({elementwise0_out_var});
+  if (with_bias_) {
+    elementwise0->LinksFrom({y, elementwise_bias_var})
+        .LinksTo({elementwise0_out_var});
+    elementwise1_out_var->assert_is_op_output("elementwise_add");
+  }
+
   elementwise1->LinksFrom({x, elementwise0_out_var})
       .LinksTo({elementwise1_out_var});
   // Create nodes for layer_norm op.
@@ -115,7 +129,8 @@ void PrelnResidualBias::operator()(PDNode *x, PDNode *y) {
 
 }  // namespace patterns
 
-void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
+int PrelnResidualBiasFusePass::ApplyPattern(ir::Graph *graph,
+                                            bool with_bias) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::PreconditionNotMet("graph should not be null."));
   FusePassBase::Init("preln_residual_bias_fuse", graph);
@@ -123,18 +138,32 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
   int found_subgraph_count = 0;
 
   GraphPatternDetector gpd;
-  auto *x = gpd.mutable_pattern()
-                ->NewNode("preln_residual_bias_fuse/x")
-                ->AsInput()
-                ->assert_is_op_input("elementwise_add")
-                ->assert_var_not_persistable();
-  auto *y = gpd.mutable_pattern()
-                ->NewNode("preln_residual_bias_fuse/y")
-                ->AsInput()
-                ->assert_is_op_input("elementwise_add", "X")
-                ->assert_var_not_persistable();
-  patterns::PrelnResidualBias fused_pattern(gpd.mutable_pattern(),
-                                            "preln_residual_bias_fuse");
+  PDNode *x = nullptr;
+  PDNode *y = nullptr;
+  if (with_bias) {
+    x = gpd.mutable_pattern()
+            ->NewNode("preln_residual_bias_fuse/x")
+            ->AsInput()
+            ->assert_is_op_input("elementwise_add")
+            ->assert_var_not_persistable();
+    y = gpd.mutable_pattern()
+            ->NewNode("preln_residual_bias_fuse/y")
+            ->AsInput()
+            ->assert_is_op_input("elementwise_add", "X")
+            ->assert_var_not_persistable();
+  } else {
+    x = gpd.mutable_pattern()
+            ->NewNode("preln_residual_bias_fuse/x")
+            ->AsInput()
+            ->assert_is_op_input("elementwise_add", "X");
+
+    y = gpd.mutable_pattern()
+            ->NewNode("preln_residual_bias_fuse/y")
+            ->AsInput()
+            ->assert_is_op_input("elementwise_add", "Y");
+  }
+  patterns::PrelnResidualBias fused_pattern(
+      gpd.mutable_pattern(), "preln_residual_bias_fuse", with_bias);
   fused_pattern(x, y);
 
   auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
@@ -145,11 +174,19 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
     }
     VLOG(4) << "handle PrelnResidualBias fuse";
-    GET_IR_NODE_FROM_SUBGRAPH(
-        elementwise_bias, elementwise_bias, fused_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(elementwise0, elementwise0, fused_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(
-        elementwise0_out, elementwise0_out, fused_pattern);
+    Node *elementwise_bias = nullptr;
+    Node *elementwise0 = nullptr;
+    Node *elementwise0_out = nullptr;
+    if (with_bias) {
+      GET_IR_NODE_FROM_SUBGRAPH(
+          tmp_elementwise_bias, elementwise_bias, fused_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(tmp_elementwise0, elementwise0, fused_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(
+          tmp_elementwise0_out, elementwise0_out, fused_pattern);
+      elementwise_bias = tmp_elementwise_bias;
+      elementwise0 = tmp_elementwise0;
+      elementwise0_out = tmp_elementwise0_out;
+    }
     GET_IR_NODE_FROM_SUBGRAPH(elementwise1, elementwise1, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         elementwise1_out, elementwise1_out, fused_pattern);
@@ -185,7 +222,9 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
     new_desc.SetInput("Y", {subgraph.at(y)->Name()});
     new_desc.SetInput("Scale", {layer_norm_scale->Name()});
     new_desc.SetInput("Bias", {layer_norm_bias->Name()});
-    new_desc.SetInput("EleBias", {elementwise_bias->Name()});
+    if (with_bias) {
+      new_desc.SetInput("EleBias", {elementwise_bias->Name()});
+    }
     // outputs
     new_desc.SetOutput("Out_0", {layer_norm_out->Name()});
     new_desc.SetOutput("Out_1", {elementwise1_out->Name()});
@@ -194,16 +233,20 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
     new_desc.SetAttr("begin_norm_axis",
                      layer_norm->Op()->GetAttr("begin_norm_axis"));
     auto fused_node = graph->CreateOpNode(&new_desc);  // OpDesc will be copied.
-    del_node_set.insert(elementwise0);
+    if (with_bias) {
+      del_node_set.insert(elementwise0);
+      del_node_set.insert(elementwise0_out);
+    }
     del_node_set.insert(elementwise1);
-    del_node_set.insert(elementwise0_out);
     del_node_set.insert(layer_norm);
     del_node_set.insert(layer_norm_mean);
     del_node_set.insert(layer_norm_variance);
     GraphSafeRemoveNodes(graph, del_node_set);
 
     IR_NODE_LINK_TO(subgraph.at(x), fused_node);
     IR_NODE_LINK_TO(subgraph.at(y), fused_node);
-    IR_NODE_LINK_TO(elementwise_bias, fused_node);
+    if (with_bias) {
+      IR_NODE_LINK_TO(elementwise_bias, fused_node);
+    }
     IR_NODE_LINK_TO(layer_norm_scale, fused_node);
     IR_NODE_LINK_TO(layer_norm_bias, fused_node);
     IR_NODE_LINK_TO(fused_node, layer_norm_out);
@@ -212,6 +255,17 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
   };
 
   gpd(graph, handler);
+  return found_subgraph_count;
+}
+
+void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  FusePassBase::Init("preln_residual_bias_fuse", graph);
+
+  int found_subgraph_count = 0;
+  found_subgraph_count = ApplyPattern(graph, true);
+  found_subgraph_count += ApplyPattern(graph, false);
   AddStatis(found_subgraph_count);
 }
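Note: with this change the pass matches two variants of the pre-LayerNorm residual subgraph. `ApplyPattern(graph, true)` keeps the original shape (persistable bias add, then residual add, then `layer_norm`), while `ApplyPattern(graph, false)` matches a plain residual add feeding `layer_norm`. The sketch below is a minimal, hypothetical repro of the no-bias variant built with the public `paddle.static` API; it is not part of the patch, but it mirrors the unit test added further down.

```python
# Minimal sketch of the no-bias subgraph the new pattern targets:
# elementwise_add(X, Y) -> layer_norm, with no persistable bias input.
import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[128, 768], dtype="float32")
    y = paddle.static.data(name="y", shape=[128, 768], dtype="float32")
    residual = x + y                                   # elementwise_add(X, Y)
    out = paddle.static.nn.layer_norm(input=residual)  # normalizes the sum
# After preln_residual_bias_fuse_pass runs, the two ops above should be
# replaced by one preln_residual_bias op whose Out_0 is the layer_norm
# result and whose Out_1 is the residual sum.
```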
diff --git a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.h b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.h
index a22bc6d517a..6c8bf8b1496 100644
--- a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.h
+++ b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.h
@@ -29,6 +29,16 @@ namespace ir {
 //     other_op4   layer_norm       other_op4   other_op3
 //                     |
 //                 other_op3
+// or
+//
+//     |    |                           |    |
+//  other_op1   other_op2            other_op1   other_op2
+//      |         |        fuse           \       /
+//      |------elementwise_add    ->    preln_residual_bias
+//              |      |                     |       |
+//     other_op4   layer_norm        other_op4   other_op3
+//                     |
+//                 other_op3
 class Graph;
 
 class PrelnResidualBiasFusePass : public FusePassBase {
@@ -80,6 +90,7 @@ class PrelnResidualBiasFusePass : public FusePassBase {
 
  protected:
   void ApplyImpl(ir::Graph* graph) const override;
+  int ApplyPattern(ir::Graph* graph, bool with_bias) const;
 };
 
 }  // namespace ir
diff --git a/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc b/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc
index bdcb54cfe2e..722cc336ec1 100644
--- a/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc
+++ b/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc
@@ -51,12 +51,15 @@ class PrelnResidualBiasOpConverter : public OpConverter {
     framework::DDim bias_dims, scale_dims, ele_bias_dims;
     auto* bias = get_persistable_data("Bias", &bias_dims);
     auto* scale = get_persistable_data("Scale", &scale_dims);
-    auto* ele_bias = get_persistable_data("EleBias", &ele_bias_dims);
+    auto const& vars = op_desc.Inputs(false);
+    bool has_bias = vars.find("EleBias") != vars.end();
+    float* ele_bias =
+        has_bias ? get_persistable_data("EleBias", &ele_bias_dims) : nullptr;
     int bias_size = phi::product(bias_dims);
     int scale_size = phi::product(scale_dims);
-    int ele_bias_size = phi::product(ele_bias_dims);
+    int ele_bias_size = has_bias ? phi::product(ele_bias_dims) : 0;
     float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"));
     bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
     if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
@@ -66,18 +69,22 @@ class PrelnResidualBiasOpConverter : public OpConverter {
     nvinfer1::ILayer* layer = nullptr;
     plugin::DynamicPluginTensorRT* plugin = nullptr;
     if (with_fp16) {
-      auto half_ele_bias_data = new half[ele_bias_size];
-      for (int i = 0; i < ele_bias_size; i++) {
-        half_ele_bias_data[i] = static_cast<half>(ele_bias[i]);
+      half* half_ele_bias_data = nullptr;
+      if (ele_bias_size > 0) {
+        half_ele_bias_data = new half[ele_bias_size];
+        for (int i = 0; i < ele_bias_size; i++) {
+          half_ele_bias_data[i] = static_cast<half>(ele_bias[i]);
+        }
       }
-      plugin = new plugin::PrelnResidualBiasPluginDynamic(bias,
-                                                          scale,
-                                                          half_ele_bias_data,
-                                                          bias_size,
-                                                          scale_size,
-                                                          ele_bias_size,
-                                                          epsilon,
-                                                          with_fp16);
+      plugin = new plugin::PrelnResidualBiasPluginDynamic(
+          bias,
+          scale,
+          ele_bias_size > 0 ? half_ele_bias_data : nullptr,
+          bias_size,
+          scale_size,
+          ele_bias_size,
+          epsilon,
+          with_fp16);
     } else {
       plugin = new plugin::PrelnResidualBiasPluginDynamic(bias,
                                                           scale,
diff --git a/paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu
index 18b0124698e..f8df772d5e4 100644
--- a/paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu
@@ -44,19 +44,22 @@ int PrelnResidualBiasPluginDynamic::initialize() TRT_NOEXCEPT {
              scale_.data(),
              scale_size_ * sizeof(float),
              cudaMemcpyHostToDevice);
-
-  if (with_fp16_) {
-    cudaMalloc(&ele_bias_gpu_, sizeof(half) * ele_bias_size_);
-    cudaMemcpy(ele_bias_gpu_,
-               fp16_ele_bias_.data(),
-               ele_bias_size_ * sizeof(half),
-               cudaMemcpyHostToDevice);
+  if (ele_bias_size_ > 0) {
+    if (with_fp16_) {
+      cudaMalloc(&ele_bias_gpu_, sizeof(half) * ele_bias_size_);
+      cudaMemcpy(ele_bias_gpu_,
+                 fp16_ele_bias_.data(),
+                 ele_bias_size_ * sizeof(half),
+                 cudaMemcpyHostToDevice);
+    } else {
+      cudaMalloc(&ele_bias_gpu_, sizeof(float) * ele_bias_size_);
+      cudaMemcpy(ele_bias_gpu_,
+                 fp32_ele_bias_.data(),
+                 ele_bias_size_ * sizeof(float),
+                 cudaMemcpyHostToDevice);
+    }
   } else {
-    cudaMalloc(&ele_bias_gpu_, sizeof(float) * ele_bias_size_);
-    cudaMemcpy(ele_bias_gpu_,
-               fp32_ele_bias_.data(),
-               ele_bias_size_ * sizeof(float),
-               cudaMemcpyHostToDevice);
+    ele_bias_gpu_ = nullptr;
   }
 
   return 0;
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 1daf55e630c..a67187869b8 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -142,6 +142,7 @@ if(WIN32)
   list(REMOVE_ITEM TEST_OPS test_complex_matmul)
   list(REMOVE_ITEM TEST_OPS test_ops_nms)
   list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias)
+  list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_no_bias)
  list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op)
 endif()
 list(REMOVE_ITEM TEST_OPS test_checkpoint_saver)
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
index 5f3bfa62ebc..fbe6fbf26f8 100755
--- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
@@ -22,6 +22,10 @@ if(NOT WITH_DISTRIBUTE)
        "test_trt_convert_preln_residual_bias")
   list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_preln_residual_bias")
   list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_preln_residual_bias")
+  list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES
+       "test_trt_convert_preln_residual_no_bias")
+  list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_preln_residual_no_bias")
+  list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_preln_residual_no_bias")
   list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_trt_convert_c_allreduce")
   list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_c_allreduce")
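Note: the new converter test that follows drives this path through the auto-scan harness with dynamic shapes. Outside the harness, the same TensorRT path can be exercised roughly as sketched below; the model files are placeholders and the calls are an assumption based on the standard `paddle.inference` API, not something this patch adds.

```python
import paddle.inference as paddle_infer

# "model_dir" is a placeholder for a saved inference model that contains the
# elementwise_add + layer_norm subgraph.
config = paddle_infer.Config("model_dir/model.pdmodel",
                             "model_dir/model.pdiparams")
config.enable_use_gpu(256, 0)
config.enable_tensorrt_engine(workspace_size=1 << 30,
                              max_batch_size=4,
                              min_subgraph_size=3,
                              precision_mode=paddle_infer.PrecisionType.Half,
                              use_static=False,
                              use_calib_mode=False)
# The preln_residual_bias plugin only runs with dynamic shape, so shape
# ranges must be provided for every input.
shape = {"x": [4, 128, 768], "y": [4, 128, 768]}
config.set_trt_dynamic_shape_info(shape, shape, shape)
predictor = paddle_infer.create_predictor(config)
```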
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_no_bias.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_no_bias.py
new file mode 100644
index 00000000000..de6f10bf6cb
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_no_bias.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
+from program_config import TensorConfig, ProgramConfig
+import numpy as np
+import paddle.inference as paddle_infer
+from functools import partial
+from typing import Optional, List, Callable, Dict, Any, Set
+import unittest
+
+
+class TrtConvertSkipLayernormTest(TrtLayerAutoScanTest):
+
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        inputs = program_config.inputs
+        weights = program_config.weights
+        outputs = program_config.outputs
+
+        attrs = [
+            program_config.ops[i].attrs for i in range(len(program_config.ops))
+        ]
+
+        # The rank of the input must be greater than begin_norm_axis.
+        if 'begin_norm_axis' in attrs[0] and attrs[0]['begin_norm_axis'] >= 0:
+            if len(inputs['inputX_data'].shape) <= attrs[0]['begin_norm_axis']:
+                return False
+        return True
+
+    def sample_program_configs(self):
+
+        def generate_input1(attrs: List[Dict[str, Any]], batch):
+            return np.ones([batch, 128, 768]).astype(np.float32)
+
+        def generate_input2(attrs: List[Dict[str, Any]], batch):
+            return np.ones([batch, 128, 768]).astype(np.float32)
+
+        def generate_weight1(attrs: List[Dict[str, Any]]):
+            return np.random.random([768]).astype(np.float32)
+
+        def generate_weight2(attrs: List[Dict[str, Any]]):
+            return np.random.random([768]).astype(np.float32)
+
+        for batch in [4]:
+            for epsilon in [1e-5]:
+                for begin_norm_axis in [2]:
+                    for enable_int8 in [False, True]:
+                        dics = [{
+                            "epsilon": epsilon,
+                            "begin_norm_axis": begin_norm_axis,
+                        }, {}]
+
+                        ops_config = [{
+                            "op_type": "elementwise_add",
+                            "op_inputs": {
+                                "X": ["inputX_data"],
+                                "Y": ["inputY_data"]
+                            },
+                            "op_outputs": {
+                                "Out": ["ele_out"]
+                            },
+                            "op_attrs": {
+                                "axis": -1
+                            }
+                        }, {
+                            "op_type": "layer_norm",
+                            "op_inputs": {
+                                "X": ["ele_out"],
+                                "Bias": ["Bias"],
+                                "Scale": ["Scale"]
+                            },
+                            "op_outputs": {
+                                "Y": ["layernorm_out"],
+                                "Mean": ["Mean"],
+                                "Variance": ["Variance"]
+                            },
+                            "op_attrs": dics[0]
+                        }]
+                        ops = self.generate_op_config(ops_config)
+                        program_config = ProgramConfig(
+                            ops=ops,
+                            weights={
+                                "Bias":
+                                TensorConfig(
+                                    data_gen=partial(generate_weight1, dics)),
+                                "Scale":
+                                TensorConfig(
+                                    data_gen=partial(generate_weight2, dics))
+                            },
+                            inputs={
+                                "inputX_data":
+                                TensorConfig(data_gen=partial(
+                                    generate_input1, dics, batch)),
+                                "inputY_data":
+                                TensorConfig(data_gen=partial(
+                                    generate_input2, dics, batch))
+                            },
+                            outputs=["ele_out", "layernorm_out"])
+
+                        yield program_config
+
+    def sample_predictor_configs(
+            self, program_config) -> (paddle_infer.Config, List[int], float):
+
+        def generate_dynamic_shape(attrs):
+            self.dynamic_shape.min_input_shape = {
+                "inputX_data": [4, 128, 768],
+                "inputY_data": [4, 128, 768],
+                "Bias": [768],
+                "Scale": [768]
+            }
+            self.dynamic_shape.max_input_shape = {
+                "inputX_data": [4, 128, 768],
+                "inputY_data": [4, 128, 768],
+                "Bias": [768],
+                "Scale": [768]
+            }
+            self.dynamic_shape.opt_input_shape = {
+                "inputX_data": [4, 128, 768],
+                "inputY_data": [4, 128, 768],
+                "Bias": [768],
+                "Scale": [768]
+            }
+
+        def clear_dynamic_shape():
+            self.dynamic_shape.min_input_shape = {}
+            self.dynamic_shape.max_input_shape = {}
+            self.dynamic_shape.opt_input_shape = {}
+
+        def generate_trt_nodes_num(attrs, dynamic_shape):
+            return 1, 4
+
+        attrs = [
+            program_config.ops[i].attrs for i in range(len(program_config.ops))
+        ]
+
+        # just support dynamic_shape
+        generate_dynamic_shape(attrs)
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True), 1e-2  # atol=1e-2 while rtol is 1e-8
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True), 1e-2  # atol=1e-2 while rtol is 1e-8
+
+    def add_skip_trt_case(self):
+        pass
+
+    def test(self):
+        self.add_skip_trt_case()
+        self.run_test()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_preln_residual_bias_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_preln_residual_bias_fuse_pass.py
index 071c0803a49..efa5da7da3b 100644
--- a/python/paddle/fluid/tests/unittests/ir/test_ir_preln_residual_bias_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/test_ir_preln_residual_bias_fuse_pass.py
@@ -57,5 +57,37 @@ class PrelnResidualBiasFusePassTest(PassTest):
             self.check_program(opt_program)
 
 
+class PrelnResidualBiasFusePassNoBiasTest(PassTest):
+
+    def setUp(self):
+        paddle.enable_static()
+        with paddle.static.program_guard(self.main_program,
+                                         self.startup_program):
+            x = paddle.static.data(name="x",
+                                   shape=[128, 768],
+                                   dtype="float32",
+                                   lod_level=0)
+            y = paddle.static.data(name="y",
+                                   shape=[128, 768],
+                                   dtype="float32",
+                                   lod_level=0)
+            elementwise_out = x + y
+            out = paddle.static.nn.layer_norm(input=elementwise_out)
+
+        self.fetch_list = [out, elementwise_out]
+        self.pass_names = "preln_residual_bias_fuse_pass"
+        self.fused_op_type = "preln_residual_bias"
+        self.num_fused_ops = 1
+
+    def test_check_program(self):
+        use_gpu_set = [False]
+        if paddle.device.is_compiled_with_cuda():
+            use_gpu_set.append(True)
+        for use_gpu in use_gpu_set:
+            place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace()
+            opt_program = self._apply_ir_passes()
+            self.check_program(opt_program)
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab
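Note: a quick numerical sanity check for the no-bias path is that omitting `EleBias` must behave exactly like passing an all-zero bias. Using the reference function sketched earlier (a hypothetical helper, not part of the patch):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 128, 768)).astype(np.float32)
y = rng.standard_normal((4, 128, 768)).astype(np.float32)
scale = rng.standard_normal(768).astype(np.float32)
bias = rng.standard_normal(768).astype(np.float32)

no_bias = preln_residual_bias_ref(x, y, scale, bias, 1e-5)
zero_bias = preln_residual_bias_ref(x, y, scale, bias, 1e-5,
                                    ele_bias=np.zeros(768, dtype=np.float32))
for a, b in zip(no_bias, zero_bias):
    assert np.allclose(a, b)
```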