Unverified commit c5ffad12, authored by Adam Osewski, committed by GitHub

[oneDNN] Refactor fuse pass helper functions to one place. (#30460)

* Move pass tester helper functions to single common place.

* Use helper functions in two more fuse pass tests.
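
After this change each tester builds its program and runs the pass through the shared helpers in the new test namespace. A minimal sketch of the resulting pattern, assuming a registered pass (the pass name "some_fuse_pass" below is illustrative; the real tests pass e.g. "fc_act_mkldnn_fuse_pass"):

    auto prog = test::BuildProgramDesc({"x", "y"}, {"weights"});
    test::CreateOp(&prog, "relu", {{"X", "x"}}, {{"Out", "y"}});
    Graph graph(prog);
    EXPECT_TRUE(test::RunPassAndAssert(&graph, "some_fuse_pass", "x", "y",
                                       /*removed_nodes_count=*/0));
    EXPECT_TRUE(test::AssertOpsCount(graph, {{"relu", 1}}));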
Parent 1d7bf1de
@@ -158,13 +158,14 @@ if(NOT WIN32)
cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass)
endif()
if (WITH_MKLDNN)
cc_library(pass_test_util SRCS mkldnn/pass_test_util.cc DEPS graph pass)
cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass)
cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
cc_test(test_fc_act_mkldnn_fuse_pass SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc DEPS fc_act_mkldnn_fuse_pass)
cc_test(test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass pass_test_util)
cc_test(test_fc_act_mkldnn_fuse_pass SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc DEPS fc_act_mkldnn_fuse_pass pass_test_util)
cc_test(test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass pass_test_util)
set(TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context)
if (WITH_GPU)
set(TEST_CONV_BN_PASS_DEPS ${TEST_CONV_BN_PASS_DEPS} depthwise_conv)
@@ -13,17 +13,11 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <exception>
#include <functional>
#include <iterator>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/pass_test_util.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/errors.h"
@@ -31,209 +25,15 @@ namespace paddle {
namespace framework {
namespace ir {
// -------------------------- helper functions --------------------------------
namespace {
using InOutVarNamePair = std::pair<std::string, std::string>;
using OpTypeCountPair = std::pair<std::string, int>;
///
/// @brief Creates the specified operator and sets up its inputs/outputs.
///
/// @param prog The program descriptor to which we add new op.
/// @param[in] op_type_name The operator type name.
/// @param[in] inputs The vector of input pairs: {input_name, variable
/// name}
/// @param[in] outputs The vector of output pairs {output_name, variable}
/// @param[in] use_mkldnn The flag deciding whether or not to set
/// 'use_mkldnn' attribute.
///
/// @return Returns pointer to the created operator descriptor.
///
OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name,
const std::vector<InOutVarNamePair>& inputs,
const std::vector<InOutVarNamePair>& outputs,
bool use_mkldnn = true) {
auto op = prog->MutableBlock(0)->AppendOp();
op->SetType(op_type_name);
op->SetAttr("use_mkldnn", use_mkldnn);
for (const auto& input : inputs) {
op->SetInput(input.first, {input.second});
}
for (const auto& output : outputs) {
op->SetOutput(output.first, {output.second});
}
return op;
}
///
/// @brief Check whether node 'to' is reachable from node 'from' in graph.
///
/// @param[in] graph The graph we're checking for reachability.
/// @param[in] from The 'from' node name.
/// @param[in] to The 'to' node name.
///
/// @return True if there is a connection between nodes 'from' and 'to'.
///
bool TestIsReachable(const Graph& graph, std::string from, std::string to) {
auto hash = [](const Node* node) -> std::string {
return node->Name() + std::to_string(node->id());
};
auto find_node = [&](const Graph& graph, const std::string& name) -> Node* {
for (auto& node : GraphTraits::DFS(graph)) {
if (name == hash(&node)) {
return &node;
}
}
return nullptr;
};
if (from == to) return true;
std::map<std::string, bool> visited;
// update the from and to strings to hashed equivs in loop from graph traits
for (auto& node : GraphTraits::DFS(graph)) {
auto hashed = hash(&node);
if (node.Name() == from) {
from = hashed;
}
if (node.Name() == to) {
to = hashed;
}
visited[hashed] = false;
}
visited[from] = true;
std::list<std::string> queue;
queue.push_back(from);
while (!queue.empty()) {
auto cur = find_node(graph, queue.front());
queue.pop_front();
if (cur == nullptr) {
return false;
}
for (auto n : cur->outputs) {
auto hashed_name = hash(n);
if (hashed_name == to) {
return true;
}
if (!visited[hashed_name]) {
visited[hashed_name] = true;
queue.push_back(hashed_name);
}
}
}
return false;
}
///
/// @brief Searches through the graph and counts occurrences of the provided operators.
///
/// @param[in] graph The graph we search through.
/// @param[in] op_type_count The vector of pairs {op_type_name, op count}
///
/// @note After going through all graph nodes this function asserts
/// whether the counted number for each requested op is as expected.
///
void AssertOpsCount(const Graph& graph,
std::vector<OpTypeCountPair> op_type_count) {
for (auto* node : graph.Nodes()) {
if (!node->IsOp()) {
continue;
}
const std::string op_type_name = node->Op()->Type();
auto op_it =
std::find_if(std::begin(op_type_count), std::end(op_type_count),
[op_type_name](const OpTypeCountPair& p) {
return op_type_name == p.first;
});
if (op_it != std::end(op_type_count)) {
op_it->second--;
}
}
for (const OpTypeCountPair& p : op_type_count) {
EXPECT_EQ(p.second, 0);
}
}
///
/// @brief Builds a program descriptor.
///
/// @param[in] transient_vars The vector of transient variable names.
/// @param[in] persistent_vars The vector of persistent variable names. Those
/// will have the persistable attribute set to true.
///
/// @return The program descriptor object.
///
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
const std::vector<std::string>& persistent_vars) {
ProgramDesc prog;
auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* {
auto var = prog.MutableBlock(0)->Var(var_name);
var->SetType(proto::VarType::LOD_TENSOR);
return var;
};
for (const auto& v : transient_vars) {
add_var_to_prog(v);
}
for (const auto& v : persistent_vars) {
auto* var = add_var_to_prog(v);
var->SetPersistable(true);
}
return prog;
}
///
/// @brief Execute pass on provided graph and perform checks.
///
/// @param graph The graph we run pass on.
/// @param[in] from The name of a 'starting' node sequence in a
/// graph. This would be used to test for
/// correct node connections.
/// @param[in] to The name of an 'ending' node sequence in a
/// graph. This would be used to test for
/// correct node connections.
/// @param[in] removed_nodes_count The number of nodes we expect will be
/// removed/fused after pass execution.
/// @param[in] added_nodes_count The number of nodes we expect will be
/// added after pass execution.
///
void RunPassAndAssert(Graph* graph, const std::string& from,
const std::string& to, int removed_nodes_count,
int added_nodes_count = 0) {
EXPECT_TRUE(TestIsReachable(*graph, from, to));
int original_nodes_num = graph->Nodes().size();
auto pass = PassRegistry::Instance().Get("batch_norm_act_fuse_pass");
pass->Apply(graph);
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(TestIsReachable(*graph, from, to));
EXPECT_EQ(original_nodes_num - removed_nodes_count + added_nodes_count,
current_nodes_num);
}
void SetBatchNormAttrs(OpDesc* bn_op, bool is_test = true,
bool trainable_stats = true) {
bn_op->SetAttr("is_test", is_test);
bn_op->SetAttr("trainable_statistics", trainable_stats);
bn_op->SetAttr("fuse_with_relu", false);
}
} // namespace
}
// ------------------------------ Test cases -----------------------------------
@@ -244,10 +44,10 @@ void SetBatchNormAttrs(OpDesc* bn_op, bool is_test = true,
// The test case name includes only the attributes whose value is true.
TEST(FuseBatchNormActOneDNNPass, ThrowIsTestTrainableStats) {
auto prog = BuildProgramDesc(
auto prog = test::BuildProgramDesc(
{"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
{"scale", "bias"});
auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"},
auto* bn_op = test::CreateOp(&prog, "batch_norm", {{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
@@ -258,33 +58,35 @@ TEST(FuseBatchNormActOneDNNPass, ThrowIsTestTrainableStats) {
{"SavedMean", "sm"},
{"SavedVariance", "sv"}});
SetBatchNormAttrs(bn_op, true, true);
CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
// No fusion in this attribute configuration
constexpr int removed_nodes_count = 0;
EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count),
EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
"act_y", removed_nodes_count),
paddle::platform::EnforceNotMet);
}
TEST(FuseBatchNormActOneDNNPass, FuseIsTest) {
auto prog =
BuildProgramDesc({"x", "m", "v", "bn_y", "act_y"}, {"scale", "bias"});
auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"},
auto prog = test::BuildProgramDesc({"x", "m", "v", "bn_y", "act_y"},
{"scale", "bias"});
auto* bn_op = test::CreateOp(&prog, "batch_norm", {{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
{"Variance", "v"}},
{{"Y", "bn_y"}});
SetBatchNormAttrs(bn_op, true, false);
CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 0}});
EXPECT_TRUE(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
"act_y", removed_nodes_count));
EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 0}}));
for (const auto* node : graph.Nodes()) {
if (node->IsOp() && node->Op()->Type() == "batch_norm") {
@@ -300,10 +102,10 @@ TEST(FuseBatchNormActOneDNNPass, FuseIsTest) {
}
TEST(FuseBatchNormActOneDNNPass, ThrowTrainableStats) {
auto prog = BuildProgramDesc(
auto prog = test::BuildProgramDesc(
{"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
{"scale", "bias"});
auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"},
auto* bn_op = test::CreateOp(&prog, "batch_norm", {{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
@@ -314,21 +116,22 @@ TEST(FuseBatchNormActOneDNNPass, ThrowTrainableStats) {
{"SavedMean", "sm"},
{"SavedVariance", "sv"}});
SetBatchNormAttrs(bn_op, false, true);
CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
// No fusion in this attribute configuration
constexpr int removed_nodes_count = 0;
EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count),
EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
"act_y", removed_nodes_count),
paddle::platform::EnforceNotMet);
}
TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) {
auto prog = BuildProgramDesc(
auto prog = test::BuildProgramDesc(
{"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
{"scale", "bias"});
auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"},
auto* bn_op = test::CreateOp(&prog, "batch_norm", {{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
@@ -339,21 +142,22 @@ TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) {
{"SavedMean", "sm"},
{"SavedVariance", "sv"}});
SetBatchNormAttrs(bn_op, false, false);
CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
// No fusion in this attribute configuration
constexpr int removed_nodes_count = 0;
EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count),
EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
"act_y", removed_nodes_count),
paddle::platform::EnforceNotMet);
}
TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) {
auto prog = BuildProgramDesc(
auto prog = test::BuildProgramDesc(
{"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
{"scale", "bias"});
auto* bn_op = CreateOp(&prog, "batch_norm", {{"X", "x"},
auto* bn_op = test::CreateOp(&prog, "batch_norm", {{"X", "x"},
{"Scale", "scale"},
{"Bias", "bias"},
{"Mean", "m"},
@@ -365,16 +169,23 @@ TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) {
{"SavedVariance", "sv"}},
false);
SetBatchNormAttrs(bn_op, false, false);
CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
test::CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
// No fusion in this attribute configuration
constexpr int removed_nodes_count = 0;
EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count),
EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
"act_y", removed_nodes_count),
paddle::platform::EnforceNotMet);
}
TEST(FuseBatchNormActOneDNNPass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("batch_norm_act_fuse_pass"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -13,259 +13,151 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <string>
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/pass_test_util.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
constexpr int nodes_removed = 3;
constexpr int nodes_added = 1;
void SetOp(ProgramDesc* prog, const std::string& type,
const std::vector<std::pair<std::string, std::string>>& inputs,
const std::pair<std::string, std::string>& output) {
auto op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("use_mkldnn", true);
for (const auto& input : inputs) {
op->SetInput(input.first, {input.second});
}
op->SetOutput(output.first, {output.second});
}
struct TestIsReachable {
using func = std::function<bool(const std::string&, const std::string&)>;
auto operator()(const std::unique_ptr<ir::Graph>& graph) -> func {
auto hash = [](const Node* node) -> std::string {
return node->Name() + std::to_string(node->id());
};
auto find_node = [&](const std::unique_ptr<ir::Graph>& graph,
const std::string& name) -> Node* {
for (auto& node : GraphTraits::DFS(*graph)) {
if (name == hash(&node)) {
return &node;
}
}
return nullptr;
};
// update the from and to strings to hashed equivs in loop from graph traits
return [&](std::string from, std::string to) -> bool {
if (from == to) return true;
std::map<std::string, bool> visited;
for (auto& node : GraphTraits::DFS(*graph)) {
auto hashed = hash(&node);
if (node.Name() == from) from = hashed;
if (node.Name() == to) to = hashed;
visited[hashed] = false;
}
visited[from] = true;
std::list<std::string> queue;
queue.push_back(from);
while (!queue.empty()) {
auto cur = find_node(graph, queue.front());
queue.pop_front();
if (cur == nullptr) return false;
for (auto n : cur->outputs) {
auto hashed_name = hash(n);
if (hashed_name == to) return true;
if (!visited[hashed_name]) {
visited[hashed_name] = true;
queue.push_back(hashed_name);
}
}
}
return false;
};
}
};
void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph,
int expected_conv_count,
int expected_elementwise_add_count = 0) {
int conv_count = 0;
int elementwise_add_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "conv2d") {
++conv_count;
}
if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
++elementwise_add_count;
}
}
EXPECT_EQ(conv_count, expected_conv_count);
EXPECT_EQ(elementwise_add_count, expected_elementwise_add_count);
}
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
const std::vector<std::string>& persistent_vars) {
ProgramDesc prog;
auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* {
auto var = prog.MutableBlock(0)->Var(var_name);
var->SetType(proto::VarType::LOD_TENSOR);
return var;
};
for (const auto& v : transient_vars) {
add_var_to_prog(v);
}
for (const auto& v : persistent_vars) {
auto var = add_var_to_prog(v);
var->SetPersistable(true);
}
return prog;
}
void RunPassAndAssert(ProgramDesc* prog, const std::string& from,
const std::string& to, int expected_conv_num) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(*prog));
TestIsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)(from, to));
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)(from, to));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num);
AssertOpsCount(graph, expected_conv_num);
}
} // namespace
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsYWithElementwiseAddRelu) {
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"});
auto prog =
test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"});
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
SetOp(&prog, "conv2d",
test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
test::CreateOp(&prog, "conv2d",
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
{"Output", "c"});
SetOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}}, {"Out", "d"});
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
RunPassAndAssert(&prog, "a", "relu", 1);
{{"Output", "c"}});
test::CreateOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}},
{{"Out", "d"}});
test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});
Graph graph(prog);
EXPECT_TRUE(test::RunPassAndAssert(&graph,
"conv_elementwise_add_mkldnn_fuse_pass",
"a", "relu", nodes_removed, nodes_added));
EXPECT_TRUE(
test::AssertOpsCount(graph, {{"conv2d", 1}, {"elementwise_add", 0}}));
}
TEST(ConvElementwiseAddMKLDNNFusePass,
ConvolutionProjectionAsYWithElementwiseAddRelu) {
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e", "f"},
auto prog = test::BuildProgramDesc({"a", "b", "c", "d", "e", "f"},
{"bias", "weights", "bias2", "weights2"});
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
// right branch
SetOp(&prog, "conv2d",
test::CreateOp(&prog, "conv2d",
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
{"Output", "c"});
{{"Output", "c"}});
// left branch
SetOp(&prog, "conv2d",
test::CreateOp(&prog, "conv2d",
{{"Input", "a"}, {"Bias", "bias2"}, {"Filter", "weights2"}},
{"Output", "f"});
{{"Output", "f"}});
SetOp(&prog, "elementwise_add", {{"X", "f"}, {"Y", "c"}}, {"Out", "d"});
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
test::CreateOp(&prog, "elementwise_add", {{"X", "f"}, {"Y", "c"}},
{{"Out", "d"}});
test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});
RunPassAndAssert(&prog, "a", "relu", 2);
Graph graph(prog);
EXPECT_TRUE(test::RunPassAndAssert(&graph,
"conv_elementwise_add_mkldnn_fuse_pass",
"a", "relu", nodes_removed, nodes_added));
EXPECT_TRUE(
test::AssertOpsCount(graph, {{"conv2d", 2}, {"elementwise_add", 0}}));
}
TEST(ConvElementwiseAddMKLDNNFusePass,
ConvolutionAsYWithElementwiseAddReluNoBias) {
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
SetOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
{"Output", "c"});
SetOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}}, {"Out", "d"});
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
RunPassAndAssert(&prog, "a", "relu", 1);
auto prog = test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
test::CreateOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
{{"Output", "c"}});
test::CreateOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}},
{{"Out", "d"}});
test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});
Graph graph(prog);
EXPECT_TRUE(test::RunPassAndAssert(&graph,
"conv_elementwise_add_mkldnn_fuse_pass",
"a", "relu", nodes_removed, nodes_added));
EXPECT_TRUE(
test::AssertOpsCount(graph, {{"conv2d", 1}, {"elementwise_add", 0}}));
}
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsXWithElementwiseAddRelu) {
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"});
auto prog =
test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"});
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
SetOp(&prog, "conv2d",
test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
test::CreateOp(&prog, "conv2d",
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
{"Output", "c"});
{{"Output", "c"}});
SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}}, {"Out", "d"});
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
test::CreateOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}},
{{"Out", "d"}});
test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});
RunPassAndAssert(&prog, "a", "relu", 1);
Graph graph(prog);
EXPECT_TRUE(test::RunPassAndAssert(&graph,
"conv_elementwise_add_mkldnn_fuse_pass",
"a", "relu", nodes_removed, nodes_added));
EXPECT_TRUE(
test::AssertOpsCount(graph, {{"conv2d", 1}, {"elementwise_add", 0}}));
}
TEST(ConvElementwiseAddMKLDNNFusePass,
ConvolutionAsXWithElementwiseAddReluNoBias) {
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
SetOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
{"Output", "c"});
SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}}, {"Out", "d"});
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
RunPassAndAssert(&prog, "a", "relu", 1);
auto prog = test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
test::CreateOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
{{"Output", "c"}});
test::CreateOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}},
{{"Out", "d"}});
test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});
Graph graph(prog);
EXPECT_TRUE(test::RunPassAndAssert(&graph,
"conv_elementwise_add_mkldnn_fuse_pass",
"a", "relu", nodes_removed, nodes_added));
EXPECT_TRUE(
test::AssertOpsCount(graph, {{"conv2d", 1}, {"elementwise_add", 0}}));
}
TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
auto prog =
BuildProgramDesc({"a", "b", "c", "d", "e", "f", "g"}, {"weights"});
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
SetOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
{"Output", "c"});
SetOp(&prog, "conv2d", {{"Input", "d"}, {"Filter", "weights"}},
{"Output", "e"});
SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "e"}}, {"Out", "f"});
SetOp(&prog, "relu", {{"X", "f"}}, {"Out", "g"});
test::BuildProgramDesc({"a", "b", "c", "d", "e", "f", "g"}, {"weights"});
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
test::CreateOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
{{"Output", "c"}});
TestIsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)("a", "g"));
test::CreateOp(&prog, "conv2d", {{"Input", "d"}, {"Filter", "weights"}},
{{"Output", "e"}});
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
test::CreateOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "e"}},
{{"Out", "f"}});
test::CreateOp(&prog, "relu", {{"X", "f"}}, {{"Out", "g"}});
EXPECT_TRUE(is_reachable(graph)("a", "g"));
EXPECT_EQ(original_nodes_num, current_nodes_num);
Graph graph(prog);
AssertOpsCount(graph, 2, 1);
EXPECT_TRUE(test::RunPassAndAssert(
&graph, "conv_elementwise_add_mkldnn_fuse_pass", "a", "g", 0, 0));
EXPECT_TRUE(
test::AssertOpsCount(graph, {{"conv2d", 2}, {"elementwise_add", 1}}));
}
TEST(ConvElementwiseAddMKLDNNFusePass, pass_op_version_check) {
@@ -13,17 +13,11 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <exception>
#include <functional>
#include <iterator>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/pass_test_util.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/errors.h"
@@ -31,238 +25,45 @@ namespace paddle {
namespace framework {
namespace ir {
// -------------------------- helper functions --------------------------------
namespace {
using InOutVarNamePair = std::pair<std::string, std::string>;
using OpTypeCountPair = std::pair<std::string, int>;
///
/// @brief Creates the specified operator and sets up its inputs/outputs.
///
/// @param prog The program descriptor to which we add new op.
/// @param[in] op_type_name The operator type name.
/// @param[in] inputs The vector of input pairs: {input_name, variable
/// name}
/// @param[in] outputs The vector of output pairs {output_name, variable}
/// @param[in] use_mkldnn The flag deciding whether or not to set
/// 'use_mkldnn' attribute.
///
/// @return Returns pointer to the created operator descriptor.
///
OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name,
const std::vector<InOutVarNamePair>& inputs,
const std::vector<InOutVarNamePair>& outputs,
bool use_mkldnn = true) {
auto op = prog->MutableBlock(0)->AppendOp();
op->SetType(op_type_name);
op->SetAttr("use_mkldnn", use_mkldnn);
for (const auto& input : inputs) {
op->SetInput(input.first, {input.second});
}
for (const auto& output : outputs) {
op->SetOutput(output.first, {output.second});
}
return op;
}
///
/// @brief Check whether node 'to' is reachable from node 'from' in graph.
///
/// @param[in] graph The graph we're checking for reachability.
/// @param[in] from The 'from' node name.
/// @param[in] to The 'to' node name.
///
/// @return True if there is a connection between nodes 'from' and 'to'.
///
bool TestIsReachable(const Graph& graph, std::string from, std::string to) {
auto hash = [](const Node* node) -> std::string {
return node->Name() + std::to_string(node->id());
};
auto find_node = [&](const Graph& graph, const std::string& name) -> Node* {
for (auto& node : GraphTraits::DFS(graph)) {
if (name == hash(&node)) {
return &node;
}
}
return nullptr;
};
if (from == to) return true;
std::map<std::string, bool> visited;
// update the from and to strings to hashed equivs in loop from graph traits
for (auto& node : GraphTraits::DFS(graph)) {
auto hashed = hash(&node);
if (node.Name() == from) {
from = hashed;
}
if (node.Name() == to) {
to = hashed;
}
visited[hashed] = false;
}
visited[from] = true;
std::list<std::string> queue;
queue.push_back(from);
while (!queue.empty()) {
auto cur = find_node(graph, queue.front());
queue.pop_front();
if (cur == nullptr) {
return false;
}
for (auto n : cur->outputs) {
auto hashed_name = hash(n);
if (hashed_name == to) {
return true;
}
if (!visited[hashed_name]) {
visited[hashed_name] = true;
queue.push_back(hashed_name);
}
}
}
return false;
}
///
/// @brief Searches through the graph and counts occurrences of the provided operators.
///
/// @param[in] graph The graph we search through.
/// @param[in] op_type_count The vector of pairs {op_type_name, op count}
///
/// @note After going through all graph nodes this function asserts
/// whether the counted number for each requested op is as expected.
///
void AssertOpsCount(const Graph& graph,
std::vector<OpTypeCountPair> op_type_count) {
for (auto* node : graph.Nodes()) {
if (!node->IsOp()) {
continue;
}
const std::string op_type_name = node->Op()->Type();
auto op_it =
std::find_if(std::begin(op_type_count), std::end(op_type_count),
[op_type_name](const OpTypeCountPair& p) {
return op_type_name == p.first;
});
if (op_it != std::end(op_type_count)) {
op_it->second--;
}
}
for (const OpTypeCountPair& p : op_type_count) {
EXPECT_EQ(p.second, 0);
}
}
///
/// @brief Builds a program descriptor.
///
/// @param[in] transient_vars The vector of transient variable names.
/// @param[in] persistent_vars The vector of persistent variable names. Those
/// will have the persistable attribute set to true.
///
/// @return The program descriptor object.
///
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
const std::vector<std::string>& persistent_vars) {
ProgramDesc prog;
auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* {
auto var = prog.MutableBlock(0)->Var(var_name);
var->SetType(proto::VarType::LOD_TENSOR);
return var;
};
for (const auto& v : transient_vars) {
add_var_to_prog(v);
}
for (const auto& v : persistent_vars) {
auto* var = add_var_to_prog(v);
var->SetPersistable(true);
}
return prog;
}
///
/// @brief Execute pass on provided graph and perform checks.
///
/// @param graph The graph we run pass on.
/// @param[in] from The name of a 'starting' node sequence in a
/// graph. This would be used to test for
/// correct node connections.
/// @param[in] to The name of an 'ending' node sequence in a
/// graph. This would be used to test for
/// correct node connections.
/// @param[in] removed_nodes_count The number of nodes we expect will be
/// removed/fused after pass execution.
/// @param[in] added_nodes_count The number of nodes we expect will be
/// added after pass execution.
///
void RunPassAndAssert(Graph* graph, const std::string& from,
const std::string& to, int removed_nodes_count,
int added_nodes_count = 0) {
EXPECT_TRUE(TestIsReachable(*graph, from, to));
int original_nodes_num = graph->Nodes().size();
auto pass = PassRegistry::Instance().Get("fc_act_mkldnn_fuse_pass");
pass->Apply(graph);
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(TestIsReachable(*graph, from, to));
EXPECT_EQ(original_nodes_num - removed_nodes_count + added_nodes_count,
current_nodes_num);
}
} // namespace
// ------------------------------ Test cases -----------------------------------
TEST(FuseFCActOneDNNPass, ThrowUseMkldnn) {
auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
CreateOp(&prog, "fc",
auto prog =
test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
test::CreateOp(&prog, "fc",
{
{"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
},
{{"Out", "fc_y"}}, false);
CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
test::CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
// No fusion in this attribute configuration
constexpr int removed_nodes_count = 0;
EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count),
EXPECT_THROW(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x",
"act_y", removed_nodes_count),
paddle::platform::EnforceNotMet);
}
TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) {
auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
CreateOp(&prog, "fc",
auto prog =
test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
test::CreateOp(&prog, "fc",
{
{"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
},
{{"Out", "fc_y"}});
auto* act_op =
CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
auto* act_op = test::CreateOp(&prog, "gelu", {{"Input", "fc_y"}},
{{"Out", "act_y"}}, false);
act_op->SetAttr("approximate", true);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}});
EXPECT_TRUE(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x",
"act_y", removed_nodes_count));
EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}}));
for (const auto* node : graph.Nodes()) {
if (node->IsOp() && node->Op()->Type() == "fc") {
@@ -272,27 +73,29 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) {
ASSERT_TRUE(op->HasAttr("activation_type"));
auto act_type =
BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
EXPECT_TRUE(act_type.compare("gelu_tanh") == 0);
EXPECT_EQ(act_type.compare("gelu_tanh"), 0);
}
}
}
TEST(FuseFCActOneDNNPass, FuseWithGeluErf) {
auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
CreateOp(&prog, "fc",
auto prog =
test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
test::CreateOp(&prog, "fc",
{
{"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
},
{{"Out", "fc_y"}});
auto* act_op =
CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
auto* act_op = test::CreateOp(&prog, "gelu", {{"Input", "fc_y"}},
{{"Out", "act_y"}}, false);
act_op->SetAttr("approximate", false);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}});
EXPECT_TRUE(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x",
"act_y", removed_nodes_count));
EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}}));
for (const auto* node : graph.Nodes()) {
if (node->IsOp() && node->Op()->Type() == "fc") {
@@ -302,25 +105,27 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluErf) {
ASSERT_TRUE(op->HasAttr("activation_type"));
auto act_type =
BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
EXPECT_TRUE(act_type.compare("gelu_erf") == 0);
EXPECT_EQ(act_type.compare("gelu_erf"), 0);
}
}
}
TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) {
auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
CreateOp(&prog, "fc",
auto prog =
test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
test::CreateOp(&prog, "fc",
{
{"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
},
{{"Out", "fc_y"}});
CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
test::CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}});
EXPECT_TRUE(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x",
"act_y", removed_nodes_count));
EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}}));
for (const auto* node : graph.Nodes()) {
if (node->IsOp() && node->Op()->Type() == "fc") {
@@ -330,25 +135,27 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) {
ASSERT_TRUE(op->HasAttr("activation_type"));
auto act_type =
BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
EXPECT_TRUE(act_type.compare("gelu") == 0);
EXPECT_EQ(act_type.compare("gelu"), 0);
}
}
}
TEST(FuseFCActOneDNNPass, FuseWithTanh) {
auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
CreateOp(&prog, "fc",
auto prog =
test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
test::CreateOp(&prog, "fc",
{
{"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
},
{{"Out", "fc_y"}});
CreateOp(&prog, "tanh", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
test::CreateOp(&prog, "tanh", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
AssertOpsCount(graph, {{"fc", 1}, {"tanh", 0}});
EXPECT_TRUE(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x",
"act_y", removed_nodes_count));
EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"tanh", 0}}));
for (const auto* node : graph.Nodes()) {
if (node->IsOp() && node->Op()->Type() == "fc") {
@@ -358,25 +165,28 @@ TEST(FuseFCActOneDNNPass, FuseWithTanh) {
ASSERT_TRUE(op->HasAttr("activation_type"));
auto act_type =
BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
EXPECT_TRUE(act_type.compare("tanh") == 0);
EXPECT_EQ(act_type.compare("tanh"), 0);
}
}
}
TEST(FuseFCActOneDNNPass, FuseWithSigmoid) {
auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
CreateOp(&prog, "fc",
auto prog =
test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
test::CreateOp(&prog, "fc",
{
{"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
},
{{"Out", "fc_y"}});
CreateOp(&prog, "sigmoid", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
test::CreateOp(&prog, "sigmoid", {{"Input", "fc_y"}}, {{"Out", "act_y"}},
false);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
AssertOpsCount(graph, {{"fc", 1}, {"sigmoid", 0}});
EXPECT_TRUE(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x",
"act_y", removed_nodes_count));
EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"sigmoid", 0}}));
for (const auto* node : graph.Nodes()) {
if (node->IsOp() && node->Op()->Type() == "fc") {
@@ -386,11 +196,17 @@ TEST(FuseFCActOneDNNPass, FuseWithSigmoid) {
ASSERT_TRUE(op->HasAttr("activation_type"));
auto act_type =
BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
EXPECT_TRUE(act_type.compare("sigmoid") == 0);
EXPECT_EQ(act_type.compare("sigmoid"), 0);
}
}
}
TEST(FuseFCActOneDNNPass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("fc_act_mkldnn_fuse_pass"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <exception>
#include <functional>
#include <iterator>
#include <list>
#include <map>
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/pass_test_util.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
namespace test {
OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name,
const std::vector<InOutVarNamePair>& inputs,
const std::vector<InOutVarNamePair>& outputs,
bool use_mkldnn) {
auto op = prog->MutableBlock(0)->AppendOp();
op->SetType(op_type_name);
op->SetAttr("use_mkldnn", use_mkldnn);
for (const auto& input : inputs) {
op->SetInput(input.first, {input.second});
}
for (const auto& output : outputs) {
op->SetOutput(output.first, {output.second});
}
return op;
}
bool TestIsReachable(const Graph& graph, std::string from, std::string to) {
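// Node names are not guaranteed unique in the graph, so nodes are keyed
// by name + id throughout this function.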
auto hash = [](const Node* node) -> std::string {
return node->Name() + std::to_string(node->id());
};
auto find_node = [&](const Graph& graph, const std::string& name) -> Node* {
for (auto& node : GraphTraits::DFS(graph)) {
if (name == hash(&node)) {
return &node;
}
}
return nullptr;
};
if (from == to) return true;
std::map<std::string, bool> visited;
// Replace the 'from' and 'to' names with their hashed equivalents while
// walking the graph.
for (auto& node : GraphTraits::DFS(graph)) {
auto hashed = hash(&node);
if (node.Name() == from) {
from = hashed;
}
if (node.Name() == to) {
to = hashed;
}
visited[hashed] = false;
}
visited[from] = true;
std::list<std::string> queue;
queue.push_back(from);
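// BFS from 'from': follow each node's output links until 'to' is found or
// the queue drains.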
while (!queue.empty()) {
auto cur = find_node(graph, queue.front());
queue.pop_front();
if (cur == nullptr) {
return false;
}
for (auto n : cur->outputs) {
auto hashed_name = hash(n);
if (hashed_name == to) {
return true;
}
if (!visited[hashed_name]) {
visited[hashed_name] = true;
queue.push_back(hashed_name);
}
}
}
return false;
}
bool AssertOpsCount(const Graph& graph,
std::vector<OpTypeCountPair> op_type_count) {
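// Walk all op nodes and decrement the expected count of each matching op
// type; every counter must end at exactly zero for the check to pass.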
for (auto* node : graph.Nodes()) {
if (!node->IsOp()) {
continue;
}
const std::string op_type_name = node->Op()->Type();
auto op_it =
std::find_if(std::begin(op_type_count), std::end(op_type_count),
[op_type_name](const OpTypeCountPair& p) {
return op_type_name == p.first;
});
if (op_it != std::end(op_type_count)) {
op_it->second--;
}
}
bool result{true};
for (const OpTypeCountPair& p : op_type_count) {
result = result && (p.second == 0);
}
return result;
}
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
const std::vector<std::string>& persistent_vars) {
ProgramDesc prog;
auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* {
auto var = prog.MutableBlock(0)->Var(var_name);
var->SetType(proto::VarType::LOD_TENSOR);
return var;
};
for (const auto& v : transient_vars) {
add_var_to_prog(v);
}
for (const auto& v : persistent_vars) {
auto* var = add_var_to_prog(v);
var->SetPersistable(true);
}
return prog;
}
bool RunPassAndAssert(Graph* graph, const std::string& pass_name,
const std::string& from, const std::string& to,
int removed_nodes_count, int added_nodes_count) {
if (!TestIsReachable(*graph, from, to)) return false;
int original_nodes_num = graph->Nodes().size();
auto pass = PassRegistry::Instance().Get(pass_name);
pass->Apply(graph);
int current_nodes_num = graph->Nodes().size();
if (!TestIsReachable(*graph, from, to)) return false;
int expected_nodes_num =
original_nodes_num - removed_nodes_count + added_nodes_count;
return expected_nodes_num == current_nodes_num;
}
} // namespace test
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
// -------------------------- helper functions --------------------------------
namespace test {
/// The pair associating an input/output name with a variable name.
using InOutVarNamePair = std::pair<std::string, std::string>;
/// The pair describing the number of occurrences of a given op type.
using OpTypeCountPair = std::pair<std::string, int>;
///
/// @brief Creates the specified operator and sets up its inputs/outputs.
///
/// @param prog The program descriptor to which we add new op.
/// @param[in] op_type_name The operator type name.
/// @param[in] inputs The vector of input pairs: {input_name, variable
/// name}
/// @param[in] outputs The vector of output pairs: {output_name, variable name}
/// @param[in] use_mkldnn The flag deciding whether or not to set
/// 'use_mkldnn' attribute.
///
/// @return Returns pointer to the created operator descriptor.
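///
/// @note Usage sketch (variable names are illustrative): the call below
/// appends a mkldnn relu op reading variable "x" and writing "y":
/// @code
/// test::CreateOp(&prog, "relu", {{"X", "x"}}, {{"Out", "y"}});
/// @endcode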
///
OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name,
const std::vector<InOutVarNamePair>& inputs,
const std::vector<InOutVarNamePair>& outputs,
bool use_mkldnn = true);
///
/// @brief Checks whether node 'to' is reachable from node 'from' in the graph.
///
/// @param[in] graph The graph we're checking for reachability.
/// @param[in] from The 'from' node name.
/// @param[in] to The 'to' node name.
///
/// @return True if there is a connection between nodes 'from' and 'to'.
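///
/// @note Usage sketch (node names are illustrative):
/// @code
/// EXPECT_TRUE(test::TestIsReachable(graph, "x", "relu"));
/// @endcode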
///
bool TestIsReachable(const Graph& graph, std::string from, std::string to);
///
/// @brief Searches through the graph and counts occurrences of the provided operators.
///
/// @param[in] graph The graph we search through.
/// @param[in] op_type_count The vector of pairs {op_type_name, op count}
///
/// @note After going through all graph nodes this function checks
/// whether the counted number for each requested op is as expected.
///
/// @return Returns true if the occurrences of all ops are as expected.
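///
/// @note Usage sketch (op counts are illustrative):
/// @code
/// EXPECT_TRUE(test::AssertOpsCount(graph, {{"conv2d", 1}, {"relu", 0}}));
/// @endcode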
///
bool AssertOpsCount(const Graph& graph,
std::vector<OpTypeCountPair> op_type_count);
///
/// @brief Builds a program descriptor.
///
/// @param[in] transient_vars The vector of transient variable names.
/// @param[in] persistent_vars The vector of persistent variable names. Those
/// will have the persistable attribute set to true.
///
/// @return The program descriptor object.
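///
/// @note Usage sketch: "x" and "y" become transient variables, while
/// "weights" is additionally marked persistable:
/// @code
/// ProgramDesc prog = test::BuildProgramDesc({"x", "y"}, {"weights"});
/// @endcode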
///
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
const std::vector<std::string>& persistent_vars);
///
/// @brief Executes the pass on the provided graph and performs checks.
///
/// @note Checks whether the balance of removed and added nodes after the
/// pass run is as expected.
///
/// @param graph The graph we run the pass on.
/// @param[in] pass_name The name of the registered pass to run.
/// @param[in] from The name of the 'starting' node of a sequence in
/// the graph, used to test for correct node
/// connections.
/// @param[in] to The name of the 'ending' node of a sequence in
/// the graph, used to test for correct node
/// connections.
/// @param[in] removed_nodes_count The number of nodes we expect will be
/// removed/fused after pass execution.
/// @param[in] added_nodes_count The number of nodes we expect will be added
/// after pass execution.
///
/// @return Returns true if all checks passed, otherwise false.
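///
/// @note Usage sketch, assuming a registered pass named "some_fuse_pass"
/// (illustrative) that is expected to remove two nodes:
/// @code
/// Graph graph(prog);
/// EXPECT_TRUE(test::RunPassAndAssert(&graph, "some_fuse_pass", "x", "act_y",
///                                    /*removed_nodes_count=*/2));
/// @endcode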
///
bool RunPassAndAssert(Graph* graph, const std::string& pass_name,
const std::string& from, const std::string& to,
int removed_nodes_count, int added_nodes_count = 0);
} // namespace test
} // namespace ir
} // namespace framework
} // namespace paddle