Unverified commit edc06c6a, authored by: J jakpiase, committed by: GitHub

Added fc + activation fuse pass (currently only gelu, sigmoid and tanh are supported) (#29772)

Parent 0e0bb1b9
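The new pass is not switched on by default (it is added commented-out in the CPU pass strategy further down); callers opt in through the inference pass builder. Below is a minimal sketch of that opt-in, mirroring the analyzer-test changes in this commit; the header path and helper name are assumptions, not part of the diff.

#include "paddle/fluid/inference/api/paddle_analysis_config.h"

// Hypothetical helper: enable the oneDNN FC kernel and the new
// fc + activation fusion on an inference config, exactly as the
// analyzer tests touched by this commit do.
void EnableFcActFusion(paddle::AnalysisConfig *cfg) {
  cfg->EnableMKLDNN();
  cfg->pass_builder()->AppendPass("fc_mkldnn_pass");
  cfg->pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
}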
......@@ -105,6 +105,7 @@ if(WITH_MKLDNN)
pass_library(cpu_bfloat16_placement_pass inference DIR mkldnn)
pass_library(cpu_bfloat16_pass inference DIR mkldnn)
pass_library(fc_mkldnn_pass inference DIR mkldnn)
pass_library(fc_act_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(cpu_quantize_placement_pass base DIR mkldnn)
pass_library(cpu_quantize_pass inference DIR mkldnn)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
......@@ -155,6 +156,7 @@ if (WITH_MKLDNN)
cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass)
cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
cc_test(test_fc_act_mkldnn_fuse_pass SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc DEPS fc_act_mkldnn_fuse_pass)
cc_test(test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass)
set(TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context)
if (WITH_GPU)
......
......@@ -1017,6 +1017,23 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
return fc_out_var;
}
PDNode *patterns::FCActOneDNN::operator()(const std::string &act_type) {
auto *fc = pattern->NewNode(fc_repr())->assert_is_op("fc");
auto *fc_out = pattern->NewNode(fc_out_repr())
->assert_is_op_output("fc", "Out")
->assert_is_op_input(act_type);
auto *act =
pattern->NewNode(act_repr())->assert_is_op(act_type)->AsIntermediate();
auto *act_out = pattern->NewNode(act_out_repr())
->assert_is_op_output(act_type, "Out")
->AsOutput();
fc->LinksTo({fc_out});
act->LinksFrom({fc_out}).LinksTo({act_out});
return act_out;
}
PDNode *patterns::Embedding::operator()(PDNode *x) {
x->assert_is_op_input("lookup_table", "Ids");
auto *lookup_table_op =
......
......@@ -552,6 +552,27 @@ struct FCMKLDNN : public PatternBase {
PATTERN_DECL_NODE(output);
};
//
// \brief Pattern looking for fc and a directly following activation
// operator.
//
// \note Currently only gelu, sigmoid and tanh are supported as activation
// functions.
// Formula: act(fc(x))
// Op: fc + act
struct FCActOneDNN : public PatternBase {
FCActOneDNN(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fc_act_onednn") {}
PDNode* operator()(const std::string& act_type);
// declare operator node's name
PATTERN_DECL_NODE(fc);
PATTERN_DECL_NODE(act);
PATTERN_DECL_NODE(fc_out);
PATTERN_DECL_NODE(act_out);
};
// Embedding
struct Embedding : public PatternBase {
Embedding(PDPattern* pattern, const std::string& name_scope)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
std::vector<std::string> act_types = {"gelu", "tanh", "sigmoid"};
for (std::string act_type : act_types) FuseFCAct(graph, act_type);
}
void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
const std::string &act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("fc_act", graph);
GraphPatternDetector gpd;
patterns::FCActOneDNN fc_act_pattern(gpd.mutable_pattern(), "fc_act");
fc_act_pattern(act_type);
int found_fc_act_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "Fuse fc with activation op.";
// FC output
GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc_act_pattern);
// ACT output
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fc_act_pattern);
// ops
GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act, act, fc_act_pattern);
auto *fc_op = fc->Op();
auto *act_op = act->Op();
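// Refuse to fuse when the matched FC explicitly opted out of oneDNN;
// the fused activation only exists as a post-op of the oneDNN FC kernel.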
if (fc_op->HasAttr("use_mkldnn")) {
PADDLE_ENFORCE(
BOOST_GET_CONST(bool, fc_op->GetAttr("use_mkldnn")),
platform::errors::PreconditionNotMet(
"The FC+Act fusion may happen only when oneDNN library "
"is used."));
}
if (act_type == "gelu" && act_op->HasAttr("approximate")) {
bool approximate = BOOST_GET_CONST(bool, act_op->GetAttr("approximate"));
std::string type = approximate ? "_tanh" : "_erf";
fc_op->SetAttr("activation_type", act_type + type);
} else {
fc_op->SetAttr("activation_type", act_type);
}
fc_op->SetAttr("use_mkldnn", true);
fc_op->SetOutput("Out", {act_out->Name()});
IR_OP_VAR_LINK(fc, act_out);
GraphSafeRemoveNodes(g, {act, fc_out});
found_fc_act_count++;
};
gpd(graph, handler);
AddStatis(found_fc_act_count);
PrettyLogDetail("--- fused %d fc with %s activation", found_fc_act_count,
act_type);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fc_act_mkldnn_fuse_pass,
paddle::framework::ir::FuseFCActOneDNNPass);
REGISTER_PASS_CAPABILITY(fc_act_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("fc", 0)
.LE("gelu", 0)
.LE("sigmoid", 0)
.LE("tanh", 0));
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* \brief Fuse the FC and activation operators into single OneDNN's
* FC with post-op.
*
* \note Currently only GeLU, sigmoid and tanh are supported as activation
* functions.
*/
class FuseFCActOneDNNPass : public FusePassBase {
public:
virtual ~FuseFCActOneDNNPass() {}
protected:
void ApplyImpl(ir::Graph *graph) const override;
void FuseFCAct(ir::Graph *graph, const std::string &act_type) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <exception>
#include <functional>
#include <iterator>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace framework {
namespace ir {
// -------------------------- helper functions --------------------------------
namespace {
using InOutVarNamePair = std::pair<std::string, std::string>;
using OpTypeCountPair = std::pair<std::string, int>;
///
/// @brief Creates the specified operator and sets up its inputs/outputs.
///
/// @param prog The program descriptor to which we add new op.
/// @param[in] op_type_name The operator type name.
/// @param[in] inputs The vector of input pairs: {input_name, variable
/// name}
/// @param[in] outputs The vector of output pairs: {output_name, variable name}
/// @param[in] use_mkldnn The flag deciding whether or not to set
/// 'use_mkldnn' attribute.
///
/// @return Returns pointer to the created operator descriptor.
///
OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name,
const std::vector<InOutVarNamePair>& inputs,
const std::vector<InOutVarNamePair>& outputs,
bool use_mkldnn = true) {
auto op = prog->MutableBlock(0)->AppendOp();
op->SetType(op_type_name);
op->SetAttr("use_mkldnn", use_mkldnn);
for (const auto& input : inputs) {
op->SetInput(input.first, {input.second});
}
for (const auto& output : outputs) {
op->SetOutput(output.first, {output.second});
}
return op;
}
///
/// @brief Check whether node 'to' is reachable from node 'from' in graph.
///
/// @param[in] graph The graph we're checking for reachability.
/// @param[in] from The 'from' node name.
/// @param[in] to The 'to' node name.
///
/// @return True if there is a connection between nodes 'from' and 'to'.
///
bool TestIsReachable(const Graph& graph, std::string from, std::string to) {
auto hash = [](const Node* node) -> std::string {
return node->Name() + std::to_string(node->id());
};
auto find_node = [&](const Graph& graph, const std::string& name) -> Node* {
for (auto& node : GraphTraits::DFS(graph)) {
if (name == hash(&node)) {
return &node;
}
}
return nullptr;
};
if (from == to) return true;
std::map<std::string, bool> visited;
// Update the 'from' and 'to' names to their hashed equivalents while traversing the graph.
for (auto& node : GraphTraits::DFS(graph)) {
auto hashed = hash(&node);
if (node.Name() == from) {
from = hashed;
}
if (node.Name() == to) {
to = hashed;
}
visited[hashed] = false;
}
visited[from] = true;
std::list<std::string> queue;
queue.push_back(from);
while (!queue.empty()) {
auto cur = find_node(graph, queue.front());
queue.pop_front();
if (cur == nullptr) {
return false;
}
for (auto n : cur->outputs) {
auto hashed_name = hash(n);
if (hashed_name == to) {
return true;
}
if (!visited[hashed_name]) {
visited[hashed_name] = true;
queue.push_back(hashed_name);
}
}
}
return false;
}
///
/// @brief Searches through the graph and counts occurrences of the provided operators.
///
/// @param[in] graph The graph we search through.
/// @param[in] op_type_count The vector of pairs {op_type_name, op count}
///
/// @note After going through all graph nodes this function asserts
/// whether the counted number for each requested op is as expected.
///
void AssertOpsCount(const Graph& graph,
std::vector<OpTypeCountPair> op_type_count) {
for (auto* node : graph.Nodes()) {
if (!node->IsOp()) {
continue;
}
const std::string op_type_name = node->Op()->Type();
auto op_it =
std::find_if(std::begin(op_type_count), std::end(op_type_count),
[op_type_name](const OpTypeCountPair& p) {
return op_type_name == p.first;
});
if (op_it != std::end(op_type_count)) {
op_it->second--;
}
}
for (const OpTypeCountPair& p : op_type_count) {
EXPECT_EQ(p.second, 0);
}
}
///
/// @brief Builds a program descriptor.
///
/// @param[in] transient_vars The vector of transient variable names.
/// @param[in] persistent_vars The vector of persistent variable names. These
/// will have the persistable attribute set to true.
///
/// @return The program descriptor object.
///
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
const std::vector<std::string>& persistent_vars) {
ProgramDesc prog;
auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* {
auto var = prog.MutableBlock(0)->Var(var_name);
var->SetType(proto::VarType::LOD_TENSOR);
return var;
};
for (const auto& v : transient_vars) {
add_var_to_prog(v);
}
for (const auto& v : persistent_vars) {
auto* var = add_var_to_prog(v);
var->SetPersistable(true);
}
return prog;
}
///
/// @brief Execute pass on provided graph and perform checks.
///
/// @param graph The graph we run pass on.
/// @param[in] from The name of a 'starting' node sequence in a
/// graph. This would be used to test for
/// correct node connections.
/// @param[in] to The name of a 'ending' node sequence in a
/// graph. This would be used to test for
/// correct node connections.
/// @param[in] removed_nodes_count The number of nodes we expect will be
/// removed/fused after pass execution.
/// @param[in] added_nodes_count The number of nodes we expect will be
/// added after pass execution.
///
void RunPassAndAssert(Graph* graph, const std::string& from,
const std::string& to, int removed_nodes_count,
int added_nodes_count = 0) {
EXPECT_TRUE(TestIsReachable(*graph, from, to));
int original_nodes_num = graph->Nodes().size();
auto pass = PassRegistry::Instance().Get("fc_act_mkldnn_fuse_pass");
pass->Apply(graph);
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(TestIsReachable(*graph, from, to));
EXPECT_EQ(original_nodes_num - removed_nodes_count + added_nodes_count,
current_nodes_num);
}
} // namespace
// ------------------------------ Test cases -----------------------------------
TEST(FuseFCActOneDNNPass, ThrowUseMkldnn) {
auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
CreateOp(&prog, "fc",
{
{"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
},
{{"Out", "fc_y"}}, false);
CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
// No fusion in this attribute configuration
constexpr int removed_nodes_count = 0;
EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count),
paddle::platform::EnforceNotMet);
}
TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) {
auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
CreateOp(&prog, "fc",
{
{"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
},
{{"Out", "fc_y"}});
auto* act_op =
CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
act_op->SetAttr("approximate", true);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}});
for (const auto* node : graph.Nodes()) {
if (node->IsOp() && node->Op()->Type() == "fc") {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
auto act_type =
BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
EXPECT_TRUE(act_type.compare("gelu_tanh") == 0);
}
}
}
TEST(FuseFCActOneDNNPass, FuseWithGeluErf) {
auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
CreateOp(&prog, "fc",
{
{"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
},
{{"Out", "fc_y"}});
auto* act_op =
CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
act_op->SetAttr("approximate", false);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}});
for (const auto* node : graph.Nodes()) {
if (node->IsOp() && node->Op()->Type() == "fc") {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
auto act_type =
BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
EXPECT_TRUE(act_type.compare("gelu_erf") == 0);
}
}
}
TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) {
auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
CreateOp(&prog, "fc",
{
{"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
},
{{"Out", "fc_y"}});
CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}});
for (const auto* node : graph.Nodes()) {
if (node->IsOp() && node->Op()->Type() == "fc") {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
auto act_type =
BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
EXPECT_TRUE(act_type.compare("gelu") == 0);
}
}
}
TEST(FuseFCActOneDNNPass, FuseWithTanh) {
auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
CreateOp(&prog, "fc",
{
{"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
},
{{"Out", "fc_y"}});
CreateOp(&prog, "tanh", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
AssertOpsCount(graph, {{"fc", 1}, {"tanh", 0}});
for (const auto* node : graph.Nodes()) {
if (node->IsOp() && node->Op()->Type() == "fc") {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
auto act_type =
BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
EXPECT_TRUE(act_type.compare("tanh") == 0);
}
}
}
TEST(FuseFCActOneDNNPass, FuseWithSigmoid) {
auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
CreateOp(&prog, "fc",
{
{"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
},
{{"Out", "fc_y"}});
CreateOp(&prog, "sigmoid", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
AssertOpsCount(graph, {{"fc", 1}, {"sigmoid", 0}});
for (const auto* node : graph.Nodes()) {
if (node->IsOp() && node->Op()->Type() == "fc") {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
auto act_type =
BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
EXPECT_TRUE(act_type.compare("sigmoid") == 0);
}
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(fc_act_mkldnn_fuse_pass);
......@@ -206,7 +206,8 @@ void CpuPassStrategy::EnableMKLDNN() {
"reshape_transpose_matmul_mkldnn_fuse_pass", //
"matmul_transpose_reshape_fuse_pass", //
// Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass",
//"fc_mkldnn_pass",
//"fc_act_mkldnn_fuse_pass",
"batch_norm_act_fuse_pass",
"mkldnn_inplace_pass", // This pass should be activated after
// fuses
......
......@@ -206,6 +206,7 @@ void profile(bool use_mkldnn = false) {
"relu", "fc"};
cfg.SetMKLDNNOp(op_list);
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
}
std::vector<std::vector<PaddleTensor>> outputs;
......@@ -262,6 +263,7 @@ void compare(bool use_mkldnn = false) {
"relu"};
cfg.SetMKLDNNOp(op_list);
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
}
std::vector<std::vector<PaddleTensor>> input_slots_all;
......
......@@ -50,8 +50,10 @@ void profile(bool use_mkldnn = false) {
if (use_mkldnn) {
cfg.EnableMKLDNN();
if (!FLAGS_disable_mkldnn_fc)
if (!FLAGS_disable_mkldnn_fc) {
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
}
}
std::vector<std::vector<PaddleTensor>> outputs;
......@@ -83,8 +85,10 @@ void compare(bool use_mkldnn = false) {
SetConfig(&cfg);
if (use_mkldnn) {
cfg.EnableMKLDNN();
if (!FLAGS_disable_mkldnn_fc)
if (!FLAGS_disable_mkldnn_fc) {
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
}
}
std::vector<std::vector<PaddleTensor>> input_slots_all;
......
......@@ -163,6 +163,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
if (use_mkldnn) {
cfg->EnableMKLDNN();
cfg->pass_builder()->AppendPass("fc_mkldnn_pass");
cfg->pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
}
// Enable seqpool_concat_fuse_pass, disabled by default since it takes much
// time
......
......@@ -25,6 +25,7 @@ void compare(bool use_mkldnn = false) {
if (use_mkldnn) {
cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
}
std::vector<std::vector<PaddleTensor>> input_slots_all;
......
......@@ -26,6 +26,7 @@ void profile(bool use_mkldnn = false) {
if (use_mkldnn) {
cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
}
std::vector<std::vector<PaddleTensor>> input_slots_all;
......
......@@ -86,6 +86,7 @@ void profile(bool use_mkldnn = false) {
if (use_mkldnn) {
cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
}
// cfg.pass_builder()->TurnOnDebug();
std::vector<std::vector<PaddleTensor>> outputs;
......@@ -136,6 +137,7 @@ void compare(bool use_mkldnn = false) {
if (use_mkldnn) {
cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
}
std::vector<std::vector<PaddleTensor>> input_slots_all;
......
......@@ -459,6 +459,36 @@ class FCPrimitiveFactory {
constexpr float placeholder = 1.0f; // beta
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder);
} else if (ctx.Attr<std::string>("activation_type") == "gelu") {
constexpr float scale = 1.0f;
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_gelu,
alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "gelu_tanh") {
constexpr float scale = 1.0f;
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(
scale, mkldnn::algorithm::eltwise_gelu_tanh, alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "gelu_erf") {
constexpr float scale = 1.0f;
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_gelu_erf,
alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "tanh") {
constexpr float scale = 1.0f;
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_tanh,
alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "sigmoid") {
constexpr float scale = 1.0f;
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_logistic,
alpha, beta);
}
attributes.set_post_ops(post_operations);
......
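The else-if chain above repeats the same scale/alpha/beta constants for every activation added in this commit. A hedged alternative sketch (not part of the commit) that expresses the same mapping as a lookup table; it assumes <unordered_map> is available in this translation unit and leaves the existing relu branch, which uses negative_slope, untouched:

// Hypothetical refactor: map the fused activation name to its oneDNN
// eltwise algorithm, then append a single post-op. Behaviour matches the
// chain above for the activations added in this commit.
static const std::unordered_map<std::string, mkldnn::algorithm> kFusedActs = {
    {"gelu", mkldnn::algorithm::eltwise_gelu},
    {"gelu_tanh", mkldnn::algorithm::eltwise_gelu_tanh},
    {"gelu_erf", mkldnn::algorithm::eltwise_gelu_erf},
    {"tanh", mkldnn::algorithm::eltwise_tanh},
    {"sigmoid", mkldnn::algorithm::eltwise_logistic}};
const auto act_it = kFusedActs.find(ctx.Attr<std::string>("activation_type"));
if (act_it != kFusedActs.end()) {
  constexpr float scale = 1.0f, alpha = 0.0f, beta = 0.0f;
  post_operations.append_eltwise(scale, act_it->second, alpha, beta);
}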
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for fusion of fc and activation."""
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from inference_pass_test import InferencePassTest
from paddle import enable_static
from paddle.fluid.core import PassVersionChecker
enable_static()
class FCGeluTanhOneDnnFusePassTest(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 128, 768], dtype="float32")
fc_out = fluid.layers.fc(input=data, size=3072, num_flatten_dims=2)
gelu_out = fluid.layers.gelu(fc_out, approximate=True)
self.feeds = {"data": np.random.random((1, 128, 768)).astype("float32")}
self.fetch_list = [gelu_out]
self.enable_mkldnn = True
def set_params(self):
self.pass_name = "fc_act_mkldnn_fuse_pass"
def test_check_output(self):
self.check_output()
class FCGeluErfOneDnnFusePassTest(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 128, 768], dtype="float32")
fc_out = fluid.layers.fc(input=data, size=3072, num_flatten_dims=2)
gelu_out = fluid.layers.gelu(fc_out, approximate=False)
self.feeds = {"data": np.random.random((1, 128, 768)).astype("float32")}
self.fetch_list = [gelu_out]
self.enable_mkldnn = True
def set_params(self):
self.pass_name = "fc_act_mkldnn_fuse_pass"
def test_check_output(self):
self.check_output()
self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
class FCTanhOneDnnFusePassTest(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 128, 768], dtype="float32")
fc_out = fluid.layers.fc(input=data, size=3072, num_flatten_dims=2)
tanh_out = fluid.layers.tanh(fc_out)
self.feeds = {"data": np.random.random((1, 128, 768)).astype("float32")}
self.fetch_list = [tanh_out]
self.enable_mkldnn = True
def set_params(self):
self.pass_name = "fc_act_mkldnn_fuse_pass"
def test_check_output(self):
self.check_output()
self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
class FCSigmoidOneDnnFusePassTest(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 128, 768], dtype="float32")
fc_out = fluid.layers.fc(input=data, size=3072, num_flatten_dims=2)
sigmoid_out = fluid.layers.sigmoid(fc_out)
self.feeds = {"data": np.random.random((1, 128, 768)).astype("float32")}
self.fetch_list = [sigmoid_out]
self.enable_mkldnn = True
def set_params(self):
self.pass_name = "fc_act_mkldnn_fuse_pass"
def test_check_output(self):
self.check_output()
self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
if __name__ == "__main__":
unittest.main()