未验证 提交 0127e92d 编写于 作者: H heliqi 提交者: GitHub

add fc_elementwise_layernorm_fuse_pass (#37771)

* add fc_elementwise_layernorm_fuse_pass

* fix name conflictn

* rebuild CI

* fix Ran Programs=0 bug
上级 26c44a86
......@@ -156,7 +156,7 @@ cc_test(test_seqpool_cvm_concat_fuse_pass SRCS seqpool_cvm_concat_fuse_pass_test
cc_test(test_repeated_fc_relu_fuse_pass_cc SRCS repeated_fc_relu_fuse_pass_tester.cc DEPS repeated_fc_relu_fuse_pass framework_proto)
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
cc_test(test_simplify_with_basic_ops_pass SRCS simplify_with_basic_ops_pass_tester.cc DEPS simplify_with_basic_ops_pass)
cc_test(test_fc_elementwise_layernorm_fuse_pass SRCS fc_elementwise_layernorm_fuse_pass_tester.cc DEPS fc_elementwise_layernorm_fuse_pass)
cc_test(test_fc_elementwise_layernorm_fuse_pass_cc SRCS fc_elementwise_layernorm_fuse_pass_tester.cc DEPS fc_elementwise_layernorm_fuse_pass)
cc_test(test_skip_layernorm_fuse_pass SRCS skip_layernorm_fuse_pass_tester.cc DEPS skip_layernorm_fuse_pass)
cc_test(test_multihead_matmul_fuse_pass SRCS multihead_matmul_fuse_pass_tester.cc DEPS multihead_matmul_fuse_pass)
cc_test(test_conv_bn_fuse_pass_cc SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_fuse_pass)
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
......@@ -338,3 +339,9 @@ void FCElementwiseLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
REGISTER_PASS(fc_elementwise_layernorm_fuse_pass,
paddle::framework::ir::FCElementwiseLayerNormFusePass);
REGISTER_PASS_CAPABILITY(fc_elementwise_layernorm_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("fc", 0)
.LE("elementwise_add", 1)
.EQ("layer_norm", 0));
......@@ -322,14 +322,14 @@ class PassAutoScanTest(AutoScanTest):
"Expected operator list after fusion is {}, but now it's {}".format(
op_list_after_fusion, after_op_list), )
def run_and_statis(
self,
quant=False,
max_examples=100,
reproduce=None,
min_success_num=25,
max_duration=180,
passes=None, ):
def run_and_statis(self,
quant=False,
max_examples=100,
reproduce=None,
min_success_num=25,
max_duration=180,
passes=None,
use_gpu_run_baseline=False):
if os.getenv('HYPOTHESIS_TEST_PROFILE', 'ci') == "dev":
max_examples *= 10
min_success_num *= 10
......@@ -354,7 +354,10 @@ class PassAutoScanTest(AutoScanTest):
return self.sample_program_config(draw)
def run_test(prog_config):
return self.run_test(quant=quant, prog_configs=[prog_config])
return self.run_test(
quant=quant,
prog_configs=[prog_config],
use_gpu_run_baseline=use_gpu_run_baseline)
generator = st.composite(program_generator)
loop_func = given(generator())(run_test)
......@@ -371,8 +374,8 @@ class PassAutoScanTest(AutoScanTest):
logging.info("Number of Ran Programs: {}".format(self.num_ran_programs))
logging.info("Number of Ignore Tests: {}".format(self.num_ignore_tests))
successful_ran_programs = int(self.num_ran_programs -
self.num_ignore_tests /
self.num_predictor_kinds)
self.num_ignore_tests / max(
self.num_predictor_kinds, 1))
logging.info(
"Number of successfully ran programs approximately equal to {}".
format(successful_ran_programs))
......@@ -391,7 +394,10 @@ class PassAutoScanTest(AutoScanTest):
format(max_duration))
assert False
def run_test(self, quant=False, prog_configs=None):
def run_test(self,
quant=False,
prog_configs=None,
use_gpu_run_baseline=False):
status = True
for prog_config in prog_configs:
......@@ -413,7 +419,9 @@ class PassAutoScanTest(AutoScanTest):
results: List[Dict[str, np.ndarray]] = []
# baseline: cpu no ir_optim run
base_config = self.create_inference_config(ir_optim=False)
base_config = self.create_inference_config(
ir_optim=False, use_gpu=use_gpu_run_baseline)
logging.info('RUN program_config: ' + str(prog_config))
results.append(
self.run_test_config(model, params, prog_config, base_config,
......
......@@ -109,7 +109,7 @@ class TestAdaptivePool2dConvertGlobalPass(PassAutoScanTest):
def test(self):
self.run_and_statis(
quant=False,
max_examples=100,
max_examples=300,
passes=["adaptive_pool2d_convert_global_pass"],
min_success_num=40)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest, IgnoreReasons
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import hypothesis
from hypothesis import given, settings, seed, example, assume, reproduce_failure
import hypothesis.strategies as st
class TestFCElementwiseLayerNormFusePass(PassAutoScanTest):
"""
x_var w(persistable) bias_var(persistable)
\ | /
fc
|
fc_out_var bias_var(persistable)
\ /
elementwise_add bias_var(persistable) scale_var(persistable)
\ | /
layer_norm
/ | \
Y mean_var variance_var
"""
def sample_predictor_configs(self, program_config):
# for gpu
config = self.create_inference_config(use_gpu=True)
yield config, ["fused_fc_elementwise_layernorm"], (1e-5, 1e-5)
def sample_program_config(self, draw):
# 1. Generate shape of input:X of fc
x_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=8), min_size=2, max_size=5))
x_shape = [2, 1]
x_rank = len(x_shape)
# 2. Generate attr:in_num_col_dims of fc
in_num_col_dims = draw(st.integers(min_value=1, max_value=x_rank - 1))
# 3. Generate legal shape of input:W/bias of fc
w_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=8), min_size=2, max_size=2))
w_shape[0] = int(np.prod(x_shape[in_num_col_dims:]))
w_shape = [1, 2]
fc_bias_shape = [w_shape[1], ]
if draw(st.booleans()):
fc_bias_shape.insert(0, 1)
fc_bias_shape = [2, ]
fc_out_shape = x_shape[:in_num_col_dims] + w_shape[1:]
# 4. Generate legal attr:axis/shape of elementwise_add
add_bias_shape = fc_out_shape[:]
axis = draw(st.integers(min_value=-1, max_value=0))
# 5. Generate legal shape of layer_norm
begin_norm_axis = draw(
st.integers(
min_value=1, max_value=len(fc_out_shape) - 1))
layer_norm_shape = [int(np.prod(fc_out_shape[begin_norm_axis:]))]
epsilon = 1e-5
fc_op = OpConfig(
"fc",
inputs={"Input": ["fc_x"],
"W": ["fc_w"],
"Bias": ["fc_bias"]},
outputs={"Out": ["fc_out"]},
in_num_col_dims=in_num_col_dims,
padding_weights=False,
activation_type="",
use_quantizer=False,
use_mkldnn=False, )
add_op = OpConfig(
"elementwise_add",
inputs={"X": ["fc_out"],
"Y": ["add_bias"]},
outputs={"Out": ["add_out"]},
axis=axis, )
layer_norm_op = OpConfig(
"layer_norm",
inputs={
"X": ["add_out"],
"Scale": ["scale"],
"Bias": ["layer_norm_bias"]
},
outputs={
"Y": ["layer_norm_out"],
"Mean": ["layer_norm_mean"],
"Variance": ["layer_norm_var"]
},
begin_norm_axis=begin_norm_axis,
epsilon=epsilon)
ops = [fc_op, add_op, layer_norm_op]
program_config = ProgramConfig(
ops=ops,
weights={
"fc_w": TensorConfig(shape=w_shape),
"fc_bias": TensorConfig(shape=fc_bias_shape),
"add_bias": TensorConfig(shape=add_bias_shape),
"scale": TensorConfig(shape=layer_norm_shape),
"layer_norm_bias": TensorConfig(shape=layer_norm_shape),
},
inputs={"fc_x": TensorConfig(shape=x_shape), },
outputs=ops[-1].outputs["Y"], )
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=300,
passes=["fc_elementwise_layernorm_fuse_pass"],
use_gpu_run_baseline=True)
if __name__ == "__main__":
unittest.main()
......@@ -202,7 +202,7 @@ HIGH_PARALLEL_JOB_NEW = [
'test_fleet_runtime',
'test_rnn_cudnn_params_packing',
'test_mkldnn_placement_pass',
'test_fc_elementwise_layernorm_fuse_pass',
'test_fc_elementwise_layernorm_fuse_pass_cc',
'program_desc_test',
'test_simplify_with_basic_ops_pass',
'test_dygraph_mode_of_unittest',
......@@ -1417,7 +1417,7 @@ CPU_PARALLEL_JOB = [
'test_fc_mkldnn_op',
'test_fc_lstm_fuse_pass',
'test_fc_gru_fuse_pass',
'test_fc_elementwise_layernorm_fuse_pass',
'test_fc_elementwise_layernorm_fuse_pass_cc',
'test_fc_bf16_mkldnn_op',
'test_executor_feed_non_tensor',
'test_executor_check_feed',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册