depthwise_conv.h 72.0 KB
Newer Older
H
hong 已提交
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Z
zlx 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

H
hong 已提交
15
#pragma once
A
Abhinav Arora 已提交
16
#include <vector>
17

H
hong 已提交
18 19 20 21
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/hostdevice.h"

22 23 24 25 26 27 28
#ifdef __NVCC__
#include <cub/cub.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
H
hong 已提交
29

30
#include "paddle/phi/backends/gpu/gpu_device_function.h"
W
Wang Xin 已提交
31
#include "paddle/phi/backends/gpu/gpu_primitives.h"
32
#include "paddle/phi/kernels/funcs/math_function.h"
Z
zlx 已提交
33 34 35 36 37

namespace paddle {
namespace operators {
namespace math {

H
hong 已提交
38 39 40 41 42 43 44 45 46 47
/*
 * \brief Forward pass of a depthwise convolution.
 *
 * Only the call operator is declared here; its definition is not part of
 * this header section (presumably supplied per device elsewhere).
 *
 * \param context      device context the computation is issued on.
 * \param input        input tensor.
 * \param filter       depthwise filter tensor.
 * \param strides      convolution strides.
 * \param paddings     convolution paddings.
 * \param dilations    convolution dilations.
 * \param output       destination tensor for the result.
 * \param data_layout  tensor layout, NCHW unless stated otherwise.
 *
 * With fuse_relu_before_conv the input is passed through max(x, 0) before
 * being convolved (as done by the device routines in this header).
 */
template <typename DeviceContext,
          typename T,
          bool fuse_relu_before_conv = false>
class DepthwiseConvFunctor {
 public:
  void operator()(const DeviceContext& context,
                  const phi::DenseTensor& input,
                  const phi::DenseTensor& filter,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  phi::DenseTensor* output,
                  const DataLayout data_layout = DataLayout::kNCHW);
};

/*
 * \brief Gradient of a depthwise convolution with respect to its input.
 *
 * Declaration only; the definition is provided outside this header section.
 *
 * \param context      device context the computation is issued on.
 * \param input        forward-pass input (also consulted for the fused ReLU).
 * \param filter       depthwise filter tensor.
 * \param output_grad  gradient flowing back from the convolution output.
 * \param strides      convolution strides.
 * \param paddings     convolution paddings.
 * \param dilations    convolution dilations.
 * \param input_grad   destination tensor for the input gradient.
 * \param data_layout  tensor layout, NCHW unless stated otherwise.
 */
template <typename DeviceContext,
          typename T,
          bool fuse_relu_before_conv = false>
class DepthwiseConvInputGradFunctor {
 public:
  void operator()(const DeviceContext& context,
                  const phi::DenseTensor& input,
                  const phi::DenseTensor& filter,
                  const phi::DenseTensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  phi::DenseTensor* input_grad,
                  const DataLayout data_layout = DataLayout::kNCHW);
};

/*
 * \brief Gradient of a depthwise convolution with respect to its filter.
 *
 * Declaration only; the definition is provided outside this header section.
 *
 * \param context      device context the computation is issued on.
 * \param input        forward-pass input.
 * \param output_grad  gradient flowing back from the convolution output.
 * \param strides      convolution strides.
 * \param paddings     convolution paddings.
 * \param dilations    convolution dilations.
 * \param filter_grad  destination tensor for the filter gradient.
 * \param data_layout  tensor layout, NCHW unless stated otherwise.
 */
template <typename DeviceContext,
          typename T,
          bool fuse_relu_before_conv = false>
class DepthwiseConvFilterGradFunctor {
 public:
  void operator()(const DeviceContext& context,
                  const phi::DenseTensor& input,
                  const phi::DenseTensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  phi::DenseTensor* filter_grad,
                  const DataLayout data_layout = DataLayout::kNCHW);
};

88 89 90 91
#define FINAL_MASK 0xffffffff
#define HALF_WARP 16
#define WARP_SIZE 32

/*
 * Shuffle-down tree reduction of `val` with lane offsets 16, 8, 4, 2, 1.
 * Assuming CudaShuffleDownSync has __shfl_down_sync semantics, lane 0 of
 * every 32-lane group ends up holding that group's sum; the other lanes hold
 * partial sums. `lane_mask` is passed unchanged to every shuffle.
 */
template <typename T>
__forceinline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) {
  int offset = HALF_WARP;
  while (offset > 0) {
    val += phi::backends::gpu::CudaShuffleDownSync(lane_mask, val, offset);
    offset >>= 1;
  }
  return val;
}
98

W
wangguanzhong 已提交
99
/*
 * Sums `val` over the threads of a block, using the flattened id
 * threadIdx.y * blockDim.x + threadIdx.x.
 *
 * 1. Each 32-lane group reduces itself with WarpReduceSum, and its lane 0
 *    stores the partial in a shared array.
 * 2. Every group reloads one partial per lane (0 for lanes past the number
 *    of groups) and reduces again, so lane 0 of each group holds the total.
 *
 * Contains __syncthreads(): every thread of the block must call it.
 * NOTE(review): threadIdx.z is ignored, so this presumes blockDim.z == 1.
 */
template <typename T>
__forceinline__ __device__ T BlockReduceSum(T val, unsigned mask = FINAL_MASK) {
  static __shared__ T warp_partials[WARP_SIZE];
  const int flat_tid = threadIdx.y * blockDim.x + threadIdx.x;
  const int lane_id = flat_tid & 0x1f;
  const int warp_id = flat_tid >> 5;

  val = WarpReduceSum<T>(val, mask);

  // Barrier ahead of the store, presumably so readers from a previous call
  // are done with warp_partials before it is overwritten.
  __syncthreads();
  if (lane_id == 0) {
    warp_partials[warp_id] = val;
  }
  __syncthreads();

  // Number of 32-lane groups in the block (rounded up).
  const int num_warps = (blockDim.x * blockDim.y + WARP_SIZE - 1) >> 5;
  val = lane_id < num_warps ? warp_partials[lane_id] : static_cast<T>(0.0f);
  return WarpReduceSum<T>(val, mask);
}

121 122 123 124 125 126 127 128
// Parameter list shared by the forward depthwise-convolution routines.
// Expands to 19 parameters: input and filter pointers, batch size, output
// C/H/W, input C/H/W, channel multiplier, filter H/W, stride H/W,
// padding H/W, dilation H/W and the output pointer. `T` must name the
// element type wherever the macro is expanded.
#define ARG_DEFINE_KernelDepthwiseConv                                         \
  const T *const input_data, const T *const filter_data, const int batch_size, \
      const int output_channels, const int output_height,                      \
      const int output_width, const int input_channels,                        \
      const int input_height, const int input_width,                           \
      const int filter_multiplier, const int filter_height,                    \
      const int filter_width, const int stride_height, const int stride_width, \
      const int padding_height, const int padding_width,                       \
      const int dilate_height, const int dilate_width, T *const output_data
130

131 132
// Device routine for the depthwise convolution forward pass in NCHW layout.
// One output element per thread: the flat thread index is split into
// (batch, c_out, h_out, w_out) in row-major NCHW order, and threads past the
// end of the output return immediately. Output channel c_out reads input
// channel c_out / filter_multiplier.
//
// When c_filter != -1, both filter loop extents are the compile-time
// c_filter (allowing full unrolling), while the weight offset still uses the
// runtime filter_height * filter_width. NOTE(review): this assumes
// c_filter == filter_height == filter_width whenever c_filter != -1.
//
// fuse_relu_before_conv applies max(x, 0) to each input sample before the
// multiply.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNCHW(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int fw_size = c_filter != -1 ? c_filter : filter_width;
  const int fh_size = c_filter != -1 ? c_filter : filter_height;
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= (output_channels * batch_size * output_height * output_width))
    return;

  // Peel off w, h and c with one division each.
  int tmp_1 = idx / output_width;
  const int w_out = idx - tmp_1 * output_width;
  int tmp_2 = tmp_1 / output_height;
  const int h_out = tmp_1 - tmp_2 * output_height;
  tmp_1 = tmp_2;
  tmp_2 = tmp_1 / output_channels;
  const int c_out = tmp_1 - tmp_2 * output_channels;
  const int batch = tmp_2;

  const int c_in = c_out / filter_multiplier;
  T value(0);

