未验证 提交 fc06be9d 编写于 作者: W wenbin 提交者: GitHub

remove conv_affine_channel_fuse_pass (#39817)

* remove

* pass

* more pass
上级 25650774
......@@ -78,7 +78,6 @@ pass_library(is_test_pass base)
pass_library(conv_elementwise_add_act_fuse_pass inference)
pass_library(conv_elementwise_add2_act_fuse_pass inference)
pass_library(conv_elementwise_add_fuse_pass inference)
pass_library(conv_affine_channel_fuse_pass inference)
pass_library(transpose_flatten_concat_fuse_pass inference)
pass_library(identity_scale_op_clean_pass base)
pass_library(sync_batch_norm_pass base)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.h"
#include <cmath>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
class Node;
#define GET_CONV_BN_NODES(pattern_name) \
/* OPERATORS */ \
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(affine_channel, affine_channel, pattern_name); \
/* CONV inputs */ \
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, pattern_name); \
/* CONV outputs */ \
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, pattern_name); \
/* Affine Channel inputs */ \
GET_IR_NODE_FROM_SUBGRAPH(ac_scale, ac_scale, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(ac_bias, ac_bias, pattern_name); \
/* Affine channel outputs */ \
GET_IR_NODE_FROM_SUBGRAPH(ac_out, ac_out, pattern_name); /* Out */
void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight,
const ir::Node& ac_scale,
const LoDTensor& ac_bias_tensor,
LoDTensor* eltwise_y_in_tensor) {
using EigenVectorArrayMap =
Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<float, Eigen::Dynamic, 1>>;
using EigenMatrixArrayMap = Eigen::Map<
Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
// Re-compute bias of conv2d from AffineChannel
PADDLE_ENFORCE_EQ(
eltwise_y_in_tensor->dims(), ac_bias_tensor.dims(),
platform::errors::InvalidArgument(
"Tensor elementwise y(%d) and activation bias(%d) must have same "
"dimension.",
eltwise_y_in_tensor->dims().size(), ac_bias_tensor.dims().size()));
auto* scale_tensor = scope->FindVar(ac_scale.Name())->GetMutable<LoDTensor>();
ConstEigenVectorArrayMap scale_array(scale_tensor->data<float>(),
scale_tensor->numel(), 1);
ConstEigenVectorArrayMap ac_bias_array(ac_bias_tensor.data<float>(),
ac_bias_tensor.numel(), 1);
EigenVectorArrayMap eltwise_y_in_array(
eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
eltwise_y_in_tensor->numel(), 1);
eltwise_y_in_array = (eltwise_y_in_array * scale_array) + ac_bias_array;
// Re-compute weight of conv2d from AffineChannel
auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
auto weights_shape = weights->dims();
auto weights_shape_2d = phi::flatten_to_2d(weights_shape, 1);
auto* weights_data = weights->mutable_data<float>(platform::CPUPlace());
EigenMatrixArrayMap weights_array_2d(weights_data, weights_shape_2d[0],
weights_shape_2d[1]);
weights_array_2d.colwise() *= scale_array;
// Check for subnormal values that slows down convolution execution
for (int i = 0; i < weights->numel(); ++i) {
if (std::fpclassify(weights_data[i]) == FP_SUBNORMAL) weights_data[i] = 0;
}
}
ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("affine_channel"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("data_layout")
.IsStringIn({"NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
}
void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
GraphPatternDetector gpd;
auto* conv_input =
gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput()
->assert_is_op_input("conv2d", "Input");
patterns::ConvAffineChannel conv_ac_pattern(gpd.mutable_pattern(),
name_scope_);
conv_ac_pattern(conv_input, false /*with_eltwise_add*/);
int found_conv_ac_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "ConvAffineChannelFusePass in op compat failed.";
return;
}
VLOG(4) << "handle ConvAffineChannel fuse";
GET_CONV_BN_NODES(conv_ac_pattern);
auto data_format = conv->Op()->GetAttrIfExists<std::string>("data_format");
if (data_format == "AnyLayout") {
LOG_FIRST_N(WARNING, 1) << "conv_affine_channel_fuse_pass is enabled, "
"it's wrong if data_format of conv is not "
"NCHW.";
}
// Get affine_channel bias for resizing eltwise_y!
auto* ac_bias_tensor =
scope->FindVar(ac_bias->Name())->GetMutable<LoDTensor>();
// Create eltwise_y (conv bias) variable
VarDesc eltwise_y_in_desc(
patterns::PDNodeName(name_scope_, "eltwise_y_in"));
// Set shape && datatype manually
eltwise_y_in_desc.SetShape(phi::vectorize(ac_bias_tensor->dims()));
eltwise_y_in_desc.SetDataType(
framework::TransToProtoVarType(ac_bias_tensor->dtype()));
eltwise_y_in_desc.SetLoDLevel(ac_bias->Var()->GetLoDLevel());
eltwise_y_in_desc.SetPersistable(true);
// Initialize eltwise_y
auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
auto* eltwise_y_in_tensor =
scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();
eltwise_y_in_tensor->Resize(ac_bias_tensor->dims());
std::fill_n(eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
eltwise_y_in_tensor->numel(), 0.0f);
// update weights and biases
recompute_bias_and_weights(scope, conv_weight, *ac_scale, *ac_bias_tensor,
eltwise_y_in_tensor);
// create an elementwise add node.
OpDesc desc;
desc.SetInput("X", std::vector<std::string>({conv_out->Name()}));
desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()}));
desc.SetOutput("Out", std::vector<std::string>({ac_out->Name()}));
desc.SetType("elementwise_add");
desc.SetAttr("axis", 1);
desc.SetAttr("use_mkldnn", conv->Op()->GetAttrIfExists<bool>("use_mkldnn"));
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph, {ac_scale, ac_bias, affine_channel});
IR_NODE_LINK_TO(conv_out, eltwise_op);
IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
IR_NODE_LINK_TO(eltwise_op, ac_out);
found_conv_ac_count++;
};
gpd(graph, handler);
AddStatis(found_conv_ac_count);
}
ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("affine_channel"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("data_layout")
.IsStringIn({"NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
}
void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
GraphPatternDetector gpd;
auto* conv_input =
gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput()
->assert_is_op_input("conv2d", "Input");
patterns::ConvAffineChannel conv_ac_pattern(gpd.mutable_pattern(),
name_scope_);
conv_ac_pattern(conv_input, true /*with_eltwise_add*/);
int found_conv_ac_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING)
<< "ConvEltwiseAddAffineChannelFusePass in op compat failed.";
return;
}
VLOG(4) << "handle ConvBN fuse";
GET_CONV_BN_NODES(conv_ac_pattern);
auto data_format = conv->Op()->GetAttrIfExists<std::string>("data_format");
if (data_format == "AnyLayout") {
LOG_FIRST_N(WARNING, 1) << "conv_eltwiseadd_affine_channel_fuse_pass is "
"enabled, it's wrong if data_format of conv "
"is not NCHW.";
}
// OPERATORS
GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_ac_pattern);
// BIAS inputs
GET_IR_NODE_FROM_SUBGRAPH(eltwise_y_in, eltwise_y_in, conv_ac_pattern);
// BIAS outputs
GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_ac_pattern);
// Get eltwise_y (conv bias) variable
auto* eltwise_y_in_tensor =
scope->FindVar(eltwise_y_in->Name())->GetMutable<LoDTensor>();
// Get batch norm bias
auto* ac_bias_tensor =
scope->FindVar(ac_bias->Name())->GetMutable<LoDTensor>();
recompute_bias_and_weights(scope, conv_weight, *ac_scale, *ac_bias_tensor,
eltwise_y_in_tensor);
// Update the elementwise_add node
eltwise->Op()->SetAttr("axis", 1);
eltwise->Op()->SetOutput("Out", std::vector<std::string>({ac_out->Name()}));
GraphSafeRemoveNodes(graph,
{ac_scale, ac_bias, affine_channel, eltwise_out});
IR_NODE_LINK_TO(eltwise, ac_out);
found_conv_ac_count++;
};
gpd(graph, handler);
AddStatis(found_conv_ac_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_affine_channel_fuse_pass,
paddle::framework::ir::ConvAffineChannelFusePass);
REGISTER_PASS(conv_eltwiseadd_affine_channel_fuse_pass,
paddle::framework::ir::ConvEltwiseAddAffineChannelFusePass);
REGISTER_PASS_CAPABILITY(conv_affine_channel_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.EQ("affine_channel", 0));
REGISTER_PASS_CAPABILITY(conv_eltwiseadd_affine_channel_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.LE("elementwise_add", 1)
.EQ("affine_channel", 0));
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the Conv and ConvAffineChannel.
*/
class Graph;
class ConvAffineChannelFusePass : public FusePassBase {
public:
ConvAffineChannelFusePass();
virtual ~ConvAffineChannelFusePass() {}
protected:
void ApplyImpl(ir::Graph*) const override;
const std::string name_scope_{"conv_affine_channel_fuse"};
};
class ConvEltwiseAddAffineChannelFusePass : public FusePassBase {
public:
ConvEltwiseAddAffineChannelFusePass();
virtual ~ConvEltwiseAddAffineChannelFusePass() {}
protected:
void ApplyImpl(ir::Graph*) const override;
const std::string name_scope_{"conv_eltwiseadd_affine_channel_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -75,13 +75,11 @@ void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) {
void PaddlePassBuilder::ClearPasses() { passes_.clear(); }
const std::vector<std::string> kTRTSubgraphPasses({
"conv_affine_channel_fuse_pass", //
"adaptive_pool2d_convert_global_pass",
"conv_eltwiseadd_affine_channel_fuse_pass", //
"shuffle_channel_detect_pass", //
"quant_conv2d_dequant_fuse_pass", //
"delete_quant_dequant_op_pass", //
"delete_quant_dequant_filter_op_pass", //
"adaptive_pool2d_convert_global_pass",
"shuffle_channel_detect_pass", //
"quant_conv2d_dequant_fuse_pass", //
"delete_quant_dequant_op_pass", //
"delete_quant_dequant_filter_op_pass", //
// "fc_fuse_pass", //
"simplify_with_basic_ops_pass", //
"embedding_eltwise_layernorm_fuse_pass", //
......@@ -134,22 +132,20 @@ const std::vector<std::string> kLiteSubgraphPasses({
GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
passes_.assign({
// "identity_scale_op_clean_pass", //
"is_test_pass", //
"simplify_with_basic_ops_pass", //
"conv_affine_channel_fuse_pass", //
"conv_eltwiseadd_affine_channel_fuse_pass", //
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
"embedding_eltwise_layernorm_fuse_pass", //
"multihead_matmul_fuse_pass_v2", //
"gpu_cpu_squeeze2_matmul_fuse_pass", //
"gpu_cpu_reshape2_matmul_fuse_pass", //
"gpu_cpu_flatten2_matmul_fuse_pass", //
"gpu_cpu_map_matmul_v2_to_mul_pass", //
"gpu_cpu_map_matmul_v2_to_matmul_pass", //
"gpu_cpu_map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"fc_elementwise_layernorm_fuse_pass", //
"is_test_pass", //
"simplify_with_basic_ops_pass", //
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
"embedding_eltwise_layernorm_fuse_pass", //
"multihead_matmul_fuse_pass_v2", //
"gpu_cpu_squeeze2_matmul_fuse_pass", //
"gpu_cpu_reshape2_matmul_fuse_pass", //
"gpu_cpu_flatten2_matmul_fuse_pass", //
"gpu_cpu_map_matmul_v2_to_mul_pass", //
"gpu_cpu_map_matmul_v2_to_matmul_pass", //
"gpu_cpu_map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"fc_elementwise_layernorm_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
// guaranteed at least v7
// cudnn8.0 has memory leak problem in conv + eltwise + act, so we
......@@ -236,14 +232,12 @@ void CpuPassStrategy::EnableMKLDNN() {
passes_.insert(passes_.begin(), "mkldnn_placement_pass");
for (auto &pass : std::vector<std::string>({
"depthwise_conv_mkldnn_pass", //
"conv_bn_fuse_pass", // Execute BN passes again to
"conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order
"conv_affine_channel_fuse_pass", //
"conv_eltwiseadd_affine_channel_fuse_pass", //
"conv_transpose_bn_fuse_pass", //
"conv_transpose_eltwiseadd_bn_fuse_pass", //
"conv_bias_mkldnn_fuse_pass", //
"depthwise_conv_mkldnn_pass", //
"conv_bn_fuse_pass", // Execute BN passes again to
"conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order
"conv_transpose_bn_fuse_pass", //
"conv_transpose_eltwiseadd_bn_fuse_pass", //
"conv_bias_mkldnn_fuse_pass", //
"conv_transpose_bias_mkldnn_fuse_pass",
// TODO(baoachun): Need to support 5-dimensional input.
// "conv3d_bias_mkldnn_fuse_pass", //
......
......@@ -426,9 +426,6 @@ class Quant2Int8MkldnnPass(object):
graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
graph = self._apply_pass(graph, 'conv_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_affine_channel_fuse_pass')
graph = self._apply_pass(graph,
'conv_eltwiseadd_affine_channel_fuse_pass')
graph = self._apply_pass(graph, 'conv_transpose_bn_fuse_pass')
graph = self._apply_pass(graph,
'conv_transpose_eltwiseadd_bn_fuse_pass')
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest, IgnoreReasons
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import hypothesis
from hypothesis import given, settings, seed, example, assume, reproduce_failure
import hypothesis.strategies as st
class TestConvAffineChannelFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_config(self, draw):
padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
groups = draw(st.integers(min_value=1, max_value=3))
data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
axis = draw(st.sampled_from([1]))
filter_channel = draw(st.integers(min_value=1, max_value=16)) * 4
filter_size = draw(st.integers(min_value=1, max_value=4))
in_channel = groups * filter_channel
out_channel_factor = draw(st.integers(min_value=1, max_value=16)) * 4
out_channel = groups * out_channel_factor
batch_size = draw(st.integers(min_value=1, max_value=4))
dilations = draw(
st.lists(
st.integers(
min_value=1, max_value=2), min_size=2, max_size=2))
paddings = draw(
st.lists(
st.integers(
min_value=0, max_value=2), min_size=2, max_size=2))
strides = draw(
st.lists(
st.integers(
min_value=1, max_value=2), min_size=2, max_size=2))
has_bias = draw(st.booleans())
x_shape = [
batch_size, in_channel, 64, 64
] if data_format == "NCHW" else [batch_size, 64, 64, in_channel]
w_shape = [out_channel, filter_channel, filter_size, filter_size]
scale_shape = [out_channel]
bias_shape = [out_channel]
def generate_input():
return np.random.random(x_shape).astype(np.float32)
def generate_weight():
return np.random.random(w_shape).astype(np.float32)
def generate_bias():
return np.random.random(bias_shape).astype(np.float32)
def generate_scale_bias():
return np.random.random(bias_shape).astype(np.float32)
conv2d_op = OpConfig(
"conv2d",
inputs={
"Input": ["input_data"],
"Filter": ["conv2d_weight"],
},
outputs={"Output": ["conv_output"]},
data_format=data_format,
dilations=dilations,
padding_algorithm=padding_algorithm,
groups=groups,
paddings=paddings,
strides=strides,
has_bias=has_bias,
is_test=True)
ac_op = OpConfig(
"affine_channel",
inputs={
"X": ["conv_output"],
"Scale": ["affine_channel_scale"],
"Bias": ["affine_channel_bias"]
},
outputs={"Out": ["affine_channel_ouput"]},
data_layout=data_format)
if has_bias == True:
conv2d_op.inputs["Bias"] = ["conv2d_bias"]
ops = [conv2d_op, ac_op]
program_config = ProgramConfig(
ops=ops,
inputs={
"input_data": TensorConfig(data_gen=partial(generate_input)),
},
weights={
"conv2d_weight":
TensorConfig(data_gen=partial(generate_weight)),
"affine_channel_scale":
TensorConfig(data_gen=partial(generate_scale_bias)),
"affine_channel_bias":
TensorConfig(data_gen=partial(generate_scale_bias)),
},
outputs=["affine_channel_ouput"])
if has_bias == True:
program_config.weights["conv2d_bias"] = TensorConfig(
data_gen=partial(generate_bias))
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_gpu=True)
yield config, ['conv2d', 'elementwise_add'], (1e-4, 1e-4)
config = self.create_inference_config(use_mkldnn=True)
yield config, ['conv2d', 'elementwise_add'], (1e-4, 1e-4)
def add_ignore_pass_case(self):
# If the problem has been fixed, the judgment
# in is_program_valid needs to be deleted!!!
def teller1(program_config, predictor_config):
if program_config.ops[0].attrs['data_format'] == "NHWC":
return True
return False
# mkldnn Output has diff with bias!
def teller2(program_config, predictor_config):
return predictor_config.mkldnn_enabled() and program_config.ops[
0].attrs['has_bias'] == True
self.add_ignore_check_case(
teller1, IgnoreReasons.PASS_ACCURACY_ERROR,
"The output format of conv2d is wrong when data_format attribute is NHWC, \
because currently its fused op (Conv2DFusion) only supports data format of channel first (NCHW)."
)
self.add_ignore_check_case(
teller2, IgnoreReasons.PASS_ACCURACY_ERROR,
"Currently mkldnn Output has diff with bias!")
def test(self):
self.run_and_statis(
quant=False,
passes=["conv_affine_channel_fuse_pass"], )
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest, IgnoreReasons
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import hypothesis
from hypothesis import given, settings, seed, example, assume
import hypothesis.strategies as st
class TestConvEltwiseAddAffineChannelFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
if attrs[0]['data_format'] == "NHWC" and attrs[1]['axis'] != 3:
return False
return True
def sample_program_config(self, draw):
padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
groups = draw(st.integers(min_value=1, max_value=3))
data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
axis = draw(st.sampled_from([1]))
filter_channel = draw(st.integers(min_value=1, max_value=16)) * 4
filter_size = draw(st.integers(min_value=1, max_value=4))
in_channel = groups * filter_channel
out_channel_factor = draw(st.integers(min_value=1, max_value=16)) * 4
out_channel = groups * out_channel_factor
batch_size = draw(st.integers(min_value=1, max_value=4))
dilations = draw(
st.lists(
st.integers(
min_value=1, max_value=2), min_size=2, max_size=2))
paddings = draw(
st.lists(
st.integers(
min_value=0, max_value=2), min_size=2, max_size=2))
strides = draw(
st.lists(
st.integers(
min_value=1, max_value=2), min_size=2, max_size=2))
has_bias = draw(st.booleans())
x_shape = [
batch_size, in_channel, 64, 64
] if data_format == "NCHW" else [batch_size, 64, 64, in_channel]
w_shape = [out_channel, filter_channel, filter_size, filter_size]
scale_shape = [out_channel]
bias_shape = [out_channel]
def generate_input():
return np.random.random(x_shape).astype(np.float32)
def generate_weight():
return np.random.random(w_shape).astype(np.float32)
def generate_bias():
return np.random.random(bias_shape).astype(np.float32)
def generate_scale_bias():
return np.random.random(bias_shape).astype(np.float32)
conv2d_op = OpConfig(
"conv2d",
inputs={
"Input": ["input_data"],
"Filter": ["conv2d_weight"],
},
outputs={"Output": ["conv_output"]},
data_format=data_format,
dilations=dilations,
padding_algorithm=padding_algorithm,
groups=groups,
paddings=paddings,
strides=strides,
has_bias=has_bias,
is_test=True)
eltwise_op = OpConfig(
"elementwise_add",
inputs={"X": ["conv_output"],
"Y": ["conv2d_bias"]},
outputs={"Out": ["elementwise_output"]},
axis=axis)
ac_op = OpConfig(
"affine_channel",
inputs={
"X": ["elementwise_output"],
"Scale": ["affine_channel_scale"],
"Bias": ["affine_channel_bias"]
},
outputs={"Out": ["affine_channel_ouput"]},
data_layout=data_format)
if has_bias == True:
conv2d_op.inputs["Bias"] = ["conv2d_bias"]
ops = [conv2d_op, eltwise_op, ac_op]
program_config = ProgramConfig(
ops=ops,
inputs={
"input_data": TensorConfig(data_gen=partial(generate_input)),
},
weights={
"conv2d_weight":
TensorConfig(data_gen=partial(generate_weight)),
"conv2d_bias": TensorConfig(data_gen=partial(generate_bias)),
"affine_channel_scale":
TensorConfig(data_gen=partial(generate_scale_bias)),
"affine_channel_bias":
TensorConfig(data_gen=partial(generate_scale_bias)),
},
outputs=["affine_channel_ouput"])
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_gpu=True)
yield config, ['conv2d', 'elementwise_add'], (1e-4, 1e-4)
config = self.create_inference_config(use_mkldnn=True)
yield config, ['conv2d', 'elementwise_add'], (1e-4, 1e-4)
# TRT
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
workspace_size=1 << 20,
max_batch_size=4,
min_subgraph_size=1,
precision_mode=paddle_infer.PrecisionType.Float32,
use_static=False,
use_calib_mode=False)
yield config, ['conv2d', 'elementwise_add'], (1e-4, 1e-4)
def add_ignore_pass_case(self):
# If the problem has been fixed, the judgment
# in is_program_valid needs to be deleted!!!
def teller1(program_config, predictor_config):
if program_config.ops[0].attrs['data_format'] == "NHWC":
return True
return False
# mkldnn Output has diff with bias!
def teller2(program_config, predictor_config):
return predictor_config.mkldnn_enabled() and program_config.ops[
0].attrs['has_bias'] == True
self.add_ignore_check_case(
teller1, IgnoreReasons.PASS_ACCURACY_ERROR,
"The output format of conv2d is wrong when data_format attribute is NHWC, \
it will trigger Broadcast dimension mismatch bug \
when data_format attribute is NHWC and axis of eltwise op is 1 for this pass."
)
self.add_ignore_check_case(
teller2, IgnoreReasons.PASS_ACCURACY_ERROR,
"Currently mkldnn Output has diff with bias!")
def test(self):
self.run_and_statis(
quant=False,
passes=["conv_eltwiseadd_affine_channel_fuse_pass"], )
if __name__ == "__main__":
unittest.main()
......@@ -958,7 +958,6 @@ FOURTH_HIGH_PARALLEL_JOB_NEW = [
'test_dynamic_rnn_stop_gradient', 'test_raw_program_optimizer', 'test_pow',
'test_inplace_softmax_with_cross_entropy', 'test_transforms',
'test_unfold_op', 'test_assign_op', 'test_isinstance',
'test_conv_affine_channel_fuse_pass',
'auto_growth_best_fit_allocator_facade_test', 'test_cholesky_op',
'test_adaptive_avg_pool3d', 'test_paddle_save_load_binary',
'test_fused_fc_elementwise_layernorm_op', 'test_sequence_enumerate_op',
......@@ -1873,7 +1872,6 @@ TETRAD_PARALLEL_JOB = [
'test_dataloader_unkeep_order',
'test_parallel_executor_profiler',
'test_correlation',
'test_conv_affine_channel_fuse_pass',
'test_ir_inplace_pass',
'test_moving_average_abs_max_scale_op',
'test_flatten_contiguous_range_op',
......
......@@ -578,7 +578,6 @@ STATIC_MODE_TESTING_LIST = [
'test_ir_embedding_eltwise_layernorm_fuse_pass',
'test_ir_fc_fuse_pass',
'test_ir_skip_layernorm_pass',
'test_conv_affine_channel_fuse_pass',
'test_conv_bias_mkldnn_fuse_pass',
'test_conv_bn_fuse_pass',
'test_conv_elementwise_add2_act_fuse_pass',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册