// fusion_conv_inception_op.cu (13.8 KB) — recovered from a repository blame view; lines 1-16 last committed by qingqing01.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"

namespace paddle {
namespace operators {

#if CUDNN_VERSION >= 7100
// Paddle framework / platform types used by the kernel below.
using Tensor = framework::Tensor;
using DataLayout = platform::DataLayout;
using PoolingMode = platform::PoolingMode;

// cuDNN descriptor wrapper types provided by the platform layer.
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using ScopedActivationDescriptor = platform::ScopedActivationDescriptor;
using ScopedPoolingDescriptor = platform::ScopedPoolingDescriptor;

// Per-element-type cuDNN traits (data type enum, kOne()/kZero() constants).
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;

// Scalar type used for the alpha/beta scaling factors handed to cuDNN.
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;

template <typename T>
class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel<T> {
 public:
  // Fused Inception-style block built from cuDNN primitives (NCHW only).
  //
  // Four branches are computed from `Input`. Each one ends in
  // cudnnConvolutionBiasActivationForward with Filter[i] / Bias[i] and the
  // `activation` attribute:
  //   branch 0: 3x3 pooling (padding 1, stride 1) of Input into
  //             TempOutput[0], then a conv with padding 0 -> Output channels
  //             [0, oc0).
  //   branch 1: conv with padding 0 on Input -> Output channels
  //             [oc0, oc0 + Filter[1].dims[0]). The trailing
  //             2 * Filter[2].dims[1] of those channels are intermediate
  //             data consumed by branch 2.
  //   branch 2: 2-group conv with padding 1 on that intermediate slice ->
  //             TempOutput[1] (Filter[2].dims[0] channels). Its first oc2
  //             channels are this branch's result; the rest feed branch 3.
  //   branch 3: conv with padding 1 on TempOutput[1] channels [oc2, ...) ->
  //             Output channels [oc0 + oc1 + oc2, oc).
  // Finally the first oc2 channels of TempOutput[1] are copied into Output
  // channels [oc0 + oc1, oc0 + oc1 + oc2). This overwrites branch 1's
  // intermediate data.
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto* input = ctx.Input<Tensor>("Input");
    auto filters = ctx.MultiInput<framework::Tensor>("Filter");
    auto bias = ctx.MultiInput<framework::Tensor>("Bias");

    auto* output = ctx.Output<Tensor>("Output");
    auto temp_outs = ctx.MultiOutput<framework::Tensor>("TempOutput");

    const std::string pool_type = ctx.Attr<std::string>("pooling_type");
    const std::string activation = ctx.Attr<std::string>("activation");
    const bool exclusive = ctx.Attr<bool>("exclusive");

    // Workspace cap requested through the op attribute, in MB.
    int64_t user_workspace_size =
        static_cast<int64_t>(ctx.Attr<int>("workspace_size_MB"));

    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>(ctx.GetPlace());
    // TempOutput[0] receives branch 0's pooled input (same shape as Input).
    T* temp_data = temp_outs[0]->mutable_data<T>(input->dims(), ctx.GetPlace());

    DataLayout layout = DataLayout::kNCHW;
    std::vector<int> in_dim = phi::vectorize<int>(input->dims());

    // ------------------- cudnn descriptors ---------------------
    PoolingMode pooling_mode;
    if (pool_type == "max") {
      pooling_mode = PoolingMode::kMaximum;
    } else {
      pooling_mode = exclusive ? PoolingMode::kAverageExclusive
                               : PoolingMode::kAverageInclusive;
    }
    std::vector<int> k0x0 = {0, 0};
    std::vector<int> k1x1 = {1, 1};
    std::vector<int> k3x3 = {3, 3};
    ScopedPoolingDescriptor pool_desc;
    ScopedActivationDescriptor act_desc;
    ScopedTensorDescriptor out_pool_desc;
    ScopedTensorDescriptor input_desc;
    // 3x3 window with padding and stride both {1, 1}, so the pooled tensor
    // keeps the input's spatial size.
    cudnnPoolingDescriptor_t cudnn_pool_desc =
        pool_desc.descriptor(pooling_mode, k3x3, k1x1, k1x1);