  // Start of the (batch, c_in) input plane.
  int in_offset =
      ((batch * input_channels + c_in) * input_height) * input_width;
  // Filter of c_out scanned row-major; advanced for every tap, padded or not.
  int weight_offset = c_out * filter_height * filter_width;
  int h_in_start = -padding_height + h_out * stride_height;
  int w_in_start = -padding_width + w_out * stride_width;

#pragma unroll
  for (int fh = 0, h_in = h_in_start; fh < fh_size;
       fh++, h_in += dilate_height) {
#pragma unroll
    for (int fw = 0, w_in = w_in_start; fw < fw_size;
         fw++, w_in += dilate_width) {
      // Taps that land in the zero padding contribute nothing.
      if (h_in >= 0 && h_in < input_height && w_in >= 0 && w_in < input_width) {
        int offset = in_offset + h_in * input_width + w_in;
        T in_data = input_data[offset];
        if (fuse_relu_before_conv) {
          // ReLU evaluated directly in T; a round trip through double would
          // force FP64 arithmetic, which is slow on most GPUs.
          const T zero = static_cast<T>(0);
          value += filter_data[weight_offset] *
                   (in_data > zero ? in_data : zero);
        } else {
          value += filter_data[weight_offset] * in_data;
        }
      }
      weight_offset++;
    }
  }
  output_data[idx] = value;
}
181

182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
// Device routine for the depthwise convolution forward pass in NHWC layout.
// One output element per thread: the flat thread index is split with the
// channel varying fastest, giving (batch, h_out, w_out, c_out). Threads past
// the end of the output return immediately. Filter weights are read at
// (tap * output_channels + c_out), so the output channel is the innermost
// filter dimension.
//
// fuse_relu_before_conv applies max(x, 0) to each input sample before the
// multiply.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNHWC(
    ARG_DEFINE_KernelDepthwiseConv) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= (output_channels * batch_size * output_height * output_width))
    return;

  const int c_out = idx % output_channels;
  const int w_out = (idx / output_channels) % output_width;
  const int h_out = (idx / output_channels / output_width) % output_height;
  const int batch = idx / output_width / output_height / output_channels;

  const int c_in = c_out / filter_multiplier;
  T value(0);
  const int h_in_start = -padding_height + h_out * stride_height;
  const int w_in_start = -padding_width + w_out * stride_width;
  const int h_in_end = h_in_start + filter_height * dilate_height;
  const int w_in_end = w_in_start + filter_width * dilate_width;
  // Row-major index of the current filter tap; advanced for every tap,
  // including taps that fall into the padding.
  int weight_offset = 0;

#pragma unroll
  for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) {
#pragma unroll
    for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) {
      // The loops already keep h_in/w_in within [start, end), so clamping
      // the window to the input reduces to a plain bounds check.
      if (h_in >= 0 && h_in < input_height && w_in >= 0 && w_in < input_width) {
        int offset = ((batch * input_height + h_in) * input_width + w_in) *
                         input_channels +
                     c_in;
        T in_data = input_data[offset];
        const T weight = filter_data[weight_offset * output_channels + c_out];
        if (fuse_relu_before_conv) {
          // ReLU evaluated directly in T instead of through double (FP64).
          const T zero = static_cast<T>(0);
          value += weight * (in_data > zero ? in_data : zero);
        } else {
          value += weight * in_data;
        }
      }
      weight_offset++;
    }
  }
  int index = batch * output_channels * output_height * output_width +
              h_out * output_width * output_channels + w_out * output_channels +
              c_out;
  output_data[index] = value;
}
234

235
// Forward depthwise convolution in NCHW layout for a compile-time square
// c_filter x c_filter filter.
//
// Launch shape as read from the indexing: blockIdx.x is the output channel,
// blockIdx.y the batch, and the block's threads stride over w_out
// (threadIdx.x) and h_out (threadIdx.y). The channel's weights are first
// copied into a local array.
// NOTE(review): the output index uses gridDim.x as the channel count, which
// assumes the grid is launched with gridDim.x == output_channels.
//
// fuse_relu_before_conv applies max(x, 0) to each input sample before the
// multiply.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvCFilterNCHW(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_out = blockIdx.x;
  const T* weight = filter_data + c_out * c_filter * c_filter;
  for (int i = 0; i < c_filter * c_filter; i++) r_weight[i] = weight[i];

  // Invariant for the whole block: the input channel feeding c_out and the
  // start of its (batch, c_in) plane.
  const int c_in = c_out / filter_multiplier;
  const int in_offset =
      ((batch * input_channels + c_in) * input_height) * input_width;

  for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) {
    for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) {
      T value(0);
      const int h_in_start = -padding_height + h_out * stride_height;
      const int w_in_start = -padding_width + w_out * stride_width;

      for (int h_in = h_in_start, h_f = 0; h_f < c_filter;
           h_in += dilate_height, h_f++) {
        for (int w_in = w_in_start, w_f = 0; w_f < c_filter;
             w_in += dilate_width, w_f++) {
          // Taps that land in the zero padding contribute nothing.
          if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
              w_in < input_width) {
            int offset = in_offset + h_in * input_width + w_in;
            const T in_data = input_data[offset];
            if (fuse_relu_before_conv) {
              // ReLU evaluated directly in T instead of through double.
              const T zero = static_cast<T>(0);
              value += r_weight[h_f * c_filter + w_f] *
                       (in_data > zero ? in_data : zero);
            } else {
              value += r_weight[h_f * c_filter + w_f] * in_data;
            }
          }
        }
      }
      int index =
          ((batch * gridDim.x + c_out) * output_height + h_out) * output_width +
          w_out;
      output_data[index] = value;
    }
  }
}

// Forward depthwise convolution in NHWC layout for a compile-time square
// c_filter x c_filter filter.
//
// Launch shape as read from the indexing: blockIdx.z is the batch and
// h_out = blockIdx.x * dilate_height + blockIdx.y; rows with
// h_out >= output_height return immediately. threadIdx.x strides over output
// channels. threadIdx.y strides over a reordering of w_out in which all
// columns with the same w_out % dilate_width are visited consecutively.
// For each output channel, its c_filter*c_filter weights (stored with the
// output channel innermost) are copied into a local array.
//
// fuse_relu_before_conv applies max(x, 0) to each input sample before the
// multiply.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvCFilterNHWC(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int batch = blockIdx.z;
  int h_out = blockIdx.x * dilate_height + blockIdx.y;
  if (h_out >= output_height) {
    return;
  }
  int in_offset = batch * input_height * input_width * input_channels;
  int out_offset =
      (batch * output_height + h_out) * output_width * output_channels;
  const int h_in_start = -padding_height + h_out * stride_height;
  const int wi_size = (output_width + dilate_width - 1) / dilate_width;
  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];

  for (int c_out = threadIdx.x; c_out < output_channels; c_out += blockDim.x) {
    for (int i = 0; i < c_filter * c_filter; i++) {
      const T* weight = filter_data + i * output_channels + c_out;
      r_weight[i] = weight[0];
    }
    const int c_in = c_out / filter_multiplier;
    for (int i = threadIdx.y; i < wi_size * dilate_width; i += blockDim.y) {
      // i enumerates (w_out % dilate_width, w_out / dilate_width).
      int i_dw = i / wi_size;
      int i_wi = i - i_dw * wi_size;
      int w_out = i_wi * dilate_width + i_dw;
      if (w_out >= output_width) {
        continue;
      }
      T value(0);
      const int w_in_start = -padding_width + w_out * stride_width;
      for (int h_in = h_in_start, h_f = 0; h_f < c_filter;
           h_in += dilate_height, h_f++) {
        for (int w_in = w_in_start, w_f = 0; w_f < c_filter;
             w_in += dilate_width, w_f++) {
          // Taps that land in the zero padding contribute nothing.
          if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
              w_in < input_width) {
            int offset =
                in_offset + (h_in * input_width + w_in) * input_channels + c_in;
            if (fuse_relu_before_conv) {
              // ReLU evaluated directly in T; the former max(0.0, double)
              // forced FP64 arithmetic and was inconsistent with the other
              // routines.
              const T zero = static_cast<T>(0);
              const T in_data = input_data[offset];
              value += r_weight[h_f * c_filter + w_f] *
                       (in_data > zero ? in_data : zero);
            } else {
              value += r_weight[h_f * c_filter + w_f] * input_data[offset];
            }
          }
        }
      }
      int index = out_offset + w_out * output_channels + c_out;
      output_data[index] = value;
    }
  }
}

