conv_fusion_op.cu 19.5 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <array>
#include <cstdlib>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/conv_search_cache.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/operators/math/padding.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#else
#include "paddle/fluid/platform/cudnn_helper.h"
#endif

DECLARE_int64(cudnn_exhaustive_search_times);

namespace paddle {
namespace operators {

#if PADDLE_WITH_HIP || CUDNN_VERSION >= 7100
using Tensor = framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using ScopedActivationDescriptor = platform::ScopedActivationDescriptor;
using DataLayout = platform::DataLayout;
using framework::AlgorithmsCache;
using framework::ConvSearchCache;

// Host-side type of the alpha/beta scaling factors handed to cuDNN/MIOpen for
// element type T.
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;

// GPU kernel of the `conv2d_fusion` operator.
//
// Computes
//   Output = act(conv(Input, Filter) + Bias + ResidualData)
// where the ResidualData term is dropped when that input is absent and `act`
// is selected by the "activation" attribute ("identity" means no activation).
//
// Asymmetric paddings are handled by explicitly padding Input by the
// difference between the two sides, so cuDNN/MIOpen only sees one (common)
// padding value per spatial dim.
//
// If the "split_channels" attribute is non-empty, each tensor in "Outputs"
// becomes a view sharing memory with Output that covers split_channels[i]
// consecutive channels. Only batch size 1 is supported for that.
template <typename T>
class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto* input = ctx.Input<Tensor>("Input");
    auto* filter = ctx.Input<Tensor>("Filter");
    auto* bias = ctx.Input<Tensor>("Bias");
    auto* residual = ctx.Input<Tensor>("ResidualData");
    auto* output = ctx.Output<Tensor>("Output");
    output->mutable_data<T>(ctx.GetPlace());

    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    const std::string activation = ctx.Attr<std::string>("activation");
    int groups = ctx.Attr<int>("groups");
    // Workspace size requested through the op attribute, in MB.
    int64_t user_workspace_size =
        static_cast<int64_t>(ctx.Attr<int>("workspace_size_MB"));
    bool exhaustive_search =
        FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");

    const T* filter_data = filter->data<T>();
    const T* bias_data = bias->data<T>();

    const std::string padding_algorithm =
        ctx.Attr<std::string>("padding_algorithm");

    Tensor transformed_input_channel(input->type());
    Tensor transformed_output(output->type());
    transformed_input_channel = *input;
    transformed_output = *output;
    T* output_data = transformed_output.data<T>();

    // Without ResidualData this pointer aliases the output; it is then either
    // scaled by 0 (CUDA fused kernel) or never read.
    const T* residual_data = residual ? residual->data<T>() : output_data;

    // Update paddings/dilations according to `padding_algorithm`.
    auto in_dims = transformed_input_channel.dims();
    auto filter_dims = filter->dims();
    framework::DDim in_data_dims =
        framework::slice_ddim(in_dims, 2, in_dims.size());
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    const int data_dim = static_cast<int>(strides.size());  // 2d or 3d
    bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);

