Unverified Commit e2aacd21 authored by H Hulek, committed by GitHub

Rewrite mkldnn conv bn fuse pass tester (#50034)

* New onednn test

* checkpoint

* added new test, fixed issue with onednn bias

* fix bias check

* remove prints, refactor code

* delete old test

* update python tests cmake

* Delete deprecated conv bias

* Delete outdated bias from convolution test
Parent dd1410d7
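For context, the conv_bn_fuse_pass family folds the batch-norm statistics into the preceding convolution's weights and bias so the standalone batch_norm op can be dropped at inference time. A minimal NumPy sketch of that folding (illustrative only; the function and variable names are assumptions, not identifiers from Paddle):

```python
import numpy as np

def fold_bn_into_conv(weight, conv_bias, gamma, beta, mean, var, eps=1e-5):
    """Fold BatchNorm(scale=gamma, bias=beta, mean, var) into a preceding conv.

    weight:    (out_c, in_c, kh, kw) convolution filter
    conv_bias: (out_c,) existing conv bias, or zeros if the conv had none
    Returns fused (weight, bias) such that
    conv(x, w_fused) + b_fused == BN(conv(x, weight) + conv_bias).
    """
    std = np.sqrt(var + eps)                       # per-output-channel std-dev
    scale = gamma / std                            # channel-wise scaling factor
    w_fused = weight * scale[:, None, None, None]  # rescale each output channel
    b_fused = (conv_bias - mean) * scale + beta    # shift the bias accordingly
    return w_fused, b_fused

# Tiny smoke check with random data (shapes chosen arbitrarily)
out_c, in_c, k = 4, 3, 3
w = np.random.rand(out_c, in_c, k, k).astype(np.float32)
b = np.random.rand(out_c).astype(np.float32)
gamma, beta = np.random.rand(out_c), np.random.rand(out_c)
mean, var = np.random.rand(out_c), np.random.rand(out_c) + 1.0
w_f, b_f = fold_bn_into_conv(w, b, gamma, beta, mean, var)
print(w_f.shape, b_f.shape)  # (4, 3, 3, 3) (4,)
```

In the rewritten pass below, when oneDNN is enabled and the conv already has a Bias input, that existing bias tensor plays the role of conv_bias above and is updated in place; otherwise a zero-initialized eltwise_y tensor is created and the folded bias is placed there.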
......@@ -432,10 +432,6 @@ if(WITH_MKLDNN)
if(WITH_GPU OR WITH_ROCM)
set(TEST_CONV_BN_PASS_DEPS ${TEST_CONV_BN_PASS_DEPS} depthwise_conv)
endif()
cc_test(
test_conv_batch_norm_mkldnn_fuse_pass
SRCS mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc
DEPS ${TEST_CONV_BN_PASS_DEPS})
cc_test(
test_mkldnn_placement_pass
SRCS mkldnn/mkldnn_placement_pass_tester.cc
......
......@@ -359,73 +359,95 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
auto* bn_bias_tensor =
scope->FindVar(bn_bias->Name())->GetMutable<phi::DenseTensor>();
float epsilon =
PADDLE_GET_CONST(float, batch_norm->Op()->GetAttr("epsilon"));
bool is_mkldnn = fuse_option == FUSE_MKLDNN;
auto input_names = conv->Op()->InputNames();
bool has_bias = std::find(input_names.begin(), input_names.end(), "Bias") !=
input_names.end() &&
conv->Op()->Input("Bias").size() > 0;
bool mkldnn_with_bias = is_mkldnn && has_bias;
// Create eltwise_y (conv bias) variable
VarDesc eltwise_y_in_desc(
patterns::PDNodeName("fuse_conv_bn", conv_type() + "_eltwise_y_in"));
eltwise_y_in_desc.SetShape(phi::vectorize(bn_bias_tensor->dims()));
eltwise_y_in_desc.SetDataType(
framework::TransToProtoVarType(bn_bias_tensor->dtype()));
eltwise_y_in_desc.SetLoDLevel(bn_bias->Var()->GetLoDLevel());
eltwise_y_in_desc.SetPersistable(true);
auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
auto* eltwise_y_in_tensor =
scope->Var(eltwise_y_in_node->Name())->GetMutable<phi::DenseTensor>();
phi::DenseTensor* eltwise_y_in_tensor;
Node* eltwise_y_in_node;
if (!mkldnn_with_bias) {
VarDesc eltwise_y_in_desc(
patterns::PDNodeName("fuse_conv_bn", conv_type() + "_eltwise_y_in"));
eltwise_y_in_desc.SetShape(phi::vectorize(bn_bias_tensor->dims()));
eltwise_y_in_desc.SetDataType(
framework::TransToProtoVarType(bn_bias_tensor->dtype()));
eltwise_y_in_desc.SetLoDLevel(bn_bias->Var()->GetLoDLevel());
eltwise_y_in_desc.SetPersistable(true);
eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
eltwise_y_in_tensor =
scope->Var(eltwise_y_in_node->Name())->GetMutable<phi::DenseTensor>();
// Initialize eltwise_y
eltwise_y_in_tensor->Resize(bn_bias_tensor->dims());
std::fill_n(eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
eltwise_y_in_tensor->numel(),
0.0f);
// Initialize eltwise_y
eltwise_y_in_tensor->Resize(bn_bias_tensor->dims());
std::fill_n(
eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
eltwise_y_in_tensor->numel(),
0.0f);
// update weights and biases
float epsilon =
PADDLE_GET_CONST(float, batch_norm->Op()->GetAttr("epsilon"));
recompute_bias_and_weights(scope,
conv_weight,
*bn_scale,
*bn_bias_tensor,
*bn_mean,
*bn_variance,
eltwise_y_in_tensor,
epsilon,
conv_type());
// update weights and biases
recompute_bias_and_weights(scope,
conv_weight,
*bn_scale,
*bn_bias_tensor,
*bn_mean,
*bn_variance,
eltwise_y_in_tensor,
epsilon,
conv_type());
if (tensor_type == paddle::experimental::DataType::FLOAT16) {
ConvertTensorType<float, float16>(conv_weight_tensor);
ConvertTensorType<float, float16>(eltwise_y_in_tensor);
if (tensor_type == paddle::experimental::DataType::FLOAT16) {
ConvertTensorType<float, float16>(conv_weight_tensor);
ConvertTensorType<float, float16>(eltwise_y_in_tensor);
}
}
// with MKL-DNN fuse conv+bn into conv with bias
// without MKL-DNN fuse conv+bn into conv+elementwise_add
if (fuse_option == FUSE_MKLDNN) {
if (is_mkldnn) {
if (conv->Op()->Type() == "conv2d" ||
conv->Op()->Type() == "depthwise_conv2d") {
conv->Op()->SetType("fused_conv2d");
}
auto input_names = conv->Op()->InputNames();
bool has_bias =
std::find(input_names.begin(), input_names.end(), "Bias") !=
input_names.end();
if (has_bias && conv->Op()->Input("Bias").size() > 0) {
if (mkldnn_with_bias) {
// reuse existing conv bias node
auto conv_bias_names = conv->Op()->Input("Bias");
PADDLE_ENFORCE_EQ(
conv_bias_names.size(),
1UL,
platform::errors::InvalidArgument("Find input var Bais error."));
phi::errors::InvalidArgument("Find input var Bias error."));
auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
auto* conv_bias_tensor = conv_bias_var->GetMutable<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(
conv_bias_tensor->dims(),
eltwise_y_in_tensor->dims(),
platform::errors::InvalidArgument(
"phi::DenseTensor convolution bias(%d) and elementwise y(%d) "
"must have same dims.",
conv_bias_tensor->dims().size(),
eltwise_y_in_tensor->dims().size()));
auto eigen_conv_bias = EigenVector<float>::From(*conv_bias_tensor);
eigen_conv_bias += EigenVector<float>::From(*eltwise_y_in_tensor);
PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(),
bn_bias_tensor->dims(),
phi::errors::InvalidArgument(
"phi::DenseTensor convolution bias(%d) and batch "
"normalization bias (%d) "
"must have same dims.",
conv_bias_tensor->dims().size(),
bn_bias_tensor->dims().size()));
recompute_bias_and_weights(scope,
conv_weight,
*bn_scale,
*bn_bias_tensor,
*bn_mean,
*bn_variance,
conv_bias_tensor,
epsilon,
conv_type());
if (tensor_type == paddle::experimental::DataType::FLOAT16) {
ConvertTensorType<float, float16>(conv_weight_tensor);
ConvertTensorType<float, float16>(conv_bias_tensor);
}
} else {
// add new conv_bias node
conv->Op()->SetInput(
......@@ -453,7 +475,7 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
IR_NODE_LINK_TO(conv, bn_out);
found_conv_bn_count++;
} else { // fuse_option == FUSE_NATIVE
// create an elementwise add node.
// create an elementwise add node.
OpDesc desc;
desc.SetInput("X", std::vector<std::string>({conv_out->Name()}));
desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()}));
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <random>
#include <string>
#include <unordered_set>
#include "paddle/utils/tribool.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
PD_DECLARE_KERNEL(conv2d_transpose, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(batch_norm, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(gelu, CPU, ALL_LAYOUT);
USE_OP_ITSELF(batch_norm);
PD_DECLARE_KERNEL(batch_norm, OneDNN, ONEDNN);
USE_OP_ITSELF(conv2d_transpose);
PD_DECLARE_KERNEL(conv2d_transpose, OneDNN, ONEDNN);
USE_OP_ITSELF(elementwise_add);
PD_DECLARE_KERNEL(add_raw, OneDNN, ONEDNN);
USE_OP_ITSELF(gelu);
PD_DECLARE_KERNEL(gelu, OneDNN, ONEDNN);
PD_DECLARE_ARG_MAPPING_FN(gelu);
namespace paddle {
namespace framework {
namespace ir {
class MKLDNNConvBatchNormPassTest {
private:
void SetOp(ProgramDesc* prog,
const std::string& type,
const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
paddle::tribool use_mkldnn) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
if (!paddle::indeterminate(use_mkldnn))
op->SetAttr("use_mkldnn", use_mkldnn);
if (type == "conv2d_transpose") {
op->SetAttr("name", name);
op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]});
op->SetOutput("Output", {outputs[0]});
op->SetAttr("is_test", true);
op->SetAttr("strides", std::vector<int>(2, 2));
} else if (std::unordered_set<std::string>{
"gelu", "leaky_relu", "relu", "tanh"}
.count(type)) {
op->SetInput("X", inputs);
op->SetOutput("Out", {outputs[0]});
} else if (type == "elementwise_add") {
op->SetAttr("axis", static_cast<int>(1));
op->SetInput("X", {inputs[0]});
op->SetInput("Y", {inputs[1]});
op->SetOutput("Out", {outputs[0]});
} else if (type == "batch_norm") {
op->SetAttr("is_test", true);
op->SetAttr("epsilon", static_cast<float>(1e-5));
op->SetInput("X", {inputs[0]});
op->SetInput("Scale", {inputs[1]});
op->SetInput("Bias", {inputs[2]});
op->SetInput("Mean", {inputs[3]});
op->SetInput("Variance", {inputs[4]});
op->SetOutput("Y", {outputs[0]});
op->SetOutput("MeanOut", {outputs[1]});
op->SetOutput("VarianceOut", {outputs[2]});
op->SetOutput("SavedMean", {outputs[3]});
op->SetOutput("SavedVariance", {outputs[4]});
} else {
FAIL() << "Unexpected operator type.";
}
}
ProgramDesc BuildProgramDesc(bool is_elementwise_add) {
ProgramDesc prog;
// params
for (auto& v : std::vector<std::string>({"weights",
"weights2",
"bias_bn",
"scale",
"mean",
"variance",
"saved_mean",
"saved_variance",
"bias_bn2",
"scale2",
"mean2",
"variance2",
"saved_mean2",
"saved_variance2"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::LOD_TENSOR);
var->SetPersistable(true);
}
// inputs and non-persistant holders
for (auto& v : std::vector<std::string>(
{"a", "b", "e", "f", "g", "h", "i", "j", "k", "l", "m"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::LOD_TENSOR);
}
SetOp(&prog,
"conv2d_transpose",
"conv1",
std::vector<std::string>({"a", "weights"}),
std::vector<std::string>({"f"}),
true);
if (is_elementwise_add == true) {
SetOp(&prog,
"conv2d_transpose",
"conv2",
std::vector<std::string>({"b", "weights2"}),
std::vector<std::string>({"e"}),
true);
SetOp(&prog,
"elementwise_add",
"elementwise_add1",
std::vector<std::string>({"f", "g"}),
std::vector<std::string>({"h"}),
true);
SetOp(&prog,
"elementwise_add",
"elementwise_add2",
std::vector<std::string>({"e", "g"}),
std::vector<std::string>({"j"}),
true);
SetOp(&prog,
"batch_norm",
"batch_norm1",
std::vector<std::string>(
{"h", "scale", "bias_bn", "mean", "variance"}),
std::vector<std::string>(
{"i", "mean", "variance", "saved_mean", "saved_variance"}),
true);
SetOp(&prog,
"batch_norm",
"batch_norm2",
std::vector<std::string>(
{"j", "scale2", "bias_bn2", "mean2", "variance2"}),
std::vector<std::string>(
{"k", "mean2", "variance2", "saved_mean2", "saved_variance2"}),
true);
SetOp(&prog,
"elementwise_add",
"elementwise_add3",
std::vector<std::string>({"i", "k"}),
std::vector<std::string>({"l"}),
true);
} else {
SetOp(&prog,
"batch_norm",
"batch_norm1",
std::vector<std::string>(
{"f", "scale", "bias_bn", "mean", "variance"}),
std::vector<std::string>(
{"l", "mean", "variance", "saved_mean", "saved_variance"}),
true);
}
SetOp(&prog,
"gelu",
"gelu1",
std::vector<std::string>({"l"}),
std::vector<std::string>({"m"}),
true);
return prog;
}
void FillTensorWithRandomData(phi::DenseTensor* tnsr,
float lowb,
float upb,
phi::CPUPlace place) {
float* ptr = tnsr->mutable_data<float>(place);
// Initialize input data
std::uniform_real_distribution<float> dist(static_cast<float>(lowb),
static_cast<float>(upb));
std::mt19937 engine;
for (int i = 0; i < tnsr->numel(); ++i) {
ptr[i] = dist(engine);
}
}
void CompareTensors(phi::DenseTensor* tensor1, phi::DenseTensor* tensor2) {
// check dims
for (int i = 0; i < tensor1->numel(); ++i) {
EXPECT_NEAR(tensor1->data<float>()[i], tensor2->data<float>()[i], 1e-3);
}
}
public:
void MainTest(bool is_elementwise_add) {
auto base_prog = BuildProgramDesc(is_elementwise_add);
std::unique_ptr<ir::Graph> graph(new ir::Graph(base_prog));
Scope scope;
auto place = phi::CPUPlace();
NaiveExecutor exe{place};
auto pass = PassRegistry::Instance().Get(
is_elementwise_add ? "conv_transpose_eltwiseadd_bn_fuse_pass"
: "conv_transpose_bn_fuse_pass");
graph->SetNotOwned(kParamScopeAttr, &scope);
auto& prog = graph->OriginProgram();
exe.CreateVariables(prog, 0, true, &scope);
exe.CreateVariables(prog, 0, false, &scope);
exe.Prepare(&scope, prog, 0, false);
std::cout << GenScopeTreeDebugInfo(&scope);
auto* a_tensor = exe.FindTensor("a");
auto* b_tensor = exe.FindTensor("b");
auto* weights_tensor = exe.FindTensor("weights");
auto* weights2_tensor = exe.FindTensor("weights2");
auto* g_tensor = exe.FindTensor("g");
// Batch Norm
auto* bias_bn_tensor = exe.FindTensor("bias_bn"); // shift
auto* scale_tensor = exe.FindTensor("scale");
auto* mean_tensor = exe.FindTensor("mean");
auto* variance_tensor = exe.FindTensor("variance");
auto* bias_bn2_tensor = exe.FindTensor("bias_bn2"); // shift
auto* scale2_tensor = exe.FindTensor("scale2");
auto* mean2_tensor = exe.FindTensor("mean2");
auto* variance2_tensor = exe.FindTensor("variance2");
int ic, oc, iw, ih, n, fw, fh;
n = 1;
fw = fh = 2;
oc = ic = 24;
iw = ih = 160;
// mb1_ic24oc24_ih8oh16kh2sh2dh0ph0_iw80ow160kw2sw2dw0pw0 deconv
a_tensor->Resize({n, ic, ih, iw});
weights_tensor->Resize({oc, ic, fh, fw});
g_tensor->Resize({oc});
bias_bn_tensor->Resize({oc});
scale_tensor->Resize({oc});
mean_tensor->Resize({oc});
variance_tensor->Resize({oc});
if (is_elementwise_add) {
b_tensor->Resize({n, ic, ih, iw});
weights2_tensor->Resize({oc, ic, fh, fw});
bias_bn2_tensor->Resize({oc});
scale2_tensor->Resize({oc});
mean2_tensor->Resize({oc});
variance2_tensor->Resize({oc});
}
// Input and conv transpose
FillTensorWithRandomData(a_tensor, 1.0f, 2.0f, place);
FillTensorWithRandomData(g_tensor, 1.0f, 2.0f, place);
FillTensorWithRandomData(weights_tensor, 1.0f, 2.0f, place);
if (is_elementwise_add) {
FillTensorWithRandomData(b_tensor, 1.0f, 2.0f, place);
FillTensorWithRandomData(weights2_tensor, 1.0f, 2.0f, place);
}
// First Batch_Norm
FillTensorWithRandomData(bias_bn_tensor, 1.0f, 2.0f, place);
FillTensorWithRandomData(scale_tensor, 1.0f, 2.0f, place);
FillTensorWithRandomData(mean_tensor, 1.0f, 2.0f, place);
FillTensorWithRandomData(variance_tensor, 1.0f, 2.0f, place);
// Second Batch Norm (exists only when elementwise_add is present)
if (is_elementwise_add) {
FillTensorWithRandomData(bias_bn2_tensor, 1.0f, 2.0f, place);
FillTensorWithRandomData(scale2_tensor, 1.0f, 2.0f, place);
FillTensorWithRandomData(mean2_tensor, 1.0f, 2.0f, place);
FillTensorWithRandomData(variance2_tensor, 1.0f, 2.0f, place);
}
exe.Run();
// Get result without IR passes applied
// Need to copy result over as the same scope is used in both executors
// so first result will be overwritten by second
auto* m_tensor = exe.FindTensor("m");
phi::DenseTensor no_ir_result;
TensorCopy(*m_tensor, place, &no_ir_result);
graph.reset(pass->Apply(graph.release()));
// Get Program from graph
ProgramDesc optimized_prog;
auto graph2program_pass =
paddle::framework::ir::PassRegistry::Instance().Get(
"graph_to_program_pass");
graph2program_pass->SetNotOwned<paddle::framework::ProgramDesc>(
"program", &optimized_prog);
graph2program_pass->Apply(graph.release());
exe.Prepare(&scope, optimized_prog, 0, false);
exe.Run();
auto* ir_result = exe.FindTensor("m");
// Two graphs. Execute both and compare results
CompareTensors(&no_ir_result, ir_result);
VLOG(3) << DebugString(graph);
}
};
TEST(MKLDNNConvBatchNormPassTest, conv_batch_norm) {
MKLDNNConvBatchNormPassTest().MainTest(false);
}
TEST(MKLDNNConvBatchNormPassTest, conv_elementwise_add_batch_norm) {
MKLDNNConvBatchNormPassTest().MainTest(true);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(conv_transpose_bn_fuse_pass);
USE_PASS(conv_transpose_eltwiseadd_bn_fuse_pass);
USE_PASS(graph_to_program_pass);
......@@ -252,6 +252,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
endif()
if(WITH_MKLDNN)
set_tests_properties(test_onednn_conv_bn_fuse_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_onednn_conv_elementwise_add_fuse_pass
PROPERTIES TIMEOUT 120)
set_tests_properties(test_mkldnn_depthwise_conv_pass PROPERTIES TIMEOUT 120)
......
......@@ -60,7 +60,6 @@ class TestConvBnFusePass(PassAutoScanTest):
st.integers(min_value=1, max_value=2), min_size=2, max_size=2
)
)
has_bias = draw(st.booleans())
use_mkldnn = draw(st.booleans())
epsilon = draw(st.floats(min_value=0.0, max_value=0.001))
......@@ -110,7 +109,7 @@ class TestConvBnFusePass(PassAutoScanTest):
paddings=paddings,
strides=strides,
use_mkldnn=use_mkldnn,
has_bias=has_bias,
has_bias=False,
is_test=True,
)
bn_op = OpConfig(
......@@ -135,8 +134,6 @@ class TestConvBnFusePass(PassAutoScanTest):
data_layout=data_format,
is_test=True,
)
if has_bias:
conv2d_op.inputs["Bias"] = ["conv2d_bias"]
ops = [conv2d_op, bn_op]
program_config = ProgramConfig(
......@@ -157,10 +154,6 @@ class TestConvBnFusePass(PassAutoScanTest):
},
outputs=["batch_norm_Y"],
)
if has_bias:
program_config.weights["conv2d_bias"] = TensorConfig(
data_gen=partial(generate_conv2d_Bias)
)
return program_config
def sample_predictor_configs(self, program_config):
......@@ -185,10 +178,7 @@ class TestConvBnFusePass(PassAutoScanTest):
use_static=False,
use_calib_mode=False,
)
if program_config.ops[0].attrs['has_bias']:
yield config, ['conv2d', 'elementwise_add'], (1e-5, 1e-5)
else: # it will enter conv_elementwise_add_fuse_pass
yield config, ['conv2d_fusion'], (1e-5, 1e-5)
yield config, ['conv2d_fusion'], (1e-5, 1e-5)
def add_ignore_pass_case(self):
def teller1(program_config, predictor_config):
......@@ -199,25 +189,12 @@ class TestConvBnFusePass(PassAutoScanTest):
return True
return False
# mkldnn Output has diff with bias!
def teller2(program_config, predictor_config):
return (
predictor_config.mkldnn_enabled()
and program_config.ops[0].attrs['has_bias']
)
self.add_ignore_check_case(
teller1,
IgnoreReasons.PASS_ACCURACY_ERROR,
"The output format of conv2d is wrong when data_format attribute is NHWC",
)
self.add_ignore_check_case(
teller2,
IgnoreReasons.PASS_ACCURACY_ERROR,
"Currently mkldnn Output has diff with bias!",
)
def test(self):
self.run_and_statis(
quant=False,
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestOneDNNConvBnFusePass(PassAutoScanTest):
def sample_program_config(self, draw):
use_mkldnn = True
padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
groups = draw(st.integers(min_value=1, max_value=3))
data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
axis = draw(st.sampled_from([1]))
filter_channel = draw(st.integers(min_value=1, max_value=16)) * 4
filter_size = draw(st.integers(min_value=1, max_value=4))
in_channel = groups * filter_channel
out_channel_factor = draw(st.integers(min_value=1, max_value=16)) * 4
out_channel = groups * out_channel_factor
batch_size = draw(st.integers(min_value=1, max_value=4))
dilations = draw(
st.lists(
st.integers(min_value=1, max_value=2), min_size=2, max_size=2
)
)
paddings = draw(
st.lists(
st.integers(min_value=0, max_value=2), min_size=2, max_size=2
)
)
strides = draw(
st.lists(
st.integers(min_value=1, max_value=2), min_size=2, max_size=2
)
)
epsilon = draw(st.floats(min_value=0.0, max_value=0.001))
x_shape = (
[batch_size, in_channel, 64, 64]
if data_format == "NCHW"
else [batch_size, 64, 64, in_channel]
)
w_shape = [out_channel, filter_channel, filter_size, filter_size]
scale_shape = [out_channel]
bias_shape = [out_channel]
var_shape = [out_channel]
mean_shape = [out_channel]
def generate_data(shape):
return np.random.random(shape).astype(np.float32)
conv2d_op = OpConfig(
"conv2d",
inputs={
"Input": ["conv2d_input"],
"Filter": ["conv2d_weight"],
},
outputs={"Output": ["conv2d_out"]},
data_format=data_format,
dilations=dilations,
padding_algorithm=padding_algorithm,
groups=groups,
paddings=paddings,
strides=strides,
use_mkldnn=use_mkldnn,
has_bias=False,
is_test=True,
)
bn_op = OpConfig(
"batch_norm",
inputs={
"X": ["conv2d_out"],
"Scale": ["batch_norm_Scale"],
"Bias": ["batch_norm_Bias"],
"Mean": ["batch_norm_Mean"],
"Variance": ["batch_norm_Variance"],
},
outputs={
"Y": ["batch_norm_Y"],
"MeanOut": ["batch_norm_Mean"],
"VarianceOut": ["batch_norm_Variance"],
"SavedMean": ["batch_norm_SavedMean"],
"SavedVariance": ["batch_norm_SavedVariance"],
"ReserveSpace": ["batch_norm_ReserveSpace"],
},
epsilon=epsilon,
trainable_statistics=False,
data_layout=data_format,
is_test=True,
)
ops = [conv2d_op, bn_op]
program_config = ProgramConfig(
ops=ops,
inputs={
"conv2d_input": TensorConfig(
data_gen=partial(generate_data, x_shape)
),
},
weights={
"conv2d_weight": TensorConfig(
data_gen=partial(generate_data, w_shape)
),
"batch_norm_Scale": TensorConfig(
data_gen=partial(generate_data, scale_shape)
),
"batch_norm_Bias": TensorConfig(
data_gen=partial(generate_data, bias_shape)
),
"batch_norm_Mean": TensorConfig(
data_gen=partial(generate_data, mean_shape)
),
"batch_norm_Variance": TensorConfig(
data_gen=partial(generate_data, var_shape)
),
},
outputs=["batch_norm_Y"],
)
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config()
config.enable_mkldnn()
yield config, ['fused_conv2d'], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
quant=False,
max_examples=100,
passes=["conv_bn_fuse_pass"],
)
if __name__ == "__main__":
unittest.main()