    cudnnTensorDescriptor_t cudnn_input_desc =
        input_desc.descriptor<T>(layout, in_dim);
    cudnnTensorDescriptor_t pool_out_desc =
        out_pool_desc.descriptor<T>(layout, in_dim);

    cudnnDataType_t cudnn_dtype = CudnnDataType<T>::type;

    // One descriptor of each kind per branch. These are automatic arrays:
    // the previous `new[]` allocations were never delete[]'d and leaked
    // host memory on every Compute() call.
    constexpr int kNumBranches = 4;
    cudnnTensorDescriptor_t out_desc[kNumBranches];
    cudnnFilterDescriptor_t filter_desc[kNumBranches];
    cudnnTensorDescriptor_t bias_desc[kNumBranches];
    cudnnTensorDescriptor_t in_desc[kNumBranches];
    cudnnConvolutionDescriptor_t conv_desc[kNumBranches];
    for (int i = 0; i < kNumBranches; ++i) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnCreateFilterDescriptor(&filter_desc[i]));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnCreateTensorDescriptor(&bias_desc[i]));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnCreateTensorDescriptor(&in_desc[i]));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnCreateTensorDescriptor(&out_desc[i]));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnCreateConvolutionDescriptor(&conv_desc[i]));
    }

    std::vector<std::vector<int>> filter_dims;
    std::vector<std::vector<int>> bias_dims;
    std::vector<std::vector<int>> in_dims;
    std::vector<std::vector<int>> out_dims;
    std::vector<std::vector<int>> in_strides;
    std::vector<std::vector<int>> out_strides;
    std::vector<std::vector<int>> bias_strides;

    cudnnTensorFormat_t format = CUDNN_TENSOR_NCHW;
    int n = in_dim[0];
    int h = in_dim[2];
    int w = in_dim[3];
    int oc = output->dims()[1];

    // Accumulate in double only for double inputs; otherwise in float.
    cudnnDataType_t compute_type = (cudnn_dtype == CUDNN_DATA_DOUBLE)
                                       ? CUDNN_DATA_DOUBLE
                                       : CUDNN_DATA_FLOAT;

    for (int i = 0; i < kNumBranches; ++i) {
      filter_dims.push_back(phi::vectorize<int>(filters[i]->dims()));
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetFilterNdDescriptor(
          filter_desc[i], cudnn_dtype, format, 4, filter_dims[i].data()));
      bias_dims.push_back({1, filter_dims[i][0], 1, 1});
      bias_strides.push_back({filter_dims[i][0], 1, 1, 1});
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnSetTensorNdDescriptor(bias_desc[i],
                                                        cudnn_dtype,
                                                        4,
                                                        bias_dims[i].data(),
                                                        bias_strides[i].data()));
      // Default layout: the input is a dense NCHW block with
      // filter_dims[i][1] channels. The output is a slice of a tensor with oc
      // channels (batch stride oc * h * w). Branches 2 and 3 are adjusted
      // after this loop.
      in_dims.push_back({n, filter_dims[i][1], h, w});
      out_dims.push_back({n, filter_dims[i][0], h, w});
      in_strides.push_back({filter_dims[i][1] * h * w, h * w, w, 1});
      out_strides.push_back({oc * h * w, h * w, w, 1});

      // Branches 0 and 1 use zero padding and branches 2 and 3 pad by 1.
      // All use stride 1 and dilation 1.
      const std::vector<int>& pads = (i < 2) ? k0x0 : k1x1;
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnSetConvolutionNdDescriptor(
              conv_desc[i],
              2,
              pads.data(),
              k1x1.data(),
              k1x1.data(),
              CUDNN_CROSS_CORRELATION,
              compute_type));
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
          conv_desc[i], CUDNN_DEFAULT_MATH));