    // cuDNN/MIOpen accept a single padding value per spatial dim. For
    // asymmetric padding, pad the input explicitly by (side - common) and let
    // the library apply the common part.
    Tensor transformed_input;
    std::vector<int> padding_common(data_dim, 0);
    if (!is_sys_pad) {
      std::vector<int> padding_diff(data_dim);
      std::vector<int> new_input_shape_vec(data_dim + 2);
      new_input_shape_vec[0] = transformed_input_channel.dims()[0];
      new_input_shape_vec[1] = transformed_input_channel.dims()[1];

      // (before, after) pad pairs for every dim; the N and C pairs (first four
      // entries) stay zero.
      std::vector<int> input_pad(transformed_input_channel.dims().size() * 2,
                                 0);
      for (int i = 0; i < data_dim; ++i) {
        padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]);
        padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]);
        new_input_shape_vec[i + 2] =
            transformed_input_channel.dims()[i + 2] + padding_diff[i];
        input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i];
        input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i];
      }
      framework::DDim new_input_shape(
          framework::make_ddim(new_input_shape_vec));
      transformed_input =
          ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
              new_input_shape, dev_ctx);
      const int rank = transformed_input_channel.dims().size();
      T pad_value(0.0);
      switch (rank) {
        case 4: {
          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
              ctx, input_pad, transformed_input_channel, pad_value,
              &transformed_input);
        } break;
        case 5: {
          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
              ctx, input_pad, transformed_input_channel, pad_value,
              &transformed_input);
        } break;
        default:
          PADDLE_THROW(platform::errors::PermissionDenied(
              "Operator Conv2DFusion expects Input to be a 4-D or 5-D Tensor. "
              "But recieved the actual dimension = %d, shape = [%s].",
              rank, transformed_input_channel.dims()));
      }
    } else {
      transformed_input = transformed_input_channel;
      if (paddings.size() == static_cast<size_t>(data_dim)) {
        for (int i = 0; i < data_dim; ++i) {
          padding_common[i] = paddings[i];
        }
      } else {
        for (int i = 0; i < data_dim; ++i) {
          padding_common[i] = paddings[2 * i];
        }
      }
    }

    const T* input_data = transformed_input.data<T>();

    // ------------------- cudnn descriptors ---------------------
    ScopedTensorDescriptor input_desc;
    ScopedTensorDescriptor output_desc;
    ScopedFilterDescriptor filter_desc;
    ScopedTensorDescriptor bias_desc;
    ScopedConvolutionDescriptor conv_desc;
    ScopedActivationDescriptor act_desc;
    DataLayout layout = input->dims().size() == 5 ? DataLayout::kNCDHW
                                                  : DataLayout::kNCHW;

    // Bias is broadcast over every dim but the channel dim: {1, C, 1, ..., 1}.
    // It has the same rank as the output so the bias and output descriptors
    // agree for both 4-D and 5-D tensors.
    std::vector<int> bias_dim(transformed_output.dims().size(), 1);
    bias_dim[1] = static_cast<int>(transformed_output.dims()[1]);

    auto x_dims = framework::vectorize(transformed_input.dims());

#ifdef PADDLE_WITH_HIP
    miopenConvolutionDescriptor_t cudnn_conv_desc =
        conv_desc.descriptor<T>(padding_common, strides, dilations);
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::miopenSetConvolutionGroupCount(cudnn_conv_desc,
                                                          groups));
    miopenTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
        layout, framework::vectorize<int>(transformed_input.dims()));
    miopenTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
        layout, framework::vectorize<int>(transformed_output.dims()));
    miopenTensorDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
        layout, framework::vectorize<int>(filter->dims()));
    miopenTensorDescriptor_t cudnn_bias_desc =
        bias_desc.descriptor<T>(layout, bias_dim);
    miopenActivationDescriptor_t cudnn_act_desc =
        act_desc.descriptor<T>(activation);

    miopenConvFwdAlgorithm_t algo;
    auto handle = dev_ctx.cudnn_handle();
    auto workspace_handle = dev_ctx.cudnn_workspace_handle();

    size_t workspace_size = 0;
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::miopenConvolutionForwardGetWorkSpaceSize(
            handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
            cudnn_output_desc, &workspace_size));
    int find_count;
    miopenConvAlgoPerf_t find_result;
    auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::miopenFindConvolutionForwardAlgorithm(
              handle, cudnn_input_desc, input_data, cudnn_filter_desc,
              filter_data, cudnn_conv_desc, cudnn_output_desc, output_data,
              kNUM_CUDNN_FWD_ALGS, &find_count, &find_result,
              cudnn_workspace_ptr, workspace_size, false));
    };
    workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
    algo = find_result.fwd_algo;
    VLOG(3) << "cuDNN forward algo " << algo;

    {
      ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
      auto cudnn_func = [&](void* cudnn_workspace) {
        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenConvolutionForward(
            handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc,
            filter_data, cudnn_conv_desc, algo, &beta, cudnn_output_desc,
            output_data, cudnn_workspace, workspace_size));
      };
      workspace_handle.RunFunc(cudnn_func, workspace_size);
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::miopenConvolutionForwardBias(
              handle, &alpha, cudnn_bias_desc, bias_data, &beta,
              cudnn_output_desc, output_data));
      // Add the residual BEFORE the activation so the HIP result matches the
      // CUDA fused kernel, which computes act(conv + residual + bias).
      if (residual) {
        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenOpTensor(
            handle, miopenTensorOpAdd, &alpha, cudnn_output_desc, output_data,
            &alpha, cudnn_output_desc, residual_data, &beta, cudnn_output_desc,
            output_data));
      }
      if (activation != "identity") {
        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenActivationForward(
            handle, cudnn_act_desc, &alpha, cudnn_output_desc, output_data,
            &beta, cudnn_output_desc, output_data));
      }
    }
