// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>

#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/pass_test_util.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/errors.h"

namespace paddle {
namespace framework {
namespace ir {

// ------------------------------ Test cases -----------------------------------

// Verifies that the pass refuses to fuse when the fc op is created with
// use_mkldnn=false (last CreateOp argument): running the pass must raise
// EnforceNotMet instead of silently producing an unfused/invalid graph.
TEST(FuseFCActOneDNNPass, ThrowUseMkldnn) {
  auto prog =
      test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
  test::CreateOp(&prog,
                 "fc",
                 {
                     {"Input", "x"},
                     {"W", "weights"},
                     {"Bias", "bias"},
                 },
                 {{"Out", "fc_y"}},
                 false);  // use_mkldnn=false -> fusion must be rejected
  test::CreateOp(&prog, "gelu", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);

  Graph graph(prog);
  // No fusion in this attribute configuration
  constexpr int removed_nodes_count = 0;

  EXPECT_THROW(
      test::RunPassAndAssert(
          &graph, "fc_act_mkldnn_fuse_pass", "x", "act_y", removed_nodes_count),
      paddle::platform::EnforceNotMet);
}

// fc + gelu(approximate=true) must fuse into a single oneDNN fc op whose
// fuse_activation attribute is "gelu_tanh" (the tanh approximation of GELU).
TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) {
  auto prog =
      test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
  test::CreateOp(&prog,
                 "fc",
                 {
                     {"Input", "x"},
                     {"W", "weights"},
                     {"Bias", "bias"},
                 },
                 {{"Out", "fc_y"}});
  auto* act_op =
      test::CreateOp(&prog, "gelu", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
  act_op->SetAttr("approximate", true);  // selects the tanh-based GELU variant

  Graph graph(prog);
  // Fusion removes the activation op node and its intermediate var node.
  constexpr int removed_nodes_count = 2;

  EXPECT_TRUE(test::RunPassAndAssert(
      &graph, "fc_act_mkldnn_fuse_pass", "x", "act_y", removed_nodes_count));
  EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}}));

  for (const auto* node : graph.Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "fc") {
      const auto* op = node->Op();
      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
      EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
      ASSERT_TRUE(op->HasAttr("fuse_activation"));
      auto act_type =
          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
      EXPECT_EQ(act_type.compare("gelu_tanh"), 0);
    }
  }
}

// fc + gelu(approximate=false) must fuse into a single oneDNN fc op whose
// fuse_activation attribute is "gelu_erf" (the exact erf-based GELU).
TEST(FuseFCActOneDNNPass, FuseWithGeluErf) {
  auto prog =
      test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
  test::CreateOp(&prog,
                 "fc",
                 {
                     {"Input", "x"},
                     {"W", "weights"},
                     {"Bias", "bias"},
                 },
                 {{"Out", "fc_y"}});
  auto* act_op =
      test::CreateOp(&prog, "gelu", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
  act_op->SetAttr("approximate", false);  // selects the erf-based GELU variant

  Graph graph(prog);
  // Fusion removes the activation op node and its intermediate var node.
  constexpr int removed_nodes_count = 2;

  EXPECT_TRUE(test::RunPassAndAssert(
      &graph, "fc_act_mkldnn_fuse_pass", "x", "act_y", removed_nodes_count));
  EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}}));

  for (const auto* node : graph.Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "fc") {
      const auto* op = node->Op();
      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
      EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
      ASSERT_TRUE(op->HasAttr("fuse_activation"));
      auto act_type =
          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
      EXPECT_EQ(act_type.compare("gelu_erf"), 0);
    }
  }
}

// fc + gelu with no "approximate" attribute set must still fuse; the fused
// fc op keeps the plain "gelu" activation type.
TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) {
  auto prog =
      test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
  test::CreateOp(&prog,
                 "fc",
                 {
                     {"Input", "x"},
                     {"W", "weights"},
                     {"Bias", "bias"},
                 },
                 {{"Out", "fc_y"}});
  test::CreateOp(&prog, "gelu", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);

  Graph graph(prog);
  // Fusion removes the activation op node and its intermediate var node.
  constexpr int removed_nodes_count = 2;

  EXPECT_TRUE(test::RunPassAndAssert(
      &graph, "fc_act_mkldnn_fuse_pass", "x", "act_y", removed_nodes_count));
  EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}}));

  for (const auto* node : graph.Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "fc") {
      const auto* op = node->Op();
      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
      EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
      ASSERT_TRUE(op->HasAttr("fuse_activation"));
      auto act_type =
          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
      EXPECT_EQ(act_type.compare("gelu"), 0);
    }
  }
}

