未验证 提交 aa0e84e3 编写于 作者: W wenbin 提交者: GitHub

residual_no_bias (#46129)

* residual_no_bias

* comments

* more ut

* fix input
上级 3d59fee5
...@@ -33,11 +33,16 @@ namespace ir { ...@@ -33,11 +33,16 @@ namespace ir {
namespace patterns { namespace patterns {
struct PrelnResidualBias : public PatternBase { struct PrelnResidualBias : public PatternBase {
PrelnResidualBias(PDPattern *pattern, const std::string &name_scope) PrelnResidualBias(PDPattern *pattern,
: PatternBase(pattern, name_scope, "preln_residual_bias") {} const std::string &name_scope,
bool with_bias)
: PatternBase(pattern, name_scope, "preln_residual_bias") {
with_bias_ = with_bias;
}
void operator()(PDNode *x, PDNode *y); void operator()(PDNode *x, PDNode *y);
bool with_bias_;
// declare operator node's name // declare operator node's name
PATTERN_DECL_NODE(elementwise_bias); PATTERN_DECL_NODE(elementwise_bias);
PATTERN_DECL_NODE(elementwise0); PATTERN_DECL_NODE(elementwise0);
...@@ -55,15 +60,17 @@ struct PrelnResidualBias : public PatternBase { ...@@ -55,15 +60,17 @@ struct PrelnResidualBias : public PatternBase {
}; };
void PrelnResidualBias::operator()(PDNode *x, PDNode *y) { void PrelnResidualBias::operator()(PDNode *x, PDNode *y) {
PDNode *elementwise0 = nullptr;
PDNode *elementwise_bias_var = nullptr;
PDNode *elementwise0_out_var = nullptr;
// Create nodes for elementwise add op. // Create nodes for elementwise add op.
x->assert_is_op_input("elementwise_add"); if (with_bias_) {
y->assert_is_op_input("elementwise_add", "X"); elementwise0 =
auto *elementwise0 =
pattern->NewNode(elementwise0_repr())->assert_is_op("elementwise_add"); pattern->NewNode(elementwise0_repr())->assert_is_op("elementwise_add");
auto *elementwise_bias_var = pattern->NewNode(elementwise_bias_repr()) elementwise_bias_var = pattern->NewNode(elementwise_bias_repr())
->assert_is_op_input("elementwise_add", "Y") ->assert_is_op_input("elementwise_add", "Y")
->assert_is_persistable_var(); ->assert_is_persistable_var();
auto *elementwise0_out_var = pattern->NewNode(elementwise0_out_repr()) elementwise0_out_var = pattern->NewNode(elementwise0_out_repr())
->assert_is_op_output("elementwise_add") ->assert_is_op_output("elementwise_add")
->assert_is_op_input("elementwise_add") ->assert_is_op_input("elementwise_add")
->assert_more([](Node *x) { ->assert_more([](Node *x) {
...@@ -73,14 +80,21 @@ void PrelnResidualBias::operator()(PDNode *x, PDNode *y) { ...@@ -73,14 +80,21 @@ void PrelnResidualBias::operator()(PDNode *x, PDNode *y) {
return false; return false;
} }
}); });
} else {
elementwise0_out_var = y;
}
auto *elementwise1 = auto *elementwise1 =
pattern->NewNode(elementwise1_repr())->assert_is_op("elementwise_add"); pattern->NewNode(elementwise1_repr())->assert_is_op("elementwise_add");
auto *elementwise1_out_var = pattern->NewNode(elementwise1_out_repr()) auto *elementwise1_out_var = pattern->NewNode(elementwise1_out_repr())
->assert_is_op_output("elementwise_add")
->assert_is_op_input("layer_norm", "X"); ->assert_is_op_input("layer_norm", "X");
// Add links for elementwise_add op. // Add links for elementwise_add op.
if (with_bias_) {
elementwise0->LinksFrom({y, elementwise_bias_var}) elementwise0->LinksFrom({y, elementwise_bias_var})
.LinksTo({elementwise0_out_var}); .LinksTo({elementwise0_out_var});
elementwise1_out_var->assert_is_op_output("elementwise_add");
}
elementwise1->LinksFrom({x, elementwise0_out_var}) elementwise1->LinksFrom({x, elementwise0_out_var})
.LinksTo({elementwise1_out_var}); .LinksTo({elementwise1_out_var});
// Create nodes for layer_norm op. // Create nodes for layer_norm op.
...@@ -115,7 +129,8 @@ void PrelnResidualBias::operator()(PDNode *x, PDNode *y) { ...@@ -115,7 +129,8 @@ void PrelnResidualBias::operator()(PDNode *x, PDNode *y) {
} // namespace patterns } // namespace patterns
void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const { int PrelnResidualBiasFusePass::ApplyPattern(ir::Graph *graph,
bool with_bias) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null.")); graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init("preln_residual_bias_fuse", graph); FusePassBase::Init("preln_residual_bias_fuse", graph);
...@@ -123,18 +138,32 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const { ...@@ -123,18 +138,32 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
int found_subgraph_count = 0; int found_subgraph_count = 0;
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto *x = gpd.mutable_pattern() PDNode *x = nullptr;
PDNode *y = nullptr;
if (with_bias) {
x = gpd.mutable_pattern()
->NewNode("preln_residual_bias_fuse/x") ->NewNode("preln_residual_bias_fuse/x")
->AsInput() ->AsInput()
->assert_is_op_input("elementwise_add") ->assert_is_op_input("elementwise_add")
->assert_var_not_persistable(); ->assert_var_not_persistable();
auto *y = gpd.mutable_pattern() y = gpd.mutable_pattern()
->NewNode("preln_residual_bias_fuse/y") ->NewNode("preln_residual_bias_fuse/y")
->AsInput() ->AsInput()
->assert_is_op_input("elementwise_add", "X") ->assert_is_op_input("elementwise_add", "X")
->assert_var_not_persistable(); ->assert_var_not_persistable();
patterns::PrelnResidualBias fused_pattern(gpd.mutable_pattern(), } else {
"preln_residual_bias_fuse"); x = gpd.mutable_pattern()
->NewNode("preln_residual_bias_fuse/x")
->AsInput()
->assert_is_op_input("elementwise_add", "X");
y = gpd.mutable_pattern()
->NewNode("preln_residual_bias_fuse/y")
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
}
patterns::PrelnResidualBias fused_pattern(
gpd.mutable_pattern(), "preln_residual_bias_fuse", with_bias);
fused_pattern(x, y); fused_pattern(x, y);
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
...@@ -145,11 +174,19 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const { ...@@ -145,11 +174,19 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
} }
VLOG(4) << "handle PrelnResidualBias fuse"; VLOG(4) << "handle PrelnResidualBias fuse";
Node *elementwise_bias = nullptr;
Node *elementwise0 = nullptr;
Node *elementwise0_out = nullptr;
if (with_bias) {
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
elementwise_bias, elementwise_bias, fused_pattern); tmp_elementwise_bias, elementwise_bias, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise0, elementwise0, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(tmp_elementwise0, elementwise0, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
elementwise0_out, elementwise0_out, fused_pattern); tmp_elementwise0_out, elementwise0_out, fused_pattern);
elementwise_bias = tmp_elementwise_bias;
elementwise0 = tmp_elementwise0;
elementwise0_out = tmp_elementwise0_out;
}
GET_IR_NODE_FROM_SUBGRAPH(elementwise1, elementwise1, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(elementwise1, elementwise1, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
elementwise1_out, elementwise1_out, fused_pattern); elementwise1_out, elementwise1_out, fused_pattern);
...@@ -185,7 +222,9 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const { ...@@ -185,7 +222,9 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
new_desc.SetInput("Y", {subgraph.at(y)->Name()}); new_desc.SetInput("Y", {subgraph.at(y)->Name()});
new_desc.SetInput("Scale", {layer_norm_scale->Name()}); new_desc.SetInput("Scale", {layer_norm_scale->Name()});
new_desc.SetInput("Bias", {layer_norm_bias->Name()}); new_desc.SetInput("Bias", {layer_norm_bias->Name()});
if (with_bias) {
new_desc.SetInput("EleBias", {elementwise_bias->Name()}); new_desc.SetInput("EleBias", {elementwise_bias->Name()});
}
// outputs // outputs
new_desc.SetOutput("Out_0", {layer_norm_out->Name()}); new_desc.SetOutput("Out_0", {layer_norm_out->Name()});
new_desc.SetOutput("Out_1", {elementwise1_out->Name()}); new_desc.SetOutput("Out_1", {elementwise1_out->Name()});
...@@ -194,16 +233,20 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const { ...@@ -194,16 +233,20 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
new_desc.SetAttr("begin_norm_axis", new_desc.SetAttr("begin_norm_axis",
layer_norm->Op()->GetAttr("begin_norm_axis")); layer_norm->Op()->GetAttr("begin_norm_axis"));
auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied. auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied.
if (with_bias) {
del_node_set.insert(elementwise0); del_node_set.insert(elementwise0);
del_node_set.insert(elementwise1);
del_node_set.insert(elementwise0_out); del_node_set.insert(elementwise0_out);
}
del_node_set.insert(elementwise1);
del_node_set.insert(layer_norm); del_node_set.insert(layer_norm);
del_node_set.insert(layer_norm_mean); del_node_set.insert(layer_norm_mean);
del_node_set.insert(layer_norm_variance); del_node_set.insert(layer_norm_variance);
GraphSafeRemoveNodes(graph, del_node_set); GraphSafeRemoveNodes(graph, del_node_set);
IR_NODE_LINK_TO(subgraph.at(x), fused_node); IR_NODE_LINK_TO(subgraph.at(x), fused_node);
IR_NODE_LINK_TO(subgraph.at(y), fused_node); IR_NODE_LINK_TO(subgraph.at(y), fused_node);
if (with_bias) {
IR_NODE_LINK_TO(elementwise_bias, fused_node); IR_NODE_LINK_TO(elementwise_bias, fused_node);
}
IR_NODE_LINK_TO(layer_norm_scale, fused_node); IR_NODE_LINK_TO(layer_norm_scale, fused_node);
IR_NODE_LINK_TO(layer_norm_bias, fused_node); IR_NODE_LINK_TO(layer_norm_bias, fused_node);
IR_NODE_LINK_TO(fused_node, layer_norm_out); IR_NODE_LINK_TO(fused_node, layer_norm_out);
...@@ -212,6 +255,17 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const { ...@@ -212,6 +255,17 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
}; };
gpd(graph, handler); gpd(graph, handler);
return found_subgraph_count;
}
void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init("preln_residual_bias_fuse", graph);
int found_subgraph_count = 0;
found_subgraph_count = ApplyPattern(graph, true);
found_subgraph_count += ApplyPattern(graph, false);
AddStatis(found_subgraph_count); AddStatis(found_subgraph_count);
} }
......
...@@ -29,6 +29,16 @@ namespace ir { ...@@ -29,6 +29,16 @@ namespace ir {
// other_op4 layer_norm other_op4 other_op3 // other_op4 layer_norm other_op4 other_op3
// | // |
// other_op3 // other_op3
// or
//
// | | | |
// other_op1 other_op2 other_op1 other_op2
// | | fuse \ /
// |------elementwise_add -> preln_residual_bias
// | | | |
// other_op4 layer_norm other_op4 other_op3
// |
// other_op3
class Graph; class Graph;
class PrelnResidualBiasFusePass : public FusePassBase { class PrelnResidualBiasFusePass : public FusePassBase {
...@@ -80,6 +90,7 @@ class PrelnResidualBiasFusePass : public FusePassBase { ...@@ -80,6 +90,7 @@ class PrelnResidualBiasFusePass : public FusePassBase {
protected: protected:
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
int ApplyPattern(ir::Graph* graph, bool with_bias) const;
}; };
} // namespace ir } // namespace ir
......
...@@ -51,12 +51,15 @@ class PrelnResidualBiasOpConverter : public OpConverter { ...@@ -51,12 +51,15 @@ class PrelnResidualBiasOpConverter : public OpConverter {
framework::DDim bias_dims, scale_dims, ele_bias_dims; framework::DDim bias_dims, scale_dims, ele_bias_dims;
auto* bias = get_persistable_data("Bias", &bias_dims); auto* bias = get_persistable_data("Bias", &bias_dims);
auto* scale = get_persistable_data("Scale", &scale_dims); auto* scale = get_persistable_data("Scale", &scale_dims);
auto* ele_bias = get_persistable_data("EleBias", &ele_bias_dims); auto const& vars = op_desc.Inputs(false);
bool has_bias = vars.find("EleBias") != vars.end();
float* ele_bias =
has_bias ? get_persistable_data("EleBias", &ele_bias_dims) : nullptr;
int bias_size = phi::product(bias_dims); int bias_size = phi::product(bias_dims);
int scale_size = phi::product(scale_dims); int scale_size = phi::product(scale_dims);
int ele_bias_size = phi::product(ele_bias_dims); int ele_bias_size = has_bias ? phi::product(ele_bias_dims) : 0;
float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon")); float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"));
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (engine_->precision() == AnalysisConfig::Precision::kInt8) { if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
...@@ -66,13 +69,17 @@ class PrelnResidualBiasOpConverter : public OpConverter { ...@@ -66,13 +69,17 @@ class PrelnResidualBiasOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr; nvinfer1::ILayer* layer = nullptr;
plugin::DynamicPluginTensorRT* plugin = nullptr; plugin::DynamicPluginTensorRT* plugin = nullptr;
if (with_fp16) { if (with_fp16) {
auto half_ele_bias_data = new half[ele_bias_size]; half* half_ele_bias_data = nullptr;
if (ele_bias_size > 0) {
half_ele_bias_data = new half[ele_bias_size];
for (int i = 0; i < ele_bias_size; i++) { for (int i = 0; i < ele_bias_size; i++) {
half_ele_bias_data[i] = static_cast<half>(ele_bias[i]); half_ele_bias_data[i] = static_cast<half>(ele_bias[i]);
} }
plugin = new plugin::PrelnResidualBiasPluginDynamic(bias, }
plugin = new plugin::PrelnResidualBiasPluginDynamic(
bias,
scale, scale,
half_ele_bias_data, ele_bias_size > 0 ? half_ele_bias_data : nullptr,
bias_size, bias_size,
scale_size, scale_size,
ele_bias_size, ele_bias_size,
......
...@@ -44,7 +44,7 @@ int PrelnResidualBiasPluginDynamic::initialize() TRT_NOEXCEPT { ...@@ -44,7 +44,7 @@ int PrelnResidualBiasPluginDynamic::initialize() TRT_NOEXCEPT {
scale_.data(), scale_.data(),
scale_size_ * sizeof(float), scale_size_ * sizeof(float),
cudaMemcpyHostToDevice); cudaMemcpyHostToDevice);
if (ele_bias_size_ > 0) {
if (with_fp16_) { if (with_fp16_) {
cudaMalloc(&ele_bias_gpu_, sizeof(half) * ele_bias_size_); cudaMalloc(&ele_bias_gpu_, sizeof(half) * ele_bias_size_);
cudaMemcpy(ele_bias_gpu_, cudaMemcpy(ele_bias_gpu_,
...@@ -58,6 +58,9 @@ int PrelnResidualBiasPluginDynamic::initialize() TRT_NOEXCEPT { ...@@ -58,6 +58,9 @@ int PrelnResidualBiasPluginDynamic::initialize() TRT_NOEXCEPT {
ele_bias_size_ * sizeof(float), ele_bias_size_ * sizeof(float),
cudaMemcpyHostToDevice); cudaMemcpyHostToDevice);
} }
} else {
ele_bias_gpu_ = nullptr;
}
return 0; return 0;
} }
......
...@@ -142,6 +142,7 @@ if(WIN32) ...@@ -142,6 +142,7 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_complex_matmul) list(REMOVE_ITEM TEST_OPS test_complex_matmul)
list(REMOVE_ITEM TEST_OPS test_ops_nms) list(REMOVE_ITEM TEST_OPS test_ops_nms)
list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias) list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias)
list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_no_bias)
list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op) list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op)
endif() endif()
list(REMOVE_ITEM TEST_OPS test_checkpoint_saver) list(REMOVE_ITEM TEST_OPS test_checkpoint_saver)
......
...@@ -22,6 +22,10 @@ if(NOT WITH_DISTRIBUTE) ...@@ -22,6 +22,10 @@ if(NOT WITH_DISTRIBUTE)
"test_trt_convert_preln_residual_bias") "test_trt_convert_preln_residual_bias")
list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_preln_residual_bias") list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_preln_residual_bias")
list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_preln_residual_bias") list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_preln_residual_bias")
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES
"test_trt_convert_preln_residual_no_bias")
list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_preln_residual_no_bias")
list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_preln_residual_no_bias")
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_trt_convert_c_allreduce") list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_trt_convert_c_allreduce")
list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_c_allreduce") list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_c_allreduce")
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
class TrtConvertSkipLayernormTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
inputs = program_config.inputs
weights = program_config.weights
outputs = program_config.outputs
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
#The input dimension should be less than or equal to the set axis.
if 'begin_norm_axis' in attrs[0] and attrs[0]['begin_norm_axis'] >= 0:
if len(inputs['inputX_data'].shape) <= attrs[0]['begin_norm_axis']:
return False
return True
def sample_program_configs(self):
def generate_input1(attrs: List[Dict[str, Any]], batch):
return np.ones([batch, 128, 768]).astype(np.float32)
def generate_input2(attrs: List[Dict[str, Any]], batch):
return np.ones([batch, 128, 768]).astype(np.float32)
def generate_weight1(attrs: List[Dict[str, Any]]):
return np.random.random([768]).astype(np.float32)
def generate_weight2(attrs: List[Dict[str, Any]]):
return np.random.random([768]).astype(np.float32)
for batch in [4]:
for epsilon in [1e-5]:
for begin_norm_axis in [2]:
for enable_int8 in [False, True]:
dics = [{
"epsilon": epsilon,
"begin_norm_axis": begin_norm_axis,
}, {}]
ops_config = [{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["inputX_data"],
"Y": ["inputY_data"]
},
"op_outputs": {
"Out": ["ele_out"]
},
"op_attrs": {
"axis": -1
}
}, {
"op_type": "layer_norm",
"op_inputs": {
"X": ["ele_out"],
"Bias": ["Bias"],
"Scale": ["Scale"]
},
"op_outputs": {
"Y": ["layernorm_out"],
"Mean": ["Mean"],
"Variance": ["Variance"]
},
"op_attrs": dics[0]
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"Bias":
TensorConfig(
data_gen=partial(generate_weight1, dics)),
"Scale":
TensorConfig(
data_gen=partial(generate_weight2, dics))
},
inputs={
"inputX_data":
TensorConfig(data_gen=partial(
generate_input1, dics, batch)),
"inputY_data":
TensorConfig(data_gen=partial(
generate_input2, dics, batch))
},
outputs=["ele_out", "layernorm_out"])
yield program_config
def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {
"inputX_data": [4, 128, 768],
"inputY_data": [4, 128, 768],
"Bias": [768],
"Scale": [768]
}
self.dynamic_shape.max_input_shape = {
"inputX_data": [4, 128, 768],
"inputY_data": [4, 128, 768],
"Bias": [768],
"Scale": [768]
}
self.dynamic_shape.opt_input_shape = {
"inputX_data": [4, 128, 768],
"inputY_data": [4, 128, 768],
"Bias": [768],
"Scale": [768]
}
def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 4
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# just support dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), 1e-2 # atol=1e-2 while rtol is 1e-8
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), 1e-2 # atol=1e-2 while rtol is 1e-8
def add_skip_trt_case(self):
pass
def test(self):
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__":
unittest.main()
...@@ -57,5 +57,37 @@ class PrelnResidualBiasFusePassTest(PassTest): ...@@ -57,5 +57,37 @@ class PrelnResidualBiasFusePassTest(PassTest):
self.check_program(opt_program) self.check_program(opt_program)
class PrelnResidualBiasFusePassNoBiasTest(PassTest):
def setUp(self):
paddle.enable_static()
with paddle.static.program_guard(self.main_program,
self.startup_program):
x = paddle.static.data(name="x",
shape=[128, 768],
dtype="float32",
lod_level=0)
y = paddle.static.data(name="y",
shape=[128, 768],
dtype="float32",
lod_level=0)
elementwise_out = x + y
out = paddle.static.nn.layer_norm(input=elementwise_out)
self.fetch_list = [out, elementwise_out]
self.pass_names = "preln_residual_bias_fuse_pass"
self.fused_op_type = "preln_residual_bias"
self.num_fused_ops = 1
def test_check_program(self):
use_gpu_set = [False]
if paddle.device.is_compiled_with_cuda():
use_gpu_set.append(True)
for use_gpu in use_gpu_set:
place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace()
opt_program = self._apply_ir_passes()
self.check_program(opt_program)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册