/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"

namespace paddle {
namespace operators {

#if CUDNN_VERSION >= 7100
using Tensor = framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using ScopedActivationDescriptor = platform::ScopedActivationDescriptor;
using DataLayout = platform::DataLayout;

using ScopedPoolingDescriptor = platform::ScopedPoolingDescriptor;
using PoolingMode = platform::PoolingMode;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;

template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;

// cuDNN kernel for the fused `conv2d_inception_fusion` op (compiled only when
// CUDNN_VERSION >= 7100).
//
// It issues one pooling call, four cudnnConvolutionBiasActivationForward calls
// (one per "Filter"/"Bias" entry) and one cudnnTransformTensor call. Each
// branch's final result lands in a disjoint channel range of "Output":
//
//   [0, oc0)                       pool(Input) -> conv0 + bias0 + act
//   [oc0, oc0 + oc1)               conv1 + bias1 + act on Input
//   [oc0 + oc1, oc0 + oc1 + oc2)   conv2 (2 groups) on the extra
//                                  2 * Filter[2].dims[1] channels conv1 wrote
//                                  right after its oc1 final channels
//   [oc0 + oc1 + oc2, oc)          conv3 on the channels of conv2's result
//                                  that follow its first oc2
//
// where oc  = Output.dims[1],
//       oc0 = Filter[0].dims[0],
//       oc1 = Filter[1].dims[0] - 2 * Filter[2].dims[1],
//       oc3 = Filter[3].dims[0],
//       oc2 = oc - oc0 - oc1 - oc3.
//
// Scratch tensors: TempOutput[0] holds the pooled input (same shape as Input);
// TempOutput[1] holds conv2's full result {N, Filter[2].dims[0], H, W}, whose
// first oc2 channels are copied into Output at the end.
//
// Assumptions not checked here (TODO confirm against the op's InferShape):
// NCHW layout; every branch preserves H x W. Filters 0/1 are run with zero
// padding (presumably 1x1) and filters 2/3 with padding 1 (presumably 3x3).
template <typename T>
class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // One conv + bias + activation call per inception branch.
    constexpr int kNumBranches = 4;

    auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
    auto* input = ctx.Input<Tensor>("Input");
    auto filters = ctx.MultiInput<framework::Tensor>("Filter");
    auto bias = ctx.MultiInput<framework::Tensor>("Bias");

    auto* output = ctx.Output<Tensor>("Output");
    auto temp_outs = ctx.MultiOutput<framework::Tensor>("TempOutput");

    const std::string pool_type = ctx.Attr<std::string>("pooling_type");
    const std::string activation = ctx.Attr<std::string>("activation");
    const bool exclusive = ctx.Attr<bool>("exclusive");

    int64_t user_workspace_size =
        static_cast<int64_t>(ctx.Attr<int>("workspace_size_MB"));

    const T* input_data = input->data<T>();
    T* output_data = dev_ctx.Alloc<T>(output, output->numel() * sizeof(T));
    // TempOutput[0] receives the pooled input, so it takes the input's shape.
    temp_outs[0]->Resize(input->dims());
    T* temp_data =
        dev_ctx.Alloc<T>(temp_outs[0], temp_outs[0]->numel() * sizeof(T));

    DataLayout layout = DataLayout::kNCHW;
    std::vector<int> in_dim = phi::vectorize<int>(input->dims());

    // ------------------- cudnn descriptors ---------------------
    PoolingMode pooling_mode;
    if (pool_type == "max") {
      pooling_mode = PoolingMode::kMaximum;
    } else {
      pooling_mode = exclusive ? PoolingMode::kAverageExclusive
                               : PoolingMode::kAverageInclusive;
    }
    std::vector<int> k0x0 = {0, 0};
    std::vector<int> k1x1 = {1, 1};
    std::vector<int> k3x3 = {3, 3};
    ScopedPoolingDescriptor pool_desc;
    ScopedActivationDescriptor act_desc;
    ScopedTensorDescriptor out_pool_desc;
    ScopedTensorDescriptor input_desc;
    // 3x3 window; the two {1, 1} arguments are presumably padding and stride,
    // which keeps the pooled output at the input's H x W.
    cudnnPoolingDescriptor_t cudnn_pool_desc =
        pool_desc.descriptor(pooling_mode, k3x3, k1x1, k1x1);