#if CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000
      // Force plain FMA math when TF32 is disabled for cuDNN.
      if (!platform::allow_tf32_cudnn) {
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cudnnSetConvolutionMathType(conv_desc[i],
                                                           CUDNN_FMA_MATH));
      }
#endif  // CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000
    }
    // Branch 2 reads the 2 * filter_dims[2][1] channels that branch 1 wrote
    // inside Output (batch stride oc * h * w). It writes a dense
    // TempOutput[1].
    in_dims[2][1] *= 2;
    in_strides[2][0] = oc * h * w;
    out_strides[2][0] = filter_dims[2][0] * h * w;  // this out is continuous.
    // Branch 3 reads from TempOutput[1], whose batch stride is
    // filter_dims[2][0] * h * w.
    in_strides[3][0] = filter_dims[2][0] * h * w;
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cudnnSetConvolutionGroupCount(conv_desc[2], 2));

    cudnnConvolutionFwdAlgo_t algo[kNumBranches];
    auto handle = dev_ctx.cudnn_handle();
    size_t workspace_size_in_bytes = 0;  // final workspace to allocate.

    // NOTE(review): this limit is computed but not applied when choosing the
    // algorithms below. The first-ranked algorithm is always taken, whatever
    // workspace it needs. It is kept unchanged to preserve existing behavior;
    // TODO confirm whether it should cap the algorithm choice.
    size_t workspace_size_limit = 0;
    if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
      int64_t max_user_size =
          std::min(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
                   user_workspace_size);
      workspace_size_limit = max_user_size * 1024 * 1024;
    }

    for (int i = 0; i < kNumBranches; ++i) {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
          in_desc[i], cudnn_dtype, 4, in_dims[i].data(), in_strides[i].data()));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnSetTensorNdDescriptor(out_desc[i],
                                                        cudnn_dtype,
                                                        4,
                                                        out_dims[i].data(),
                                                        out_strides[i].data()));

      int perf_count = 0;
      size_t tmp_size = 0;
      std::unique_ptr<cudnnConvolutionFwdAlgoPerf_t[]> perf_results(
          new cudnnConvolutionFwdAlgoPerf_t[kNUM_CUDNN_FWD_ALGS]);
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7(
              handle,
              in_desc[i],
              filter_desc[i],
              conv_desc[i],
              out_desc[i],
              kNUM_CUDNN_FWD_ALGS,
              &perf_count,
              perf_results.get()));
      // perf_results[0] is only meaningful if cuDNN reported a candidate.
      PADDLE_ENFORCE_GT(
          perf_count,
          0,
          platform::errors::PreconditionNotMet(
              "cudnnGetConvolutionForwardAlgorithm_v7 returned no forward "
              "algorithm for branch %d of conv2d_inception_fusion.",
              i));
      // cuDNN returns the candidates ranked; take the first one.
      algo[i] = perf_results[0].algo;

      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
              handle,
              in_desc[i],
              filter_desc[i],
              conv_desc[i],
              out_desc[i],
              algo[i],
              &tmp_size));

      // One workspace, sized for the most demanding branch, serves all four.
      workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
    }
    cudnnActivationDescriptor_t cudnn_act_desc =
        act_desc.descriptor<T>(activation);

    // Channel count of each branch's slot in Output.
    int oc0 = filter_dims[0][0];
    int oc1 = filter_dims[1][0] - filter_dims[2][1] * 2;
    int oc3 = filter_dims[3][0];
    int oc2 = oc - oc0 - oc1 - oc3;

    // branch 0: pool into TempOutput[0]; its conv runs in the loop below.
    ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cudnnPoolingForward(handle,
                                               cudnn_pool_desc,
                                               &alpha,
                                               cudnn_input_desc,
                                               input_data,
                                               &beta,
                                               pool_out_desc,
                                               temp_data));

    // Source buffer of each branch's convolution.
    std::vector<const void*> in_datas;
    in_datas.push_back(static_cast<const void*>(temp_data));
    in_datas.push_back(static_cast<const void*>(input_data));
    in_datas.push_back(
        static_cast<const void*>(output_data + (oc0 + oc1) * h * w));
    T* temp2_data = temp_outs[1]->mutable_data<T>(phi::make_ddim(out_dims[2]),
                                                  ctx.GetPlace());
    in_datas.push_back(static_cast<const void*>(temp2_data + oc2 * h * w));

    // Destination buffer of each branch's convolution.
    std::vector<void*> out_datas;
    out_datas.push_back(static_cast<void*>(output_data));
    out_datas.push_back(static_cast<void*>(output_data + oc0 * h * w));
    out_datas.push_back(static_cast<void*>(temp2_data));
    out_datas.push_back(
        static_cast<void*>(output_data + (oc0 + oc1 + oc2) * h * w));

    for (int i = 0; i < kNumBranches; ++i) {
      auto func = [&](void* cudnn_workspace) {
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cudnnConvolutionBiasActivationForward(
                handle,
                &alpha,
                in_desc[i],
                in_datas[i],
                filter_desc[i],
                static_cast<const void*>(filters[i]->data<T>()),
                conv_desc[i],
                algo[i],
                cudnn_workspace,
                workspace_size_in_bytes,
                &beta,
                out_desc[i],
                out_datas[i],
                bias_desc[i],
                static_cast<const void*>(bias[i]->data<T>()),
                cudnn_act_desc,
                out_desc[i],
                out_datas[i]));
      };
      auto workspace_handle = dev_ctx.cudnn_workspace_handle();
      workspace_handle.RunFunc(func, workspace_size_in_bytes);
    }

    // Move branch 2's result (the first oc2 channels of TempOutput[1]) into
    // its slot of Output, overwriting branch 1's intermediate channels.
    // The copy covers exactly oc2 channels. The previous code sized it with
    // out_dims[3] (oc3 channels), which overran into branch 3's output when
    // oc3 > oc2 and left part of the slot stale when oc3 < oc2.
    std::vector<int> branch2_dims = {n, oc2, h, w};
    cudnnTensorDescriptor_t x_desc;
    cudnnTensorDescriptor_t y_desc;
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cudnnCreateTensorDescriptor(&x_desc));
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cudnnCreateTensorDescriptor(&y_desc));
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
        x_desc, cudnn_dtype, 4, branch2_dims.data(), out_strides[2].data()));
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
        y_desc, cudnn_dtype, 4, branch2_dims.data(), out_strides[3].data()));
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnTransformTensor(
        handle,
        CudnnDataType<T>::kOne(),
        x_desc,
        static_cast<const void*>(out_datas[2]),
        CudnnDataType<T>::kZero(),
        y_desc,
        static_cast<void*>(output_data + (oc0 + oc1) * h * w)));

    for (int i = 0; i < kNumBranches; ++i) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnDestroyTensorDescriptor(in_desc[i]));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnDestroyTensorDescriptor(out_desc[i]));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnDestroyFilterDescriptor(filter_desc[i]));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnDestroyTensorDescriptor(bias_desc[i]));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnDestroyConvolutionDescriptor(conv_desc[i]));
    }
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cudnnDestroyTensorDescriptor(x_desc));
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cudnnDestroyTensorDescriptor(y_desc));
  }
};
#endif

}  // namespace operators
}  // namespace paddle

#if CUDNN_VERSION >= 7100
namespace ops = paddle::operators;
// Registers the float and double instantiations of the fused kernel as the
// CUDA kernels of conv2d_inception_fusion. This uses the same cuDNN >= 7.1
// guard as the kernel definition.
REGISTER_OP_CUDA_KERNEL(
    conv2d_inception_fusion,
    ops::CUDNNConvInceptionFusionOpKernel<float>,
    ops::CUDNNConvInceptionFusionOpKernel<double>);
#endif