// conv_miopen_helper.h
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/kernels/gpudnn/conv_gpudnn_base.h"

namespace phi {
20

21
using ConvArgs = ConvArgsBase<miopenHandle_t, miopenDataType_t>;
22

Y
Yiqun Liu 已提交
23 24 25
template <typename PerfT>
struct SearchAlgorithm {};

26 27 28 29 30 31
template <>
struct SearchAlgorithm<miopenConvFwdAlgorithm_t> {
  using perf_t = miopenConvAlgoPerf_t;
  using algo_t = miopenConvFwdAlgorithm_t;

  template <typename T>
32 33 34 35
  static algo_t Find(const ConvArgs& args,
                     bool exhaustive_search,
                     bool deterministic,
                     size_t workspace_size,
H
hong 已提交
36
                     const phi::GPUContext& ctx) {
37 38
    algo_t algo;

H
hong 已提交
39
    auto workspace_handle = ctx.cudnn_workspace_handle();
40

41 42 43
    int find_count;
    miopenConvAlgoPerf_t find_result;
    auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
44
      PADDLE_ENFORCE_GPU_SUCCESS(
45
          phi::dynload::miopenFindConvolutionForwardAlgorithm(
46 47 48 49 50 51 52 53 54 55 56 57 58 59
              args.handle,
              args.idesc.desc(),
              args.x->data<T>(),
              args.wdesc.desc(),
              args.w->data<T>(),
              args.cdesc.desc(),
              args.odesc.desc(),
              const_cast<T*>(args.o->data<T>()),
              kNUM_CUDNN_FWD_ALGS,
              &find_count,
              &find_result,
              cudnn_workspace_ptr,
              workspace_size,
              false));
60 61
    };

R
ronnywang 已提交
62 63
    workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
    algo = find_result.fwd_algo;
64 65 66 67
    VLOG(3) << "choose algo " << algo;
    return algo;
  }

68
  static size_t GetWorkspaceSize(const ConvArgs& args) {
69
    size_t workspace_size = 0;
70
    PADDLE_ENFORCE_GPU_SUCCESS(
71
        phi::dynload::miopenConvolutionForwardGetWorkSpaceSize(
72 73 74 75 76 77
            args.handle,
            args.wdesc.desc(),
            args.idesc.desc(),
            args.cdesc.desc(),
            args.odesc.desc(),
            &workspace_size));
78 79 80 81 82 83 84 85 86 87
    return workspace_size;
  }
};

template <>
struct SearchAlgorithm<miopenConvBwdDataAlgorithm_t> {
  using perf_t = miopenConvAlgoPerf_t;
  using algo_t = miopenConvBwdDataAlgorithm_t;

  template <typename T>
88 89 90 91
  static algo_t Find(const ConvArgs& args,
                     bool exhaustive_search,
                     bool deterministic,
                     size_t workspace_size,
H
hong 已提交
92
                     const phi::GPUContext& ctx) {
93 94
    algo_t algo;

H
hong 已提交
95
    auto workspace_handle = ctx.cudnn_workspace_handle();
96

97 98 99
    int find_count;
    miopenConvAlgoPerf_t find_result;
    auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
100
      PADDLE_ENFORCE_GPU_SUCCESS(
101
          phi::dynload::miopenFindConvolutionBackwardDataAlgorithm(
102 103 104 105 106 107 108 109 110 111 112 113 114 115
              args.handle,
              args.odesc.desc(),
              args.o->data<T>(),
              args.wdesc.desc(),
              args.w->data<T>(),
              args.cdesc.desc(),
              args.idesc.desc(),
              const_cast<T*>(args.x->data<T>()),
              kNUM_CUDNN_BWD_DATA_ALGS,
              &find_count,
              &find_result,
              cudnn_workspace_ptr,
              workspace_size,
              false));
116 117
    };

R
ronnywang 已提交
118 119
    workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
    algo = find_result.bwd_data_algo;
120 121 122 123
    VLOG(3) << "choose algo " << algo;
    return algo;
  }

124
  static size_t GetWorkspaceSize(const ConvArgs& args) {
125
    size_t workspace_size = 0;
126
    PADDLE_ENFORCE_GPU_SUCCESS(
127
        phi::dynload::miopenConvolutionBackwardDataGetWorkSpaceSize(
128 129 130 131 132 133
            args.handle,
            args.odesc.desc(),
            args.wdesc.desc(),
            args.cdesc.desc(),
            args.idesc.desc(),
            &workspace_size));
134 135 136 137 138 139 140 141 142 143
    return workspace_size;
  }
};

template <>
struct SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t> {
  using perf_t = miopenConvAlgoPerf_t;
  using algo_t = miopenConvBwdWeightsAlgorithm_t;

  template <typename T>
144 145 146 147
  static algo_t Find(const ConvArgs& args,
                     bool exhaustive_search,
                     bool deterministic,
                     size_t workspace_size,
H
hong 已提交
148
                     const phi::GPUContext& ctx) {
149 150
    algo_t algo;

H
hong 已提交
151
    auto workspace_handle = ctx.cudnn_workspace_handle();
152 153 154 155

    int find_count;
    miopenConvAlgoPerf_t find_result;
    auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
156
      PADDLE_ENFORCE_GPU_SUCCESS(
157
          phi::dynload::miopenFindConvolutionBackwardWeightsAlgorithm(
158 159 160 161 162 163 164 165 166 167 168 169 170 171
              args.handle,
              args.odesc.desc(),
              args.o->data<T>(),
              args.idesc.desc(),
              args.x->data<T>(),
              args.cdesc.desc(),
              args.wdesc.desc(),
              const_cast<T*>(args.w->data<T>()),
              kNUM_CUDNN_BWD_FILTER_ALGS,
              &find_count,
              &find_result,
              cudnn_workspace_ptr,
              workspace_size,
              false));
172 173
    };

R
ronnywang 已提交
174 175
    workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
    algo = find_result.bwd_weights_algo;
176 177 178 179
    VLOG(3) << "choose algo " << algo;
    return algo;
  }

180
  static size_t GetWorkspaceSize(const ConvArgs& args) {
181
    size_t workspace_size = 0;
182
    PADDLE_ENFORCE_GPU_SUCCESS(
183
        phi::dynload::miopenConvolutionBackwardWeightsGetWorkSpaceSize(
184 185 186 187 188 189
            args.handle,
            args.odesc.desc(),
            args.idesc.desc(),
            args.cdesc.desc(),
            args.wdesc.desc(),
            &workspace_size));
190 191 192 193
    return workspace_size;
  }
};

194
}  // namespace phi