From 5134f1100554e5d3074220ac1c05f94cbe9956d0 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Wed, 27 Apr 2022 17:57:01 +0200 Subject: [PATCH] Added missing test for shuffle_channel_mkldnn_detect_pass (#42001) * added test for shuffle_channel_mkldnn_detect_pass * added UT using new framework * CI fix --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + ...uffle_channel_mkldnn_detect_pass_tester.cc | 83 +++++++++++ ...test_mkldnn_shuffle_channel_detect_pass.py | 138 ++++++++++++++++++ 3 files changed, 222 insertions(+) create mode 100644 paddle/fluid/framework/ir/mkldnn/shuffle_channel_mkldnn_detect_pass_tester.cc create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shuffle_channel_detect_pass.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 207ee713bf..a2f3b8dc79 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -226,6 +226,7 @@ endif() cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor) cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass reshape_transpose_matmul_v2_mkldnn_fuse_pass) cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass matmul_v2_transpose_reshape_fuse_pass) + cc_test(test_shuffle_channel_mkldnn_detect_pass SRCS mkldnn/shuffle_channel_mkldnn_detect_pass_tester.cc DEPS shuffle_channel_mkldnn_detect_pass) cc_test(test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass) cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass) cc_test(test_multi_gru_fuse_pass SRCS mkldnn/multi_gru_fuse_pass_tester.cc DEPS multi_gru_fuse_pass) diff --git a/paddle/fluid/framework/ir/mkldnn/shuffle_channel_mkldnn_detect_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/shuffle_channel_mkldnn_detect_pass_tester.cc new file mode 100644 index 0000000000..fe42e8f96f --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/shuffle_channel_mkldnn_detect_pass_tester.cc @@ -0,0 +1,83 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/fluid/framework/ir/mkldnn/shuffle_channel_mkldnn_detect_pass.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +void AddVarToScope(Scope* param_scope, const std::string& name, + const DDim& dims) { + auto* tensor = param_scope->Var(name)->GetMutable(); + tensor->Resize(dims); + tensor->mutable_data(platform::CPUPlace()); +} + +Scope* CreateParamScope() { + auto param_scope = new Scope(); + AddVarToScope(param_scope, "prog_x", {1, 128, 52, 52}); + return param_scope; +} + +void MainTest() { + Layers layers; + auto prog_x = layers.data("prog_x", {1, 128, 52, 52}); + auto first_reshape2 = layers.reshape2(prog_x, {-1, 2, 64, 52, 52}, true); + first_reshape2->SetShape({-1, 2, 64, 52, 52}); + auto transpose2 = layers.transpose2(first_reshape2, {0, 2, 1, 3, 4}, true); + transpose2->SetShape({-1, 64, 2, 52, 52}); + auto second_reshape2 = layers.reshape2(transpose2, {-1, 128, 52, 52}, true); + second_reshape2->SetShape({-1, 128, 52, 52}); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + graph->Set("__param_scope__", CreateParamScope()); + + int added_nodes = 1; // shuffle_channel + int removed_nodes = 5; // 2 * reshape, reshape_out, transpose, transpose_out + + int original_nodes_num = graph->Nodes().size(); + auto pass = + PassRegistry::Instance().Get("shuffle_channel_mkldnn_detect_pass"); + graph.reset(pass->Apply(graph.release())); + int current_nodes_num = graph->Nodes().size(); + + EXPECT_EQ(current_nodes_num, + original_nodes_num + added_nodes - removed_nodes); + EXPECT_EQ(GetNumOpNodes(graph, "reshape2"), 0); + EXPECT_EQ(GetNumOpNodes(graph, "transpose2"), 0); + EXPECT_EQ(GetNumOpNodes(graph, "shuffle_channel"), 1); + + for (const auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op()->Type() == "shuffle_channel") { + const auto* op = node->Op(); + ASSERT_TRUE(op->HasAttr("use_mkldnn")); + EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn"))); + } + } +} + +TEST(ShuffleChannelOneDNNDetectPass, ShuffleChannelOneDNNDetectPassTest) { + MainTest(); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(shuffle_channel_mkldnn_detect_pass); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shuffle_channel_detect_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shuffle_channel_detect_pass.py new file mode 100644 index 0000000000..828e92dc03 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shuffle_channel_detect_pass.py @@ -0,0 +1,138 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import PassAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import numpy as np +from functools import partial +import unittest + +from hypothesis import given, settings, seed, example, assume +import hypothesis.strategies as st + + +def product(input): + result = 1 + + for value in input: + result = result * value + + return result + + +class TestShuffleChannelMKLDNNDetectPass(PassAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + input_shape = program_config.inputs['input_data'].shape + first_reshape2_shape = program_config.ops[0].attrs['shape'] + transpose2_axis = program_config.ops[1].attrs['axis'] + second_reshape2_shape = program_config.ops[2].attrs['shape'] + + shape_prod = product(input_shape) + img_h = input_shape[-2] + img_w = input_shape[-1] + + if shape_prod != product(first_reshape2_shape) or shape_prod != product( + second_reshape2_shape): + return False + if len(input_shape) != 4 or len(first_reshape2_shape) != 5 or len( + second_reshape2_shape) != 4: + return False + if transpose2_axis != [0, 2, 1, 3, 4]: + return False + if first_reshape2_shape[-1] != img_w or first_reshape2_shape[ + -2] != img_h: + return False + if second_reshape2_shape[-1] != img_w or second_reshape2_shape[ + -2] != img_h: + return False + + return True + + def sample_program_config(self, draw): + input_shape = draw(st.sampled_from([[128, 32, 32]])) + first_reshape2_shape = draw( + st.sampled_from([[2, 64, 32, 32], [8, 16, 32, 32]])) + transpose2_axis = draw(st.sampled_from([[0, 2, 1, 3, 4], [0, 2, 1, 3]])) + second_reshape2_shape = draw( + st.sampled_from([[128, 32, 32], [128, 31, 32]])) + batch_size = draw(st.integers(min_value=1, max_value=10)) + + input_shape.insert(0, batch_size) + first_reshape2_shape.insert(0, batch_size) + second_reshape2_shape.insert(0, batch_size) + + def generate_input(): + return np.random.random(input_shape).astype(np.float32) + + ops_config = [{ + "op_type": "reshape2", + "op_inputs": { + "X": ["input_data"] + }, + "op_outputs": { + "Out": ["first_reshape2_output"], + "XShape": ["first_reshape2_xshape"] + }, + "op_attrs": { + 'shape': first_reshape2_shape + }, + }, { + "op_type": "transpose2", + "op_inputs": { + "X": ["first_reshape2_output"] + }, + "op_outputs": { + "Out": ["transpose2_output"], + "XShape": ["transpose2_xshape"] + }, + "op_attrs": { + 'axis': transpose2_axis + }, + }, { + "op_type": "reshape2", + "op_inputs": { + "X": ["transpose2_output"], + }, + "op_outputs": { + "Out": ["output_data"], + "XShape": ["second_reshape2_xshape"] + }, + "op_attrs": { + 'shape': second_reshape2_shape + } + }] + + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "input_data": TensorConfig(data_gen=partial(generate_input)) + }, + outputs=["output_data"]) + + return program_config + + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_mkldnn=True) + yield config, ["shuffle_channel"], (1e-5, 1e-5) + + def test(self): + self.run_and_statis( + quant=False, passes=["shuffle_channel_mkldnn_detect_pass"]) + + +if __name__ == "__main__": + unittest.main() -- GitLab