// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>

#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/pass_test_util.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/errors.h"

namespace paddle {
namespace framework {
namespace ir {

// ------------------------------ Test cases -----------------------------------

TEST(FuseFCActOneDNNPass, ThrowUseMkldnn) {
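  // Both ops are created with use_mkldnn=false; the pass cannot fuse this
  // configuration and is expected to throw instead of silently skipping it.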
  auto prog =
      test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
  test::CreateOp(&prog, "fc",
                 {
                     {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
                 },
                 {{"Out", "fc_y"}}, false);
  test::CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);

  Graph graph(prog);
  // No fusion in this attribute configuration
  constexpr int removed_nodes_count = 0;

  EXPECT_THROW(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x",
                                      "act_y", removed_nodes_count),
               paddle::platform::EnforceNotMet);
}

TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) {
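  // gelu with approximate=true should be absorbed into fc as the
  // "gelu_tanh" activation.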
  auto prog =
      test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
  test::CreateOp(&prog, "fc",
                 {
                     {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
                 },
                 {{"Out", "fc_y"}});
  auto* act_op = test::CreateOp(&prog, "gelu", {{"Input", "fc_y"}},
                                {{"Out", "act_y"}}, false);
  act_op->SetAttr("approximate", true);

  Graph graph(prog);
  constexpr int removed_nodes_count = 2;

  EXPECT_TRUE(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x",
                                     "act_y", removed_nodes_count));
  EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}}));

  for (const auto* node : graph.Nodes()) {
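    // After fusion the surviving fc op must keep the oneDNN flag and record
    // the absorbed activation in its activation_type attribute.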
    if (node->IsOp() && node->Op()->Type() == "fc") {
      const auto* op = node->Op();
      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
      EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
      ASSERT_TRUE(op->HasAttr("activation_type"));
      auto act_type =
          BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
      EXPECT_EQ(act_type.compare("gelu_tanh"), 0);
    }
  }
}

TEST(FuseFCActOneDNNPass, FuseWithGeluErf) {
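  // gelu with approximate=false should be absorbed into fc as the
  // "gelu_erf" activation.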
  auto prog =
      test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
  test::CreateOp(&prog, "fc",
                 {
                     {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
                 },
                 {{"Out", "fc_y"}});
  auto* act_op = test::CreateOp(&prog, "gelu", {{"Input", "fc_y"}},
                                {{"Out", "act_y"}}, false);
  act_op->SetAttr("approximate", false);

  Graph graph(prog);
  constexpr int removed_nodes_count = 2;

  EXPECT_TRUE(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x",
                                     "act_y", removed_nodes_count));
  EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}}));

  for (const auto* node : graph.Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "fc") {
      const auto* op = node->Op();
      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
      EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
      ASSERT_TRUE(op->HasAttr("activation_type"));
      auto act_type =
          BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
      EXPECT_EQ(act_type.compare("gelu_erf"), 0);
    }
  }
}

TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) {
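  // gelu without the approximate attribute is fused as plain "gelu".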
  auto prog =
      test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
  test::CreateOp(&prog, "fc",
                 {
                     {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
                 },
                 {{"Out", "fc_y"}});
  test::CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);

  Graph graph(prog);
  constexpr int removed_nodes_count = 2;

  EXPECT_TRUE(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x",
                                     "act_y", removed_nodes_count));
  EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}}));

  for (const auto* node : graph.Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "fc") {
      const auto* op = node->Op();
      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
      EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
      ASSERT_TRUE(op->HasAttr("activation_type"));
      auto act_type =
          BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
      EXPECT_EQ(act_type.compare("gelu"), 0);
    }
  }
}

TEST(FuseFCActOneDNNPass, FuseWithTanh) {
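  // A trailing tanh op is removed and recorded on fc as activation_type
  // "tanh".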
  auto prog =
      test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
  test::CreateOp(&prog, "fc",
                 {
                     {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
                 },
                 {{"Out", "fc_y"}});
  test::CreateOp(&prog, "tanh", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);

  Graph graph(prog);
  constexpr int removed_nodes_count = 2;

  EXPECT_TRUE(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x",
                                     "act_y", removed_nodes_count));
  EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"tanh", 0}}));

  for (const auto* node : graph.Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "fc") {
      const auto* op = node->Op();
      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
      EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
      ASSERT_TRUE(op->HasAttr("activation_type"));
      auto act_type =
          BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
      EXPECT_EQ(act_type.compare("tanh"), 0);
    }
  }
}

TEST(FuseFCActOneDNNPass, FuseWithSigmoid) {
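  // Likewise for sigmoid: the activation op disappears and fc ends up with
  // activation_type "sigmoid".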
  auto prog =
      test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
  test::CreateOp(&prog, "fc",
                 {
                     {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
                 },
                 {{"Out", "fc_y"}});
  test::CreateOp(&prog, "sigmoid", {{"Input", "fc_y"}}, {{"Out", "act_y"}},
                 false);

  Graph graph(prog);
  constexpr int removed_nodes_count = 2;

  EXPECT_TRUE(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x",
                                     "act_y", removed_nodes_count));
  EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"sigmoid", 0}}));

  for (const auto* node : graph.Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "fc") {
      const auto* op = node->Op();
      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
      EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
      ASSERT_TRUE(op->HasAttr("activation_type"));
      auto act_type =
          BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
      EXPECT_EQ(act_type.compare("sigmoid"), 0);
    }
  }
}

TEST(FuseFCActOneDNNPass, pass_op_version_check) {
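  // The pass should declare its op-version compatibility requirements and
  // satisfy the registrar's check.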
  ASSERT_TRUE(
      paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
          .IsPassCompatible("fc_act_mkldnn_fuse_pass"));
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(fc_act_mkldnn_fuse_pass);