From 804271cff9f43cd06409962b3bef80827374fa25 Mon Sep 17 00:00:00 2001 From: lidanqing Date: Mon, 16 Nov 2020 11:42:01 +0100 Subject: [PATCH] Op version python mkldnn_inplace test (#28354) * add mkldnn inplace op version test * update mkldnn_inplace fuse pass * update the inplace test --- .../ir/mkldnn/mkldnn_inplace_pass.cc | 8 +++ .../test_mkldnn_inplace_fuse_pass.py | 56 +++++++++++++++++++ tools/static_mode_white_list.py | 1 + 3 files changed, 65 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_inplace_fuse_pass.py diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc index 7bd94bf55e..d655837f74 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc @@ -17,10 +17,12 @@ #include #include #include +#include #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_info.h" +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -215,3 +217,9 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { } // namespace paddle REGISTER_PASS(mkldnn_inplace_pass, paddle::framework::ir::MKLDNNInPlacePass); +REGISTER_PASS_CAPABILITY(mkldnn_inplace_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("softmax", 0) + .EQ("elementwise_add", 0) + .EQ("tanh", 0)); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_inplace_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_inplace_fuse_pass.py new file mode 100644 index 0000000000..4215e56de2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_inplace_fuse_pass.py @@ -0,0 +1,56 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import AnalysisConfig +from paddle.fluid.core import PassVersionChecker + + +class MkldnnInplacePassTest(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + paddle.enable_static() + data = fluid.data( + name="data", shape=[-1, 3, 100, 100], dtype="float32") + conv_out_1 = fluid.layers.conv2d( + data, num_filters=3, filter_size=3, bias_attr=False) + softmax_out = fluid.layers.softmax(conv_out_1) + relu_out = fluid.layers.relu(conv_out_1) + eltwise_out = fluid.layers.elementwise_add( + softmax_out, relu_out, axis=-1) + + self.pass_name = 'mkldnn_inplace_pass' + self.feeds = { + "data": np.random.random((1, 3, 100, 100)).astype("float32") + } + self.fetch_list = [softmax_out, relu_out, eltwise_out] + self.enable_mkldnn = True + + def test_check_output(self): + use_gpu = False + self.check_output_with_option(use_gpu) + + def test_pass_compatible(self): + self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 6a2a121cd6..1f153442af 100644 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -611,6 +611,7 @@ STATIC_MODE_TESTING_LIST = [ 'test_mkldnn_matmul_op_output_fuse_pass', 'test_mkldnn_matmul_transpose_reshape_fuse_pass', 'test_mkldnn_scale_matmul_fuse_pass', + 'test_mkldnn_inplace_fuse_pass', 'test_batch_fc_op', 'test_c_comm_init_all_op', 'test_conv2d_fusion_op', -- GitLab