Unverified commit 7db747d9, authored by Adam Osewski, committed by GitHub

oneDNN BatchNorm + Act fusion pass. (#27912)

Parent fb7f8529
@@ -110,6 +110,7 @@ if(WITH_MKLDNN)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(batch_norm_act_fuse_pass inference DIR mkldnn)
endif()
cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector )
@@ -151,6 +152,7 @@ if (WITH_MKLDNN)
cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass)
cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
cc_test(test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass)
set(TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context)
if (WITH_GPU)
set(TEST_CONV_BN_PASS_DEPS ${TEST_CONV_BN_PASS_DEPS} depthwise_conv)
......
@@ -1188,6 +1188,26 @@ PDNode *patterns::BatchNormActGrad::operator()(
return bn_grad;
}
PDNode *patterns::BatchNormActOneDNN::operator()(const std::string &act_type) {
auto *bn_x = pattern->NewNode(bn_in_repr())
->AsInput()
->assert_is_op_input("batch_norm", "X");
auto *bn = pattern->NewNode(batch_norm_repr())->assert_is_op("batch_norm");
auto *bn_out = pattern->NewNode(bn_out_repr())
->assert_is_op_output("batch_norm", "Y")
->assert_is_op_input(act_type);
auto *act =
pattern->NewNode(act_repr())->assert_is_op(act_type)->AsIntermediate();
auto *act_out = pattern->NewNode(act_out_repr())
->assert_is_op_output(act_type, "Out")
->AsOutput();
bn->LinksFrom({bn_x}).LinksTo({bn_out});
act->LinksFrom({bn_out}).LinksTo({act_out});
return act_out;
}
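// For orientation, a sketch of the subgraph this pattern matches, derived
// from the node definitions above (names are the pattern's own):
//
//   bn_x -> batch_norm -> bn_out -> act -> act_out
//
// 'act' is asserted to be of the given act_type and is marked AsIntermediate,
// so the fusion pass may remove it together with bn_out.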
PDNode *patterns::ElewiseAddAct::operator()(
paddle::framework::ir::PDNode *ele_x_var,
std::unordered_set<std::string> act_types) {
......
@@ -664,6 +664,27 @@ struct BatchNormActGrad : public PatternBase {
PATTERN_DECL_NODE(d_bn_bias);
};
//
// \brief Pattern looking for batch_norm and a directly following activation
// operator.
//
// \note Currently only ReLU is supported as an activation function.
// Formula: act(bn(x))
// Op: batch_norm + act
struct BatchNormActOneDNN : public PatternBase {
BatchNormActOneDNN(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "bn_act_onednn") {}
PDNode* operator()(const std::string& act_type);
// declare operator node's name
PATTERN_DECL_NODE(bn_in);
PATTERN_DECL_NODE(batch_norm);
PATTERN_DECL_NODE(act);
PATTERN_DECL_NODE(bn_out);
PATTERN_DECL_NODE(act_out);
};
// The following patterns are used to fuse elewise_add and act
// formula: act(ele_add(x, y))
// op: elementwise_add + act
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void FuseBatchNormActOneDNNPass::ApplyImpl(Graph *graph) const {
std::string act_type("relu");
FuseBatchNormAct(graph, act_type);
}
void FuseBatchNormActOneDNNPass::FuseBatchNormAct(
Graph *graph, const std::string &act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument(
"The input graph of "
"FuseBatchNormActOneDNNPass should not be nullptr."));
FusePassBase::Init("bn_act", graph);
GraphPatternDetector gpd;
patterns::BatchNormActOneDNN bn_act_pattern(gpd.mutable_pattern(), "bn_act");
bn_act_pattern(act_type);
int found_bn_act_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "Fuse BatchNorm with ReLU activation op.";
// BN output
GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, bn_act_pattern);
// ACT output
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, bn_act_pattern);
// ops
GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, bn_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act, act, bn_act_pattern);
auto *bn_op = batch_norm->Op();
if (bn_op->HasAttr("use_mkldnn")) {
PADDLE_ENFORCE(
BOOST_GET_CONST(bool, bn_op->GetAttr("use_mkldnn")),
platform::errors::PreconditionNotMet(
"The BatchNorm+Act fusion may happen only when oneDNN library "
"is used."));
}
if (bn_op->HasAttr("trainable_statistics")) {
PADDLE_ENFORCE(
!BOOST_GET_CONST(bool, bn_op->GetAttr("trainable_statistics")),
platform::errors::PreconditionNotMet(
"The BatchNorm+Act fusion may happen only when mean and variance "
"are not calculated by current batch statistics."));
}
if (bn_op->HasAttr("is_test")) {
PADDLE_ENFORCE(
BOOST_GET_CONST(bool, bn_op->GetAttr("is_test")),
platform::errors::PreconditionNotMet(
"The BatchNorm+Act fusion may happen only during inference."));
}
bn_op->SetAttr("use_mkldnn", true);
bn_op->SetAttr("is_test", true);
bn_op->SetAttr("fuse_with_relu", true);
bn_op->SetAttr("trainable_statistics", false);
bn_op->SetOutput("Y", {act_out->Name()});
IR_OP_VAR_LINK(batch_norm, act_out);
GraphSafeRemoveNodes(g, {act, bn_out});
found_bn_act_count++;
};
gpd(graph, handler);
AddStatis(found_bn_act_count);
PrettyLogDetail("--- fused %d batch norm with relu activation",
found_bn_act_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(batch_norm_act_fuse_pass,
paddle::framework::ir::FuseBatchNormActOneDNNPass);
REGISTER_PASS_CAPABILITY(batch_norm_act_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("batch_norm", 0)
.EQ("relu", 0));
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* \brief Fuse the BatchNorm and activation operators into a single oneDNN
* BatchNorm operator with a post-op.
*
* \note Currently only ReLU is supported as an activation function.
*/
class FuseBatchNormActOneDNNPass : public FusePassBase {
public:
virtual ~FuseBatchNormActOneDNNPass() {}
protected:
void ApplyImpl(ir::Graph *graph) const override;
void FuseBatchNormAct(ir::Graph *graph, const std::string &act_type) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
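// A minimal usage sketch, mirroring the tester below; it assumes the pass has
// been registered under the name "batch_norm_act_fuse_pass":
//
//   auto pass = PassRegistry::Instance().Get("batch_norm_act_fuse_pass");
//   pass->Apply(graph);  // 'graph' is an ir::Graph* built from a ProgramDesc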
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <exception>
#include <functional>
#include <iterator>
#include <list>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace framework {
namespace ir {
// -------------------------- helper functions --------------------------------
namespace {
using InOutVarNamePair = std::pair<std::string, std::string>;
using OpTypeCountPair = std::pair<std::string, int>;
///
/// @brief Creates the specified operator and sets up its inputs/outputs.
///
/// @param prog The program descriptor to which the new op is added.
/// @param[in] op_type_name The operator type name.
/// @param[in] inputs The vector of input pairs: {input_name, variable
/// name}.
/// @param[in] outputs The vector of output pairs: {output_name, variable
/// name}.
/// @param[in] use_mkldnn The flag deciding whether or not to set
/// 'use_mkldnn' attribute.
///
/// @return Returns a pointer to the created operator descriptor.
///
OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name,
const std::vector<InOutVarNamePair>& inputs,
const std::vector<InOutVarNamePair>& outputs,
bool use_mkldnn = true) {
auto op = prog->MutableBlock(0)->AppendOp();
op->SetType(op_type_name);
op->SetAttr("use_mkldnn", use_mkldnn);
for (const auto& input : inputs) {
op->SetInput(input.first, {input.second});
}
for (const auto& output : outputs) {
op->SetOutput(output.first, {output.second});
}
return op;
}
///
/// @brief Checks whether node 'to' is reachable from node 'from' in the
/// graph.
///
/// @param[in] graph The graph checked for reachability.
/// @param[in] from The 'from' node name.
/// @param[in] to The 'to' node name.
///
/// @return True if there is a connection between nodes 'from' and 'to'.
///
bool TestIsReachable(const Graph& graph, std::string from, std::string to) {
auto hash = [](const Node* node) -> std::string {
return node->Name() + std::to_string(node->id());
};
auto find_node = [&](const Graph& graph, const std::string& name) -> Node* {
for (auto& node : GraphTraits::DFS(graph)) {
if (name == hash(&node)) {
return &node;
}
}
return nullptr;
};
if (from == to) return true;
std::map<std::string, bool> visited;
// Replace the 'from' and 'to' names with their hashed equivalents found
// while traversing the graph.
for (auto& node : GraphTraits::DFS(graph)) {
auto hashed = hash(&node);
if (node.Name() == from) {
from = hashed;
}
if (node.Name() == to) {
to = hashed;
}
visited[hashed] = false;
}
visited[from] = true;
std::list<std::string> queue;
queue.push_back(from);
while (!queue.empty()) {
auto cur = find_node(graph, queue.front());
queue.pop_front();
if (cur == nullptr) {
return false;
}
for (auto n : cur->outputs) {
auto hashed_name = hash(n);
if (hashed_name == to) {
return true;
}
if (!visited[hashed_name]) {
visited[hashed_name] = true;
queue.push_back(hashed_name);
}
}
}
return false;
}
///
/// @brief Searches through the graph and counts the provided operators'
/// occurrences.
///
/// @param[in] graph The graph searched through.
/// @param[in] op_type_count The vector of pairs: {op_type_name, op count}.
///
/// @note After going through all graph nodes this function asserts
/// that the counted number for each requested op is as expected.
///
void AssertOpsCount(const Graph& graph,
std::vector<OpTypeCountPair> op_type_count) {
for (auto* node : graph.Nodes()) {
if (!node->IsOp()) {
continue;
}
const std::string op_type_name = node->Op()->Type();
auto op_it =
std::find_if(std::begin(op_type_count), std::end(op_type_count),
[op_type_name](const OpTypeCountPair& p) {
return op_type_name == p.first;
});
if (op_it != std::end(op_type_count)) {
op_it->second--;
}
}
for (const OpTypeCountPair& p : op_type_count) {
EXPECT_EQ(p.second, 0);
}
}
///
/// @brief Builds a program descriptor.
///
/// @param[in] transient_vars The vector of transient variable names.
/// @param[in] persistent_vars The vector of persistent variable names. These
/// will have the persistable attribute set to true.
///
/// @return The program descriptor object.
///
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
const std::vector<std::string>& persistent_vars) {
ProgramDesc prog;
auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* {
auto var = prog.MutableBlock(0)->Var(var_name);
var->SetType(proto::VarType::LOD_TENSOR);
return var;
};
for (const auto& v : transient_vars) {
add_var_to_prog(v);
}
for (const auto& v : persistent_vars) {
auto* var = add_var_to_prog(v);
var->SetPersistable(true);
}
return prog;
}
///
/// @brief Executes the pass on the provided graph and performs checks.
///
/// @param graph The graph the pass is run on.
/// @param[in] from The name of the node starting the tested
/// sequence in the graph; used to verify
/// correct node connections.
/// @param[in] to The name of the node ending the tested
/// sequence in the graph; used to verify
/// correct node connections.
/// @param[in] removed_nodes_count The number of nodes expected to be
/// removed/fused after pass execution.
/// @param[in] added_nodes_count The number of nodes expected to be
/// added after pass execution.
///
void RunPassAndAssert(Graph* graph, const std::string& from,
const std::string& to, int removed_nodes_count,
int added_nodes_count = 0) {
EXPECT_TRUE(TestIsReachable(*graph, from, to));
int original_nodes_num = graph->Nodes().size();
auto pass = PassRegistry::Instance().Get("batch_norm_act_fuse_pass");
pass->Apply(graph);
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(TestIsReachable(*graph, from, to));
EXPECT_EQ(original_nodes_num - removed_nodes_count + added_nodes_count,
current_nodes_num);
}
void SetBatchNormAttrs(OpDesc* bn_op, bool is_test = true,
bool trainable_stats = true) {
bn_op->SetAttr("is_test", is_test);
bn_op->SetAttr("trainable_statistics", trainable_stats);
bn_op->SetAttr("fuse_with_relu", false);
}
} // namespace
// ------------------------------ Test cases -----------------------------------
// The test cases below are distinguished by whether the following attributes
// are true or false:
// - is_test
// - trainable_statistics
// A test case's name lists only the attributes that are set to true.
TEST(FuseBatchNormActOneDNNPass, ThrowIsTestTrainableStats) {
auto prog = BuildProgramDesc(
{"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
{"scale", "bias"});
auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
{"Variance", "v"}},
{{"Y", "bn_y"},
{"MeanOut", "m_out"},
{"VarianceOut", "var_out"},
{"SavedMean", "sm"},
{"SavedVariance", "sv"}});
SetBatchNormAttrs(bn_op, true, true);
CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
// No fusion in this attribute configuration
constexpr int removed_nodes_count = 0;
EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count),
paddle::platform::EnforceNotMet);
}
TEST(FuseBatchNormActOneDNNPass, FuseIsTest) {
auto prog =
BuildProgramDesc({"x", "m", "v", "bn_y", "act_y"}, {"scale", "bias"});
auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
{"Variance", "v"}},
{{"Y", "bn_y"}});
SetBatchNormAttrs(bn_op, true, false);
CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 0}});
for (const auto* node : graph.Nodes()) {
if (node->IsOp() && node->Op()->Type() == "batch_norm") {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("fuse_with_relu"));
EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("fuse_with_relu")));
ASSERT_TRUE(op->HasAttr("trainable_statistics"));
EXPECT_FALSE(BOOST_GET_CONST(bool, op->GetAttr("trainable_statistics")));
}
}
}
TEST(FuseBatchNormActOneDNNPass, ThrowTrainableStats) {
auto prog = BuildProgramDesc(
{"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
{"scale", "bias"});
auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
{"Variance", "v"}},
{{"Y", "bn_y"},
{"MeanOut", "m_out"},
{"VarianceOut", "var_out"},
{"SavedMean", "sm"},
{"SavedVariance", "sv"}});
SetBatchNormAttrs(bn_op, false, true);
CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
// No fusion in this attribute configuration
constexpr int removed_nodes_count = 0;
EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count),
paddle::platform::EnforceNotMet);
}
TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) {
auto prog = BuildProgramDesc(
{"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
{"scale", "bias"});
auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
{"Variance", "v"}},
{{"Y", "bn_y"},
{"MeanOut", "m_out"},
{"VarianceOut", "var_out"},
{"SavedMean", "sm"},
{"SavedVariance", "sv"}});
SetBatchNormAttrs(bn_op, false, false);
CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
// No fusion in this attribute configuration
constexpr int removed_nodes_count = 0;
EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count),
paddle::platform::EnforceNotMet);
}
TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) {
auto prog = BuildProgramDesc(
{"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
{"scale", "bias"});
auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
{"Variance", "v"}},
{{"Y", "bn_y"},
{"MeanOut", "m_out"},
{"VarianceOut", "var_out"},
{"SavedMean", "sm"},
{"SavedVariance", "sv"}},
false);
SetBatchNormAttrs(bn_op, false, false);
CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
// No fusion in this attribute configuration
constexpr int removed_nodes_count = 0;
EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count),
paddle::platform::EnforceNotMet);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(batch_norm_act_fuse_pass);
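// USE_PASS references the pass registrar so that the registration above is
// linked into this test binary and the pass can be found in the PassRegistry.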
@@ -207,6 +207,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"matmul_transpose_reshape_fuse_pass", //
// Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass",
"batch_norm_act_fuse_pass",
"mkldnn_inplace_pass", // This pass should be activated after
// fuses
})) {
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for fusion of batch norm and activation."""
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from inference_pass_test import InferencePassTest
from paddle import enable_static
from paddle.fluid.core import PassVersionChecker
enable_static()
class BnReluOneDnnFusePassTest(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 100, 100], dtype="float32")
bn_out = fluid.layers.batch_norm(
input=data, is_test=True, use_global_stats=self.global_stats)
relu_out = fluid.layers.relu(bn_out)
self.feeds = {
"data": np.random.random((1, 3, 100, 100)).astype("float32")
}
self.fetch_list = [relu_out]
self.enable_mkldnn = True
def set_params(self):
self.global_stats = False
self.pass_name = "batch_norm_act_fuse_pass"
def test_check_output(self):
self.check_output()
self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
class BnReluGlobalStatsOneDnnFusePassTest(BnReluOneDnnFusePassTest):
    # Identical to the base test except that batch_norm uses global statistics.
    def set_params(self):
        self.global_stats = True
        self.pass_name = "batch_norm_act_fuse_pass"
if __name__ == "__main__":
unittest.main()