H
hong 已提交
345 346 347 348 349 350
// Entry kernel for the forward depthwise convolution.
//
// Template knobs:
//   c_filter_multiplier != 0 : replaces the runtime channel multiplier with
//                              c_filter_multiplier and both strides with
//                              c_stride.
//   c_filter == -1           : runtime filter size, so the per-element
//                              routines are used; otherwise the fixed-size
//                              "CFilter" routines are used.
//   data_layout              : kNHWC selects the NHWC routines; any other
//                              value selects the NCHW ones.
template <typename T,
          int c_filter_multiplier,
          int c_stride,
          int c_filter,
          DataLayout data_layout,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
  const bool use_static_params = c_filter_multiplier != 0;
  const int final_filter_multiplier =
      use_static_params ? c_filter_multiplier : filter_multiplier;
  const int h_stride = use_static_params ? c_stride : stride_height;
  const int w_stride = use_static_params ? c_stride : stride_width;
  const bool is_nhwc = data_layout == DataLayout::kNHWC;

  if (c_filter == -1 && !is_nhwc) {
    KernelDepthwiseConvNCHW<T, c_filter, fuse_relu_before_conv>(
        input_data, filter_data, batch_size, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width, h_stride,
        w_stride, padding_height, padding_width, dilate_height, dilate_width,
        output_data);
  } else if (c_filter == -1) {
    KernelDepthwiseConvNHWC<T, fuse_relu_before_conv>(
        input_data, filter_data, batch_size, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width, h_stride,
        w_stride, padding_height, padding_width, dilate_height, dilate_width,
        output_data);
  } else if (!is_nhwc) {
    KernelDepthwiseConvCFilterNCHW<T, c_filter, fuse_relu_before_conv>(
        input_data, filter_data, batch_size, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width, h_stride,
        w_stride, padding_height, padding_width, dilate_height, dilate_width,
        output_data);
  } else {
    KernelDepthwiseConvCFilterNHWC<T, c_filter, fuse_relu_before_conv>(
        input_data, filter_data, batch_size, output_channels, output_height,
        output_width, input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width, h_stride,
        w_stride, padding_height, padding_width, dilate_height, dilate_width,
        output_data);
  }
}

Z
zlx 已提交
450
// Parameter list shared by the input-gradient depthwise-convolution routines.
// Expands to 20 parameters: input, output-gradient and filter pointers, batch
// size, output C/H/W, input C/H/W, channel multiplier, filter H/W,
// stride H/W, padding H/W, dilation H/W and the input-gradient pointer.
// `T` must name the element type wherever the macro is expanded.
#define ARG_DEFINE_KernelDepthwiseConvInputGrad                                \
  const T *const input_data, const T *const output_grad_data,                  \
      const T *const filter_data, const int batch_size,                        \
      const int output_channels, const int output_height,                      \
      const int output_width, const int input_channels,                        \
      const int input_height, const int input_width,                           \
      const int filter_multiplier, const int filter_height,                    \
      const int filter_width, const int stride_height, const int stride_width, \
      const int padding_height, const int padding_width,                       \
      const int dilate_height, const int dilate_width,                         \
      T *const input_grad_data
462

463
// Input gradient of the depthwise convolution in NCHW layout, one input
// element per thread (flat index split row-major into batch, c_in, h_in,
// w_in; threads past the end return).
//
// For every output channel c_out = c_in * filter_multiplier + c_mul and
// every filter tap (fh, fw), the output position whose window reaches this
// input through that tap is
//   h_out = (h_in + padding_height - fh * dilate_height) / stride_height
//   w_out = (w_in + padding_width  - fw * dilate_width)  / stride_width.
// A tap is used only if both numerators divide exactly by their stride and
// the output position is in range. When c_filter != -1, the compile-time
// extent replaces filter_height/filter_width in the loops.
// With fuse_relu_before_conv, inputs <= 0 receive a zero gradient.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradNCHW(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int taps_w = c_filter != -1 ? c_filter : filter_width;
  const int taps_h = c_filter != -1 ? c_filter : filter_height;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= batch_size * input_channels * input_height * input_width) {
    return;
  }
  if (fuse_relu_before_conv && input_data[tid] <= static_cast<T>(0.0f)) {
    input_grad_data[tid] = 0;
    return;
  }

  int rest = tid;
  const int w_in = rest % input_width;
  rest /= input_width;
  const int h_in = rest % input_height;
  rest /= input_height;
  const int c_in = rest % input_channels;
  const int batch = rest / input_channels;

  T grad(0);
  for (int c_mul = 0; c_mul < filter_multiplier; ++c_mul) {
    const int c_out = c_in * filter_multiplier + c_mul;
    // Walks the c_out filter row-major; advanced on every tap, skipped or not.
    const T* filter_ptr = filter_data + c_out * filter_height * filter_width;

#pragma unroll
    for (int fh = 0; fh < taps_h; ++fh) {
#pragma unroll
      for (int fw = 0; fw < taps_w; ++fw, ++filter_ptr) {
        int h_out = h_in + padding_height - fh * dilate_height;
        int w_out = w_in + padding_width - fw * dilate_width;
        if (h_out % stride_height != 0 || w_out % stride_width != 0) {
          continue;
        }
        h_out /= stride_height;
        w_out /= stride_width;
        if (h_out < 0 || h_out >= output_height || w_out < 0 ||
            w_out >= output_width) {
          continue;
        }
        const int grad_offset =
            ((batch * output_channels + c_out) * output_height + h_out) *
                output_width +
            w_out;
        grad += output_grad_data[grad_offset] * (*filter_ptr);
      }
    }
  }
  input_grad_data[tid] = grad;
}

// Input gradient of the depthwise convolution in NHWC layout.
//
// Launch shape as read from the indexing: blockIdx.z is the batch and
// h_in = blockIdx.x * dilate_height + blockIdx.y; rows past input_height
// return immediately. threadIdx.x strides over input channels and
// threadIdx.y over input columns.
// Output positions are scanned forward from the first one whose window can
// reach the input, so filter tap k (row-major) pairs with the flipped weight
// filter_height * filter_width - 1 - k. Filter weights are stored with the
// output channel innermost. Taps whose output coordinate is not a stride
// multiple, or that fall outside the output, are skipped.
// With fuse_relu_before_conv, inputs <= 0 receive a zero gradient.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradNHWC(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int batch = blockIdx.z;
  const int h_in = blockIdx.x * dilate_height + blockIdx.y;
  if (h_in >= input_height) {
    return;
  }

  const int taps = filter_height * filter_width;
  const int h_out_first =
      h_in - (filter_height - 1) * dilate_height + padding_height;

  for (int c_in = threadIdx.x; c_in < input_channels; c_in += blockDim.x) {
    for (int w_in = threadIdx.y; w_in < input_width; w_in += blockDim.y) {
      const int w_out_first =
          w_in - (filter_width - 1) * dilate_width + padding_width;
      const int index =
          ((batch * input_height + h_in) * input_width + w_in) *
              input_channels +
          c_in;
      if (fuse_relu_before_conv && input_data[index] <= T(0)) {
        input_grad_data[index] = 0;
        continue;
      }

      T grad(0);
      for (int mul = 0; mul < filter_multiplier; ++mul) {
        const int c_out = c_in * filter_multiplier + mul;
        int h_out = h_out_first;
        for (int h_f = 0; h_f < filter_height; ++h_f, h_out += dilate_height) {
          int w_out = w_out_first;
          for (int w_f = 0; w_f < filter_width; ++w_f, w_out += dilate_width) {
            if (h_out % stride_height != 0 || w_out % stride_width != 0) {
              continue;
            }
            const int s_h_out = h_out / stride_height;
            const int s_w_out = w_out / stride_width;
            if (s_h_out < 0 || s_h_out >= output_height || s_w_out < 0 ||
                s_w_out >= output_width) {
              continue;
            }
            const int flipped_tap = taps - 1 - (h_f * filter_width + w_f);
            const int grad_offset =
                ((batch * output_height + s_h_out) * output_width + s_w_out) *
                    output_channels +
                c_out;
            grad += output_grad_data[grad_offset] *
                    filter_data[flipped_tap * output_channels + c_out];
          }
        }
      }
      input_grad_data[index] = grad;
    }
  }
}

H
hong 已提交
577 578 579
// Input gradient of the depthwise convolution in NCHW layout for a
// compile-time square c_filter x c_filter filter.
//
// Launch shape as read from the indexing: blockIdx.x is the input channel
// and blockIdx.y the batch; threads stride over w_in (threadIdx.x) and h_in
// (threadIdx.y). NOTE(review): the input index uses gridDim.x as the channel
// count, which assumes gridDim.x == input_channels at launch.
// The filters of all output channels fed by c_in are cached in reverse order,
// so that scanning outputs forward pairs each tap with its flipped weight.
// NOTE(review): the cache loop runs to the runtime filter_multiplier while
// the array is sized by c_filter_multiplier; callers must keep
// filter_multiplier <= c_filter_multiplier.
// With fuse_relu_before_conv, inputs <= 0 receive a zero gradient.
template <typename T,
          int c_filter,
          int c_filter_multiplier,
          bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradCFilterNCHW(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int kTaps = c_filter * c_filter;
  const int kWeightSize = kTaps * c_filter_multiplier + 1;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_in = blockIdx.x;

  for (int c_i = 0; c_i < filter_multiplier; c_i++) {
    const T* weight = filter_data + (c_in * filter_multiplier + c_i) * kTaps;
    for (int i = 0; i < kTaps; i++) {
      r_weight[c_i * kTaps + i] = weight[kTaps - 1 - i];
    }
  }

  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      const int index =
          ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
          w_in;
      if (fuse_relu_before_conv && input_data[index] <= T(0)) {
        input_grad_data[index] = 0;
        continue;
      }

      const int h_out_start =
          h_in - (c_filter - 1) * dilate_height + padding_height;
      const int w_out_start =
          w_in - (c_filter - 1) * dilate_width + padding_width;
      T grad(0);
      for (int c_i = 0; c_i < filter_multiplier; c_i++) {
        const int c_out = c_in * filter_multiplier + c_i;
        for (int h_f = 0; h_f < c_filter; h_f++) {
          const int h_out = h_out_start + h_f * dilate_height;
          for (int w_f = 0; w_f < c_filter; w_f++) {
            const int w_out = w_out_start + w_f * dilate_width;
            if (h_out % stride_height != 0 || w_out % stride_width != 0) {
              continue;
            }
            const int s_h_out = h_out / stride_height;
            const int s_w_out = w_out / stride_width;
            if (s_h_out < 0 || s_h_out >= output_height || s_w_out < 0 ||
                s_w_out >= output_width) {
              continue;
            }
            const int grad_offset =
                ((batch * output_channels + c_out) * output_height + s_h_out) *
                    output_width +
                s_w_out;
            grad += output_grad_data[grad_offset] *
                    r_weight[h_f * c_filter + w_f + c_i * kTaps];
          }
        }
      }
      input_grad_data[index] = grad;
    }
  }
}