#else  // PADDLE_WITH_HIP
    cudnnConvolutionDescriptor_t cudnn_conv_desc =
        conv_desc.descriptor<T>(padding_common, strides, dilations);
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnSetConvolutionGroupCount(cudnn_conv_desc,
                                                         groups));

    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
        layout, framework::vectorize<int>(transformed_input.dims()));
    cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
        layout, framework::vectorize<int>(transformed_output.dims()));
    cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
        layout, framework::vectorize<int>(filter->dims()));
    cudnnTensorDescriptor_t cudnn_bias_desc =
        bias_desc.descriptor<T>(layout, bias_dim);
    cudnnActivationDescriptor_t cudnn_act_desc =
        act_desc.descriptor<T>(activation);

    // ------------------- cudnn conv workspace ---------------------
    size_t workspace_size_in_bytes;  // final workspace to allocate.
    size_t workspace_size_limit = 0;
    if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
      int64_t max_user_size =
          std::min(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
                   user_workspace_size);
      workspace_size_limit = max_user_size * 1024 * 1024;
    }

    // ------------------- cudnn conv algorithm ---------------------
    cudnnConvolutionFwdAlgo_t algo;
    auto handle = dev_ctx.cudnn_handle();
    auto workspace_handle = dev_ctx.cudnn_workspace_handle();

    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
        cudnn_conv_desc, CUDNN_DEFAULT_MATH));
#if CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000
    if (!platform::allow_tf32_cudnn) {
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnSetConvolutionMathType(cudnn_conv_desc,
                                                         CUDNN_FMA_MATH));
    }
#endif  // CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000

    auto f_dims = framework::vectorize(filter->dims());
    // Set when the cuDNN v8 heuristic picks the algorithm: the workspace
    // limit is then grown to whatever the finally used algorithm needs.
    bool grow_limit_to_fit_algo = false;
    if (!exhaustive_search) {
#if CUDNN_VERSION >= 8000
      int perf_count;
      std::unique_ptr<cudnnConvolutionFwdAlgoPerf_t[]> perf_results(
          new cudnnConvolutionFwdAlgoPerf_t[kNUM_CUDNN_FWD_ALGS]);
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7(
              handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
              cudnn_output_desc, kNUM_CUDNN_FWD_ALGS, &perf_count,
              perf_results.get()));
      // Use the first returned heuristic result.
      algo = perf_results[0].algo;
      grow_limit_to_fit_algo = true;
#else
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnGetConvolutionForwardAlgorithm(
              handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
              cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
              workspace_size_limit, &algo));
      VLOG(3) << "cuDNN forward algo " << algo;