    cudnnTensorDescriptor_t cudnn_input_desc =
        input_desc.descriptor<T>(layout, phi::vectorize<int>(input->dims()));
    cudnnTensorDescriptor_t pool_out_desc =
        out_pool_desc.descriptor<T>(layout, phi::vectorize<int>(input->dims()));

    cudnnDataType_t cudnn_dtype = CudnnDataType<T>::type;
    // Plain local arrays of descriptor handles: no heap allocation to leak.
    // The cuDNN objects themselves are created here and destroyed at the end
    // of Compute.
    cudnnTensorDescriptor_t out_desc[kNumBranches];
    cudnnFilterDescriptor_t filter_desc[kNumBranches];
    cudnnTensorDescriptor_t bias_desc[kNumBranches];
    cudnnTensorDescriptor_t in_desc[kNumBranches];
    cudnnConvolutionDescriptor_t conv_desc[kNumBranches];
    for (int i = 0; i < kNumBranches; ++i) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnCreateFilterDescriptor(&filter_desc[i]));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnCreateTensorDescriptor(&bias_desc[i]));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnCreateTensorDescriptor(&in_desc[i]));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnCreateTensorDescriptor(&out_desc[i]));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnCreateConvolutionDescriptor(&conv_desc[i]));
    }

    std::vector<std::vector<int>> filter_dims;
    std::vector<std::vector<int>> bias_dims;
    std::vector<std::vector<int>> in_dims;
    std::vector<std::vector<int>> out_dims;
    std::vector<std::vector<int>> in_strides;
    std::vector<std::vector<int>> out_strides;
    std::vector<std::vector<int>> bias_strides;

    cudnnTensorFormat_t format = CUDNN_TENSOR_NCHW;
    int n = in_dim[0];
    int h = in_dim[2];
    int w = in_dim[3];
    int oc = output->dims()[1];

    cudnnDataType_t compute_type = (cudnn_dtype == CUDNN_DATA_DOUBLE)
                                       ? CUDNN_DATA_DOUBLE
                                       : CUDNN_DATA_FLOAT;

    for (int i = 0; i < kNumBranches; ++i) {
      filter_dims.push_back(phi::vectorize<int>(filters[i]->dims()));
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetFilterNdDescriptor(
          filter_desc[i], cudnn_dtype, format, 4, filter_dims[i].data()));
      bias_dims.push_back({1, filter_dims[i][0], 1, 1});
      bias_strides.push_back({filter_dims[i][0], 1, 1, 1});
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
          bias_desc[i],
          cudnn_dtype,
          4,
          bias_dims[i].data(),
          bias_strides[i].data()));
      // Every branch keeps H x W. Outputs use the batch stride of the full
      // Output tensor so each branch writes in place into its channel slice.
      in_dims.push_back({n, filter_dims[i][1], h, w});
      out_dims.push_back({n, filter_dims[i][0], h, w});
      in_strides.push_back({filter_dims[i][1] * h * w, h * w, w, 1});
      out_strides.push_back({oc * h * w, h * w, w, 1});

      // Branches 0/1 use zero padding, branches 2/3 use padding 1. Stride and
      // dilation are 1 for every branch.
      const std::vector<int>& pads = (i < 2) ? k0x0 : k1x1;
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnSetConvolutionNdDescriptor(
              conv_desc[i],
              2,
              pads.data(),
              k1x1.data(),
              k1x1.data(),
              CUDNN_CROSS_CORRELATION,
              compute_type));
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
          conv_desc[i], CUDNN_DEFAULT_MATH));