H
hong 已提交
640 641 642
// Input gradient of the depthwise convolution in NHWC layout for a
// compile-time square c_filter x c_filter filter and compile-time channel
// multiplier.
//
// Launch shape as read from the indexing: blockIdx.z is the batch and
// h_in = blockIdx.x * dilate_height + blockIdx.y; rows past input_height
// return immediately. threadIdx.x strides over input channels. threadIdx.y
// strides over a reordering of w_in in which all columns with the same
// w_in % dilate_width are visited consecutively.
// For each input channel, the filters of the output channels it feeds
// (stored with the output channel innermost) are cached reversed, pairing
// each forward-scanned tap with its flipped weight.
// With fuse_relu_before_conv, inputs <= 0 receive a zero gradient.
template <typename T,
          int c_filter,
          int c_filter_multiplier,
          bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradCFilterNHWC(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int h_in = blockIdx.x * dilate_height + blockIdx.y;
  if (h_in >= input_height) {
    return;
  }
  const int kTaps = c_filter * c_filter;
  const int kWeightSize = kTaps * c_filter_multiplier + 1;
  T r_weight[kWeightSize];
  const int batch = blockIdx.z;
  const int wi_size = (input_width + dilate_width - 1) / dilate_width;
  const int h_out_start =
      h_in - (c_filter - 1) * dilate_height + padding_height;

  for (int c_in = threadIdx.x; c_in < input_channels; c_in += blockDim.x) {
    for (int c_i = 0; c_i < c_filter_multiplier; c_i++) {
      const int c_out = c_in * c_filter_multiplier + c_i;
      for (int i = 0; i < kTaps; i++) {
        r_weight[c_i * kTaps + i] =
            filter_data[(kTaps - 1 - i) * output_channels + c_out];
      }
    }

    for (int i = threadIdx.y; i < wi_size * dilate_width; i += blockDim.y) {
      const int phase = i / wi_size;  // w_in % dilate_width
      const int w_in = (i - phase * wi_size) * dilate_width + phase;
      if (w_in >= input_width) {
        continue;
      }

      const int index =
          ((batch * input_height + h_in) * input_width + w_in) *
              input_channels +
          c_in;
      if (fuse_relu_before_conv && input_data[index] <= T(0)) {
        input_grad_data[index] = 0;
        continue;
      }

      const int w_out_start =
          w_in - (c_filter - 1) * dilate_width + padding_width;
      T grad(0);
      for (int c_i = 0; c_i < c_filter_multiplier; c_i++) {
        const int c_out = c_in * c_filter_multiplier + c_i;
        for (int h_f = 0; h_f < c_filter; h_f++) {
          const int h_out = h_out_start + h_f * dilate_height;
          for (int w_f = 0; w_f < c_filter; w_f++) {
            const int w_out = w_out_start + w_f * dilate_width;
            if (h_out % stride_height != 0 || w_out % stride_width != 0) {
              continue;
            }
            const int s_h_out = h_out / stride_height;
            const int s_w_out = w_out / stride_width;
            if (s_h_out < 0 || s_h_out >= output_height || s_w_out < 0 ||
                s_w_out >= output_width) {
              continue;
            }
            const int grad_offset =
                ((batch * output_height + s_h_out) * output_width + s_w_out) *
                    output_channels +
                c_out;
            grad += output_grad_data[grad_offset] *
                    r_weight[h_f * c_filter + w_f + c_i * kTaps];
          }
        }
      }
      input_grad_data[index] = grad;
    }
  }
}

H
hong 已提交
712 713 714 715 716 717
/*
 * Entry kernel for the depthwise-convolution gradient w.r.t. the input.
 *
 * The parameter list comes from ARG_DEFINE_KernelDepthwiseConvInputGrad
 * (defined elsewhere; presumably it declares input_data ... input_grad_data
 * exactly as they are used below -- confirm against the macro).
 *
 * Template knobs:
 *   c_filter_multiplier - compile-time filter multiplier. When non-zero it
 *                         replaces the runtime filter_multiplier, and
 *                         c_stride replaces both runtime strides.
 *   c_filter            - compile-time square filter size, or -1 for "use the
 *                         runtime filter size".
 *   data_layout         - selects the NCHW or NHWC device helper.
 *
 * The specialized CFilter helpers are used only when a compile-time
 * multiplier is given AND c_filter != -1; otherwise the generic helpers run.
 */
template <typename T,
          int c_filter_multiplier,
          int c_stride,
          int c_filter,
          DataLayout data_layout,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvInputGradSp(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const bool has_const_multiplier = (c_filter_multiplier != 0);
  // Effective values: compile-time constants win when they are provided.
  const int multiplier =
      has_const_multiplier ? c_filter_multiplier : filter_multiplier;
  const int h_stride = has_const_multiplier ? c_stride : stride_height;
  const int w_stride = has_const_multiplier ? c_stride : stride_width;

  const bool use_generic = !has_const_multiplier || c_filter == -1;
  const bool is_nhwc = (data_layout == DataLayout::kNHWC);

  if (use_generic && !is_nhwc) {
    KernelDepthwiseConvInputGradNCHW<T, c_filter, fuse_relu_before_conv>(
        input_data,
        output_grad_data,
        filter_data,
        batch_size,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        multiplier,
        filter_height,
        filter_width,
        h_stride,
        w_stride,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        input_grad_data);
  } else if (use_generic) {
    KernelDepthwiseConvInputGradNHWC<T, fuse_relu_before_conv>(
        input_data,
        output_grad_data,
        filter_data,
        batch_size,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        multiplier,
        filter_height,
        filter_width,
        h_stride,
        w_stride,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        input_grad_data);
  } else if (!is_nhwc) {
    // Here multiplier == c_filter_multiplier and both strides == c_stride.
    KernelDepthwiseConvInputGradCFilterNCHW<T,
                                            c_filter,
                                            c_filter_multiplier,
                                            fuse_relu_before_conv>(
        input_data,
        output_grad_data,
        filter_data,
        batch_size,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        multiplier,
        filter_height,
        filter_width,
        h_stride,
        w_stride,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        input_grad_data);
  } else {
    // Here multiplier == c_filter_multiplier and both strides == c_stride.
    KernelDepthwiseConvInputGradCFilterNHWC<T,
                                            c_filter,
                                            c_filter_multiplier,
                                            fuse_relu_before_conv>(
        input_data,
        output_grad_data,
        filter_data,
        batch_size,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        multiplier,
        filter_height,
        filter_width,
        h_stride,
        w_stride,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        input_grad_data);
  }
}