// fc + tanh must fuse into a single oneDNN fc op with
// fuse_activation == "tanh".
TEST(FuseFCActOneDNNPass, FuseWithTanh) {
  auto prog =
      test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
  test::CreateOp(&prog,
                 "fc",
                 {
                     {"Input", "x"},
                     {"W", "weights"},
                     {"Bias", "bias"},
                 },
                 {{"Out", "fc_y"}});
  test::CreateOp(&prog, "tanh", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);

  Graph graph(prog);
  // Fusion removes the activation op node and its intermediate var node.
  constexpr int removed_nodes_count = 2;

  EXPECT_TRUE(test::RunPassAndAssert(
      &graph, "fc_act_mkldnn_fuse_pass", "x", "act_y", removed_nodes_count));
  EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"tanh", 0}}));

  for (const auto* node : graph.Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "fc") {
      const auto* op = node->Op();
      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
      EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
      ASSERT_TRUE(op->HasAttr("fuse_activation"));
      auto act_type =
          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
      EXPECT_EQ(act_type.compare("tanh"), 0);
    }
  }
}

// fc + sigmoid must fuse into a single oneDNN fc op with
// fuse_activation == "sigmoid".
TEST(FuseFCActOneDNNPass, FuseWithSigmoid) {
  auto prog =
      test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
  test::CreateOp(&prog,
                 "fc",
                 {
                     {"Input", "x"},
                     {"W", "weights"},
                     {"Bias", "bias"},
                 },
                 {{"Out", "fc_y"}});
  test::CreateOp(&prog, "sigmoid", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);

  Graph graph(prog);
  // Fusion removes the activation op node and its intermediate var node.
  constexpr int removed_nodes_count = 2;

  EXPECT_TRUE(test::RunPassAndAssert(
      &graph, "fc_act_mkldnn_fuse_pass", "x", "act_y", removed_nodes_count));
  EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"sigmoid", 0}}));

  for (const auto* node : graph.Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "fc") {
      const auto* op = node->Op();
      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
      EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
      ASSERT_TRUE(op->HasAttr("fuse_activation"));
      auto act_type =
          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
      EXPECT_EQ(act_type.compare("sigmoid"), 0);
    }
  }
}

223 224 225
TEST(FuseFCActOneDNNPass, FuseWithMish) {
  auto prog =
      test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
226 227
  test::CreateOp(&prog,
                 "fc",
228
                 {
229
                     {"Input", "x"},
230
                     {"W", "weights"},
231
                     {"Bias", "bias"},
232 233
                 },
                 {{"Out", "fc_y"}});
234
  test::CreateOp(&prog, "mish", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
235 236 237 238

  Graph graph(prog);
  constexpr int removed_nodes_count = 2;

239 240
  EXPECT_TRUE(test::RunPassAndAssert(
      &graph, "fc_act_mkldnn_fuse_pass", "x", "act_y", removed_nodes_count));
241 242 243 244 245 246
  EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"mish", 0}}));

  for (const auto* node : graph.Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "fc") {
      const auto* op = node->Op();
      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
R
Ruibiao Chen 已提交
247
      EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
248
      ASSERT_TRUE(op->HasAttr("fuse_activation"));
249
      auto act_type =
250
          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
251 252 253 254 255
      EXPECT_EQ(act_type.compare("mish"), 0);
    }
  }
}

J
jakpiase 已提交
256 257 258
TEST(FuseFCActOneDNNPass, FuseWithHardSwish) {
  auto prog =
      test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
259 260
  test::CreateOp(&prog,
                 "fc",
J
jakpiase 已提交
261
                 {
262
                     {"Input", "x"},
263
                     {"W", "weights"},
264
                     {"Bias", "bias"},
J
jakpiase 已提交
265 266
                 },
                 {{"Out", "fc_y"}});
267
  test::CreateOp(
268
      &prog, "hard_swish", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
J
jakpiase 已提交
269 270 271 272

  Graph graph(prog);
  constexpr int removed_nodes_count = 2;

273 274
  EXPECT_TRUE(test::RunPassAndAssert(
      &graph, "fc_act_mkldnn_fuse_pass", "x", "act_y", removed_nodes_count));
J
jakpiase 已提交
275 276 277 278 279 280
  EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"hard_swish", 0}}));

  for (const auto* node : graph.Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "fc") {
      const auto* op = node->Op();
      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
R
Ruibiao Chen 已提交
281
      EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
282
      ASSERT_TRUE(op->HasAttr("fuse_activation"));
J
jakpiase 已提交
283
      auto act_type =
284
          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
J
jakpiase 已提交
285 286 287 288 289
      EXPECT_EQ(act_type.compare("hard_swish"), 0);
    }
  }
}

290 291 292 293 294 295
TEST(FuseFCActOneDNNPass, pass_op_version_check) {
  ASSERT_TRUE(
      paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
          .IsPassCompatible("fc_act_mkldnn_fuse_pass"));
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(fc_act_mkldnn_fuse_pass);