From eb22391c7031d8d6bcbc1e79ced4dcdf6ff0eddd Mon Sep 17 00:00:00 2001 From: Hulek Date: Tue, 28 Feb 2023 11:20:53 +0100 Subject: [PATCH] Rewrite mkldnn fc rnn fuse pass tester (#50265) * Added file * Tests separated and rewritten, fixed fc_lstm_fuse_pass * Resolve conflicts --- paddle/fluid/framework/ir/CMakeLists.txt | 6 - .../fluid/framework/ir/fc_lstm_fuse_pass.cc | 2 +- .../mkldnn/mkldnn_fc_rnn_fuse_pass_tester.cc | 94 ------------- .../ir/mkldnn/mkldnn_placement_pass.cc | 6 + .../inference/test_onednn_fc_gru_fuse_pass.py | 126 +++++++++++++++++ .../test_onednn_fc_lstm_fuse_pass.py | 131 ++++++++++++++++++ 6 files changed, 264 insertions(+), 101 deletions(-) delete mode 100644 paddle/fluid/framework/ir/mkldnn/mkldnn_fc_rnn_fuse_pass_tester.cc create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_onednn_fc_gru_fuse_pass.py create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_onednn_fc_lstm_fuse_pass.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 536b0b2f21b..d32f13e68e5 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -483,10 +483,4 @@ if(WITH_MKLDNN) test_multi_gru_seq_fuse_pass SRCS mkldnn/multi_gru_seq_fuse_pass_tester.cc DEPS multi_gru_seq_fuse_pass) - set(TEST_FC_RNN_PASS_DEPS fc_gru_fuse_pass fc_lstm_fuse_pass - mkldnn_placement_pass) - cc_test( - test_fc_rnn_mkldnn_fuse_pass - SRCS mkldnn/mkldnn_fc_rnn_fuse_pass_tester.cc - DEPS ${TEST_FC_RNN_PASS_DEPS}) endif() diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc index 5fdfb48cd36..78e6ea14e43 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -227,7 +227,7 @@ int FCLstmFusePass::BuildFusion(Graph* graph, lstm_bias_tensor->mutable_data(platform::CPUPlace()); auto* fc_bias_data = fc_bias_tensor.data(); - for (int i = 0; i < lstm_bias_tensor->numel(); i++) { + for (int i = 0; i < fc_bias_tensor.numel(); i++) { lstm_bias_data[i] += fc_bias_data[i]; } } diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_fc_rnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_fc_rnn_fuse_pass_tester.cc deleted file mode 100644 index 05e46db50af..00000000000 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_fc_rnn_fuse_pass_tester.cc +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include "paddle/fluid/framework/ir/fc_gru_fuse_pass_tester.h" -#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass_tester.h" -#include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h" -#include "paddle/fluid/framework/ir/pass_tester_helper.h" - -namespace paddle { -namespace framework { -namespace ir { - -void TestFcRNNFusePass(const std::string& pass_name, - std::string activation = "tanh", - std::string gate_activation = "sigmoid", - std::string candidate_activation = "tanh") { - std::unique_ptr graph = - (pass_name == "fc_gru_fuse_pass" - ? fc_gru_test::PrepareGraph(activation, gate_activation) - : fc_lstm_test::PrepareGraph( - gate_activation, activation, candidate_activation)); - auto mkldnn_placement_pass_ = - PassRegistry::Instance().Get("mkldnn_placement_pass"); - mkldnn_placement_pass_->Set("mkldnn_enabled_op_types", - new std::unordered_set({})); - graph->Set( - "__param_scope__", - (pass_name == "fc_gru_fuse_pass" ? fc_gru_test::CreateParamScope() - : fc_lstm_test::CreateParamScope())); - RegisterOpKernel({"mul", "elementwise_add"}); - graph.reset(mkldnn_placement_pass_->Apply(graph.release())); - - auto check_num_mkldnn_nodes = [&](const std::unique_ptr& graph) { - int nodes_cout = 0; - for (auto* node : graph->Nodes()) { - if (node->IsOp()) { - auto* op = node->Op(); - if (op->GetAttrIfExists("use_mkldnn")) nodes_cout++; - } - } - return nodes_cout; - }; - int num_mkldnn_nodes_before = check_num_mkldnn_nodes(graph); - int removed_mkldnn_nodes = 2; - - // OneDNN fusion_gru and fusion_lstm supports only sigmoid as a gate - // activation and tanh as an activation and candidate_activation - if (activation != "tanh" || gate_activation != "sigmoid" || - candidate_activation != "tanh") - removed_mkldnn_nodes += 2; - - auto fc_rnn_fuse_pass_ = PassRegistry::Instance().Get(pass_name); - graph.reset(fc_rnn_fuse_pass_->Apply(graph.release())); - int num_mkldnn_nodes_after = check_num_mkldnn_nodes(graph); - - PADDLE_ENFORCE_EQ(num_mkldnn_nodes_before - removed_mkldnn_nodes, - num_mkldnn_nodes_after, - platform::errors::PreconditionNotMet( - "The number of nodes with \"use_mkldnn\" attr after " - "passes is not as expected")); -} - -TEST(FcGruFusePass, use_mkldnn) { TestFcRNNFusePass("fc_gru_fuse_pass"); } - -TEST(FcGruFusePass, gru_unsupported_activations) { - TestFcRNNFusePass("fc_gru_fuse_pass", "relu", "sigmoid"); -} - -TEST(FcLstmFusePass, use_mkldnn) { TestFcRNNFusePass("fc_lstm_fuse_pass"); } - -TEST(FcLstmFusePass, lstm_unsupported_activations) { - TestFcRNNFusePass("fc_lstm_fuse_pass", "tanh", "relu", "tanh"); -} - -} // namespace ir -} // namespace framework -} // namespace paddle - -USE_PASS(mkldnn_placement_pass); -USE_PASS(fc_gru_fuse_pass); -USE_PASS(fc_lstm_fuse_pass); diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc index 83b06102d21..8d8d7feacfb 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h" +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/operator.h" namespace paddle { @@ -79,3 +80,8 @@ bool MKLDNNPlacementPass::IsSupport(const Node* op) const { REGISTER_PASS(mkldnn_placement_pass, paddle::framework::ir::MKLDNNPlacementPass) .RequirePassAttr("mkldnn_enabled_op_types"); + +REGISTER_PASS_CAPABILITY(mkldnn_placement_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().LE( + "fusion_gru", 1)); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_onednn_fc_gru_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_onednn_fc_gru_fuse_pass.py new file mode 100644 index 00000000000..12069aac2de --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_onednn_fc_gru_fuse_pass.py @@ -0,0 +1,126 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial + +import hypothesis.strategies as st +import numpy as np +from auto_scan_test import PassAutoScanTest +from program_config import OpConfig, ProgramConfig, TensorConfig + + +class TestOneDNNFCGruFusePass(PassAutoScanTest): + def sample_program_config(self, draw): + def generate_input(shape): + return np.random.random(shape).astype(np.float32) + + batch_size = draw(st.integers(min_value=1, max_value=16)) + fc_input_shape = [batch_size, 64] + fc_weight_shape = [64, 192] + fc_bias_shape = [1, 192] + lod = [[0, batch_size]] + + gru_weight_shape = [64, 192] + gru_bias_shape = [1, 192] + activation = draw(st.sampled_from(['tanh'])) + is_reverse = draw(st.booleans()) + gate_activation = draw(st.sampled_from(['sigmoid'])) + + mul_op = OpConfig( + type='mul', + inputs={'X': ['fc_input'], 'Y': ['fc_weight']}, + outputs={'Out': ['mul_out']}, + attrs={'x_num_col_dims': 1, 'y_num_col_dims': 1}, + ) + + elt_op = OpConfig( + type='elementwise_add', + inputs={'X': ['mul_out'], 'Y': ['fc_bias']}, + outputs={'Out': ['fc_output']}, + attrs={'axis': -1}, + ) + + gru_op = OpConfig( + type='gru', + inputs={ + 'Input': ['fc_output'], + 'Weight': ['gru_weight'], + 'Bias': ['gru_bias'], + }, + outputs={ + 'BatchGate': ['batch_gate'], + 'BatchHidden': ['batch_hidden'], + 'BatchResetHiddenPrev': ['batch_reset'], + 'Hidden': ['gru_hidden'], + }, + attrs={ + 'activation': activation, + 'is_reverse': is_reverse, + 'gate_activation': gate_activation, + 'is_test': True, + }, + ) + + model_net = [mul_op, elt_op, gru_op] + + program_config = ProgramConfig( + ops=model_net, + inputs={ + 'fc_input': TensorConfig( + lod=lod, data_gen=partial(generate_input, fc_input_shape) + ) + }, + weights={ + 'fc_weight': TensorConfig( + data_gen=partial(generate_input, fc_weight_shape) + ), + 'fc_bias': TensorConfig( + data_gen=partial(generate_input, fc_bias_shape) + ), + 'gru_weight': TensorConfig( + data_gen=partial(generate_input, gru_weight_shape) + ), + 'gru_bias': TensorConfig( + data_gen=partial(generate_input, gru_bias_shape) + ), + }, + outputs=['gru_hidden'], + ) + + return program_config + + def sample_predictor_configs(self, program_config): + config = self.create_inference_config( + use_mkldnn=True, + passes=[ + 'mkldnn_placement_pass', + 'fc_gru_fuse_pass', + ], + ) + yield config, ['fusion_gru'], (1e-5, 1e-5) + + def test(self): + self.run_and_statis( + quant=False, + passes=[ + 'mkldnn_placement_pass', + 'fc_gru_fuse_pass', + ], + max_examples=100, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_onednn_fc_lstm_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_onednn_fc_lstm_fuse_pass.py new file mode 100644 index 00000000000..c919fa118a9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_onednn_fc_lstm_fuse_pass.py @@ -0,0 +1,131 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial + +import hypothesis.strategies as st +import numpy as np +from auto_scan_test import PassAutoScanTest +from program_config import OpConfig, ProgramConfig, TensorConfig + + +class TestOneDNNFCLstmFusePass(PassAutoScanTest): + def sample_program_config(self, draw): + + batch_size = draw(st.integers(min_value=1, max_value=16)) + fc_input_shape = [batch_size, 64] + fc_weight_shape = [64, 256] + fc_bias_shape = [1, 256] + lod = [[0, batch_size]] + + use_peepholes = draw(st.booleans()) + is_reverse = draw(st.booleans()) + gate_activation = draw(st.sampled_from(['sigmoid'])) + cell_activation = draw(st.sampled_from(['tanh'])) + candidate_activation = draw(st.sampled_from(['tanh'])) + lstm_weight_shape = [64, 256] + lstm_bias_shape = [1, 448] if use_peepholes else [1, 256] + + mul_op = OpConfig( + type='mul', + inputs={'X': ['fc_input'], 'Y': ['fc_weight']}, + outputs={'Out': ['mul_out']}, + attrs={'x_num_col_dims': 1, 'y_num_col_dims': 1}, + ) + + elt_op = OpConfig( + type='elementwise_add', + inputs={'X': ['mul_out'], 'Y': ['fc_bias']}, + outputs={'Out': ['fc_output']}, + attrs={'axis': -1}, + ) + + lstm_op = OpConfig( + type='lstm', + inputs={ + 'Input': ['fc_output'], + 'Weight': ['lstm_weight'], + 'Bias': ['lstm_bias'], + }, + outputs={ + 'Hidden': ['lstm_hidden'], + 'Cell': ['lstm_cell'], + 'BatchGate': ['lstm_gate'], + 'BatchCellPreAct': ['lstm_batch_cell'], + }, + attrs={ + 'use_peepholes': use_peepholes, + 'is_reverse': is_reverse, + 'gate_activation': gate_activation, + 'cell_activation': cell_activation, + 'candidate_activation': candidate_activation, + 'is_test': True, + }, + ) + + model_net = [mul_op, elt_op, lstm_op] + + def generate_data(shape): + return np.random.random(shape).astype(np.float32) + + program_config = ProgramConfig( + ops=model_net, + inputs={ + 'fc_input': TensorConfig( + lod=lod, data_gen=partial(generate_data, fc_input_shape) + ) + }, + weights={ + 'fc_weight': TensorConfig( + data_gen=partial(generate_data, fc_weight_shape) + ), + 'fc_bias': TensorConfig( + data_gen=partial(generate_data, fc_bias_shape) + ), + 'lstm_weight': TensorConfig( + data_gen=partial(generate_data, lstm_weight_shape) + ), + 'lstm_bias': TensorConfig( + data_gen=partial(generate_data, lstm_bias_shape) + ), + }, + outputs=['lstm_hidden'], + ) + + return program_config + + def sample_predictor_configs(self, program_config): + config = self.create_inference_config( + use_mkldnn=True, + passes=[ + 'mkldnn_placement_pass', + 'fc_lstm_fuse_pass', + ], + ) + yield config, ['fusion_lstm'], (1e-5, 1e-5) + + def test(self): + self.run_and_statis( + quant=False, + passes=[ + 'mkldnn_placement_pass', + 'fc_lstm_fuse_pass', + ], + max_examples=50, + ) + + +if __name__ == '__main__': + unittest.main() -- GitLab