830
// Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
//
// NCHW device helper. Each block produces exactly one filter-gradient value:
// blockIdx.x selects the filter column, blockIdx.y the filter row, and
// blockIdx.z the output channel (whose input channel is
// oc / filter_multiplier).
//
// The block's threads accumulate partial products of output_grad and input
// over the batch and all output pixels. BlockReduceSum combines them, and
// thread (0, 0) stores the total at the block's flattened grid index in
// filter_grad_data.
//
// With fuse_relu_before_conv, the input value is clamped at zero first. The
// clamp is done directly in T rather than through a double-precision max().
// This keeps FP64 arithmetic out of the inner loop (FP64 throughput is very
// low on most GPUs) while producing the same value for finite, inf and NaN
// inputs.
//
// filter_height / filter_width are part of the shared signature but are not
// read here.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
    const T* output_grad_data,
    const T* input_data,
    const int num,
    const int output_channels,
    const int output_height,
    const int output_width,
    const int input_channels,
    const int input_height,
    const int input_width,
    const int filter_multiplier,
    const int filter_height,
    const int filter_width,
    const int stride_height,
    const int stride_width,
    const int padding_height,
    const int padding_width,
    const int dilate_height,
    const int dilate_width,
    T* filter_grad_data) {
  T f_grad(0);

  const int ohw = output_height * output_width;
  const int onhw = num * ohw;
  // Large images: spread (og_w, og_h) over the 2-D thread block and loop over
  // the batch. Small images: flatten (batch, og_h, og_w) over threadIdx.x so
  // that enough threads have work.
  const bool loop_batch = ohw >= WARP_SIZE;

  const int kw_id = blockIdx.x;
  const int kh_id = blockIdx.y;
  const int oc_id = blockIdx.z;
  const int ic_id = oc_id / filter_multiplier;
  const int idx = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;

  // Input coordinate of output pixel (0, 0) for this filter tap.
  const int h_offset = kh_id * dilate_height - padding_height;
  const int w_offset = kw_id * dilate_width - padding_width;

  if (loop_batch) {
    for (int og_w = threadIdx.x; og_w < output_width; og_w += blockDim.x) {
      // i_w depends only on og_w, so an out-of-range column skips all of its
      // (batch, row) terms at once.
      const int i_w = og_w * stride_width + w_offset;
      if (i_w < 0 || i_w >= input_width) continue;
      for (int bid = 0; bid < num; ++bid) {
        for (int og_h = threadIdx.y; og_h < output_height; og_h += blockDim.y) {
          const int i_h = og_h * stride_height + h_offset;
          if (i_h < 0 || i_h >= input_height) continue;
          const int input_offset =
              ((bid * input_channels + ic_id) * input_height + i_h) *
                  input_width +
              i_w;
          const int output_grad_offset =
              ((bid * output_channels + oc_id) * output_height + og_h) *
                  output_width +
              og_w;
          T x = input_data[input_offset];
          if (fuse_relu_before_conv) {
            x = x > T(0) ? x : T(0);
          }
          f_grad += output_grad_data[output_grad_offset] * x;
        }
      }
    }
  } else {
    for (int id = threadIdx.x; id < onhw; id += blockDim.x) {
      const int bid = id / ohw;
      const int og_hw = id - bid * ohw;
      const int og_h = og_hw / output_width;
      const int og_w = og_hw - og_h * output_width;

      const int i_h = og_h * stride_height + h_offset;
      const int i_w = og_w * stride_width + w_offset;
      if (i_w < 0 || i_w >= input_width || i_h < 0 || i_h >= input_height) {
        continue;
      }
      const int input_offset =
          ((bid * input_channels + ic_id) * input_height + i_h) * input_width +
          i_w;
      const int output_grad_offset = (bid * output_channels + oc_id) * ohw + og_hw;
      T x = input_data[input_offset];
      if (fuse_relu_before_conv) {
        x = x > T(0) ? x : T(0);
      }
      f_grad += output_grad_data[output_grad_offset] * x;
    }
  }

  // Every thread of the block reaches the reduction (no early returns above).
  T val = BlockReduceSum<T>(f_grad);
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    filter_grad_data[idx] = val;
  }
}

// NHWC device helper for the depthwise-convolution filter gradient (generic
// filter size).
//
// Grid mapping:
//   blockIdx.z -> batch index
//   blockIdx.y -> output row
//   blockIdx.x -> flattened filter tap (kernel_ih * filter_width + kernel_iw)
// Within a block, threadIdx.x strides over output channels and threadIdx.y
// strides over output columns.
//
// Each thread atomically adds its partial sum into
// filter_grad_data[(kernel_id * filter_height + kernel_ih) * filter_width +
// kernel_iw]. The destination therefore has to be zero-initialized beforehand;
// presumably the host caller does this (TODO confirm).
//
// With fuse_relu_before_conv, the input is clamped at zero in T, avoiding the
// double-precision max() the previous code used in the inner loop.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGradNHWC(
    const T* output_grad_data,
    const T* input_data,
    const int num,
    const int output_channels,
    const int output_height,
    const int output_width,
    const int input_channels,
    const int input_height,
    const int input_width,
    const int filter_multiplier,
    const int filter_height,
    const int filter_width,
    const int stride_height,
    const int stride_width,
    const int padding_height,
    const int padding_width,
    const int dilate_height,
    const int dilate_width,
    T* filter_grad_data) {
  const int bid = blockIdx.z;
  const int image_h = blockIdx.y;
  const int kernel_iw = blockIdx.x % filter_width;
  const int kernel_ih = blockIdx.x / filter_width;

  // Everything below depends only on the block, not on the channel/column
  // loops, so it is computed once.
  const int kernel_h = kernel_ih * dilate_height - padding_height;
  const int kernel_w = kernel_iw * dilate_width - padding_width;
  const int image_hk = image_h * stride_height + kernel_h;
  const bool row_in_range = image_hk >= 0 && image_hk < input_height;

  for (int kernel_id = threadIdx.x; kernel_id < output_channels;
       kernel_id += blockDim.x) {
    T s(0);
    const int gbid =
        ((kernel_id * filter_height) + kernel_ih) * filter_width + kernel_iw;
    const int input_c = kernel_id / filter_multiplier;
    if (row_in_range) {
      for (int image_w = threadIdx.y; image_w < output_width;
           image_w += blockDim.y) {
        const int image_wk = image_w * stride_width + kernel_w;
        if (image_wk < 0 || image_wk >= input_width) continue;
        const int input_id =
            ((bid * input_height + image_hk) * input_width + image_wk) *
                input_channels +
            input_c;
        const int output_id =
            ((bid * output_height + image_h) * output_width + image_w) *
                output_channels +
            kernel_id;
        T x = input_data[input_id];
        if (fuse_relu_before_conv) {
          x = x > T(0) ? x : T(0);
        }
        s += output_grad_data[output_id] * x;
      }
    }
    // As before, the add happens even when s stayed zero.
    phi::CudaAtomicAdd(&filter_grad_data[gbid], s);
  }
}

// NHWC device helper for the depthwise-convolution filter gradient with a
// compile-time square filter of size c_filter x c_filter.
//
// Grid mapping:
//   blockIdx.z -> batch index
//   image_h = blockIdx.x * dilate_height + blockIdx.y -> output row
//     (blocks whose row falls past the output return immediately)
// Within a block, threadIdx.x strides over output channels. threadIdx.y
// strides over output columns, visited in a dilation-interleaved order:
// i -> (i_dw, i_wi), image_w = i_wi * dilate_width + i_dw.
//
// Each thread accumulates all c_filter * c_filter tap gradients for its
// channel in the local array r_weight. It then atomically adds them into
// filter_grad_data[tap * output_channels + kernel_id], i.e. a tap-major
// layout. The destination therefore has to be zero-initialized beforehand;
// presumably the caller does this (TODO confirm).
//
// With fuse_relu_before_conv, the input is clamped at zero in T, avoiding the
// double-precision max() the previous code used in the inner loop.
// num / filter_height / filter_width are part of the shared signature but
// are not read here.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC(
    const T* output_grad_data,
    const T* input_data,
    const int num,
    const int output_channels,
    const int output_height,
    const int output_width,
    const int input_channels,
    const int input_height,
    const int input_width,
    const int filter_multiplier,
    const int filter_height,
    const int filter_width,
    const int stride_height,
    const int stride_width,
    const int padding_height,
    const int padding_width,
    const int dilate_height,
    const int dilate_width,
    T* filter_grad_data) {
  const int bid = blockIdx.z;
  const int image_h = blockIdx.x * dilate_height + blockIdx.y;
  if (image_h >= output_height) {
    return;
  }
  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];
  const int wi_size = (output_width + dilate_width - 1) / dilate_width;

  for (int kernel_id = threadIdx.x; kernel_id < output_channels;
       kernel_id += blockDim.x) {
    for (int i = 0; i < kWeightSize; ++i) {
      r_weight[i] = 0;
    }
    const int input_c = kernel_id / filter_multiplier;
    for (int i = threadIdx.y; i < wi_size * dilate_width; i += blockDim.y) {
      const int i_dw = i / wi_size;
      const int i_wi = i - i_dw * wi_size;
      const int image_w = i_wi * dilate_width + i_dw;
      if (image_w >= output_width) {
        continue;
      }
      // The output-gradient element is the same for every filter tap.
      const int output_id =
          ((bid * output_height + image_h) * output_width + image_w) *
              output_channels +
          kernel_id;
      for (int kernel_ih = 0; kernel_ih < c_filter; ++kernel_ih) {
        const int image_hk =
            image_h * stride_height + kernel_ih * dilate_height - padding_height;
        // The row check does not depend on kernel_iw.
        if (image_hk < 0 || image_hk >= input_height) continue;
        for (int kernel_iw = 0; kernel_iw < c_filter; ++kernel_iw) {
          const int image_wk =
              image_w * stride_width + kernel_iw * dilate_width - padding_width;
          if (image_wk < 0 || image_wk >= input_width) continue;
          const int input_id =
              ((bid * input_height + image_hk) * input_width + image_wk) *
                  input_channels +
              input_c;
          T x = input_data[input_id];
          if (fuse_relu_before_conv) {
            x = x > T(0) ? x : T(0);
          }
          r_weight[kernel_ih * c_filter + kernel_iw] +=
              output_grad_data[output_id] * x;
        }
      }
    }
    for (int i = 0; i < kWeightSize; ++i) {
      phi::CudaAtomicAdd(filter_grad_data + i * output_channels + kernel_id,
                         r_weight[i]);
    }
  }
}

