未验证 提交 3540d33b 编写于 作者: P piotrekobi 提交者: GitHub

Rea-dd conv_affine_channel fuse pass as oneDNN only pass (#41998)

* Readd conv_affine_channel fuse pass as mkldnn pass

* Fix formatting

* Add new test to parallel_UT_rule.py

* Fix Coverage and Windows CI issues

* Revert "Fix Coverage and Windows CI issues"

This reverts commit f33459846385c9fd51c07f9f44e7ff283a652637.

* Fix CI errors

* Remove unnecessary conv_eltwise_add_affine_channel fuse pass

* Remove test from parallel_UT_rule.py
上级 754edf6e
......@@ -118,6 +118,7 @@ if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base DEPS placement_pass_base DIR mkldnn)
pass_library(mkldnn_inplace_pass inference DEPS mkldnn_placement_pass op_registry elementwise_add_op gelu_op activation_op softmax_op softmax DIR mkldnn)
pass_library(depthwise_conv_mkldnn_pass base DIR mkldnn)
pass_library(conv_affine_channel_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_bias_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_concat_relu_mkldnn_fuse_pass inference DIR mkldnn)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_affine_channel_mkldnn_fuse_pass.h"
#include <cmath>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
class Node;
#define GET_CONV_BN_NODES(pattern_name) \
/* OPERATORS */ \
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(affine_channel, affine_channel, pattern_name); \
/* CONV inputs */ \
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, pattern_name); \
/* CONV outputs */ \
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, pattern_name); \
/* Affine Channel inputs */ \
GET_IR_NODE_FROM_SUBGRAPH(ac_scale, ac_scale, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(ac_bias, ac_bias, pattern_name); \
/* Affine channel outputs */ \
GET_IR_NODE_FROM_SUBGRAPH(ac_out, ac_out, pattern_name); /* Out */
void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight,
const ir::Node& ac_scale,
const LoDTensor& ac_bias_tensor,
LoDTensor* eltwise_y_in_tensor) {
using EigenVectorArrayMap =
Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<float, Eigen::Dynamic, 1>>;
using EigenMatrixArrayMap = Eigen::Map<
Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
// Re-compute bias of conv2d from AffineChannel
PADDLE_ENFORCE_EQ(
eltwise_y_in_tensor->dims(), ac_bias_tensor.dims(),
platform::errors::InvalidArgument(
"Tensor elementwise y(%d) and activation bias(%d) must have same "
"dimension.",
eltwise_y_in_tensor->dims().size(), ac_bias_tensor.dims().size()));
auto* scale_tensor = scope->FindVar(ac_scale.Name())->GetMutable<LoDTensor>();
ConstEigenVectorArrayMap scale_array(scale_tensor->data<float>(),
scale_tensor->numel(), 1);
ConstEigenVectorArrayMap ac_bias_array(ac_bias_tensor.data<float>(),
ac_bias_tensor.numel(), 1);
EigenVectorArrayMap eltwise_y_in_array(
eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
eltwise_y_in_tensor->numel(), 1);
eltwise_y_in_array = (eltwise_y_in_array * scale_array) + ac_bias_array;
// Re-compute weight of conv2d from AffineChannel
auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
auto weights_shape = weights->dims();
auto weights_shape_2d = phi::flatten_to_2d(weights_shape, 1);
auto* weights_data = weights->mutable_data<float>(platform::CPUPlace());
EigenMatrixArrayMap weights_array_2d(weights_data, weights_shape_2d[0],
weights_shape_2d[1]);
weights_array_2d.colwise() *= scale_array;
// Check for subnormal values that slows down convolution execution
for (int i = 0; i < weights->numel(); ++i) {
if (std::fpclassify(weights_data[i]) == FP_SUBNORMAL) weights_data[i] = 0;
}
}
ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("affine_channel"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("data_layout")
.IsStringIn({"NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
}
void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
GraphPatternDetector gpd;
auto* conv_input =
gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput()
->assert_is_op_input("conv2d", "Input");
patterns::ConvAffineChannel conv_ac_pattern(gpd.mutable_pattern(),
name_scope_);
conv_ac_pattern(conv_input, false /*with_eltwise_add*/);
int found_conv_ac_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "ConvAffineChannelFusePass in op compat failed.";
return;
}
VLOG(4) << "handle ConvAffineChannel fuse";
GET_CONV_BN_NODES(conv_ac_pattern);
auto data_format = conv->Op()->GetAttrIfExists<std::string>("data_format");
if (data_format == "AnyLayout") {
LOG_FIRST_N(WARNING, 1) << "conv_affine_channel_fuse_pass is enabled, "
"it's wrong if data_format of conv is not "
"NCHW.";
}
// Get affine_channel bias for resizing eltwise_y!
auto* ac_bias_tensor =
scope->FindVar(ac_bias->Name())->GetMutable<LoDTensor>();
// Create eltwise_y (conv bias) variable
VarDesc eltwise_y_in_desc(
patterns::PDNodeName(name_scope_, "eltwise_y_in"));
// Set shape && datatype manually
eltwise_y_in_desc.SetShape(phi::vectorize(ac_bias_tensor->dims()));
eltwise_y_in_desc.SetDataType(
framework::TransToProtoVarType(ac_bias_tensor->dtype()));
eltwise_y_in_desc.SetLoDLevel(ac_bias->Var()->GetLoDLevel());
eltwise_y_in_desc.SetPersistable(true);
// Initialize eltwise_y
auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
auto* eltwise_y_in_tensor =
scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();
eltwise_y_in_tensor->Resize(ac_bias_tensor->dims());
std::fill_n(eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
eltwise_y_in_tensor->numel(), 0.0f);
// update weights and biases
recompute_bias_and_weights(scope, conv_weight, *ac_scale, *ac_bias_tensor,
eltwise_y_in_tensor);
// create an elementwise add node.
OpDesc desc;
desc.SetInput("X", std::vector<std::string>({conv_out->Name()}));
desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()}));
desc.SetOutput("Out", std::vector<std::string>({ac_out->Name()}));
desc.SetType("elementwise_add");
desc.SetAttr("axis", 1);
desc.SetAttr("use_mkldnn", conv->Op()->GetAttrIfExists<bool>("use_mkldnn"));
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph, {ac_scale, ac_bias, affine_channel});
IR_NODE_LINK_TO(conv_out, eltwise_op);
IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
IR_NODE_LINK_TO(eltwise_op, ac_out);
found_conv_ac_count++;
};
gpd(graph, handler);
AddStatis(found_conv_ac_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_affine_channel_mkldnn_fuse_pass,
paddle::framework::ir::ConvAffineChannelFusePass);
REGISTER_PASS_CAPABILITY(conv_affine_channel_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.EQ("affine_channel", 0));
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the Conv and ConvAffineChannel.
*/
class Graph;
class ConvAffineChannelFusePass : public FusePassBase {
public:
ConvAffineChannelFusePass();
virtual ~ConvAffineChannelFusePass() {}
protected:
void ApplyImpl(ir::Graph*) const override;
const std::string name_scope_{"conv_affine_channel_mkldnn_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -282,7 +282,8 @@ void CpuPassStrategy::EnableMKLDNN() {
"depthwise_conv_mkldnn_pass", //
"conv_bn_fuse_pass", // Execute BN passes again to
"conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order
"conv_transpose_bn_fuse_pass", //
"conv_affine_channel_mkldnn_fuse_pass", //
"conv_transpose_bn_fuse_pass", //
"conv_transpose_eltwiseadd_bn_fuse_pass", //
"conv_bias_mkldnn_fuse_pass", //
"conv_transpose_bias_mkldnn_fuse_pass",
......
......@@ -426,6 +426,7 @@ class Quant2Int8MkldnnPass(object):
graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
graph = self._apply_pass(graph, 'conv_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_affine_channel_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_transpose_bn_fuse_pass')
graph = self._apply_pass(graph,
'conv_transpose_eltwiseadd_bn_fuse_pass')
......
......@@ -143,5 +143,6 @@ if (WITH_MKLDNN)
set_tests_properties(test_mkldnn_conv_mish_fuse_pass PROPERTIES TIMEOUT 300)
set_tests_properties(test_mkldnn_fc_mish_fuse_pass PROPERTIES TIMEOUT 300)
set_tests_properties(test_mkldnn_fc_elementwise_add_fuse_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_mkldnn_conv_affine_channel_fuse_pass PROPERTIES TIMEOUT 60)
endif()
endif()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest, IgnoreReasons
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import hypothesis
from hypothesis import given, settings, seed, example, assume, reproduce_failure
import hypothesis.strategies as st
class TestConvAffineChannelFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_config(self, draw):
padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
groups = draw(st.integers(min_value=1, max_value=3))
data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
axis = draw(st.sampled_from([1]))
filter_channel = draw(st.integers(min_value=1, max_value=16)) * 4
filter_size = draw(st.integers(min_value=1, max_value=4))
in_channel = groups * filter_channel
out_channel_factor = draw(st.integers(min_value=1, max_value=16)) * 4
out_channel = groups * out_channel_factor
batch_size = draw(st.integers(min_value=1, max_value=4))
dilations = draw(
st.lists(
st.integers(
min_value=1, max_value=2), min_size=2, max_size=2))
paddings = draw(
st.lists(
st.integers(
min_value=0, max_value=2), min_size=2, max_size=2))
strides = draw(
st.lists(
st.integers(
min_value=1, max_value=2), min_size=2, max_size=2))
has_bias = draw(st.booleans())
x_shape = [
batch_size, in_channel, 64, 64
] if data_format == "NCHW" else [batch_size, 64, 64, in_channel]
w_shape = [out_channel, filter_channel, filter_size, filter_size]
scale_shape = [out_channel]
bias_shape = [out_channel]
def generate_input():
return np.random.random(x_shape).astype(np.float32)
def generate_weight():
return np.random.random(w_shape).astype(np.float32)
def generate_bias():
return np.random.random(bias_shape).astype(np.float32)
def generate_scale_bias():
return np.random.random(bias_shape).astype(np.float32)
conv2d_op = OpConfig(
"conv2d",
inputs={
"Input": ["input_data"],
"Filter": ["conv2d_weight"],
},
outputs={"Output": ["conv_output"]},
data_format=data_format,
dilations=dilations,
padding_algorithm=padding_algorithm,
groups=groups,
paddings=paddings,
strides=strides,
has_bias=has_bias,
is_test=True)
ac_op = OpConfig(
"affine_channel",
inputs={
"X": ["conv_output"],
"Scale": ["affine_channel_scale"],
"Bias": ["affine_channel_bias"]
},
outputs={"Out": ["affine_channel_ouput"]},
data_layout=data_format)
if has_bias == True:
conv2d_op.inputs["Bias"] = ["conv2d_bias"]
ops = [conv2d_op, ac_op]
program_config = ProgramConfig(
ops=ops,
inputs={
"input_data": TensorConfig(data_gen=partial(generate_input)),
},
weights={
"conv2d_weight":
TensorConfig(data_gen=partial(generate_weight)),
"conv2d_bias": TensorConfig(data_gen=partial(generate_bias)),
"affine_channel_scale":
TensorConfig(data_gen=partial(generate_scale_bias)),
"affine_channel_bias":
TensorConfig(data_gen=partial(generate_scale_bias)),
},
outputs=["affine_channel_ouput"])
if has_bias == True:
program_config.weights["conv2d_bias"] = TensorConfig(
data_gen=partial(generate_bias))
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
yield config, ['conv2d', 'elementwise_add'], (1e-4, 1e-4)
def add_ignore_pass_case(self):
# If the problem has been fixed, the judgment
# in is_program_valid needs to be deleted!!!
def teller1(program_config, predictor_config):
if program_config.ops[0].attrs['data_format'] == "NHWC":
return True
return False
# mkldnn Output has diff with bias!
def teller2(program_config, predictor_config):
return predictor_config.mkldnn_enabled() and program_config.ops[
0].attrs['has_bias'] == True
self.add_ignore_check_case(
teller1, IgnoreReasons.PASS_ACCURACY_ERROR,
"The output format of conv2d is wrong when data_format attribute is NHWC, \
because currently its fused op (Conv2DFusion) only supports data format of channel first (NCHW)."
)
self.add_ignore_check_case(
teller2, IgnoreReasons.PASS_ACCURACY_ERROR,
"Currently mkldnn Output has diff with bias!")
def test(self):
self.run_and_statis(
quant=False,
passes=["conv_affine_channel_mkldnn_fuse_pass"], )
if __name__ == "__main__":
unittest.main()
......@@ -659,6 +659,7 @@ STATIC_MODE_TESTING_LIST = [
'test_mkldnn_matmul_transpose_reshape_fuse_pass',
'test_mkldnn_scale_matmul_fuse_pass',
'test_mkldnn_inplace_fuse_pass',
'test_mkldnn_conv_affine_channel_fuse_pass',
'test_batch_fc_op',
'test_c_comm_init_all_op',
'test_conv2d_fusion_op',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册