未验证 提交 aaa25222 编写于 作者: H Hulek 提交者: GitHub

Rewrite batch norm act fuse pass tester (#49277)

* Rewritten

* change mkldnn to onednn

* fix cmake name
上级 e9df6fcd
...@@ -404,10 +404,6 @@ if(WITH_MKLDNN) ...@@ -404,10 +404,6 @@ if(WITH_MKLDNN)
test_params_quantization_mkldnn_pass SRCS test_params_quantization_mkldnn_pass SRCS
mkldnn/params_quantization_mkldnn_pass_tester.cc DEPS mkldnn/params_quantization_mkldnn_pass_tester.cc DEPS
params_quantization_mkldnn_pass) params_quantization_mkldnn_pass)
cc_test_old(
test_batch_norm_act_fuse_pass SRCS
mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass
pass_test_util)
set(TEST_CONV_BN_PASS_DEPS set(TEST_CONV_BN_PASS_DEPS
conv_bn_fuse_pass conv_bn_fuse_pass
graph_to_program_pass graph_to_program_pass
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h"
#include "paddle/fluid/framework/ir/pass_test_util.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
void SetBatchNormAttrs(OpDesc* bn_op,
bool is_test = true,
bool trainable_stats = true) {
bn_op->SetAttr("is_test", is_test);
bn_op->SetAttr("trainable_statistics", trainable_stats);
bn_op->SetAttr("fuse_with_relu", false);
bn_op->SetAttr("epsilon", 0.001f);
}
} // namespace
// ------------------------------ Test cases -----------------------------------
// The below test cases are distinguished by whether following attributes have
// true or false value:
// - is_test
// - trainable_statistics
// The test case name would have only attributes with true value in its name.
TEST(FuseBatchNormActOneDNNPass, ThrowIsTestTrainableStats) {
auto prog = test::BuildProgramDesc(
{"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
{"scale", "bias"});
auto* bn_op = test::CreateOp(&prog,
"batch_norm",
{{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
{"Variance", "v"}},
{{"Y", "bn_y"},
{"MeanOut", "m_out"},
{"VarianceOut", "var_out"},
{"SavedMean", "sm"},
{"SavedVariance", "sv"}});
SetBatchNormAttrs(bn_op, true, true);
test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
// No fusion in this attribute configuration
constexpr int removed_nodes_count = 0;
EXPECT_THROW(test::RunPassAndAssert(&graph,
"batch_norm_act_fuse_pass",
"x",
"act_y",
removed_nodes_count),
paddle::platform::EnforceNotMet);
}
TEST(FuseBatchNormActOneDNNPass, FuseIsTest) {
auto prog = test::BuildProgramDesc({"x", "m", "v", "bn_y", "act_y"},
{"scale", "bias"});
auto* bn_op = test::CreateOp(&prog,
"batch_norm",
{{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
{"Variance", "v"}},
{{"Y", "bn_y"}});
SetBatchNormAttrs(bn_op, true, false);
test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
EXPECT_TRUE(test::RunPassAndAssert(
&graph, "batch_norm_act_fuse_pass", "x", "act_y", removed_nodes_count));
EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 0}}));
for (const auto* node : graph.Nodes()) {
if (node->IsOp() && node->Op()->Type() == "batch_norm") {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("fuse_with_relu"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("fuse_with_relu")));
ASSERT_TRUE(op->HasAttr("trainable_statistics"));
EXPECT_FALSE(PADDLE_GET_CONST(bool, op->GetAttr("trainable_statistics")));
}
}
}
TEST(FuseBatchNormActOneDNNPass, ThrowTrainableStats) {
auto prog = test::BuildProgramDesc(
{"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
{"scale", "bias"});
auto* bn_op = test::CreateOp(&prog,
"batch_norm",
{{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
{"Variance", "v"}},
{{"Y", "bn_y"},
{"MeanOut", "m_out"},
{"VarianceOut", "var_out"},
{"SavedMean", "sm"},
{"SavedVariance", "sv"}});
SetBatchNormAttrs(bn_op, false, true);
test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
// No fusion in this attribute configuration
constexpr int removed_nodes_count = 0;
EXPECT_THROW(test::RunPassAndAssert(&graph,
"batch_norm_act_fuse_pass",
"x",
"act_y",
removed_nodes_count),
paddle::platform::EnforceNotMet);
}
TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) {
auto prog = test::BuildProgramDesc(
{"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
{"scale", "bias"});
auto* bn_op = test::CreateOp(&prog,
"batch_norm",
{{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
{"Variance", "v"}},
{{"Y", "bn_y"},
{"MeanOut", "m_out"},
{"VarianceOut", "var_out"},
{"SavedMean", "sm"},
{"SavedVariance", "sv"}});
SetBatchNormAttrs(bn_op, false, false);
test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
// No fusion in this attribute configuration
constexpr int removed_nodes_count = 0;
EXPECT_THROW(test::RunPassAndAssert(&graph,
"batch_norm_act_fuse_pass",
"x",
"act_y",
removed_nodes_count),
paddle::platform::EnforceNotMet);
}
TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) {
auto prog = test::BuildProgramDesc(
{"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
{"scale", "bias"});
auto* bn_op = test::CreateOp(&prog,
"batch_norm",
{{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
{"Variance", "v"}},
{{"Y", "bn_y"},
{"MeanOut", "m_out"},
{"VarianceOut", "var_out"},
{"SavedMean", "sm"},
{"SavedVariance", "sv"}},
false);
SetBatchNormAttrs(bn_op, false, false);
test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
// No fusion in this attribute configuration
constexpr int removed_nodes_count = 0;
EXPECT_THROW(test::RunPassAndAssert(&graph,
"batch_norm_act_fuse_pass",
"x",
"act_y",
removed_nodes_count),
paddle::platform::EnforceNotMet);
}
TEST(FuseBatchNormActOneDNNPass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("batch_norm_act_fuse_pass"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(batch_norm_act_fuse_pass);
...@@ -240,7 +240,7 @@ if(WITH_GPU AND TENSORRT_FOUND) ...@@ -240,7 +240,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
PROPERTIES TIMEOUT 300) PROPERTIES TIMEOUT 300)
set_tests_properties(test_mkldnn_conv_hard_swish_fuse_pass set_tests_properties(test_mkldnn_conv_hard_swish_fuse_pass
PROPERTIES TIMEOUT 300) PROPERTIES TIMEOUT 300)
set_tests_properties(test_mkldnn_batch_norm_act_fuse_pass PROPERTIES TIMEOUT set_tests_properties(test_onednn_batch_norm_act_fuse_pass PROPERTIES TIMEOUT
100) 100)
set_tests_properties(test_mkldnn_matmul_v2_transpose_reshape_fuse_pass set_tests_properties(test_mkldnn_matmul_v2_transpose_reshape_fuse_pass
PROPERTIES TIMEOUT 100) PROPERTIES TIMEOUT 100)
......
...@@ -21,12 +21,12 @@ from auto_scan_test import PassAutoScanTest ...@@ -21,12 +21,12 @@ from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig from program_config import OpConfig, ProgramConfig, TensorConfig
class TestScaleMatmulMkldnnFusePass(PassAutoScanTest): class TestScaleOneDNNFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool: def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True return True
def sample_program_config(self, draw): def sample_program_config(self, draw):
data_layout = draw(st.sampled_from(["NCHW", "NHWC"])) data_layout = draw(st.sampled_from(['NCHW', 'NHWC']))
epsilon = draw(st.floats(min_value=0.0, max_value=0.001)) epsilon = draw(st.floats(min_value=0.0, max_value=0.001))
fuse_with_relu = draw(st.booleans()) fuse_with_relu = draw(st.booleans())
is_test = draw(st.sampled_from([True])) is_test = draw(st.sampled_from([True]))
...@@ -43,7 +43,7 @@ class TestScaleMatmulMkldnnFusePass(PassAutoScanTest): ...@@ -43,7 +43,7 @@ class TestScaleMatmulMkldnnFusePass(PassAutoScanTest):
def generate_input(): def generate_input():
shape = [input_dim1, input_dim2] shape = [input_dim1, input_dim2]
if data_layout == "NCHW": if data_layout == 'NCHW':
shape.insert(0, channel) shape.insert(0, channel)
shape.insert(0, batch_size) shape.insert(0, batch_size)
else: else:
...@@ -55,38 +55,38 @@ class TestScaleMatmulMkldnnFusePass(PassAutoScanTest): ...@@ -55,38 +55,38 @@ class TestScaleMatmulMkldnnFusePass(PassAutoScanTest):
return np.random.random(channel).astype(np.float32) return np.random.random(channel).astype(np.float32)
batch_norm_op = OpConfig( batch_norm_op = OpConfig(
type="batch_norm", type='batch_norm',
inputs={ inputs={
"X": ["input_data"], 'X': ['input_data'],
"Bias": ["Bias"], 'Bias': ['Bias'],
"Mean": ["Mean"], 'Mean': ['Mean'],
"Scale": ["Scale"], 'Scale': ['Scale'],
"Variance": ["Variance"], 'Variance': ['Variance'],
}, },
outputs={ outputs={
"Y": ["norm_output"], 'Y': ['norm_output'],
"MeanOut": ["Mean"], 'MeanOut': ['Mean'],
"VarianceOut": ["Variance"], 'VarianceOut': ['Variance'],
"SavedMean": ["SavedMean"], 'SavedMean': ['SavedMean'],
"SavedVariance": ["SavedVariance"], 'SavedVariance': ['SavedVariance'],
}, },
attrs={ attrs={
"data_layout": data_layout, 'data_layout': data_layout,
"epsilon": epsilon, 'epsilon': epsilon,
"fuse_with_relu": fuse_with_relu, 'fuse_with_relu': fuse_with_relu,
"is_test": is_test, 'is_test': is_test,
"momentum": momentum, 'momentum': momentum,
"trainable_statistics": trainable_statistics, 'trainable_statistics': trainable_statistics,
"use_global_stats": use_global_stats, 'use_global_stats': use_global_stats,
"use_mkldnn": use_mkldnn1, 'use_mkldnn': use_mkldnn1,
}, },
) )
relu_op = OpConfig( relu_op = OpConfig(
type="relu", type='relu',
inputs={"X": ["norm_output"]}, inputs={'X': ['norm_output']},
outputs={"Out": ["relu_output"]}, outputs={'Out': ['relu_output']},
attrs={"use_cudnn": use_cudnn, "use_mkldnn": use_mkldnn2}, attrs={'use_cudnn': use_cudnn, 'use_mkldnn': use_mkldnn2},
) )
model_net = [batch_norm_op, relu_op] model_net = [batch_norm_op, relu_op]
...@@ -94,26 +94,26 @@ class TestScaleMatmulMkldnnFusePass(PassAutoScanTest): ...@@ -94,26 +94,26 @@ class TestScaleMatmulMkldnnFusePass(PassAutoScanTest):
program_config = ProgramConfig( program_config = ProgramConfig(
ops=model_net, ops=model_net,
weights={ weights={
"Bias": TensorConfig(data_gen=partial(generate_weight)), 'Bias': TensorConfig(data_gen=partial(generate_weight)),
"Mean": TensorConfig(data_gen=partial(generate_weight)), 'Mean': TensorConfig(data_gen=partial(generate_weight)),
"Scale": TensorConfig(data_gen=partial(generate_weight)), 'Scale': TensorConfig(data_gen=partial(generate_weight)),
"Variance": TensorConfig(data_gen=partial(generate_weight)), 'Variance': TensorConfig(data_gen=partial(generate_weight)),
}, },
inputs={ inputs={
"input_data": TensorConfig(data_gen=partial(generate_input)) 'input_data': TensorConfig(data_gen=partial(generate_input))
}, },
outputs=["relu_output"], outputs=['relu_output'],
) )
return program_config return program_config
def sample_predictor_configs(self, program_config): def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True) config = self.create_inference_config(use_mkldnn=True)
yield config, ["batch_norm"], (1e-5, 1e-5) yield config, ['batch_norm'], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis(quant=False, passes=["batch_norm_act_fuse_pass"]) self.run_and_statis(quant=False, passes=['batch_norm_act_fuse_pass'])
if __name__ == "__main__": if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册