H
hong 已提交
1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089
/*
 * Entry kernel for the depthwise-convolution gradient w.r.t. the filter.
 *
 * When c_filter_multiplier is non-zero, it replaces the runtime
 * filter_multiplier, and c_stride replaces both runtime strides.
 *
 * Dispatch:
 *   NCHW                                        -> generic NCHW helper
 *   NHWC, no compile-time multiplier or c_filter == -1
 *                                               -> generic NHWC helper
 *   NHWC, otherwise                             -> CFilter NHWC helper
 */
template <typename T,
          int c_filter_multiplier,
          int c_stride,
          int c_filter,
          DataLayout data_layout,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvFilterGradSp(const T* output_grad_data,
                                                const T* input_data,
                                                const int num,
                                                const int output_channels,
                                                const int output_height,
                                                const int output_width,
                                                const int input_channels,
                                                const int input_height,
                                                const int input_width,
                                                const int filter_multiplier,
                                                const int filter_height,
                                                const int filter_width,
                                                const int stride_height,
                                                const int stride_width,
                                                const int padding_height,
                                                const int padding_width,
                                                const int dilate_height,
                                                const int dilate_width,
                                                T* filter_grad_data) {
  const bool has_const_multiplier = (c_filter_multiplier != 0);
  const int multiplier =
      has_const_multiplier ? c_filter_multiplier : filter_multiplier;
  const int h_stride = has_const_multiplier ? c_stride : stride_height;
  const int w_stride = has_const_multiplier ? c_stride : stride_width;

  if (data_layout != DataLayout::kNHWC) {
    // NCHW has no specialized filter-gradient helper; every configuration
    // goes through the generic one.
    KernelDepthwiseConvFilterGradNCHW<T, fuse_relu_before_conv>(
        output_grad_data,
        input_data,
        num,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        multiplier,
        filter_height,
        filter_width,
        h_stride,
        w_stride,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        filter_grad_data);
  } else if (!has_const_multiplier || c_filter == -1) {
    KernelDepthwiseConvFilterGradNHWC<T, fuse_relu_before_conv>(
        output_grad_data,
        input_data,
        num,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        multiplier,
        filter_height,
        filter_width,
        h_stride,
        w_stride,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        filter_grad_data);
  } else {
    KernelDepthwiseConvFilterGradCFilterNHWC<T,
                                             c_filter,
                                             fuse_relu_before_conv>(
        output_grad_data,
        input_data,
        num,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        multiplier,
        filter_height,
        filter_width,
        h_stride,
        w_stride,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        filter_grad_data);
  }
}

/*
 * GPU specialization of the depthwise-convolution forward pass.
 *
 * input/output are NCHW when data_layout is kNCHW (the default) and NHWC when
 * it is kNHWC. The filter's spatial size is read from filter.dims()[2] and
 * filter.dims()[3]. For NHWC, the filter is first transposed with
 * perm {2, 3, 0, 1} so that its spatial dimensions lead.
 *
 * strides, paddings and dilations are two-element vectors: height first,
 * then width.
 */
template <class T, bool fuse_relu_before_conv>
class DepthwiseConvFunctor<phi::GPUContext, T, fuse_relu_before_conv> {
 public:
  void operator()(const phi::GPUContext& context,
                  const phi::DenseTensor& input,
                  const phi::DenseTensor& filter,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  phi::DenseTensor* output,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const int batch_size = input.dims()[0];
    const int input_channels =
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
    const int input_height =
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
    const int input_width =
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
    const int output_channels =
        (data_layout != DataLayout::kNHWC ? output->dims()[1]
                                          : output->dims()[3]);
    const int output_height =
        (data_layout != DataLayout::kNHWC ? output->dims()[2]
                                          : output->dims()[1]);
    const int output_width =
        (data_layout != DataLayout::kNHWC ? output->dims()[3]
                                          : output->dims()[2]);
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    T* output_data = output->mutable_data<T>(context.GetPlace());
    // An empty output has nothing to compute. Returning early also avoids the
    // host-side integer divisions by output_width / output_channels below
    // (which would divide by zero) and a zero-sized kernel launch.
    if (output->numel() == 0) {
      return;
    }
    const T* input_data = input.data<T>();
    const T* filter_data = filter.data<T>();

    phi::DenseTensor filter_hwc;
    if (data_layout == DataLayout::kNHWC) {
      framework::DDim filter_hwc_dims({filter.dims()[2],
                                       filter.dims()[3],
                                       filter.dims()[0],
                                       filter.dims()[1]});
      filter_hwc.Resize(filter_hwc_dims);
      filter_hwc.mutable_data<T>(context.GetPlace());
      std::vector<int> perm_axis({2, 3, 0, 1});
      phi::funcs::TransposeNormal<phi::GPUContext, T> trans;
      trans(context, filter, &filter_hwc, perm_axis);
      filter_data = filter_hwc.data<T>();
    }

    int thread = 512;
    int blocks;
    dim3 threads;
    dim3 grid;

    if (data_layout != DataLayout::kNHWC) {
      // Wide rows get (almost) one thread per output column.
      if (output_width > 1024 && output_width <= 2048) {
        thread = (output_width - 1) / 2 + 1;
      } else if (output_width > 512 && output_width <= 1024) {
        thread = output_width;
      }
#ifdef __HIPCC__
      thread = std::min(thread, 256);
#endif
      blocks = std::min(std::max(thread / output_width, 1), output_height);
      threads = dim3(std::min(output_width, thread), blocks, 1);
      grid = dim3(output_channels, batch_size, 1);
    } else {
#ifdef __HIPCC__
      thread = std::min(thread, 256);
#endif
      blocks = std::min(
          std::max(thread / output_channels, 1),
          ((output_width + dilate_width - 1) / dilate_width) * dilate_width);
      threads = dim3(std::min(output_channels, thread), blocks, 1);
      grid = dim3((output_height + dilate_height - 1) / dilate_height,
                  dilate_height,
                  batch_size);
    }
    int filter_multiplier = output_channels / input_channels;
    int nums_output = output->numel();
#ifdef __HIPCC__
    int block_size = 256;
#else
    int block_size = 512;
#endif
    // 1-D launch configuration used by the c_filter == -1 (generic) cases.
    int grid_size = (nums_output + block_size - 1) / block_size;

// Launch the first matching specialization and return. c_filter_multiplier
// == 0 is the catch-all case. The grouping of the condition is written out
// explicitly: A || (B && C && D && ((E && F) || G)).
#define check_case(c_filter_multiplier, c_stride, c_filter)                \
  if ((c_filter_multiplier) == 0 ||                                        \
      (filter_multiplier == (c_filter_multiplier) &&                       \
       stride_height == stride_width && stride_height == (c_stride) &&     \
       ((ksize_height == ksize_width && ksize_height == (c_filter)) ||      \
        (c_filter) == -1))) {                                              \
    if ((c_filter) == -1) {                                                \
      threads.x = block_size;                                              \
      grid.x = grid_size;                                                  \
      threads.y = threads.z = grid.y = grid.z = 1;                         \
    }                                                                      \
    if (data_layout != DataLayout::kNHWC) {                                \
      KernelDepthwiseConvSp<T,                                             \
                            c_filter_multiplier,                           \
                            c_stride,                                      \
                            c_filter,                                      \
                            DataLayout::kNCHW,                             \
                            fuse_relu_before_conv>                         \
          <<<grid, threads, 0, context.stream()>>>(input_data,             \
                                                   filter_data,            \
                                                   batch_size,             \
                                                   output_channels,        \
                                                   output_height,          \
                                                   output_width,           \
                                                   input_channels,         \
                                                   input_height,           \
                                                   input_width,            \
                                                   filter_multiplier,      \
                                                   ksize_height,           \
                                                   ksize_width,            \
                                                   stride_height,          \
                                                   stride_width,           \
                                                   padding_height,         \
                                                   padding_width,          \
                                                   dilate_height,          \
                                                   dilate_width,           \
                                                   output_data);           \
    } else {                                                               \
      KernelDepthwiseConvSp<T,                                             \
                            c_filter_multiplier,                           \
                            c_stride,                                      \
                            c_filter,                                      \
                            DataLayout::kNHWC,                             \
                            fuse_relu_before_conv>                         \
          <<<grid, threads, 0, context.stream()>>>(input_data,             \
                                                   filter_data,            \
                                                   batch_size,             \
                                                   output_channels,        \
                                                   output_height,          \
                                                   output_width,           \
                                                   input_channels,         \
                                                   input_height,           \
                                                   input_width,            \
                                                   filter_multiplier,      \
                                                   ksize_height,           \
                                                   ksize_width,            \
                                                   stride_height,          \
                                                   stride_width,           \
                                                   padding_height,         \
                                                   padding_width,          \
                                                   dilate_height,          \
                                                   dilate_width,           \
                                                   output_data);           \
    }                                                                      \
    return;                                                                \
  }
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    check_case(0, 0, -1);
// NOTE(liangdun): 0,0 for other case
// add other case if needed, e.g. check_case(2^n,1)
#undef check_case
  }
};

