Unverified commit d03ef054, authored by Sławomir Siwek, committed by GitHub

Extend conv_concat_relu to support all activations (#45089)

* merge conv_concat_relu to conv_act

* fix typo

* extend unit test

* reuse existing gpd

* codestyle

* enforce mkldnn conv
Parent 25d25b00
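
Usage note (not part of the commit): with this change the standalone conv_concat_relu_mkldnn_fuse_pass is removed, and the conv2d + concat + activation fusion is handled by conv_activation_mkldnn_fuse_pass, which the CPU MKLDNN pass strategy already enables. Below is a minimal sketch of how an inference user picks the fusion up; the pass name and paddle.inference calls appear in this diff, while the model directory "model_dir" is a placeholder assumption.

    import paddle.inference as paddle_infer

    # Enabling MKLDNN selects the CPU pass strategy whose list now covers
    # conv2d -> concat -> <activation> subgraphs through
    # conv_activation_mkldnn_fuse_pass; no separate conv_concat_relu pass is needed.
    config = paddle_infer.Config("model_dir")  # "model_dir" is a placeholder path
    config.enable_mkldnn()
    predictor = paddle_infer.create_predictor(config)

After the pass runs, each mkldnn conv2d feeding the concat carries the fuse_activation attribute and the trailing activation op is removed, leaving conv2d, conv2d, concat — the operator list the updated tester expects.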
......@@ -199,7 +199,6 @@ if(WITH_MKLDNN)
pass_library(conv_affine_channel_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_bias_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_concat_relu_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(int8_scale_calculation_mkldnn_pass inference DIR mkldnn)
pass_library(params_quantization_mkldnn_pass inference DIR mkldnn)
......@@ -409,7 +408,7 @@ if(WITH_MKLDNN)
cc_test(
test_conv_concat_relu_mkldnn_fuse_pass
SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc
DEPS conv_concat_relu_mkldnn_fuse_pass)
DEPS conv_activation_mkldnn_fuse_pass)
cc_test(
test_conv_elementwise_add_mkldnn_fuse_pass
SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
......
......@@ -2081,46 +2081,6 @@ PDNode *patterns::Concat::operator()() {
return output_var;
}
PDNode *patterns::ConcatReLU::operator()() {
auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat");
auto relu_op = pattern->NewNode(relu_op_repr())->assert_is_op("relu");
auto concat_out =
pattern->NewNode(concat_out_repr())->assert_is_op_output("concat", "Out");
auto relu_out = pattern->NewNode(relu_out_repr())
->AsOutput()
->assert_is_op_output("relu", "Out");
concat_op->LinksTo({concat_out});
relu_op->LinksFrom({concat_out}).LinksTo({relu_out});
return relu_out;
}
PDNode *patterns::ConvConcatReLU::operator()() {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat");
auto relu_op = pattern->NewNode(relu_op_repr())->assert_is_op("relu");
auto conv_out = pattern->NewNode(conv_out_repr())
->assert_is_op_output("conv2d", "Output");
auto concat_out = pattern->NewNode(concat_out_repr())
->assert_is_op_output("concat", "Out")
->assert_is_op_input("relu", "X");
auto relu_out = pattern->NewNode(relu_out_repr())
->AsOutput()
->assert_is_op_output("relu", "Out");
conv_op->LinksTo({conv_out});
concat_op->LinksFrom({conv_out}).LinksTo({concat_out});
relu_op->LinksFrom({concat_out}).LinksTo({relu_out});
return relu_out;
}
PDNode *patterns::OpRequant::operator()() {
auto any_op = pattern->NewNode(any_op_repr())
->assert_is_op()
......
......@@ -1228,39 +1228,6 @@ struct Concat : public PatternBase {
PATTERN_DECL_NODE(concat_out);
};
// Concat + ReLU
// named nodes:
// concat_op, concat_out, relu_op, relu_out
struct ConcatReLU : public PatternBase {
ConcatReLU(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "concat_relu") {}
PDNode* operator()();
PATTERN_DECL_NODE(concat_op);
PATTERN_DECL_NODE(concat_out);
PATTERN_DECL_NODE(relu_op);
PATTERN_DECL_NODE(relu_out);
};
// Conv + Concat + ReLU
// named nodes:
// conv_op, conv_out
// concat_op, concat_out, relu_op, relu_out
struct ConvConcatReLU : public PatternBase {
ConvConcatReLU(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_concat_relu") {}
PDNode* operator()();
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(concat_op);
PATTERN_DECL_NODE(concat_out);
PATTERN_DECL_NODE(relu_op);
PATTERN_DECL_NODE(relu_out);
};
// Op + Requant
// named nodes:
// any_op, any_out
......
......@@ -28,10 +28,12 @@ void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
auto act_types = paddle::platform::GetSupportedActivations();
std::vector<std::string> conv_types = {"conv2d"};
for (const auto& conv_type : conv_types)
for (auto& act_type : act_types) {
for (auto& act_type : act_types) {
FuseConvConcatAct(graph, act_type);
for (const auto& conv_type : conv_types) {
FuseConvAct(graph, conv_type, act_type);
}
}
}
void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
......@@ -49,8 +51,6 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
int found_conv_activation_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle " + conv_type + "+" + act_type + " fuse";
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "conv_activation_mkldnn_fuse_pass op compat failed.";
return;
......@@ -89,13 +89,95 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
gpd(graph, handler);
AddStatis(found_conv_activation_count);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_conv_activation_count > 0) {
PrettyLogDetail("--- fused %d conv with %s activation",
found_conv_activation_count,
act_type);
}
}
void ConvActivationMkldnnFusePass::FuseConvConcatAct(
Graph* graph, std::string& act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("conv2d_concat_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::OperatorActivation conv_concat_act(
pattern, "conv2d_concat_" + act_type + "_mkldnn_fuse_pass");
conv_concat_act("concat", act_type);
int found_conv_concat_activation_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING)
<< "conv_concat_activation_mkldnn_fuse_pass op compat failed.";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(concat_op, preceding_op, conv_concat_act);
GET_IR_NODE_FROM_SUBGRAPH(concat_out, preceding_op_out, conv_concat_act);
GET_IR_NODE_FROM_SUBGRAPH(activation_op, activation, conv_concat_act);
GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, conv_concat_act);
auto concat_inputs = concat_op->inputs;
for (auto node : concat_inputs) {
auto prev_op_nodes = node->inputs;
if (prev_op_nodes.size() != 1) {
LOG(WARNING)
<< "Operator connected to concat can have only one output.";
return;
}
bool is_not_conv_mkldnn =
!(prev_op_nodes[0]->Op()->GetAttrIfExists<bool>("use_mkldnn"));
if (prev_op_nodes[0]->Op()->Type() != "conv2d" || is_not_conv_mkldnn) {
LOG(WARNING)
<< "This fuse pass supports only conv2d (mkldnn) + activation.";
return;
}
}
for (auto node : concat_inputs) {
OpDesc* conv_op = node->inputs[0]->Op();
OpDesc* act_op = activation_op->Op();
auto attr_map = paddle::platform::GetAttributeMap(act_type);
for (const auto& attrs : attr_map) {
if (act_op->HasAttr(attrs.first)) {
conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
}
}
if (act_type == "gelu" && act_op->HasAttr("approximate")) {
act_type = PADDLE_GET_CONST(bool, act_op->GetAttr("approximate"))
? "gelu_tanh"
: "gelu_erf";
conv_op->SetAttr("fuse_alpha", 0.0f);
conv_op->SetAttr("fuse_beta", 0.0f);
}
conv_op->SetAttr("fuse_activation", act_type);
}
concat_op->Op()->SetOutput("Out", {activation_out->Name()});
GraphSafeRemoveNodes(graph, {activation_op, concat_out});
IR_NODE_LINK_TO(concat_op, activation_out);
found_conv_concat_activation_count++;
};
gpd(graph, handler);
AddStatis(found_conv_concat_activation_count);
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_conv_concat_activation_count > 0) {
PrettyLogDetail("--- fused %d conv_concat with %s activation",
found_conv_concat_activation_count,
act_type);
}
}
ConvActivationMkldnnFusePass::ConvActivationMkldnnFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
......@@ -136,6 +218,20 @@ ConvActivationMkldnnFusePass::ConvActivationMkldnnFusePass() {
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("concat"))
.AddInput("X")
.End()
.AddInput("AxisTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumGE(0)
.End();
AddOpCompat(OpCompat("relu"))
.AddInput("X")
.IsTensor()
......@@ -276,6 +372,7 @@ REGISTER_PASS_CAPABILITY(conv_activation_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.EQ("concat", 0)
.EQ("abs", 0)
.LE("clip", 1)
.EQ("gelu", 0)
......
......@@ -34,6 +34,8 @@ class ConvActivationMkldnnFusePass : public FusePassBase {
void FuseConvAct(Graph *graph,
const std::string &conv_type,
std::string &act_type) const;
void FuseConvConcatAct(Graph *graph, std::string &act_type) const;
};
} // namespace ir
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h"
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
ConvConcatReLUFusePass::ConvConcatReLUFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("concat"))
.AddInput("X") // Input("X"): vector<tensors>
.End()
.AddInput("AxisTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumGE(0)
.End();
AddOpCompat(OpCompat("relu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
void ConvConcatReLUFusePass::FindConcatWithConvs(
ir::Graph* graph,
std::unordered_map<const Node*, int>* concat_with_convs_counter) const {
GraphPatternDetector gpd;
patterns::ConcatReLU concat_relu_pattern{gpd.mutable_pattern(),
"concat_relu"};
concat_relu_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Find Concats with Convs";
GET_IR_NODE_FROM_SUBGRAPH(concat_op, concat_op, concat_relu_pattern);
GET_IR_NODE_FROM_SUBGRAPH(relu_op, relu_op, concat_relu_pattern);
auto concat_inputs = concat_op->inputs;
for (auto node : concat_inputs) {
auto prev_op_node = node->inputs;
PADDLE_ENFORCE_EQ(prev_op_node.size(),
1,
platform::errors::InvalidArgument(
"Node(%s) input size(%d) must be 1.",
node->Name(),
prev_op_node.size()));
auto* conv_op = prev_op_node[0];
if (conv_op->Op()->Type() != "conv2d") return;
FuseOptions fuse_option = FindFuseOption(*conv_op, *relu_op);
if (fuse_option == DO_NOT_FUSE) {
return;
}
}
(*concat_with_convs_counter)[concat_op] = concat_inputs.size();
found_count++;
};
gpd(graph, handler);
AddStatis(found_count);
}
void ConvConcatReLUFusePass::FuseConvConcatReLU(
ir::Graph* graph,
std::unordered_map<const Node*, int>* concat_with_convs_counter) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::ConvConcatReLU conv_concat_relu(pattern, name_scope_);
conv_concat_relu();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle ConvConcatReLU fuse";
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_concat_relu);
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_concat_relu);
GET_IR_NODE_FROM_SUBGRAPH(concat_op, concat_op, conv_concat_relu);
GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, conv_concat_relu);
GET_IR_NODE_FROM_SUBGRAPH(relu_op, relu_op, conv_concat_relu);
GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_concat_relu);
if (!concat_with_convs_counter->count(concat_op)) {
VLOG(4) << "this concat has input from non-conv2d operator";
return;
}
// Transform Conv node into ConvReLU node.
OpDesc* conv_desc = conv_op->Op();
conv_desc->SetAttr("fuse_activation", std::string("relu"));
// Remove ReLU when all Convs were transformed.
auto number_of_unfused_convs_left =
--(*concat_with_convs_counter)[concat_op];
if (number_of_unfused_convs_left == 0) {
OpDesc* concat_desc = concat_op->Op();
concat_desc->SetOutput("Out",
std::vector<std::string>({relu_out->Name()}));
GraphSafeRemoveNodes(graph, {relu_op, concat_out});
IR_NODE_LINK_TO(concat_op, relu_out);
}
found_count++;
};
gpd(graph, handler);
AddStatis(found_count);
}
void ConvConcatReLUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph);
std::unordered_map<const Node*, int> concat_with_convs_counter;
FindConcatWithConvs(graph, &concat_with_convs_counter);
FuseConvConcatReLU(graph, &concat_with_convs_counter);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_concat_relu_mkldnn_fuse_pass,
paddle::framework::ir::ConvConcatReLUFusePass);
REGISTER_PASS_CAPABILITY(conv_concat_relu_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.EQ("concat", 0)
.EQ("relu", 0));
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the (multi conv) -> Concat -> ReLU -> next_op
* to a:
* (multi ConvReLU) -> Concat -> next_op.
*/
class ConvConcatReLUFusePass : public FusePassBase {
public:
ConvConcatReLUFusePass();
virtual ~ConvConcatReLUFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
void FindConcatWithConvs(
Graph* graph,
std::unordered_map<const Node*, int>* concat_with_convs_counter) const;
void FuseConvConcatReLU(
Graph* graph,
std::unordered_map<const Node*, int>* concat_with_convs_counter) const;
const std::string name_scope_{"conv_concat_relu_mkldnn_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -14,7 +14,7 @@
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
......@@ -47,6 +47,7 @@ void SetOp(ProgramDesc* prog,
op->SetOutput("Out", outputs);
} else if (type == "concat") {
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("axis", 0);
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
}
......@@ -103,7 +104,7 @@ void MainTest(const ProgramDesc& prog, bool fuse_relu) {
int original_nodes_num = graph->Nodes().size();
auto pass = PassRegistry::Instance().Get("conv_concat_relu_mkldnn_fuse_pass");
auto pass = PassRegistry::Instance().Get("conv_activation_mkldnn_fuse_pass");
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
......@@ -167,4 +168,4 @@ TEST(ConvConcatReLUFusePass, convs_and_pool_before_concat) {
} // namespace framework
} // namespace paddle
USE_PASS(conv_concat_relu_mkldnn_fuse_pass);
USE_PASS(conv_activation_mkldnn_fuse_pass);
......@@ -303,7 +303,6 @@ void CpuPassStrategy::EnableMKLDNN() {
// TODO(baoachun): Need to support 5-dimensional input.
// "conv3d_bias_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass",
"conv_concat_relu_mkldnn_fuse_pass",
"conv_activation_mkldnn_fuse_pass", //
"scale_matmul_fuse_pass", //
"reshape_transpose_matmul_mkldnn_fuse_pass", //
......@@ -396,7 +395,6 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("conv_bias_mkldnn_fuse_pass");
passes_.push_back("conv_transpose_bias_mkldnn_fuse_pass");
passes_.push_back("conv_elementwise_add_mkldnn_fuse_pass");
passes_.push_back("conv_concat_relu_mkldnn_fuse_pass");
passes_.push_back("conv_activation_mkldnn_fuse_pass");
passes_.push_back("fc_fuse_pass");
passes_.push_back("repeated_fc_relu_fuse_pass");
......
......@@ -439,7 +439,6 @@ class Quant2Int8MkldnnPass(object):
graph = self._apply_pass(graph, 'conv_bias_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_transpose_bias_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_concat_relu_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_activation_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'fc_fuse_pass',
['use_gpu', 'use_fc_padding'], [False, False])
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,140 +12,136 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import hypothesis
from hypothesis import given, settings, seed, example, assume
import hypothesis.strategies as st
class TestConvConcatReluMkldnnFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
class TestConvConcatActivationMkldnnFusePass(PassAutoScanTest):
def sample_program_config(self, draw):
data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
dilations = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
groups = draw(st.sampled_from([1, 2, 4]))
paddings = draw(st.sampled_from([[0, 3], [1, 2, 3, 4]]))
strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
data_format = draw(st.sampled_from(['NCHW', 'NHWC']))
dilations = draw(st.sampled_from([[2, 2]]))
padding_algorithm = draw(st.sampled_from(['VALID']))
groups = draw(st.sampled_from([4]))
paddings = draw(st.sampled_from([[0, 3]]))
strides = draw(st.sampled_from([[1, 2]]))
axis = draw(st.sampled_from([0]))
batch_size = draw(st.integers(min_value=1, max_value=4))
def generate_input(attrs):
if attrs[0]['data_format'] == "NCHW":
return np.random.random([attrs[2]['batch_size'], 48, 64,
64]).astype(np.float32)
else:
return np.random.random([attrs[2]['batch_size'], 64, 64,
48]).astype(np.float32)
def generate_weight():
return np.random.random([16, int(48 / groups), 3,
3]).astype(np.float32)
attrs = [{
"data_format": data_format,
"dilations": dilations,
"padding_algorithm": padding_algorithm,
"groups": groups,
"paddings": paddings,
"strides": strides
}, {
"axis": axis
}, {
'batch_size': batch_size
}]
ops_config = [{
"op_type": "conv2d",
"op_inputs": {
"Input": ["input_data1"],
"Filter": ["input_weight"]
},
"op_outputs": {
"Output": ["conv1_output"]
},
"op_attrs": {
"data_format": attrs[0]['data_format'],
"dilations": attrs[0]['dilations'],
"padding_algorithm": attrs[0]['padding_algorithm'],
"groups": attrs[0]['groups'],
"paddings": attrs[0]['paddings'],
"strides": attrs[0]['strides']
}
}, {
"op_type": "conv2d",
"op_inputs": {
"Input": ["input_data2"],
"Filter": ["input_weight"]
},
"op_outputs": {
"Output": ["conv2_output"]
},
"op_attrs": {
"data_format": attrs[0]['data_format'],
"dilations": attrs[0]['dilations'],
"padding_algorithm": attrs[0]['padding_algorithm'],
"groups": attrs[0]['groups'],
"paddings": attrs[0]['paddings'],
"strides": attrs[0]['strides']
}
}, {
"op_type": "concat",
"op_inputs": {
"X": ["conv1_output", "conv2_output"]
},
"op_outputs": {
"Out": ["concat_output"]
},
"op_attrs": {
'axis': attrs[1]['axis']
}
}, {
"op_type": "relu",
"op_inputs": {
"X": ["concat_output"]
},
"op_outputs": {
"Out": ["relu_output"]
},
"op_attrs": {}
}]
ops = self.generate_op_config(ops_config)
activation_type = draw(
st.sampled_from([
'relu', 'gelu', 'swish', 'mish', 'sqrt', 'hard_swish',
'sigmoid', 'abs', 'relu6', 'clip', 'tanh', 'hard_sigmoid',
'leaky_relu'
]))
def generate_data(input_type):
if input_type == 'NCHW':
return np.random.random([16, 48, 64, 64]).astype(np.float32)
elif input_type == 'NHWC':
return np.random.random([16, 64, 64, 48]).astype(np.float32)
elif input_type == 'weights':
return np.random.random([16, int(48 / groups), 3,
3]).astype(np.float32)
conv2d_op1 = OpConfig(type='conv2d',
inputs={
'Input': ['conv_input_1'],
'Filter': ['conv_weights_1']
},
outputs={'Output': ['conv_output_1']},
attrs={
'data_format': data_format,
'dilations': dilations,
'padding_algorithm': padding_algorithm,
'groups': groups,
'paddings': paddings,
'strides': strides
})
conv2d_op2 = OpConfig(type='conv2d',
inputs={
'Input': ['conv_input_2'],
'Filter': ['conv_weights_2']
},
outputs={'Output': ['conv_output_2']},
attrs={
'data_format': data_format,
'dilations': dilations,
'padding_algorithm': padding_algorithm,
'groups': groups,
'paddings': paddings,
'strides': strides
})
concat_op = OpConfig(type='concat',
inputs={'X': ['conv_output_1', 'conv_output_2']},
outputs={'Out': ['concat_output']},
attrs={'axis': axis})
if activation_type == 'relu6':
activation_op = OpConfig(activation_type,
inputs={'X': ['concat_output']},
outputs={'Out': ['activation_output']},
threshold=draw(
st.floats(min_value=1.0,
max_value=10.0)))
elif activation_type == 'leaky_relu':
activation_op = OpConfig(activation_type,
inputs={'X': ['concat_output']},
outputs={'Out': ['activation_output']},
alpha=draw(
st.floats(min_value=0.1,
max_value=1.0)))
elif activation_type == 'swish':
activation_op = OpConfig(activation_type,
inputs={'X': ['concat_output']},
outputs={'Out': ['activation_output']},
beta=draw(
st.floats(min_value=0.1,
max_value=1.0)))
elif activation_type == 'clip':
activation_op = OpConfig(
activation_type,
inputs={'X': ['concat_output']},
outputs={'Out': ['activation_output']},
min=draw(st.floats(min_value=0.1, max_value=0.49)),
max=draw(st.floats(min_value=0.5, max_value=1.0)))
else:
activation_op = OpConfig(activation_type,
inputs={'X': ['concat_output']},
outputs={'Out': ['activation_output']})
model_net = [conv2d_op1, conv2d_op2, concat_op, activation_op]
program_config = ProgramConfig(
ops=ops,
weights={
"input_weight": TensorConfig(data_gen=partial(generate_weight))
},
ops=model_net,
inputs={
"input_data1":
TensorConfig(data_gen=partial(generate_input, attrs)),
"input_data2":
TensorConfig(data_gen=partial(generate_input, attrs))
'conv_input_1':
TensorConfig(data_gen=partial(generate_data, data_format)),
'conv_input_2':
TensorConfig(data_gen=partial(generate_data, data_format))
},
weights={
'conv_weights_1':
TensorConfig(data_gen=partial(generate_data, 'weights')),
'conv_weights_2':
TensorConfig(data_gen=partial(generate_data, 'weights'))
},
outputs=["relu_output"])
outputs=['activation_output'])
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
yield config, ["conv2d", "conv2d", "concat"], (1e-5, 1e-5)
yield config, ['conv2d', 'conv2d', 'concat'], (1e-5, 1e-5)
def test(self):
self.run_and_statis(quant=False,
passes=["conv_concat_relu_mkldnn_fuse_pass"])
passes=['conv_activation_mkldnn_fuse_pass'])
if __name__ == "__main__":
if __name__ == '__main__':
unittest.main()