/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/conv_base_helper.h"

namespace paddle {
namespace operators {

using ConvArgs = ConvArgsBase<miopenHandle_t, miopenDataType_t>;

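// Copies a slice of `input` into `out`, dropping the padded region. For each
// entry in `axes`, the slice along that dimension starts at the matching
// `starts` value (negative values count from the end and are clamped to 0);
// all other dimensions start at 0. The slice extents are taken from the
// dimensions of `out`.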
template <typename DeviceContext, typename T, size_t D>
static void RemovePaddingSlice(const phi::GPUContext& context,
                               const Tensor* input,
                               Tensor* out,
                               const std::vector<int>& starts,
                               const std::vector<int>& axes) {
  auto& place = *context.eigen_device();
  auto in_dims = input->dims();
  auto new_out_dims = out->dims();
  auto offsets = Eigen::array<int, D>();
  auto extents = Eigen::array<int, D>();
  for (size_t i = 0; i < D; ++i) {
    offsets[i] = 0;
    extents[i] = new_out_dims[i];
  }

  for (size_t i = 0; i < axes.size(); ++i) {
    int start = starts[i];
    if (start < 0) {
      start = (start + in_dims[axes[i]]);
    }
    start = std::max(start, 0);
    offsets[axes[i]] = start;
  }
  auto in_t =
      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
          *input);

  auto out_t =
      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
          *out, new_out_dims);
  out_t.device(place) = in_t.slice(offsets, extents);
}

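// Algorithm search helpers for MIOpen convolutions. The primary template is
// empty; the specializations below (forward, backward-data, backward-weights)
// each provide:
//   * Find()             - runs miopenFindConvolution*Algorithm through the
//                          workspace handle and returns the best algorithm it
//                          reports (the exhaustive_search / deterministic
//                          flags are accepted for interface parity but are
//                          not used here);
//   * GetWorkspaceSize() - queries MIOpen for the workspace size in bytes.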
template <typename PerfT>
struct SearchAlgorithm {};

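// Forward convolution: picks a miopenConvFwdAlgorithm_t via
// miopenFindConvolutionForwardAlgorithm. A typical call site might look like
// the sketch below (assumed names: `args` is a filled-in ConvArgs and
// `dev_ctx` is the current phi::GPUContext; neither is defined in this file):
//
//   using search = SearchAlgorithm<miopenConvFwdAlgorithm_t>;
//   size_t workspace_size = search::GetWorkspaceSize(args);
//   auto algo = search::Find<T>(
//       args, /*exhaustive_search=*/false, /*deterministic=*/false,
//       workspace_size, dev_ctx);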
template <>
struct SearchAlgorithm<miopenConvFwdAlgorithm_t> {
  using perf_t = miopenConvAlgoPerf_t;
  using algo_t = miopenConvFwdAlgorithm_t;

  template <typename T>
  static algo_t Find(const ConvArgs& args,
                     bool exhaustive_search,
                     bool deterministic,
                     size_t workspace_size,
                     const phi::GPUContext& ctx) {
    algo_t algo;

    auto workspace_handle = ctx.cudnn_workspace_handle();

    int find_count;
    miopenConvAlgoPerf_t find_result;
    auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::miopenFindConvolutionForwardAlgorithm(
              args.handle,
              args.idesc.desc(),
              args.x->data<T>(),
              args.wdesc.desc(),
              args.w->data<T>(),
              args.cdesc.desc(),
              args.odesc.desc(),
              const_cast<T*>(args.o->data<T>()),
              kNUM_CUDNN_FWD_ALGS,
              &find_count,
              &find_result,
              cudnn_workspace_ptr,
              workspace_size,
              false));
    };

    workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
    algo = find_result.fwd_algo;
    VLOG(3) << "choose algo " << algo;
    return algo;
  }

  static size_t GetWorkspaceSize(const ConvArgs& args) {
    size_t workspace_size = 0;
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::miopenConvolutionForwardGetWorkSpaceSize(
            args.handle,
            args.wdesc.desc(),
            args.idesc.desc(),
            args.cdesc.desc(),
            args.odesc.desc(),
            &workspace_size));
    return workspace_size;
  }
};

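// Backward-data convolution: picks a miopenConvBwdDataAlgorithm_t via
// miopenFindConvolutionBackwardDataAlgorithm; used the same way as the
// forward specialization above.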
template <>
struct SearchAlgorithm<miopenConvBwdDataAlgorithm_t> {
  using perf_t = miopenConvAlgoPerf_t;
  using algo_t = miopenConvBwdDataAlgorithm_t;

  template <typename T>
  static algo_t Find(const ConvArgs& args,
                     bool exhaustive_search,
                     bool deterministic,
                     size_t workspace_size,
                     const phi::GPUContext& ctx) {
    algo_t algo;

    auto workspace_handle = ctx.cudnn_workspace_handle();

    int find_count;
    miopenConvAlgoPerf_t find_result;
    auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::miopenFindConvolutionBackwardDataAlgorithm(
              args.handle,
              args.odesc.desc(),
              args.o->data<T>(),
              args.wdesc.desc(),
              args.w->data<T>(),
              args.cdesc.desc(),
              args.idesc.desc(),
              const_cast<T*>(args.x->data<T>()),
              kNUM_CUDNN_BWD_DATA_ALGS,
              &find_count,
              &find_result,
              cudnn_workspace_ptr,
              workspace_size,
              false));
    };

    workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
    algo = find_result.bwd_data_algo;
    VLOG(3) << "choose algo " << algo;
    return algo;
  }

  static size_t GetWorkspaceSize(const ConvArgs& args) {
    size_t workspace_size = 0;
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::miopenConvolutionBackwardDataGetWorkSpaceSize(
            args.handle,
            args.odesc.desc(),
            args.wdesc.desc(),
            args.cdesc.desc(),
            args.idesc.desc(),
            &workspace_size));
    return workspace_size;
  }
};

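// Backward-weights convolution: picks a miopenConvBwdWeightsAlgorithm_t via
// miopenFindConvolutionBackwardWeightsAlgorithm; used the same way as the
// forward specialization above.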
template <>
struct SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t> {
  using perf_t = miopenConvAlgoPerf_t;
  using algo_t = miopenConvBwdWeightsAlgorithm_t;

  template <typename T>
  static algo_t Find(const ConvArgs& args,
                     bool exhaustive_search,
                     bool deterministic,
                     size_t workspace_size,
                     const phi::GPUContext& ctx) {
    algo_t algo;

    auto workspace_handle = ctx.cudnn_workspace_handle();

    int find_count;
    miopenConvAlgoPerf_t find_result;
    auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::miopenFindConvolutionBackwardWeightsAlgorithm(
              args.handle,
              args.odesc.desc(),
              args.o->data<T>(),
              args.idesc.desc(),
              args.x->data<T>(),
              args.cdesc.desc(),
              args.wdesc.desc(),
              const_cast<T*>(args.w->data<T>()),
              kNUM_CUDNN_BWD_FILTER_ALGS,
              &find_count,
              &find_result,
              cudnn_workspace_ptr,
              workspace_size,
              false));
    };

    workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
    algo = find_result.bwd_weights_algo;
    VLOG(3) << "choose algo " << algo;
    return algo;
  }

  static size_t GetWorkspaceSize(const ConvArgs& args) {
    size_t workspace_size = 0;
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::miopenConvolutionBackwardWeightsGetWorkSpaceSize(
            args.handle,
            args.odesc.desc(),
            args.idesc.desc(),
            args.cdesc.desc(),
            args.wdesc.desc(),
            &workspace_size));
    return workspace_size;
  }
};

}  // namespace operators
}  // namespace paddle