1372
template <typename T, bool fuse_relu_before_conv>
H
hong 已提交
1373
class DepthwiseConvInputGradFunctor<phi::GPUContext, T, fuse_relu_before_conv> {
Z
zlx 已提交
1374
 public:
H
hong 已提交
1375
  void operator()(const phi::GPUContext& context,
1376 1377 1378
                  const phi::DenseTensor& input,
                  const phi::DenseTensor& filter,
                  const phi::DenseTensor& output_grad,
X
xzl 已提交
1379 1380
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
1381
                  const std::vector<int>& dilations,
1382
                  phi::DenseTensor* input_grad,
1383
                  const DataLayout data_layout = DataLayout::kNCHW) {
Z
zlx 已提交
1384
    const int batch_size = input.dims()[0];
1385
    const int input_channels =
1386
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
1387
    const int input_height =
1388
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
1389
    const int input_width =
1390
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
1391
    const int output_channels =
1392
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
1393 1394
                                          : output_grad.dims()[3]);
    const int output_height =
1395
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
1396 1397
                                          : output_grad.dims()[1]);
    const int output_width =
1398
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
1399
                                          : output_grad.dims()[2]);
1400 1401 1402
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
Z
zlx 已提交
1403 1404 1405
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
1406 1407
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];
Z
zlx 已提交
1408

1409
    const T* input_data = input.data<T>();
1410
    const T* filter_data = filter.data<T>();
Z
zlx 已提交
1411 1412 1413
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

1414
    phi::DenseTensor filter_hwc;
1415
    if (data_layout == DataLayout::kNHWC) {
H
hong 已提交
1416 1417 1418 1419
      framework::DDim filter_hwc_dims({filter.dims()[2],
                                       filter.dims()[3],
                                       filter.dims()[0],
                                       filter.dims()[1]});
1420 1421 1422
      filter_hwc.Resize(filter_hwc_dims);
      filter_hwc.mutable_data<T>(context.GetPlace());
      std::vector<int> perm_axis({2, 3, 0, 1});
H
hong 已提交
1423
      phi::funcs::TransposeNormal<phi::GPUContext, T> trans;
1424 1425 1426 1427
      trans(context, filter, &filter_hwc, perm_axis);
      filter_data = filter_hwc.data<T>();
    }

1428
    int thread = 512;
1429 1430 1431
    int blocks;
    dim3 threads;
    dim3 grid;
W
wangguanzhong 已提交
1432

1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447
    if (data_layout != DataLayout::kNHWC) {
      if (input_width > 1024 && input_width <= 2048) {
        thread = (input_width - 1) / 2 + 1;
      } else if (input_width > 512 && input_width <= 1024) {
        thread = input_width;
      }
      blocks = std::min(std::max(thread / input_width, 1), input_height);
      threads = dim3(std::min(input_width, thread), blocks, 1);
      grid = dim3(input_channels, batch_size, 1);
    } else {
      blocks = std::min(
          std::max(thread / input_channels, 1),
          ((input_width + dilate_width - 1) / dilate_width) * dilate_width);
      threads = dim3(std::min(input_channels, thread), blocks, 1);
      grid = dim3((input_height + dilate_height - 1) / dilate_height,
H
hong 已提交
1448 1449
                  dilate_height,
                  batch_size);
1450
    }
1451
    int filter_multiplier = output_channels / input_channels;
1452 1453 1454 1455 1456 1457 1458
    int nums_input = input_grad->numel();
#ifdef __HIPCC__
    int block_size = 256;
#else
    int block_size = 512;
#endif
    int grid_size = (nums_input + block_size - 1) / block_size;
1459

1460 1461 1462 1463 1464 1465 1466
#define check_case(c_filter_multiplier, c_stride, c_filter)             \
  if (c_filter_multiplier == 0 ||                                       \
      filter_multiplier == c_filter_multiplier &&                       \
          stride_height == stride_width && stride_height == c_stride && \
          (ksize_height == ksize_width && ksize_height == c_filter ||   \
           c_filter == -1)) {                                           \
    if (data_layout != DataLayout::kNHWC) {                             \
1467 1468 1469 1470 1471
      if (c_filter == -1) {                                             \
        threads.x = block_size;                                         \
        grid.x = grid_size;                                             \
        threads.y = threads.z = grid.y = grid.z = 1;                    \
      }                                                                 \
1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526
      KernelDepthwiseConvInputGradSp<T,                                 \
                                     c_filter_multiplier,               \
                                     c_stride,                          \
                                     c_filter,                          \
                                     DataLayout::kNCHW,                 \
                                     fuse_relu_before_conv>             \
          <<<grid, threads, 0, context.stream()>>>(input_data,          \
                                                   output_grad_data,    \
                                                   filter_data,         \
                                                   batch_size,          \
                                                   output_channels,     \
                                                   output_height,       \
                                                   output_width,        \
                                                   input_channels,      \
                                                   input_height,        \
                                                   input_width,         \
                                                   filter_multiplier,   \
                                                   ksize_height,        \
                                                   ksize_width,         \
                                                   stride_height,       \
                                                   stride_width,        \
                                                   padding_height,      \
                                                   padding_width,       \
                                                   dilate_height,       \
                                                   dilate_width,        \
                                                   input_grad_data);    \
    } else {                                                            \
      KernelDepthwiseConvInputGradSp<T,                                 \
                                     c_filter_multiplier,               \
                                     c_stride,                          \
                                     c_filter,                          \
                                     DataLayout::kNHWC,                 \
                                     fuse_relu_before_conv>             \
          <<<grid, threads, 0, context.stream()>>>(input_data,          \
                                                   output_grad_data,    \
                                                   filter_data,         \
                                                   batch_size,          \
                                                   output_channels,     \
                                                   output_height,       \
                                                   output_width,        \
                                                   input_channels,      \
                                                   input_height,        \
                                                   input_width,         \
                                                   filter_multiplier,   \
                                                   ksize_height,        \
                                                   ksize_width,         \
                                                   stride_height,       \
                                                   stride_width,        \
                                                   padding_height,      \
                                                   padding_width,       \
                                                   dilate_height,       \
                                                   dilate_width,        \
                                                   input_grad_data);    \
    }                                                                   \
    return;                                                             \
1527
  }
1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    check_case(0, 0, -1);
// NOTE(liangdun): 0,0 for other case
// add other case if needed, e.g. check_case(2^n,1)
1543
#undef check_case
Z
zlx 已提交
1544 1545 1546
  }
};