#endif
    } else {
      std::function<cudnnConvolutionFwdAlgo_t()> search_func =
          [&]() -> cudnnConvolutionFwdAlgo_t {
        int returned_algo_count;
        std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
            fwd_perf_stat;
        auto cudnn_find_func = [&](void* cudnn_workspace) {
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
                  handle, cudnn_input_desc, input_data, cudnn_filter_desc,
                  filter_data, cudnn_conv_desc, cudnn_output_desc, output_data,
                  kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
                  fwd_perf_stat.data(), cudnn_workspace, workspace_size_limit));
        };
        workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);
        VLOG(3) << "Perf result: (algo: stat, time, memory)";
        for (int i = 0; i < returned_algo_count; ++i) {
          const auto& stat = fwd_perf_stat[i];
          VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time << " "
                  << stat.memory;
        }
        return fwd_perf_stat[0].algo;
      };
      AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
          *(framework::ConvSearchCache::Instance().GetConvFusion());
      int search_times = ctx.Attr<int>("search_times");
      search_times = std::max(
          static_cast<int>(FLAGS_cudnn_exhaustive_search_times), search_times);
      // TODO(dangqingqing): Unify this if-else.
      if (search_times > 0) {
        // The searched algo will be cached by `search_times` times for
        // different input dimension. For other dimensions, select the algo
        // of closest area.
        algo = algo_cache.GetAlgorithm(x_dims[2] * x_dims[3], search_times, 0,
                                       search_func);
      } else {
        auto dtype = platform::CudnnDataType<T>::type;
        algo = algo_cache.GetAlgorithm(x_dims, f_dims, strides, paddings,
                                       dilations, 0, dtype, search_func);
      }
      VLOG(3) << "choose algo " << algo;
    }

    // The fused cuDNN kernel is used unless the activation is identity and
    // there is no residual to add.
    const bool use_fused_kernel =
        (activation != "identity") || (residual != nullptr);
    // Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is
    // enabled with CUDNN_ACTIVATION_IDENTITY in cuDNN lib. Override the algo
    // BEFORE querying the workspace so the size matches the algo that runs.
    if (use_fused_kernel && activation == "identity") {
      algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
    }

    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
            handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
            cudnn_output_desc, algo, &workspace_size_in_bytes));
    if (grow_limit_to_fit_algo &&
        workspace_size_in_bytes > workspace_size_limit) {
      workspace_size_limit = workspace_size_in_bytes;
    }
    PADDLE_ENFORCE_LE(
        workspace_size_in_bytes, workspace_size_limit,
        platform::errors::InvalidArgument(
            "The actual workspace size to be allocated for cuDNN is expected "
            "to be less than the limit. But recieved: the actual workspace "
            "size = %d, limit = %d.",
            workspace_size_in_bytes, workspace_size_limit));

    if (!use_fused_kernel) {
      // The fused kernel with CUDNN_ACTIVATION_IDENTITY was measured to be
      // slower in some cases, so use cudnnConvolutionForward + cudnnAddTensor.
      // ------------- cudnn conv forward and bias add ---------------------
      ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
      auto cudnn_func = [&](void* cudnn_workspace) {
        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnConvolutionForward(
            handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc,
            filter_data, cudnn_conv_desc, algo, cudnn_workspace,
            workspace_size_in_bytes, &beta, cudnn_output_desc, output_data));
      };
      workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
      // output = 1 * bias + 1 * output
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnAddTensor(
          handle, &alpha, cudnn_bias_desc, bias_data, &alpha, cudnn_output_desc,
          output_data));
    } else {
      // ------------------- cudnn conv+bias+act forward --------------------
      // alpha2 scales the residual term; 0 disables it when there is none.
      ScalingParamType<T> alpha1 = 1.0f;
      ScalingParamType<T> alpha2 = residual ? 1.0f : 0.0f;
      auto cudnn_func = [&](void* cudnn_workspace) {
        PADDLE_ENFORCE_CUDA_SUCCESS(
            platform::dynload::cudnnConvolutionBiasActivationForward(
                handle, &alpha1, cudnn_input_desc, input_data,
                cudnn_filter_desc, filter_data, cudnn_conv_desc, algo,
                cudnn_workspace, workspace_size_in_bytes, &alpha2,
                cudnn_output_desc, residual_data, cudnn_bias_desc, bias_data,
                cudnn_act_desc, cudnn_output_desc, output_data));
      };
      workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
    }
#endif
    std::vector<int> channels = ctx.Attr<std::vector<int>>("split_channels");
    if (!channels.empty()) {
      auto outs = ctx.MultiOutput<framework::Tensor>("Outputs");
      // Each split output is indexed by its channel group below.
      PADDLE_ENFORCE_EQ(
          outs.size(), channels.size(),
          platform::errors::InvalidArgument(
              "The number of Outputs of Operator Conv2DFusion must be equal "
              "to the size of split_channels. But recieved: the number of "
              "Outputs = %d, the size of split_channels = %d.",
              outs.size(), channels.size()));
      if (x_dims[0] == 1) {
        // Share data with Output: view it as [C, H, W] and slice channels.
        framework::Tensor t;
        t.ShareDataWith(*output);
        auto y_dims = output->dims();
        t.Resize({y_dims[1], y_dims[2], y_dims[3]});
        int s = 0;
        for (size_t i = 0; i < channels.size(); ++i) {
          int e = s + channels[i];
          outs[i]->ShareDataWith(t.Slice(s, e));
          outs[i]->Resize({x_dims[0], channels[i], y_dims[2], y_dims[3]});
          s = e;
        }
      } else {
        // TODO(qingiqng): do copy when batch size large than 1
        PADDLE_THROW(platform::errors::Unimplemented(
            "Input with batch size greater than 1 is unsupported. The recieved "
            "batch size is %d, Input's shape is [%s].",
            x_dims[0], framework::make_ddim(x_dims)));
      }
    }
  }
};
#endif

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// ROCm/MIOpen builds expose only the single-precision fused kernel.
#ifdef PADDLE_WITH_HIP
REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>);
#endif

// cuDNN builds register float and double kernels, guarded by the same
// CUDNN_VERSION >= 7100 condition that guards the kernel definition.
#if CUDNN_VERSION >= 7100
REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>,
                        ops::CUDNNConvFusionOpKernel<double>);
#endif