From 39546aa2f32788e1b55394739d46e47cd37fc232 Mon Sep 17 00:00:00 2001 From: Wilber Date: Mon, 21 Sep 2020 13:39:17 +0800 Subject: [PATCH] Add pass compatible and unit test. (#27377) --- .../ir/embedding_fc_lstm_fuse_pass.cc | 12 ++- paddle/fluid/framework/ir/fc_fuse_pass.cc | 8 ++ paddle/fluid/framework/ir/fc_gru_fuse_pass.cc | 22 ++++- .../fluid/framework/ir/fc_lstm_fuse_pass.cc | 15 ++++ .../framework/ir/squared_mat_sub_fuse_pass.cc | 30 +++++-- .../framework/ir/squared_mat_sub_fuse_pass.h | 2 +- .../inference/api/paddle_pass_builder.cc | 3 +- python/paddle/fluid/layers/tensor.py | 2 + .../ir/inference/test_fc_fuse_pass.py | 54 ++++++++++++ .../ir/inference/test_fc_gru_fuse_pass.py | 86 +++++++++++++++++++ .../ir/inference/test_fc_lstm_fuse_pass.py | 52 +++++++++++ .../test_squared_mat_sub_fuse_pass.py | 63 ++++++++++++++ ...test_transpose_flatten_concat_fuse_pass.py | 4 +- 13 files changed, 342 insertions(+), 11 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_fc_fuse_pass.py create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_fc_gru_fuse_pass.py create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_fc_lstm_fuse_pass.py create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py diff --git a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc index c50b7476c6..02e3e2542f 100644 --- a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc @@ -23,6 +23,8 @@ #include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/platform/cpu_info.h" +#include "paddle/fluid/framework/op_version_registry.h" + namespace paddle { namespace framework { namespace ir { @@ -34,7 +36,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, // Build pattern PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "x")) - ->assert_is_op_input("lookup_table") + ->assert_is_op_input("lookup_table_v2") ->assert_var_not_persistable(); patterns::Embedding embedding_pattern(pattern, name_scope); // TODO(jczaja): Intermediate can only be for val that are not used anywhere @@ -256,3 +258,11 @@ void EmbeddingFCLSTMFusePass::ApplyImpl(ir::Graph* graph) const { REGISTER_PASS(embedding_fc_lstm_fuse_pass, paddle::framework::ir::EmbeddingFCLSTMFusePass); +REGISTER_PASS_CAPABILITY(embedding_fc_lstm_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("lookup_table_v2", 0) + .EQ("mul", 0) + .EQ("elementwise_add", 0) + .EQ("lstm", 0) + .EQ("fused_embedding_fc_lstm", 0)); diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc index 066a8fb975..d60510a407 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc @@ -18,6 +18,7 @@ #include #include #include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -182,3 +183,10 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const { REGISTER_PASS(fc_fuse_pass, paddle::framework::ir::FCFusePass) .RequirePassAttr("use_gpu"); +REGISTER_PASS_CAPABILITY(fc_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("mul", 0) + .EQ("elementwise_add", 0) + .EQ("relu", 0) + .EQ("fc", 0)); diff --git a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc index a2185cdc55..f5fea90ac2 100644 --- a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc @@ -16,6 +16,7 @@ #include #include #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace framework { @@ -125,7 +126,6 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, auto* x_n = subgraph.at(x); GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern); - GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, gru_pattern); GET_IR_NODE_FROM_SUBGRAPH(gru, gru, gru_pattern); GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, gru_pattern); @@ -136,10 +136,17 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, gru_pattern); GET_IR_NODE_FROM_SUBGRAPH(BatchHidden, BatchHidden, gru_pattern); + // TODO(wilber): Support origin_mode=True. + if (gru->Op()->GetAttrIfExists("origin_mode") == true) { + LOG(INFO) << "fc_gru_fuse_pass not supported when origin_mode=True."; + return; + } + if (with_fc_bias) { GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern); gru_creater(gru, x_n, w, Weight, Bias, Hidden, fc_bias); // Remove unneeded nodes. @@ -188,3 +195,16 @@ void FCGRUFusePass::ApplyImpl(ir::Graph* graph) const { REGISTER_PASS(mul_gru_fuse_pass, paddle::framework::ir::MulGRUFusePass); REGISTER_PASS(fc_gru_fuse_pass, paddle::framework::ir::FCGRUFusePass); +REGISTER_PASS_CAPABILITY(mul_gru_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("mul", 0) + .EQ("gru", 0) + .EQ("fusion_gru", 0)); +REGISTER_PASS_CAPABILITY(fc_gru_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("mul", 0) + .EQ("elementwise_add", 0) + .EQ("gru", 0) + .EQ("fusion_gru", 0)); diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc index 12c7fc051e..a3c57e14e1 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -16,6 +16,7 @@ #include #include #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace framework { @@ -196,3 +197,17 @@ void FCLstmFusePass::ApplyImpl(ir::Graph* graph) const { REGISTER_PASS(mul_lstm_fuse_pass, paddle::framework::ir::MulLstmFusePass); REGISTER_PASS(fc_lstm_fuse_pass, paddle::framework::ir::FCLstmFusePass); + +REGISTER_PASS_CAPABILITY(fc_lstm_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("mul", 0) + .EQ("elementwise_add", 0) + .EQ("lstm", 0) + .EQ("fusion_lstm", 0)); +REGISTER_PASS_CAPABILITY(mul_lstm_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("mul", 0) + .EQ("lstm", 0) + .EQ("fusion_lstm", 0)); diff --git a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc index 035b198bdc..d74843611c 100644 --- a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc +++ b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc @@ -17,6 +17,7 @@ #include #include #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace framework { @@ -77,7 +78,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern, }; auto is_fusion_input_var = [=](Node* x, const std::string& arg_name) { - bool basic = var_is_op_input(x, "matmul", arg_name) && + bool basic = (var_is_op_input(x, "matmul_v2", arg_name) || + var_is_op_input(x, "matmul", arg_name)) && var_is_op_input(x, "square", "X"); if (!basic) { return false; @@ -88,7 +90,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern, } auto* squared_x = squared_x_op->outputs[0]; bool next_is_matmul_from_arg = - var_is_op_input(squared_x, "matmul", arg_name) && + (var_is_op_input(squared_x, "matmul_v2", arg_name) || + var_is_op_input(squared_x, "matmul", arg_name)) && squared_x->outputs.size() == 1 && squared_x->outputs[0]->outputs.size() == 1; if (!next_is_matmul_from_arg) { @@ -103,7 +106,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern, auto is_fusion_first_mul_out = [=](Node* x) -> bool { bool input_is_matmul_op = x && x->inputs.size() == 1 && x->inputs[0]->IsOp() && - x->inputs[0]->Op()->Type() == "matmul"; + (x->inputs[0]->Op()->Type() == "matmul_v2" || + x->inputs[0]->Op()->Type() == "matmul"); if (!input_is_matmul_op) { return false; } @@ -167,7 +171,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern, auto* matmul_xy_op = pattern->NewNode( [=](Node* x) { - return x && x->IsOp() && x->Op()->Type() == "matmul" && + return x && x->IsOp() && (x->Op()->Type() == "matmul_v2" || + x->Op()->Type() == "matmul") && is_fusion_first_mul_out(x->outputs[0]); }, name_scope + "/matmul_xy_op"); @@ -189,7 +194,9 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern, auto is_fusion_mat_squared_x_y_op_out = [=](Node* x) -> bool { bool basic = x && x->IsVar() && x->inputs.size() == 1 && - x->inputs[0]->IsOp() && x->inputs[0]->Op()->Type() == "matmul"; + x->inputs[0]->IsOp() && + (x->inputs[0]->Op()->Type() == "matmul_v2" || + x->inputs[0]->Op()->Type() == "matmul"); if (!basic) { return false; } @@ -206,7 +213,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern, auto* matmul_squared_x_y_op = pattern->NewNode( [=](Node* x) { - return x && x->IsOp() && x->Op()->Type() == "matmul" && + return x && x->IsOp() && (x->Op()->Type() == "matmul_v2" || + x->Op()->Type() == "matmul") && is_fusion_mat_squared_x_y_op_out(x->outputs[0]); }, name_scope + "/matmul_squared_x_y_op"); @@ -378,3 +386,13 @@ void SquaredMatSubFusePass::ApplyImpl(ir::Graph* graph) const { REGISTER_PASS(squared_mat_sub_fuse_pass, paddle::framework::ir::SquaredMatSubFusePass); +REGISTER_PASS_CAPABILITY(squared_mat_sub_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("matmul", 0) + .EQ("matmul_v2", 0) + .EQ("square", 0) + .EQ("elementwise_mul", 0) + .EQ("elementwise_sub", 0) + .EQ("fill_constant", 0) + .EQ("fusion_squared_mat_sub", 0)); diff --git a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h index b6165a512a..56b7ec9b84 100644 --- a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h +++ b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h @@ -24,7 +24,7 @@ namespace framework { namespace ir { /** - * Fuse ( (A.^2 * B.^2) - (A * B).^2 ) .* scalar + * Fuse ( (A * B).^2 - (A.^2 * B.^2) ) .* scalar */ class SquaredMatSubFusePass : public FusePassBase { public: diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index c19e77d271..19f52422b4 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -156,7 +156,8 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { // "seqpool_concat_fuse_pass", // "seqpool_cvm_concat_fuse_pass", // // "embedding_fc_lstm_fuse_pass", // - "fc_lstm_fuse_pass", // + // TODO(wilber): fix correctness problem. + // "fc_lstm_fuse_pass", // "mul_lstm_fuse_pass", // "fc_gru_fuse_pass", // "mul_gru_fuse_pass", // diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 89acfc6075..0ce7c098e2 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -680,8 +680,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): if not isinstance(value, Variable): if dtype in ['int64', 'int32']: attrs['str_value'] = str(int(value)) + attrs['value'] = int(value) else: attrs['str_value'] = str(float(value)) + attrs['value'] = float(value) if in_dygraph_mode(): shape = utils.convert_shape_to_list(shape) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_fc_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_fc_fuse_pass.py new file mode 100644 index 0000000000..a62adcea3f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_fc_fuse_pass.py @@ -0,0 +1,54 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import AnalysisConfig +from paddle.fluid.core import PassVersionChecker + + +class FcFusePassTest(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 128, 768], dtype="float32") + data_y = fluid.data(name="y", shape=[-1, 128, 768], dtype="float32") + fc_out1 = fluid.layers.fc(input=data, + size=3072, + num_flatten_dims=2, + act="relu") + fc_out2 = fluid.layers.fc(input=fc_out1, + size=768, + num_flatten_dims=2) + + self.feeds = {"data": np.random.random((4, 128, 768)).astype("float32")} + self.fetch_list = [fc_out2] + + def test_check_output(self): + use_gpu = [False] + if core.is_compiled_with_cuda(): + use_gpu.append(True) + for i in range(len(use_gpu)): + self.check_output_with_option(use_gpu[i]) + + self.assertTrue(PassVersionChecker.IsCompatible('fc_fuse_pass')) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_fc_gru_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_fc_gru_fuse_pass.py new file mode 100644 index 0000000000..f7b43470d4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_fc_gru_fuse_pass.py @@ -0,0 +1,86 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import PassVersionChecker + + +class FcGruFusePassTest(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + dict_dim, emb_dim = 128, 64 + data = fluid.data( + name='step_data', shape=[None], dtype='int64', lod_level=1) + emb = fluid.embedding(input=data, size=[dict_dim, emb_dim]) + hidden_dim = 512 + x = fluid.layers.fc(input=emb, size=hidden_dim * 3) + hidden = fluid.layers.dynamic_gru( + input=x, + size=hidden_dim, + bias_attr=True, + origin_mode=False, + is_reverse=True) + + batch = 16 + lod_tensor = fluid.LoDTensor() + lod_tensor.set(np.random.randint( + 0, dict_dim, size=[batch]).astype("int64"), + fluid.CPUPlace()) + lod_tensor.set_lod([[0, batch]]) + self.feeds = {"step_data": lod_tensor} + self.fetch_list = [hidden] + + def test_check_output(self): + use_gpu = False + self.check_output_with_option(use_gpu) + self.assertTrue(PassVersionChecker.IsCompatible('fc_gru_fuse_pass')) + + +class MulGruFusePassTest(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + dict_dim, emb_dim = 128, 64 + data = fluid.data( + name='step_data', shape=[None], dtype='int64', lod_level=1) + emb = fluid.embedding(input=data, size=[dict_dim, emb_dim]) + hidden_dim = 512 + x = fluid.layers.fc(input=emb, size=hidden_dim * 3, bias_attr=False) + hidden = fluid.layers.dynamic_gru( + input=x, + size=hidden_dim, + bias_attr=True, + origin_mode=False, + is_reverse=True) + + batch = 16 + lod_tensor = fluid.LoDTensor() + lod_tensor.set(np.random.randint( + 0, dict_dim, size=[batch]).astype("int64"), + fluid.CPUPlace()) + lod_tensor.set_lod([[0, batch]]) + self.feeds = {"step_data": lod_tensor} + self.fetch_list = [hidden] + + def test_check_output(self): + use_gpu = False + self.check_output_with_option(use_gpu) + self.assertTrue(PassVersionChecker.IsCompatible('mul_gru_fuse_pass')) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_fc_lstm_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_fc_lstm_fuse_pass.py new file mode 100644 index 0000000000..fbb4373dae --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_fc_lstm_fuse_pass.py @@ -0,0 +1,52 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import PassVersionChecker + + +class MulLstmFusePassTest(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + dict_dim, emb_dim = 128, 64 + hidden_dim = 512 + + data = fluid.data( + name='data', shape=[1], dtype='int64', lod_level=1) + emb = fluid.embedding(input=data, size=[dict_dim, emb_dim]) + x = fluid.layers.fc(input=emb, size=hidden_dim * 4, bias_attr=False) + forward, cell = fluid.layers.dynamic_lstm( + input=x, size=hidden_dim * 4) + + batch = 16 + lod_tensor = fluid.LoDTensor() + lod_tensor.set(np.random.randint( + 0, dict_dim, size=[batch]).astype("int64"), + fluid.CPUPlace()) + lod_tensor.set_lod([[0, batch]]) + self.feeds = {"data": lod_tensor} + self.fetch_list = [forward, cell] + + def test_check_output(self): + use_gpu = False + self.check_output_with_option(use_gpu) + self.assertTrue(PassVersionChecker.IsCompatible('mul_lstm_fuse_pass')) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py new file mode 100644 index 0000000000..5fa242df4e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py @@ -0,0 +1,63 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import AnalysisConfig +from paddle.fluid.core import PassVersionChecker + + +class SquaredMatSubFusePassTest(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data_a = fluid.data(name="data_a", shape=[128, 1], dtype="float32") + data_b = fluid.data(name="data_b", shape=[256, 1], dtype="float32") + + fc_a = fluid.layers.fc(data_a, size=256) + fc_b = fluid.layers.fc(data_b, size=64) + + data_a_square = paddle.square(fc_a) + data_b_square = paddle.square(fc_b) + + matmul_ab = paddle.matmul(fc_a, fc_b) + matmul_ab_square = paddle.square(matmul_ab) + matmul_square_ab = paddle.matmul(data_a_square, data_b_square) + + scale = paddle.fill_constant(shape=[1], value=0.5, dtype='float32') + + sub_val = paddle.elementwise_sub(matmul_ab_square, matmul_square_ab) + squared_mat_sub_out = fluid.layers.elementwise_mul(sub_val, scale) + + self.feeds = { + "data_a": np.random.random((128, 1)).astype("float32"), + "data_b": np.random.random((256, 1)).astype("float32") + } + self.fetch_list = [squared_mat_sub_out] + + def test_check_output(self): + use_gpu = False + self.check_output_with_option(use_gpu) + + self.assertTrue( + PassVersionChecker.IsCompatible('squared_mat_sub_fuse_pass')) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_transpose_flatten_concat_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_transpose_flatten_concat_fuse_pass.py index 34a52e7aed..83d4b7091c 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_transpose_flatten_concat_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_transpose_flatten_concat_fuse_pass.py @@ -75,7 +75,9 @@ class TransposeFlattenConcatFusePassWithAxisTest(InferencePassTest): use_gpu = True self.check_output_with_option(use_gpu) - PassVersionChecker.IsCompatible('transpose_flatten_concat_fuse_pass') + self.assertTrue( + PassVersionChecker.IsCompatible( + 'transpose_flatten_concat_fuse_pass')) if __name__ == "__main__": -- GitLab