1547
template <typename T, bool fuse_relu_before_conv>
class DepthwiseConvFilterGradFunctor<phi::GPUContext,
                                     T,
                                     fuse_relu_before_conv> {
 public:
  /*
   * \brief Launches KernelDepthwiseConvFilterGradSp on context.stream() to
   * produce the depthwise-convolution filter gradient.
   *
   * \param context      GPU context; supplies the stream and the allocation
   *                     place for filter_grad and any temporary.
   * \param input        forward input, read as NCHW or NHWC per data_layout.
   * \param output_grad  gradient of the forward output, same layout as input.
   * \param strides      {stride_height, stride_width}.
   * \param paddings     {padding_height, padding_width}.
   * \param dilations    {dilate_height, dilate_width}.
   * \param filter_grad  result tensor; dims()[2] / dims()[3] are read as the
   *                     kernel height / width; its buffer is obtained here via
   *                     mutable_data.
   * \param data_layout  kNHWC selects the channels-last path; any other value
   *                     is handled as NCHW.
   *
   * A specialised kernel is picked for (filter_multiplier, stride, ksize) in
   * {1,2} x {1,2} x {3,5}; the ksize == -1 variants accept any kernel size for
   * a matching multiplier/stride, and the (0, 0, -1) case always matches and
   * acts as the fallback. Cases are tried in order and the first match
   * returns.
   *
   * For NHWC with a fixed ksize (3 or 5), the kernel writes into a zero-filled
   * temporary whose dims are filter_grad's dims permuted by {2, 3, 0, 1}; the
   * temporary is then transposed back into filter_grad.
   *
   * NOTE(review): the NHWC ksize == -1 path writes straight into filter_grad
   * without zeroing it in this function; presumably the caller zero-fills it
   * if the kernel accumulates -- confirm against the caller/kernel.
   */
  void operator()(const phi::GPUContext& context,
                  const phi::DenseTensor& input,
                  const phi::DenseTensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  phi::DenseTensor* filter_grad,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const bool is_nhwc = (data_layout == DataLayout::kNHWC);
    // Positions of the channel / height / width axes in a 4-D activation.
    const int c_axis = is_nhwc ? 3 : 1;
    const int h_axis = is_nhwc ? 1 : 2;
    const int w_axis = is_nhwc ? 2 : 3;

    const auto& in_dims = input.dims();
    const auto& out_dims = output_grad.dims();

    const int batch_size = in_dims[0];
    const int input_channels = in_dims[c_axis];
    const int input_height = in_dims[h_axis];
    const int input_width = in_dims[w_axis];
    const int output_channels = out_dims[c_axis];
    const int output_height = out_dims[h_axis];
    const int output_width = out_dims[w_axis];
    const int ksize_height = filter_grad->dims()[2];
    const int ksize_width = filter_grad->dims()[3];

    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* filter_grad_data = filter_grad->mutable_data<T>(context.GetPlace());

    // Default launch geometry. The NHWC ksize == -1 case below recomputes it.
    int block_size = 512;
    int blocks;
    dim3 threads;
    dim3 grid;
    if (is_nhwc) {
      // threads.x covers channels (capped at block_size); threads.y is capped
      // by the output width rounded up to a multiple of dilate_width.
      blocks = std::min(
          std::max(block_size / output_channels, 1),
          ((output_width + dilate_width - 1) / dilate_width) * dilate_width);
      grid = dim3((output_height + dilate_height - 1) / dilate_height,
                  dilate_height,
                  batch_size);
      threads = dim3(std::min(output_channels, block_size), blocks, 1);
    } else {
      // Rows wider than 512 get one thread per column up to 1024 columns, or
      // half as many threads as columns up to 2048 columns.
      if (output_width > 512 && output_width <= 1024) {
        block_size = output_width;
      } else if (output_width > 1024 && output_width <= 2048) {
        block_size = (output_width - 1) / 2 + 1;
      }
      blocks = std::min(std::max(block_size / output_width, 1), output_height);
      // One block per (filter column, filter row, output channel).
      grid = dim3(ksize_width, ksize_height, output_channels);
      threads = dim3(std::min(output_width, block_size), blocks, 1);
      // Tiny spatial maps: switch to a flat 1-D block.
      if (output_height * output_width < WARP_SIZE) {
        threads = dim3(
            std::min(block_size, batch_size * output_height * output_width));
      }
    }

    const int filter_multiplier = output_channels / input_channels;

// Argument list shared by every KernelDepthwiseConvFilterGradSp launch.
// filter_grad_data is expanded at the launch site, so a reassignment made
// just before the launch is honoured.
#define DWCONV_FILTER_GRAD_KERNEL_ARGS                                        \
  output_grad_data, input_data, batch_size, output_channels, output_height,  \
      output_width, input_channels, input_height, input_width,               \
      filter_multiplier, ksize_height, ksize_width, stride_height,           \
      stride_width, padding_height, padding_width, dilate_height,            \
      dilate_width, filter_grad_data

// Launches the kernel specialised for (c_filter_multiplier, c_stride,
// c_filter) and returns if the runtime shape matches that specialisation.
#define DWCONV_FILTER_GRAD_CASE(c_filter_multiplier, c_stride, c_filter)      \
  if ((c_filter_multiplier) == 0 ||                                           \
      ((filter_multiplier == (c_filter_multiplier)) &&                        \
       (stride_height == stride_width) && (stride_height == (c_stride)) &&    \
       ((ksize_height == ksize_width && ksize_height == (c_filter)) ||        \
        (c_filter) == -1))) {                                                 \
    if (is_nhwc) {                                                            \
      phi::DenseTensor hwc_grad;                                              \
      if ((c_filter) != -1) {                                                 \
        framework::DDim hwc_grad_dims({filter_grad->dims()[2],                \
                                       filter_grad->dims()[3],                \
                                       filter_grad->dims()[0],                \
                                       filter_grad->dims()[1]});              \
        hwc_grad.Resize(hwc_grad_dims);                                       \
        hwc_grad.mutable_data<T>(context.GetPlace());                         \
        phi::funcs::SetConstant<phi::GPUContext, T> zero_fill;                \
        zero_fill(context, &hwc_grad, static_cast<T>(0));                     \
        filter_grad_data = hwc_grad.data<T>();                                \
      } else {                                                                \
        block_size = 512;                                                     \
        if (output_channels > 512 && output_channels <= 1024) {               \
          block_size = output_channels;                                       \
        } else if (output_channels > 1024 && output_channels <= 2048) {       \
          block_size = (output_channels - 1) / 2 + 1;                         \
        }                                                                     \
        blocks =                                                              \
            std::min(std::max(block_size / output_channels, 1), output_width); \
        grid = dim3(ksize_width * ksize_height, output_height, batch_size);   \
        threads = dim3(std::min(output_channels, block_size), blocks, 1);     \
      }                                                                       \
      KernelDepthwiseConvFilterGradSp<T,                                      \
                                      c_filter_multiplier,                    \
                                      c_stride,                               \
                                      c_filter,                               \
                                      DataLayout::kNHWC,                      \
                                      fuse_relu_before_conv>                  \
          <<<grid, threads, 0, context.stream()>>>(                           \
              DWCONV_FILTER_GRAD_KERNEL_ARGS);                                \
      if ((c_filter) != -1) {                                                 \
        std::vector<int> back_perm({2, 3, 0, 1});                             \
        phi::funcs::TransposeNormal<phi::GPUContext, T> transpose;            \
        transpose(context, hwc_grad, filter_grad, back_perm);                 \
      }                                                                       \
    } else {                                                                  \
      KernelDepthwiseConvFilterGradSp<T,                                      \
                                      c_filter_multiplier,                    \
                                      c_stride,                               \
                                      c_filter,                               \
                                      DataLayout::kNCHW,                      \
                                      fuse_relu_before_conv>                  \
          <<<grid, threads, 0, context.stream()>>>(                           \
              DWCONV_FILTER_GRAD_KERNEL_ARGS);                                \
    }                                                                         \
    return;                                                                   \
  }

    // Order matters: the first matching specialisation wins.
    DWCONV_FILTER_GRAD_CASE(1, 1, 3);
    DWCONV_FILTER_GRAD_CASE(1, 1, 5);
    DWCONV_FILTER_GRAD_CASE(1, 1, -1);
    DWCONV_FILTER_GRAD_CASE(1, 2, 3);
    DWCONV_FILTER_GRAD_CASE(1, 2, 5);
    DWCONV_FILTER_GRAD_CASE(1, 2, -1);
    DWCONV_FILTER_GRAD_CASE(2, 1, 3);
    DWCONV_FILTER_GRAD_CASE(2, 1, 5);
    DWCONV_FILTER_GRAD_CASE(2, 1, -1);
    DWCONV_FILTER_GRAD_CASE(2, 2, 3);
    DWCONV_FILTER_GRAD_CASE(2, 2, 5);
    DWCONV_FILTER_GRAD_CASE(2, 2, -1);
    // Catch-all: multiplier 0 matches every shape.
    DWCONV_FILTER_GRAD_CASE(0, 0, -1);

#undef DWCONV_FILTER_GRAD_CASE
#undef DWCONV_FILTER_GRAD_KERNEL_ARGS
  }
};

H
hong 已提交
1723 1724
// Explicitly instantiate the forward, input-grad and filter-grad GPU functors
// for each supported element type, with and without fuse_relu_before_conv.
// The macro leaves the final ';' to the call site to avoid an empty
// declaration at namespace scope.
#define DWCONV_INSTANTIATE_GPU_FUNCTORS(DTYPE, FUSE_RELU)                      \
  template class DepthwiseConvFunctor<phi::GPUContext, DTYPE, FUSE_RELU>;       \
  template class DepthwiseConvInputGradFunctor<phi::GPUContext,                 \
                                               DTYPE,                           \
                                               FUSE_RELU>;                      \
  template class DepthwiseConvFilterGradFunctor<phi::GPUContext,                \
                                                DTYPE,                          \
                                                FUSE_RELU>

DWCONV_INSTANTIATE_GPU_FUNCTORS(float, false);
DWCONV_INSTANTIATE_GPU_FUNCTORS(double, false);
DWCONV_INSTANTIATE_GPU_FUNCTORS(platform::float16, false);

DWCONV_INSTANTIATE_GPU_FUNCTORS(float, true);
DWCONV_INSTANTIATE_GPU_FUNCTORS(double, true);
DWCONV_INSTANTIATE_GPU_FUNCTORS(platform::float16, true);

#undef DWCONV_INSTANTIATE_GPU_FUNCTORS
Z
zlx 已提交
1754 1755 1756 1757

}  // namespace math
}  // namespace operators
}  // namespace paddle