未验证 提交 fd7ab4e6 编写于 作者: P Pei Yang 提交者: GitHub

register pass compatibility (#27357)

* pass compatibility

* add compatibility registry

* add unittests for different padding

* add assert

* drop errmsg
上级 7e6dfcf9
......@@ -18,6 +18,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -225,3 +226,14 @@ REGISTER_PASS(conv_affine_channel_fuse_pass,
paddle::framework::ir::ConvAffineChannelFusePass);
REGISTER_PASS(conv_eltwiseadd_affine_channel_fuse_pass,
paddle::framework::ir::ConvEltwiseAddAffineChannelFusePass);
REGISTER_PASS_CAPABILITY(conv_affine_channel_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("affine_channel", 0));
REGISTER_PASS_CAPABILITY(conv_eltwiseadd_affine_channel_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("elementwise_add", 0)
.EQ("affine_channel", 0));
......@@ -18,6 +18,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -372,3 +373,14 @@ REGISTER_PASS(depthwise_conv_bn_fuse_pass,
paddle::framework::ir::DepthwiseConvBNFusePass);
REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass,
paddle::framework::ir::DepthwiseConvEltwiseAddBNFusePass);
REGISTER_PASS_CAPABILITY(conv_bn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("batch_norm", 0));
REGISTER_PASS_CAPABILITY(conv_eltwiseadd_bn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("elementwise_add", 0)
.EQ("batch_norm", 0));
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#define MAX_NUM_FC 10
......@@ -174,6 +175,10 @@ void BuildRepeatedFCReluPattern(PDPattern* pattern,
if (x->outputs.size() <= 0 || x->inputs.size() <= 0U) {
return false;
}
if (x->IsVar() && x->Var() && x->Var()->GetShape().size() > 2) {
LOG(WARNING) << "repeated fc relu only supports input dims = 2";
return false;
}
int fc_idx = FindFCIdx(x);
if (fc_idx < 0) {
return false;
......@@ -384,3 +389,8 @@ void RepeatedFCReluFusePass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(repeated_fc_relu_fuse_pass,
paddle::framework::ir::RepeatedFCReluFusePass);
REGISTER_PASS_CAPABILITY(repeated_fc_relu_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("fc", 0)
.EQ("relu", 0));
......@@ -16,6 +16,7 @@
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/shuffle_channel_detect_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
......@@ -34,6 +35,8 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "shufflechannel_pattern";
FusePassBase::Init(pattern_name, graph);
LOG(WARNING) << "There is fluid.layers.shuffle_channel API already, you can "
"use it instead of (reshape + transpose +reshape)";
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("x")
......@@ -93,3 +96,8 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(shuffle_channel_detect_pass,
paddle::framework::ir::ShuffleChannelDetectPass);
REGISTER_PASS_CAPABILITY(shuffle_channel_detect_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("reshape2", 0)
.EQ("transpose2", 0));
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import PassVersionChecker
class ConvAffineChannelFusePassExplicitPaddingTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32")
conv_out = fluid.layers.conv2d(
input=data,
num_filters=3,
filter_size=3,
groups=3,
padding=[1, 1, 1, 1],
bias_attr=False,
act=None)
input_scale = fluid.layers.create_parameter(
shape=[3], dtype="float32")
input_bias = fluid.layers.create_parameter(
shape=[3], dtype="float32")
ac_out = fluid.layers.affine_channel(
x=conv_out, scale=input_scale, bias=input_bias)
self.feeds = {
"data": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.fetch_list = [ac_out]
def test_check_output(self):
self.check_output()
self.assertTrue(
PassVersionChecker.IsCompatible('conv_affine_channel_fuse_pass'))
class ConvAffineChannelFusePassValidPaddingTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32")
conv_out = fluid.layers.conv2d(
input=data,
num_filters=3,
filter_size=3,
groups=3,
padding='VALID',
bias_attr=False,
act=None)
input_scale = fluid.layers.create_parameter(
shape=[3], dtype="float32")
input_bias = fluid.layers.create_parameter(
shape=[3], dtype="float32")
ac_out = fluid.layers.affine_channel(
x=conv_out, scale=input_scale, bias=input_bias)
self.feeds = {
"data": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.fetch_list = [ac_out]
def test_check_output(self):
self.check_output()
self.assertTrue(
PassVersionChecker.IsCompatible('conv_affine_channel_fuse_pass'))
class ConvAffineChannelFusePassSamePaddingTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32")
conv_out = fluid.layers.conv2d(
input=data,
num_filters=3,
filter_size=3,
groups=3,
padding='SAME',
bias_attr=False,
act=None)
input_scale = fluid.layers.create_parameter(
shape=[3], dtype="float32")
input_bias = fluid.layers.create_parameter(
shape=[3], dtype="float32")
ac_out = fluid.layers.affine_channel(
x=conv_out, scale=input_scale, bias=input_bias)
self.feeds = {
"data": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.fetch_list = [ac_out]
def test_check_output(self):
self.check_output()
self.assertTrue(
PassVersionChecker.IsCompatible('conv_affine_channel_fuse_pass'))
class ConvEltwiseAddAffineChannelFusePassExplicitPaddingTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32")
param_attr = fluid.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False),
learning_rate=0.001)
conv_out = fluid.layers.conv2d(
input=data,
num_filters=3,
filter_size=3,
groups=3,
padding=[1, 1, 1, 1],
bias_attr=param_attr,
act=None)
input_scale = fluid.layers.create_parameter(
shape=[3], dtype="float32")
input_bias = fluid.layers.create_parameter(
shape=[3], dtype="float32")
ac_out = fluid.layers.affine_channel(
x=conv_out, scale=input_scale, bias=input_bias)
self.feeds = {
"data": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.fetch_list = [ac_out]
def test_check_output(self):
self.check_output()
self.assertTrue(
PassVersionChecker.IsCompatible(
'conv_eltwiseadd_affine_channel_fuse_pass'))
class ConvEltwiseAddAffineChannelFusePassValidPaddingTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32")
param_attr = fluid.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False),
learning_rate=0.001)
conv_out = fluid.layers.conv2d(
input=data,
num_filters=3,
filter_size=3,
groups=3,
padding='VALID',
bias_attr=param_attr,
act=None)
input_scale = fluid.layers.create_parameter(
shape=[3], dtype="float32")
input_bias = fluid.layers.create_parameter(
shape=[3], dtype="float32")
ac_out = fluid.layers.affine_channel(
x=conv_out, scale=input_scale, bias=input_bias)
self.feeds = {
"data": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.fetch_list = [ac_out]
def test_check_output(self):
self.check_output()
self.assertTrue(
PassVersionChecker.IsCompatible(
'conv_eltwiseadd_affine_channel_fuse_pass'))
class ConvEltwiseAddAffineChannelFusePassSamePaddingTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32")
param_attr = fluid.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False),
learning_rate=0.001)
conv_out = fluid.layers.conv2d(
input=data,
num_filters=3,
filter_size=3,
groups=3,
padding='Same',
bias_attr=param_attr,
act=None)
input_scale = fluid.layers.create_parameter(
shape=[3], dtype="float32")
input_bias = fluid.layers.create_parameter(
shape=[3], dtype="float32")
ac_out = fluid.layers.affine_channel(
x=conv_out, scale=input_scale, bias=input_bias)
self.feeds = {
"data": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.fetch_list = [ac_out]
def test_check_output(self):
self.check_output()
self.assertTrue(
PassVersionChecker.IsCompatible(
'conv_eltwiseadd_affine_channel_fuse_pass'))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import PassVersionChecker
class ConvBnFusePassExplicitPaddingTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32")
conv_out = fluid.layers.conv2d(
input=data,
num_filters=6,
filter_size=6,
groups=3,
padding=[1, 1, 1, 1],
bias_attr=False,
act=None)
bn_out = fluid.layers.batch_norm(conv_out, is_test=True)
self.feeds = {
"data": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.fetch_list = [bn_out]
def test_check_output(self):
self.check_output()
self.assertTrue(PassVersionChecker.IsCompatible('conv_bn_fuse_pass'))
class ConvBnFusePassValidPaddingTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32")
conv_out = fluid.layers.conv2d(
input=data,
num_filters=6,
filter_size=6,
groups=3,
padding='VALID',
bias_attr=False,
act=None)
bn_out = fluid.layers.batch_norm(conv_out, is_test=True)
self.feeds = {
"data": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.fetch_list = [bn_out]
def test_check_output(self):
self.check_output()
self.assertTrue(PassVersionChecker.IsCompatible('conv_bn_fuse_pass'))
class ConvBnFusePassSamePaddingTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32")
conv_out = fluid.layers.conv2d(
input=data,
num_filters=6,
filter_size=6,
groups=3,
padding='SAME',
bias_attr=False,
act=None)
bn_out = fluid.layers.batch_norm(conv_out, is_test=True)
self.feeds = {
"data": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.fetch_list = [bn_out]
def test_check_output(self):
self.check_output()
self.assertTrue(PassVersionChecker.IsCompatible('conv_bn_fuse_pass'))
class ConvEltwiseAddBnFuseExplicitPaddingPass(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32")
conv_out = fluid.layers.conv2d(
input=data,
num_filters=6,
filter_size=6,
groups=3,
padding=[1, 1, 1, 1],
bias_attr=None,
act=None)
bn_out = fluid.layers.batch_norm(conv_out, is_test=True)
self.feeds = {
"data": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.fetch_list = [bn_out]
def test_check_output(self):
self.check_output()
self.assertTrue(
PassVersionChecker.IsCompatible('conv_eltwiseadd_bn_fuse_pass'))
class ConvEltwiseAddBnFuseValidPaddingPass(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32")
conv_out = fluid.layers.conv2d(
input=data,
num_filters=6,
filter_size=6,
groups=3,
padding='VALID',
bias_attr=None,
act=None)
bn_out = fluid.layers.batch_norm(conv_out, is_test=True)
self.feeds = {
"data": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.fetch_list = [bn_out]
def test_check_output(self):
self.check_output()
self.assertTrue(
PassVersionChecker.IsCompatible('conv_eltwiseadd_bn_fuse_pass'))
class ConvEltwiseAddBnFuseSamePaddingPass(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32")
conv_out = fluid.layers.conv2d(
input=data,
num_filters=6,
filter_size=6,
groups=3,
padding='SAME',
bias_attr=None,
act=None)
bn_out = fluid.layers.batch_norm(conv_out, is_test=True)
self.feeds = {
"data": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.fetch_list = [bn_out]
def test_check_output(self):
self.check_output()
self.assertTrue(
PassVersionChecker.IsCompatible('conv_eltwiseadd_bn_fuse_pass'))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import PassVersionChecker
class RepeatedFcReluFusePass3Test(InferencePassTest):
def setUp(self):
fc_num = 3
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32")
param_attr = fluid.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False),
learning_rate=0.001)
conv_out = fluid.layers.conv2d(
input=data,
num_filters=3,
filter_size=3,
bias_attr=param_attr,
act=None)
fc_outs = []
fc_outs.append(
fluid.layers.fc(input=[conv_out], act="relu", size=1000))
for i in range(1, fc_num):
fc_outs.append(
fluid.layers.fc(
input=[fc_outs[i - 1]], act="relu", size=1000))
self.feeds = {
"data": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.fetch_list = [fc_outs[fc_num - 1]]
def test_check_output(self):
use_gpu = False
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('repeated_fc_relu_fuse_pass'))
class RepeatedFcReluFusePass9Test(InferencePassTest):
def setUp(self):
fc_num = 9
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32")
param_attr = fluid.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False),
learning_rate=0.001)
conv_out = fluid.layers.conv2d(
input=data,
num_filters=3,
filter_size=3,
bias_attr=param_attr,
act=None)
fc_outs = []
fc_outs.append(
fluid.layers.fc(input=[conv_out], act="relu", size=1000))
for i in range(1, fc_num):
fc_outs.append(
fluid.layers.fc(
input=[fc_outs[i - 1]], act="relu", size=1000))
self.feeds = {
"data": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.fetch_list = [fc_outs[fc_num - 1]]
def test_check_output(self):
use_gpu = False
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('repeated_fc_relu_fuse_pass'))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import PassVersionChecker
from paddle.fluid.core import AnalysisConfig
class ShuffleChannelFuseTRTPassTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 6, 64, 64], dtype="float32")
reshape1 = fluid.layers.reshape(x=data, shape=[-1, 2, 3, 64, 64])
trans = fluid.layers.transpose(x=reshape1, perm=[0, 2, 1, 3, 4])
reshape2 = fluid.layers.reshape(x=trans, shape=[-1, 6, 64, 64])
out = fluid.layers.batch_norm(reshape2, is_test=True)
self.feeds = {
"data": np.random.random([1, 6, 64, 64]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = ShuffleChannelFuseTRTPassTest.TensorRTParam(
1 << 30, 32, 1, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [out]
def test_check_output(self):
self.check_output()
self.assertTrue(
PassVersionChecker.IsCompatible('shuffle_channel_detect_pass'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册