#if CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000
      if (!platform::allow_tf32_cudnn) {
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cudnnSetConvolutionMathType(conv_desc[i],
                                                           CUDNN_FMA_MATH));
      }
#endif  // CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000
    }
    // Branch 2 is a 2-group conv over the 2 * filter_dims[2][1] spill-over
    // channels conv1 wrote into Output (hence Output's batch stride).
    in_dims[2][1] *= 2;
    in_strides[2][0] = oc * h * w;
    out_strides[2][0] = filter_dims[2][0] * h * w;  // this out is continuous.
    // Branch 3 reads from branch 2's contiguous scratch result.
    in_strides[3][0] = filter_dims[2][0] * h * w;
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cudnnSetConvolutionGroupCount(conv_desc[2], 2));

    cudnnConvolutionFwdAlgo_t algo[kNumBranches];
    auto handle = dev_ctx.cudnn_handle();
    size_t workspace_size_in_bytes = 0;  // final workspace to allocate.

    // NOTE(review): workspace_size_limit is computed but never consulted
    // below. The algorithm choice ignores both FLAGS_conv_workspace_size_limit
    // and the "workspace_size_MB" attribute. Kept as-is to preserve behavior.
    size_t workspace_size_limit = 0;
    if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
      int64_t max_user_size =
          std::min(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
                   user_workspace_size);
      workspace_size_limit = max_user_size * 1024 * 1024;
    }
    (void)workspace_size_limit;

    for (int i = 0; i < kNumBranches; ++i) {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
          in_desc[i], cudnn_dtype, 4, in_dims[i].data(), in_strides[i].data()));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnSetTensorNdDescriptor(out_desc[i],
                                                        cudnn_dtype,
                                                        4,
                                                        out_dims[i].data(),
                                                        out_strides[i].data()));

      int perf_count = 0;
      // Take the first entry reported by the v7 heuristic (presumably its
      // top-ranked algorithm; see the cuDNN docs).
      int best_algo_idx = 0;
      size_t tmp_size = 0;
      std::unique_ptr<cudnnConvolutionFwdAlgoPerf_t[]> perf_results(
          new cudnnConvolutionFwdAlgoPerf_t[kNUM_CUDNN_FWD_ALGS]);
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7(
              handle,
              in_desc[i],
              filter_desc[i],
              conv_desc[i],
              out_desc[i],
              kNUM_CUDNN_FWD_ALGS,
              &perf_count,
              perf_results.get()));
      algo[i] = perf_results[best_algo_idx].algo;

      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
              handle,
              in_desc[i],
              filter_desc[i],
              conv_desc[i],
              out_desc[i],
              algo[i],
              &tmp_size));

      // A single workspace large enough for every branch's algorithm.
      workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
    }
    cudnnActivationDescriptor_t cudnn_act_desc =
        act_desc.descriptor<T>(activation);

    // Final channel count contributed by each branch to Output.
    int oc0 = filter_dims[0][0];
    int oc1 = filter_dims[1][0] - filter_dims[2][1] * 2;
    int oc3 = filter_dims[3][0];
    int oc2 = oc - oc0 - oc1 - oc3;

    // branch1: pool + 1x1 conv
    ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cudnnPoolingForward(handle,
                                               cudnn_pool_desc,
                                               &alpha,
                                               cudnn_input_desc,
                                               input_data,
                                               &beta,
                                               pool_out_desc,
                                               temp_data));

    // Per-branch conv input pointers.
    std::vector<const void*> in_datas;
    in_datas.push_back(static_cast<const void*>(temp_data));
    in_datas.push_back(static_cast<const void*>(input_data));
    in_datas.push_back(
        static_cast<const void*>(output_data + (oc0 + oc1) * h * w));
    temp_outs[1]->Resize(phi::make_ddim(out_dims[2]));
    T* temp2_data =
        dev_ctx.Alloc<T>(temp_outs[1], temp_outs[1]->numel() * sizeof(T));
    in_datas.push_back(static_cast<const void*>(temp2_data + oc2 * h * w));

    // Per-branch conv output pointers.
    std::vector<void*> out_datas;
    out_datas.push_back(static_cast<void*>(output_data));
    out_datas.push_back(static_cast<void*>(output_data + oc0 * h * w));
    out_datas.push_back(static_cast<void*>(temp2_data));
    out_datas.push_back(
        static_cast<void*>(output_data + (oc0 + oc1 + oc2) * h * w));

    for (int i = 0; i < kNumBranches; ++i) {
      auto func = [&](void* cudnn_workspace) {
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cudnnConvolutionBiasActivationForward(
                handle,
                &alpha,
                in_desc[i],
                in_datas[i],
                filter_desc[i],
                static_cast<const void*>(filters[i]->data<T>()),
                conv_desc[i],
                algo[i],
                cudnn_workspace,
                workspace_size_in_bytes,
                &beta,
                out_desc[i],
                out_datas[i],
                bias_desc[i],
                static_cast<const void*>(bias[i]->data<T>()),
                cudnn_act_desc,
                out_desc[i],
                out_datas[i]));
      };
      auto workspace_handle = dev_ctx.cudnn_workspace_handle();
      workspace_handle.RunFunc(func, workspace_size_in_bytes);
    }

    // Copy branch 2's final result, the leading oc2 channels of TempOutput[1]
    // (batch stride filter_dims[2][0] * h * w), into Output channels
    // [oc0 + oc1, oc0 + oc1 + oc2) (batch stride oc * h * w). The extent must
    // be oc2. Branch 3's channel count (oc3) would under- or over-copy and
    // clobber branch 3's freshly written output when oc3 > oc2.
    std::vector<int> branch2_final_dims = {n, oc2, h, w};
    cudnnTensorDescriptor_t x_desc;
    cudnnTensorDescriptor_t y_desc;
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cudnnCreateTensorDescriptor(&x_desc));
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cudnnCreateTensorDescriptor(&y_desc));
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cudnnSetTensorNdDescriptor(x_desc,
                                                      cudnn_dtype,
                                                      4,
                                                      branch2_final_dims.data(),
                                                      out_strides[2].data()));
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cudnnSetTensorNdDescriptor(y_desc,
                                                      cudnn_dtype,
                                                      4,
                                                      branch2_final_dims.data(),
                                                      out_strides[3].data()));
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnTransformTensor(
        handle,
        CudnnDataType<T>::kOne(),
        x_desc,
        static_cast<const void*>(out_datas[2]),
        CudnnDataType<T>::kZero(),
        y_desc,
        static_cast<void*>(output_data + (oc0 + oc1) * h * w)));

    for (int i = 0; i < kNumBranches; ++i) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnDestroyTensorDescriptor(in_desc[i]));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnDestroyTensorDescriptor(out_desc[i]));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnDestroyFilterDescriptor(filter_desc[i]));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnDestroyTensorDescriptor(bias_desc[i]));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnDestroyConvolutionDescriptor(conv_desc[i]));
    }
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cudnnDestroyTensorDescriptor(x_desc));
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cudnnDestroyTensorDescriptor(y_desc));
  }
};
#endif

}  // namespace operators
}  // namespace paddle

#if CUDNN_VERSION >= 7100
// Register the float and double instantiations of the fused inception kernel.
// Guarded by the same cuDNN version check as the kernel definition, so the
// registration only exists when the kernel class does.
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    conv2d_inception_fusion,
    ops::CUDNNConvInceptionFusionOpKernel<float>,
    ops::CUDNNConvInceptionFusionOpKernel<double>);
#endif  // CUDNN_VERSION >= 7100