// cudnn_placement_pass_tester.cc
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/cudnn_placement_pass.h"

#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace framework {
namespace ir {

25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
class PlacementPassTest {
 private:
  void RegisterOpKernel() {
    static bool is_registered = false;
    if (!is_registered) {
      auto& all_kernels = OperatorWithKernel::AllOpKernels();

      platform::CUDAPlace place = platform::CUDAPlace(0);
      OpKernelType plain_kernel_type =
          OpKernelType(proto::VarType::FP32, place, DataLayout::kAnyLayout,
                       LibraryType::kPlain);
      OpKernelType cudnn_kernel_type =
          OpKernelType(proto::VarType::FP32, place, DataLayout::kAnyLayout,
                       LibraryType::kCUDNN);

      auto fake_kernel_func = [](const ExecutionContext&) -> void {
        static int num_calls = 0;
        num_calls++;
      };

      all_kernels["conv2d"][cudnn_kernel_type] = fake_kernel_func;
      all_kernels["pool2d"][cudnn_kernel_type] = fake_kernel_func;
      all_kernels["depthwise_conv2d"][plain_kernel_type] = fake_kernel_func;
      all_kernels["relu"][plain_kernel_type] = fake_kernel_func;

      is_registered = true;
    }
52 53
  }

54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
 public:
  void MainTest(std::initializer_list<std::string> cudnn_enabled_op_types,
                unsigned expected_use_cudnn_true_count) {
    // operator                                 use_cudnn
    // --------------------------------------------------
    // (a,b)->concat->c                         -
    // (c,weights,bias)->conv2d->f              false
    // f->relu->g                               -
    // g->pool2d->h                             false
    // (h,weights2,bias2)->depthwise_conv2d->k  false
    // k->relu->l                               -
    Layers layers;
    VarDesc* a = layers.data("a");
    VarDesc* b = layers.data("b");
    VarDesc* c = layers.concat(std::vector<VarDesc*>({a, b}));
    VarDesc* weights_0 = layers.data("weights_0");
    VarDesc* bias_0 = layers.data("bias_0");
    VarDesc* f = layers.conv2d(c, weights_0, bias_0, false);
    VarDesc* g = layers.relu(f);
    VarDesc* h = layers.pool2d(g, false);
    VarDesc* weights_1 = layers.data("weights_1");
    VarDesc* bias_1 = layers.data("bias_1");
    VarDesc* k = layers.depthwise_conv2d(h, weights_1, bias_1, false);
    layers.relu(k);

    RegisterOpKernel();

    std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
    auto pass = PassRegistry::Instance().Get("cudnn_placement_pass");
    pass->Set("cudnn_enabled_op_types",
              new std::unordered_set<std::string>(cudnn_enabled_op_types));

    graph.reset(pass->Apply(graph.release()));

    unsigned use_cudnn_true_count = 0;
    for (auto* node : graph->Nodes()) {
      if (node->IsOp() && node->Op()) {
        auto* op = node->Op();
        if (op->HasAttr("use_cudnn") &&
            boost::get<bool>(op->GetAttr("use_cudnn"))) {
          ++use_cudnn_true_count;
        }
96 97
      }
    }
98 99

    EXPECT_EQ(use_cudnn_true_count, expected_use_cudnn_true_count);
100 101
  }

102 103 104 105 106 107
  void PlacementNameTest() {
    auto pass = PassRegistry::Instance().Get("cudnn_placement_pass");
    EXPECT_EQ(static_cast<PlacementPassBase*>(pass.get())->GetPlacementName(),
              "cuDNN");
  }
};
108 109 110

TEST(CUDNNPlacementPass, enable_conv2d) {
  // Only conv2d is enabled for cuDNN: expect 1 op (the conv2d).
  PlacementPassTest().MainTest({"conv2d"}, 1);
}

TEST(CUDNNPlacementPass, enable_relu_pool) {
  // conv2d and pool2d are enabled for cuDNN: expect 1 conv2d + 1 pool2d.
  // NOTE: despite the test name, relu is not in the enabled list; the name is
  // kept unchanged because it is the gtest identifier.
  PlacementPassTest().MainTest({"conv2d", "pool2d"}, 2);
}

TEST(CUDNNPlacementPass, enable_all) {
  // An empty op-type list is expected to enable every op that has a cuDNN
  // kernel: 1 conv2d + 1 pool2d. depthwise_conv2d and relu only have plain
  // kernels registered, so they must not be counted.
  PlacementPassTest().MainTest({}, 2);
}

TEST(CUDNNPlacementPass, placement_name) {
  // The pass must report "cuDNN" as its placement name.
  PlacementPassTest().PlacementNameTest();
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

// Pulls in the registration of cudnn_placement_pass (Paddle's USE_PASS macro)
// so that the tests above can look it up by name via PassRegistry::Instance().
USE_PASS(cudnn_placement_pass);