Unverified commit aaa25222, authored by H Hulek, committed by GitHub

Rewrite batch norm act fuse pass tester (#49277)

* Rewritten

* change mkldnn to onednn

* fix cmake name
Parent e9df6fcd
@@ -404,10 +404,6 @@ if(WITH_MKLDNN)
     test_params_quantization_mkldnn_pass SRCS
     mkldnn/params_quantization_mkldnn_pass_tester.cc DEPS
     params_quantization_mkldnn_pass)
-  cc_test_old(
-    test_batch_norm_act_fuse_pass SRCS
-    mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass
-    pass_test_util)
   set(TEST_CONV_BN_PASS_DEPS
     conv_bn_fuse_pass
     graph_to_program_pass
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h"
#include "paddle/fluid/framework/ir/pass_test_util.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
void SetBatchNormAttrs(OpDesc* bn_op,
bool is_test = true,
bool trainable_stats = true) {
bn_op->SetAttr("is_test", is_test);
bn_op->SetAttr("trainable_statistics", trainable_stats);
bn_op->SetAttr("fuse_with_relu", false);
bn_op->SetAttr("epsilon", 0.001f);
}

}  // namespace

// ------------------------------ Test cases -----------------------------------
// The test cases below are distinguished by whether the following attributes
// are true or false:
// - is_test
// - trainable_statistics
// A test case's name lists only the attributes that are set to true.
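//
// Expected outcome for each combination (the pass can fuse only an
// inference-mode batch_norm, i.e. is_test == true with
// trainable_statistics == false):
//   is_test | trainable_statistics | result
//   true    | true                 | EnforceNotMet thrown, no fusion
//   true    | false                | batch_norm and relu fused
//   false   | true                 | EnforceNotMet thrown, no fusion
//   false   | false                | EnforceNotMet thrown, no fusion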
TEST(FuseBatchNormActOneDNNPass, ThrowIsTestTrainableStats) {
auto prog = test::BuildProgramDesc(
{"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
{"scale", "bias"});
auto* bn_op = test::CreateOp(&prog,
"batch_norm",
{{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
{"Variance", "v"}},
{{"Y", "bn_y"},
{"MeanOut", "m_out"},
{"VarianceOut", "var_out"},
{"SavedMean", "sm"},
{"SavedVariance", "sv"}});
SetBatchNormAttrs(bn_op, true, true);
test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
// No fusion in this attribute configuration
constexpr int removed_nodes_count = 0;
EXPECT_THROW(test::RunPassAndAssert(&graph,
"batch_norm_act_fuse_pass",
"x",
"act_y",
removed_nodes_count),
paddle::platform::EnforceNotMet);
}

TEST(FuseBatchNormActOneDNNPass, FuseIsTest) {
auto prog = test::BuildProgramDesc({"x", "m", "v", "bn_y", "act_y"},
{"scale", "bias"});
auto* bn_op = test::CreateOp(&prog,
"batch_norm",
{{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
{"Variance", "v"}},
{{"Y", "bn_y"}});
SetBatchNormAttrs(bn_op, true, false);
test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
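  // Fusion is expected to remove two graph nodes: the relu op and the
  // intermediate bn_y variable (an inference from the count asserted below).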
constexpr int removed_nodes_count = 2;
EXPECT_TRUE(test::RunPassAndAssert(
&graph, "batch_norm_act_fuse_pass", "x", "act_y", removed_nodes_count));
EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 0}}));
for (const auto* node : graph.Nodes()) {
if (node->IsOp() && node->Op()->Type() == "batch_norm") {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("fuse_with_relu"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("fuse_with_relu")));
ASSERT_TRUE(op->HasAttr("trainable_statistics"));
EXPECT_FALSE(PADDLE_GET_CONST(bool, op->GetAttr("trainable_statistics")));
}
}
}

TEST(FuseBatchNormActOneDNNPass, ThrowTrainableStats) {
auto prog = test::BuildProgramDesc(
{"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
{"scale", "bias"});
auto* bn_op = test::CreateOp(&prog,
"batch_norm",
{{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
{"Variance", "v"}},
{{"Y", "bn_y"},
{"MeanOut", "m_out"},
{"VarianceOut", "var_out"},
{"SavedMean", "sm"},
{"SavedVariance", "sv"}});
SetBatchNormAttrs(bn_op, false, true);
test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
// No fusion in this attribute configuration
constexpr int removed_nodes_count = 0;
EXPECT_THROW(test::RunPassAndAssert(&graph,
"batch_norm_act_fuse_pass",
"x",
"act_y",
removed_nodes_count),
paddle::platform::EnforceNotMet);
}

TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) {
auto prog = test::BuildProgramDesc(
{"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
{"scale", "bias"});
auto* bn_op = test::CreateOp(&prog,
"batch_norm",
{{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
{"Variance", "v"}},
{{"Y", "bn_y"},
{"MeanOut", "m_out"},
{"VarianceOut", "var_out"},
{"SavedMean", "sm"},
{"SavedVariance", "sv"}});
SetBatchNormAttrs(bn_op, false, false);
test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
// No fusion in this attribute configuration
constexpr int removed_nodes_count = 0;
EXPECT_THROW(test::RunPassAndAssert(&graph,
"batch_norm_act_fuse_pass",
"x",
"act_y",
removed_nodes_count),
paddle::platform::EnforceNotMet);
}

TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) {
auto prog = test::BuildProgramDesc(
{"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
{"scale", "bias"});
auto* bn_op = test::CreateOp(&prog,
"batch_norm",
{{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
{"Variance", "v"}},
{{"Y", "bn_y"},
{"MeanOut", "m_out"},
{"VarianceOut", "var_out"},
{"SavedMean", "sm"},
{"SavedVariance", "sv"}},
false);
SetBatchNormAttrs(bn_op, false, false);
test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
// No fusion in this attribute configuration
constexpr int removed_nodes_count = 0;
EXPECT_THROW(test::RunPassAndAssert(&graph,
"batch_norm_act_fuse_pass",
"x",
"act_y",
removed_nodes_count),
paddle::platform::EnforceNotMet);
}

TEST(FuseBatchNormActOneDNNPass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("batch_norm_act_fuse_pass"));
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(batch_norm_act_fuse_pass);
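For context on why only the is_test configuration fuses: in inference mode, batch_norm applies fixed running statistics, so the subsequent relu can be folded into the same kernel. A minimal NumPy sketch of that fused computation follows; it is illustrative only (not Paddle or oneDNN code; bn_relu_infer is a made-up helper) and uses the same epsilon as SetBatchNormAttrs.

import numpy as np

# Illustrative sketch, not Paddle API: inference-mode batch norm followed by
# relu -- the computation a single fused kernel performs when is_test=True,
# trainable_statistics=False and fuse_with_relu=True.
def bn_relu_infer(x, scale, bias, mean, variance, epsilon=0.001):
    y = scale * (x - mean) / np.sqrt(variance + epsilon) + bias
    return np.maximum(y, 0.0)  # the folded relu

x = np.random.rand(4, 3).astype(np.float32)
scale = np.ones(3, dtype=np.float32)
bias = np.zeros(3, dtype=np.float32)
mean = np.full(3, 0.5, dtype=np.float32)       # fixed running mean
variance = np.full(3, 0.25, dtype=np.float32)  # fixed running variance

fused = bn_relu_infer(x, scale, bias, mean, variance)
unfused = np.maximum(scale * (x - mean) / np.sqrt(variance + 0.001) + bias, 0.0)
np.testing.assert_allclose(fused, unfused, atol=1e-5)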
@@ -240,7 +240,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
                        PROPERTIES TIMEOUT 300)
   set_tests_properties(test_mkldnn_conv_hard_swish_fuse_pass
                        PROPERTIES TIMEOUT 300)
-  set_tests_properties(test_mkldnn_batch_norm_act_fuse_pass PROPERTIES TIMEOUT
+  set_tests_properties(test_onednn_batch_norm_act_fuse_pass PROPERTIES TIMEOUT
                        100)
   set_tests_properties(test_mkldnn_matmul_v2_transpose_reshape_fuse_pass
                        PROPERTIES TIMEOUT 100)
@@ -21,12 +21,12 @@
 from auto_scan_test import PassAutoScanTest
 from program_config import OpConfig, ProgramConfig, TensorConfig

-class TestScaleMatmulMkldnnFusePass(PassAutoScanTest):
+class TestScaleOneDNNFusePass(PassAutoScanTest):
     def is_program_valid(self, program_config: ProgramConfig) -> bool:
         return True

     def sample_program_config(self, draw):
-        data_layout = draw(st.sampled_from(["NCHW", "NHWC"]))
+        data_layout = draw(st.sampled_from(['NCHW', 'NHWC']))
         epsilon = draw(st.floats(min_value=0.0, max_value=0.001))
         fuse_with_relu = draw(st.booleans())
         is_test = draw(st.sampled_from([True]))
@@ -43,7 +43,7 @@ class TestScaleMatmulMkldnnFusePass(PassAutoScanTest):

         def generate_input():
             shape = [input_dim1, input_dim2]
-            if data_layout == "NCHW":
+            if data_layout == 'NCHW':
                 shape.insert(0, channel)
                 shape.insert(0, batch_size)
             else:
@@ -55,38 +55,38 @@ class TestScaleMatmulMkldnnFusePass(PassAutoScanTest):
             return np.random.random(channel).astype(np.float32)

         batch_norm_op = OpConfig(
-            type="batch_norm",
+            type='batch_norm',
             inputs={
-                "X": ["input_data"],
-                "Bias": ["Bias"],
-                "Mean": ["Mean"],
-                "Scale": ["Scale"],
-                "Variance": ["Variance"],
+                'X': ['input_data'],
+                'Bias': ['Bias'],
+                'Mean': ['Mean'],
+                'Scale': ['Scale'],
+                'Variance': ['Variance'],
             },
             outputs={
-                "Y": ["norm_output"],
-                "MeanOut": ["Mean"],
-                "VarianceOut": ["Variance"],
-                "SavedMean": ["SavedMean"],
-                "SavedVariance": ["SavedVariance"],
+                'Y': ['norm_output'],
+                'MeanOut': ['Mean'],
+                'VarianceOut': ['Variance'],
+                'SavedMean': ['SavedMean'],
+                'SavedVariance': ['SavedVariance'],
             },
             attrs={
-                "data_layout": data_layout,
-                "epsilon": epsilon,
-                "fuse_with_relu": fuse_with_relu,
-                "is_test": is_test,
-                "momentum": momentum,
-                "trainable_statistics": trainable_statistics,
-                "use_global_stats": use_global_stats,
-                "use_mkldnn": use_mkldnn1,
+                'data_layout': data_layout,
+                'epsilon': epsilon,
+                'fuse_with_relu': fuse_with_relu,
+                'is_test': is_test,
+                'momentum': momentum,
+                'trainable_statistics': trainable_statistics,
+                'use_global_stats': use_global_stats,
+                'use_mkldnn': use_mkldnn1,
             },
         )

         relu_op = OpConfig(
-            type="relu",
-            inputs={"X": ["norm_output"]},
-            outputs={"Out": ["relu_output"]},
-            attrs={"use_cudnn": use_cudnn, "use_mkldnn": use_mkldnn2},
+            type='relu',
+            inputs={'X': ['norm_output']},
+            outputs={'Out': ['relu_output']},
+            attrs={'use_cudnn': use_cudnn, 'use_mkldnn': use_mkldnn2},
         )

         model_net = [batch_norm_op, relu_op]
@@ -94,26 +94,26 @@ class TestScaleMatmulMkldnnFusePass(PassAutoScanTest):

         program_config = ProgramConfig(
             ops=model_net,
             weights={
-                "Bias": TensorConfig(data_gen=partial(generate_weight)),
-                "Mean": TensorConfig(data_gen=partial(generate_weight)),
-                "Scale": TensorConfig(data_gen=partial(generate_weight)),
-                "Variance": TensorConfig(data_gen=partial(generate_weight)),
+                'Bias': TensorConfig(data_gen=partial(generate_weight)),
+                'Mean': TensorConfig(data_gen=partial(generate_weight)),
+                'Scale': TensorConfig(data_gen=partial(generate_weight)),
+                'Variance': TensorConfig(data_gen=partial(generate_weight)),
             },
             inputs={
-                "input_data": TensorConfig(data_gen=partial(generate_input))
+                'input_data': TensorConfig(data_gen=partial(generate_input))
             },
-            outputs=["relu_output"],
+            outputs=['relu_output'],
         )
         return program_config

     def sample_predictor_configs(self, program_config):
         config = self.create_inference_config(use_mkldnn=True)
-        yield config, ["batch_norm"], (1e-5, 1e-5)
+        yield config, ['batch_norm'], (1e-5, 1e-5)

     def test(self):
-        self.run_and_statis(quant=False, passes=["batch_norm_act_fuse_pass"])
+        self.run_and_statis(quant=False, passes=['batch_norm_act_fuse_pass'])


-if __name__ == "__main__":
+if __name__ == '__main__':
     unittest.main()