// cudnn_placement_pass_tester.cc
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/cudnn_placement_pass.h"

#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace framework {
namespace ir {

void RegisterOpKernel() {
  static bool is_registered = false;
  if (!is_registered) {
    auto& all_kernels = OperatorWithKernel::AllOpKernels();

    platform::CUDAPlace place = platform::CUDAPlace(0);
    OpKernelType plain_kernel_type =
        OpKernelType(proto::VarType::FP32, place, DataLayout::kAnyLayout,
                     LibraryType::kPlain);
    OpKernelType cudnn_kernel_type =
        OpKernelType(proto::VarType::FP32, place, DataLayout::kAnyLayout,
                     LibraryType::kCUDNN);

    auto fake_kernel_func = [](const ExecutionContext&) -> void {
      static int num_calls = 0;
      num_calls++;
    };

    all_kernels["conv2d"][cudnn_kernel_type] = fake_kernel_func;
    all_kernels["pool2d"][cudnn_kernel_type] = fake_kernel_func;
    all_kernels["depthwise_conv2d"][plain_kernel_type] = fake_kernel_func;
    all_kernels["relu"][plain_kernel_type] = fake_kernel_func;

    is_registered = true;
  }
}

void MainTest(std::initializer_list<std::string> cudnn_enabled_op_types,
              unsigned expected_use_cudnn_true_count) {
  // operator                                 use_cudnn
  // --------------------------------------------------
  // (a,b)->concat->c                         -
  // (c,weights,bias)->conv2d->f              false
  // f->relu->g                               -
  // g->pool2d->h                             false
  // (h,weights2,bias2)->depthwise_conv2d->k  false
  // k->relu->l                               -
  Layers layers;
  VarDesc* a = layers.data("a");
  VarDesc* b = layers.data("b");
  VarDesc* c = layers.concat(std::vector<VarDesc*>({a, b}));
  VarDesc* weights_0 = layers.data("weights_0");
  VarDesc* bias_0 = layers.data("bias_0");
  VarDesc* f = layers.conv2d(c, weights_0, bias_0, false);
  VarDesc* g = layers.relu(f);
  VarDesc* h = layers.pool2d(g, false);
  VarDesc* weights_1 = layers.data("weights_1");
  VarDesc* bias_1 = layers.data("bias_1");
  VarDesc* k = layers.depthwise_conv2d(h, weights_1, bias_1, false);
  layers.relu(k);

  RegisterOpKernel();

  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
  auto pass = PassRegistry::Instance().Get("cudnn_placement_pass");
  pass->Set("cudnn_enabled_op_types",
            new std::unordered_set<std::string>(cudnn_enabled_op_types));

  graph.reset(pass->Apply(graph.release()));

  unsigned use_cudnn_true_count = 0;
  for (auto* node : graph->Nodes()) {
    if (node->IsOp() && node->Op()) {
      auto* op = node->Op();
      if (op->HasAttr("use_cudnn") &&
          boost::get<bool>(op->GetAttr("use_cudnn"))) {
        ++use_cudnn_true_count;
      }
    }
  }

  EXPECT_EQ(use_cudnn_true_count, expected_use_cudnn_true_count);
}

TEST(CUDNNPlacementPass, enable_conv2d) {
  // Only conv2d is listed, and the program contains a single conv2d op.
  const unsigned expected_cudnn_ops = 1;
  MainTest({"conv2d"}, expected_cudnn_ops);
}

TEST(CUDNNPlacementPass, enable_relu_pool) {
  // conv2d and pool2d are listed: one op of each kind is in the program.
  // NOTE(review): the test name mentions relu, but the enabled set below is
  // conv2d + pool2d.
  const unsigned expected_cudnn_ops = 2;
  MainTest({"conv2d", "pool2d"}, expected_cudnn_ops);
}

TEST(CUDNNPlacementPass, enable_all) {
  // An empty op-type set is passed; going by the test name this stands for
  // "all ops enabled". Expected: 1 conv2d + 1 pool2d.
  // depthwise_conv2d does not have a CUDNN kernel registered in this test.
  const unsigned expected_cudnn_ops = 2;
  MainTest({}, expected_cudnn_ops);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

// MainTest looks up "cudnn_placement_pass" in PassRegistry by name.
// NOTE(review): USE_PASS is defined outside this file; presumably it makes
// that pass's registration available to this test binary — confirm.
USE_PASS(cudnn_placement_pass);