depthwise_conv.h 71.9 KB
Newer Older
H
hong 已提交
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

H
hong 已提交
15
#pragma once
A
Abhinav Arora 已提交
16
#include <vector>
17

H
hong 已提交
18 19 20
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/hostdevice.h"

21 22 23 24 25 26 27
#ifdef __NVCC__
#include <cub/cub.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
H
hong 已提交
28

29
#include "paddle/phi/backends/gpu/gpu_device_function.h"
W
Wang Xin 已提交
30
#include "paddle/phi/backends/gpu/gpu_primitives.h"
31
#include "paddle/phi/kernels/funcs/math_function.h"
Z
zlx 已提交
32 33 34 35 36

namespace paddle {
namespace operators {
namespace math {

H
hong 已提交
37 38 39 40 41 42 43 44 45 46
/*
 * \brief Functor interface for the forward pass of the depthwise
 * convolution (the backward passes have their own functors).
 *
 * Only the call operator is declared in this block; no definition is
 * given here.
 *
 * \tparam DeviceContext          type of the context object handed to the call
 * \tparam T                      NOTE(review): presumably the tensor element
 *                                type - confirm against the definition
 * \tparam fuse_relu_before_conv  compile-time flag, false by default;
 *                                NOTE(review): presumably applies ReLU to the
 *                                input before convolving - confirm
 */
template <typename DeviceContext,
          typename T,
          bool fuse_relu_before_conv = false>
class DepthwiseConvFunctor {
 public:
  // Takes the context, the input and filter tensors, the stride / padding /
  // dilation vectors, a pointer to the output tensor and the data layout
  // (NCHW unless stated otherwise).
  void operator()(
      const DeviceContext& context,
      const phi::DenseTensor& input,
      const phi::DenseTensor& filter,
      const std::vector<int>& strides,
      const std::vector<int>& paddings,
      const std::vector<int>& dilations,
      phi::DenseTensor* output,
      const DataLayout data_layout = DataLayout::kNCHW);
};

/*
 * \brief Functor interface for the depthwise-convolution gradient with
 * respect to the input.
 *
 * Only the call operator is declared in this block; no definition is
 * given here. The template parameters mirror those of the other depthwise
 * convolution functors; fuse_relu_before_conv is false by default.
 */
template <typename DeviceContext,
          typename T,
          bool fuse_relu_before_conv = false>
class DepthwiseConvInputGradFunctor {
 public:
  // Takes the context, the input, filter and output-gradient tensors, the
  // stride / padding / dilation vectors, a pointer to the input-gradient
  // tensor and the data layout (NCHW unless stated otherwise).
  void operator()(
      const DeviceContext& context,
      const phi::DenseTensor& input,
      const phi::DenseTensor& filter,
      const phi::DenseTensor& output_grad,
      const std::vector<int>& strides,
      const std::vector<int>& paddings,
      const std::vector<int>& dilations,
      phi::DenseTensor* input_grad,
      const DataLayout data_layout = DataLayout::kNCHW);
};

/*
 * \brief Functor interface for the depthwise-convolution gradient with
 * respect to the filter.
 *
 * Only the call operator is declared in this block; no definition is
 * given here. The template parameters mirror those of the other depthwise
 * convolution functors; fuse_relu_before_conv is false by default.
 */
template <typename DeviceContext,
          typename T,
          bool fuse_relu_before_conv = false>
class DepthwiseConvFilterGradFunctor {
 public:
  // Takes the context, the input and output-gradient tensors, the stride /
  // padding / dilation vectors, a pointer to the filter-gradient tensor and
  // the data layout (NCHW unless stated otherwise).
  void operator()(
      const DeviceContext& context,
      const phi::DenseTensor& input,
      const phi::DenseTensor& output_grad,
      const std::vector<int>& strides,
      const std::vector<int>& paddings,
      const std::vector<int>& dilations,
      phi::DenseTensor* filter_grad,
      const DataLayout data_layout = DataLayout::kNCHW);
};

87 88 89 90
// Mask with all 32 bits set; default `mask` argument of BlockReduceSum.
#define FINAL_MASK 0xffffffff
// First shuffle offset used by WarpReduceSum (half of WARP_SIZE).
#define HALF_WARP 16
// Warp width assumed by the reduction helpers; also the size of the shared
// scratch array in BlockReduceSum.
#define WARP_SIZE 32

91
// Accumulates `val` with the values obtained from CudaShuffleDownSync at
// offsets HALF_WARP, HALF_WARP / 2, ..., 1 and returns the accumulated value.
// `lane_mask` is forwarded unchanged to every shuffle call.
// NOTE(review): presumably lane 0 ends up with the sum over the whole warp -
// this depends on the semantics of CudaShuffleDownSync; confirm.
template <typename T>
__forceinline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) {
  int offset = HALF_WARP;
  while (offset > 0) {
    val += phi::backends::gpu::CudaShuffleDownSync(lane_mask, val, offset);
    offset >>= 1;
  }
  return val;
}
97

W
wangguanzhong 已提交
98
// Two-stage sum over a thread block (threads are linearised from
// threadIdx.x / threadIdx.y):
//   1. every warp reduces its own values with WarpReduceSum,
//   2. lane 0 of each warp stores its partial result in shared memory and the
//      partial results are reduced once more with WarpReduceSum.
// The function contains __syncthreads(), so every thread of the block has to
// call it. `mask` is passed on to both WarpReduceSum calls.
template <typename T>
__forceinline__ __device__ T BlockReduceSum(T val, unsigned mask = FINAL_MASK) {
  // One slot per warp of the block.
  static __shared__ T shared[WARP_SIZE];

  const int linear_tid = threadIdx.y * blockDim.x + threadIdx.x;
  const int lane_id = linear_tid & 0x1f;
  const int warp_id = linear_tid >> 5;

  // Stage 1: reduction inside each warp.
  val = WarpReduceSum<T>(val, mask);

  // Barrier before the scratch slots are overwritten.
  __syncthreads();
  if (lane_id == 0) {
    shared[warp_id] = val;
  }
  // Barrier so that all partial sums are visible before they are read.
  __syncthreads();

  // Number of warps in the block, rounded up to whole warps.
  const int warp_count = (blockDim.x * blockDim.y + WARP_SIZE - 1) >> 5;

  // Stage 2: lanes beyond the warp count contribute zero.
  if (lane_id < warp_count) {
    val = shared[lane_id];
  } else {
    val = static_cast<T>(0.0f);
  }
  return WarpReduceSum<T>(val, mask);
}

120 121 122 123 124 125 126 127
// Parameter list shared by the forward depthwise-convolution kernels of this
// file. The expansion refers to a type named T, which must be in scope where
// the macro is used.
#define ARG_DEFINE_KernelDepthwiseConv                    \
  const T *const input_data, const T *const filter_data,  \
      const int batch_size, const int output_channels,    \
      const int output_height, const int output_width,    \
      const int input_channels, const int input_height,   \
      const int input_width, const int filter_multiplier, \
      const int filter_height, const int filter_width,    \
      const int stride_height, const int stride_width,    \
      const int padding_height, const int padding_width,  \
      const int dilate_height, const int dilate_width,    \
      T *const output_data
129

130 131
// A CUDA device routine computing the depthwise convolution forward pass
// in NCHW format.
//
// Each thread computes the single output element whose flat index is
// threadIdx.x + blockIdx.x * blockDim.x; threads whose index lies past the
// end of the output return without writing. Filter taps that fall outside
// the input are skipped. When c_filter is not -1 it replaces the runtime
// filter height/width as loop bound. With fuse_relu_before_conv the input
// value is clamped at zero before it is multiplied with the weight.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNCHW(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int fw_size = c_filter != -1 ? c_filter : filter_width;
  const int fh_size = c_filter != -1 ? c_filter : filter_height;
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= (output_channels * batch_size * output_height * output_width))
    return;

  // Split the flat index into (batch, c_out, h_out, w_out), w_out fastest.
  int tmp_1 = idx / output_width;
  const int w_out = idx - tmp_1 * output_width;
  int tmp_2 = tmp_1 / output_height;
  const int h_out = tmp_1 - tmp_2 * output_height;
  tmp_1 = tmp_2;
  tmp_2 = tmp_1 / output_channels;
  const int c_out = tmp_1 - tmp_2 * output_channels;
  const int batch = tmp_2;

  // filter_multiplier consecutive output channels read one input channel.
  const int c_in = c_out / filter_multiplier;
  T value(0);

  // Start of the (batch, c_in) input plane and of the c_out filter.
  int in_offset =
      ((batch * input_channels + c_in) * input_height) * input_width;
  int weight_offset = c_out * filter_height * filter_width;
  int h_in_start = -padding_height + h_out * stride_height;
  int w_in_start = -padding_width + w_out * stride_width;

#pragma unroll
  for (int fh = 0, h_in = h_in_start; fh < fh_size;
       fh++, h_in += dilate_height) {
#pragma unroll
    for (int fw = 0, w_in = w_in_start; fw < fw_size;
         fw++, w_in += dilate_width) {
      if (h_in >= 0 && h_in < input_height && w_in >= 0 && w_in < input_width) {
        int offset = in_offset + h_in * input_width + w_in;
        T in_data = input_data[offset];
        if (fuse_relu_before_conv) {
          // ReLU is evaluated directly in T so that no double-precision
          // arithmetic is emitted in the inner loop.
          in_data =
              in_data > static_cast<T>(0) ? in_data : static_cast<T>(0);
        }
        value += filter_data[weight_offset] * in_data;
      }
      // Advance for every tap, including the skipped ones.
      weight_offset++;
    }
  }
  output_data[idx] = value;
}
180

181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
// A CUDA device routine computing the depthwise convolution forward pass
// in NHWC format.
//
// Each thread computes the single output element whose flat index is
// threadIdx.x + blockIdx.x * blockDim.x (channel index fastest); threads
// past the end of the output return without writing. The filter is indexed
// as filter_data[tap * output_channels + c_out]. With fuse_relu_before_conv
// the input value is clamped at zero before it is multiplied with the
// weight.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNHWC(
    ARG_DEFINE_KernelDepthwiseConv) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= (output_channels * batch_size * output_height * output_width))
    return;

  // Split the flat index into (batch, h_out, w_out, c_out), c_out fastest.
  const int c_out = idx % output_channels;
  const int w_out = (idx / output_channels) % output_width;
  const int h_out = (idx / output_channels / output_width) % output_height;
  const int batch = idx / output_width / output_height / output_channels;

  const int c_in = c_out / filter_multiplier;
  T value(0);
  // Window covered by the filter in input coordinates (may leave the input).
  const int h_in_start = -padding_height + h_out * stride_height;
  const int w_in_start = -padding_width + w_out * stride_width;
  const int h_in_end = h_in_start + filter_height * dilate_height;
  const int w_in_end = w_in_start + filter_width * dilate_width;

  // The same window clipped to the valid input range.
  const int h_end = h_in_end < input_height ? h_in_end : input_height;
  const int w_end = w_in_end < input_width ? w_in_end : input_width;
  const int h_start = h_in_start > 0 ? h_in_start : 0;
  const int w_start = w_in_start > 0 ? w_in_start : 0;
  int weight_offset = 0;

#pragma unroll
  for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) {
#pragma unroll
    for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) {
      if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) {
        int offset = ((batch * input_height + h_in) * input_width + w_in) *
                         input_channels +
                     c_in;
        T in_data = input_data[offset];
        const T* weight = filter_data + weight_offset * output_channels + c_out;
        if (fuse_relu_before_conv) {
          // ReLU is evaluated directly in T so that no double-precision
          // arithmetic is emitted in the inner loop.
          in_data =
              in_data > static_cast<T>(0) ? in_data : static_cast<T>(0);
        }
        value += weight[0] * in_data;
      }
      // Advance for every tap, including the skipped ones.
      weight_offset++;
    }
  }
  int index = batch * output_channels * output_height * output_width +
              h_out * output_width * output_channels + w_out * output_channels +
              c_out;
  output_data[index] = value;
}
233

234
// Depthwise convolution forward pass in NCHW format for a square filter
// whose side length c_filter is known at compile time.
//
// blockIdx.x selects the output channel and blockIdx.y the batch entry; the
// threads of a block stride over the output plane (threadIdx.x over width,
// threadIdx.y over height). The c_filter * c_filter weights of the channel
// are first copied into a per-thread local array. The output index is built
// with gridDim.x as the channel count. With fuse_relu_before_conv the input
// value is clamped at zero before it is multiplied with the weight.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvCFilterNCHW(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_out = blockIdx.x;
  const T* weight = filter_data + c_out * c_filter * c_filter;
  for (int i = 0; i < c_filter * c_filter; i++) r_weight[i] = weight[i];

  // Both values depend only on the block, so they are computed once.
  const int c_in = c_out / filter_multiplier;
  const int in_offset =
      ((batch * input_channels + c_in) * input_height) * input_width;

  for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) {
    for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) {
      T value(0);
      const int h_in_start = -padding_height + h_out * stride_height;
      const int w_in_start = -padding_width + w_out * stride_width;

      for (int h_in = h_in_start, h_f = 0; h_f < c_filter;
           h_in += dilate_height, h_f++) {
        for (int w_in = w_in_start, w_f = 0; w_f < c_filter;
             w_in += dilate_width, w_f++) {
          // Taps outside the input contribute nothing.
          if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
              w_in < input_width) {
            T in_data = input_data[in_offset + h_in * input_width + w_in];
            if (fuse_relu_before_conv) {
              // ReLU is evaluated directly in T so that no double-precision
              // arithmetic is emitted in the inner loop.
              in_data =
                  in_data > static_cast<T>(0) ? in_data : static_cast<T>(0);
            }
            value += r_weight[h_f * c_filter + w_f] * in_data;
          }
        }
      }
      int index =
          ((batch * gridDim.x + c_out) * output_height + h_out) * output_width +
          w_out;
      output_data[index] = value;
    }
  }
}

// Depthwise convolution forward pass in NHWC format for a square filter
// whose side length c_filter is known at compile time.
//
// blockIdx.z selects the batch entry and blockIdx.x * dilate_height +
// blockIdx.y the output row; blocks whose row lies outside the output
// return immediately. threadIdx.x strides over the output channels and
// threadIdx.y over a permuted enumeration of the output columns
// (w_out = i_wi * dilate_width + i_dw). The weights of the current channel
// are copied into a per-thread local array before the columns are processed.
// With fuse_relu_before_conv the input value is clamped at zero before it is
// multiplied with the weight.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvCFilterNHWC(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int batch = blockIdx.z;
  int h_out = blockIdx.x * dilate_height + blockIdx.y;
  if (h_out >= output_height) {
    return;
  }
  int in_offset = batch * input_height * input_width * input_channels;
  int out_offset =
      (batch * output_height + h_out) * output_width * output_channels;
  const int h_in_start = -padding_height + h_out * stride_height;
  // Number of output columns per dilation residue, rounded up.
  const int wi_size = (output_width + dilate_width - 1) / dilate_width;
  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];

  for (int c_out = threadIdx.x; c_out < output_channels; c_out += blockDim.x) {
    // Filter layout: tap index major, output channel minor.
    for (int i = 0; i < c_filter * c_filter; i++) {
      const T* weight = filter_data + i * output_channels + c_out;
      r_weight[i] = weight[0];
    }
    const int c_in = c_out / filter_multiplier;
    for (int i = threadIdx.y; i < wi_size * dilate_width; i += blockDim.y) {
      int i_dw = i / wi_size;
      int i_wi = i - i_dw * wi_size;
      int w_out = i_wi * dilate_width + i_dw;
      // The rounded-up enumeration can produce columns past the end.
      if (w_out >= output_width) {
        continue;
      }
      T value(0);
      const int w_in_start = -padding_width + w_out * stride_width;
      for (int h_in = h_in_start, h_f = 0; h_f < c_filter;
           h_in += dilate_height, h_f++) {
        for (int w_in = w_in_start, w_f = 0; w_f < c_filter;
             w_in += dilate_width, w_f++) {
          // Taps outside the input contribute nothing.
          if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
              w_in < input_width) {
            int offset =
                in_offset + (h_in * input_width + w_in) * input_channels + c_in;
            T in_data = input_data[offset];
            if (fuse_relu_before_conv) {
              // ReLU is evaluated directly in T so that no double-precision
              // arithmetic is emitted in the inner loop.
              in_data =
                  in_data > static_cast<T>(0) ? in_data : static_cast<T>(0);
            }
            value += r_weight[h_f * c_filter + w_f] * in_data;
          }
        }
      }
      int index = out_offset + w_out * output_channels + c_out;
      output_data[index] = value;
    }
  }
}

H
hong 已提交
344 345 346 347 348 349
// Global entry point of the forward depthwise convolution.
//
// Selects one of four device routines from the template arguments:
//   c_filter == -1 : generic-filter routine (NCHW or NHWC),
//   otherwise      : fixed-filter routine   (NCHW or NHWC).
// The NHWC variants are taken when data_layout is DataLayout::kNHWC, the
// NCHW variants for every other layout. When c_filter_multiplier is non-zero
// the template values c_filter_multiplier and c_stride replace the runtime
// filter multiplier and both strides.
template <typename T,
          int c_filter_multiplier,
          int c_stride,
          int c_filter,
          DataLayout data_layout,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
  const bool use_template_values = (c_filter_multiplier != 0);
  const int final_filter_multiplier =
      use_template_values ? c_filter_multiplier : filter_multiplier;
  const int h_stride = use_template_values ? c_stride : stride_height;
  const int w_stride = use_template_values ? c_stride : stride_width;

  const bool generic_filter = (c_filter == -1);
  const bool nhwc = (data_layout == DataLayout::kNHWC);

  if (generic_filter && !nhwc) {
    KernelDepthwiseConvNCHW<T, c_filter, fuse_relu_before_conv>(
        input_data, filter_data, batch_size,
        output_channels, output_height, output_width,
        input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width,
        h_stride, w_stride,
        padding_height, padding_width,
        dilate_height, dilate_width,
        output_data);
  } else if (generic_filter) {
    KernelDepthwiseConvNHWC<T, fuse_relu_before_conv>(
        input_data, filter_data, batch_size,
        output_channels, output_height, output_width,
        input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width,
        h_stride, w_stride,
        padding_height, padding_width,
        dilate_height, dilate_width,
        output_data);
  } else if (!nhwc) {
    KernelDepthwiseConvCFilterNCHW<T, c_filter, fuse_relu_before_conv>(
        input_data, filter_data, batch_size,
        output_channels, output_height, output_width,
        input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width,
        h_stride, w_stride,
        padding_height, padding_width,
        dilate_height, dilate_width,
        output_data);
  } else {
    KernelDepthwiseConvCFilterNHWC<T, c_filter, fuse_relu_before_conv>(
        input_data, filter_data, batch_size,
        output_channels, output_height, output_width,
        input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width,
        h_stride, w_stride,
        padding_height, padding_width,
        dilate_height, dilate_width,
        output_data);
  }
}

Z
zlx 已提交
449
// CUDA kernels computing the depthwise convolution backprop w.r.t. the input.
//
// Parameter list shared by those kernels. The expansion refers to a type
// named T, which must be in scope where the macro is used.
#define ARG_DEFINE_KernelDepthwiseConvInputGrad               \
  const T *const input_data, const T *const output_grad_data, \
      const T *const filter_data, const int batch_size,       \
      const int output_channels, const int output_height,     \
      const int output_width, const int input_channels,       \
      const int input_height, const int input_width,          \
      const int filter_multiplier, const int filter_height,   \
      const int filter_width, const int stride_height,        \
      const int stride_width, const int padding_height,       \
      const int padding_width, const int dilate_height,       \
      const int dilate_width, T *const input_grad_data
461

462
// Input gradient of the depthwise convolution in NCHW format.
//
// Each thread computes the gradient of the input element whose flat index is
// blockIdx.x * blockDim.x + threadIdx.x; threads past the end of the input
// return without writing. For every output channel fed by the input channel
// and every filter tap, the matching output position is looked up; it only
// contributes when it lies on the stride grid and inside the output. With
// fuse_relu_before_conv, elements whose input value is <= 0 get gradient 0.
// When c_filter is not -1 it replaces the runtime filter size as loop bound.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradNCHW(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int kernel_w = (c_filter == -1) ? filter_width : c_filter;
  const int kernel_h = (c_filter == -1) ? filter_height : c_filter;

  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int element_count =
      batch_size * input_channels * input_height * input_width;
  if (idx >= element_count) {
    return;
  }

  // Fused ReLU: no gradient flows through non-positive inputs.
  if (fuse_relu_before_conv && input_data[idx] <= static_cast<T>(0.0f)) {
    input_grad_data[idx] = 0;
    return;
  }

  // Split the flat index into (batch, c_in, h_in, w_in), w_in fastest.
  const int w_in = idx % input_width;
  const int row = idx / input_width;
  const int h_in = row % input_height;
  const int plane = row / input_height;
  const int c_in = plane % input_channels;
  const int batch = plane / input_channels;

  T grad_sum(0);
  for (int c_mul = 0; c_mul < filter_multiplier; ++c_mul) {
    const int c_out = c_in * filter_multiplier + c_mul;
    const int filter_base = c_out * filter_height * filter_width;

#pragma unroll
    for (int fh = 0; fh < kernel_h; ++fh) {
#pragma unroll
      for (int fw = 0; fw < kernel_w; ++fw) {
        const int h_pos = h_in + padding_height - fh * dilate_height;
        const int w_pos = w_in + padding_width - fw * dilate_width;
        // Only positions that are exact multiples of the stride map to an
        // output element.
        if (h_pos % stride_height != 0 || w_pos % stride_width != 0) {
          continue;
        }
        const int h_out = h_pos / stride_height;
        const int w_out = w_pos / stride_width;
        if (h_out < 0 || h_out >= output_height || w_out < 0 ||
            w_out >= output_width) {
          continue;
        }
        const int output_grad_offset =
            ((batch * output_channels + c_out) * output_height + h_out) *
                output_width +
            w_out;
        grad_sum += output_grad_data[output_grad_offset] *
                    filter_data[filter_base + fh * kernel_w + fw];
      }
    }
  }
  input_grad_data[idx] = grad_sum;
}

// Input gradient of the depthwise convolution in NHWC format.
//
// blockIdx.z selects the batch entry and blockIdx.x * dilate_height +
// blockIdx.y the input row; blocks whose row lies outside the input return
// immediately. threadIdx.x strides over the input channels and threadIdx.y
// over the input columns. The filter taps are read in reverse order and the
// filter is indexed as filter_data[tap * output_channels + c_out]. With
// fuse_relu_before_conv, elements whose input value is <= 0 get gradient 0.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradNHWC(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int batch = blockIdx.z;
  const int h_in = blockIdx.x * dilate_height + blockIdx.y;
  if (h_in >= input_height) {
    return;
  }

  const int tap_count = filter_height * filter_width;
  const int h_out_start =
      h_in - (filter_height - 1) * dilate_height + padding_height;

  for (int c_in = threadIdx.x; c_in < input_channels; c_in += blockDim.x) {
    for (int w_in = threadIdx.y; w_in < input_width; w_in += blockDim.y) {
      const int index = ((batch * input_height + h_in) * input_width + w_in) *
                            input_channels +
                        c_in;
      // Fused ReLU: no gradient flows through non-positive inputs.
      if (fuse_relu_before_conv && input_data[index] <= T(0)) {
        input_grad_data[index] = 0;
        continue;
      }

      const int w_out_start =
          w_in - (filter_width - 1) * dilate_width + padding_width;

      T grad_sum(0);
      for (int c_i = 0; c_i < filter_multiplier; c_i++) {
        const int c_out = c_in * filter_multiplier + c_i;
        for (int h_f = 0; h_f < filter_height; h_f++) {
          const int h_out = h_out_start + h_f * dilate_height;
          for (int w_f = 0; w_f < filter_width; w_f++) {
            const int w_out = w_out_start + w_f * dilate_width;
            // Taps are consumed back to front.
            const int weight_offset =
                tap_count - 1 - (h_f * filter_width + w_f);
            const int s_h_out = h_out / stride_height;
            const int s_w_out = w_out / stride_width;
            // Contribute only when the position lies on the stride grid and
            // inside the output.
            if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
                s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
                s_w_out < output_width) {
              const int output_grad_offset =
                  ((batch * output_height + s_h_out) * output_width + s_w_out) *
                      output_channels +
                  c_out;
              const int filter_offset = weight_offset * output_channels + c_out;
              grad_sum += output_grad_data[output_grad_offset] *
                          filter_data[filter_offset];
            }
          }
        }
      }
      input_grad_data[index] = grad_sum;
    }
  }
}

H
hong 已提交
576 577 578
// Input gradient of the depthwise convolution in NCHW format for a square
// filter whose side length c_filter is known at compile time.
//
// blockIdx.x selects the input channel and blockIdx.y the batch entry; the
// threads of a block stride over the input plane (threadIdx.x over width,
// threadIdx.y over height). The weights of all output channels fed by this
// input channel are first copied, in reversed tap order, into a per-thread
// local array sized from c_filter and c_filter_multiplier. The input index
// is built with gridDim.x as the channel count. With fuse_relu_before_conv,
// elements whose input value is <= 0 get gradient 0.
template <typename T,
          int c_filter,
          int c_filter_multiplier,
          bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradCFilterNCHW(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int kWeightSize = c_filter * c_filter * c_filter_multiplier + 1;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_in = blockIdx.x;

  // Cache the flipped filters of every output channel of this input channel.
  for (int c_i = 0; c_i < filter_multiplier; c_i++) {
    const int c_out = c_in * filter_multiplier + c_i;
    const T* weight = filter_data + c_out * c_filter * c_filter;
    for (int i = 0; i < c_filter * c_filter; i++) {
      r_weight[i + c_i * c_filter * c_filter] =
          weight[c_filter * c_filter - i - 1];
    }
  }

  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      const int index =
          ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
          w_in;
      // Fused ReLU: no gradient flows through non-positive inputs.
      if (fuse_relu_before_conv && input_data[index] <= T(0)) {
        input_grad_data[index] = 0;
        continue;
      }

      const int h_out_start =
          h_in - (c_filter - 1) * dilate_height + padding_height;
      const int w_out_start =
          w_in - (c_filter - 1) * dilate_width + padding_width;

      T grad_sum(0);
      for (int c_i = 0; c_i < filter_multiplier; c_i++) {
        const int c_out = c_in * filter_multiplier + c_i;
        for (int h_f = 0; h_f < c_filter; h_f++) {
          const int h_out = h_out_start + h_f * dilate_height;
          for (int w_f = 0; w_f < c_filter; w_f++) {
            const int w_out = w_out_start + w_f * dilate_width;
            const int s_h_out = h_out / stride_height;
            const int s_w_out = w_out / stride_width;
            // Contribute only when the position lies on the stride grid and
            // inside the output.
            if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
                s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
                s_w_out < output_width) {
              const int output_grad_offset =
                  ((batch * output_channels + c_out) * output_height +
                   s_h_out) *
                      output_width +
                  s_w_out;
              grad_sum +=
                  output_grad_data[output_grad_offset] *
                  r_weight[h_f * c_filter + w_f + c_i * c_filter * c_filter];
            }
          }
        }
      }
      input_grad_data[index] = grad_sum;
    }
  }
}

H
hong 已提交
639 640 641
// Input gradient of the depthwise convolution in NHWC format for a square
// filter whose side length c_filter and channel multiplier
// c_filter_multiplier are known at compile time.
//
// blockIdx.z selects the batch entry and blockIdx.x * dilate_height +
// blockIdx.y the input row; blocks whose row lies outside the input return
// immediately. threadIdx.x strides over the input channels and threadIdx.y
// over a permuted enumeration of the input columns
// (w_in = i_wi * dilate_width + i_dw). For each channel the weights are
// copied, in reversed tap order, into a per-thread local array. With
// fuse_relu_before_conv, elements whose input value is <= 0 get gradient 0.
template <typename T,
          int c_filter,
          int c_filter_multiplier,
          bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradCFilterNHWC(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int h_in = blockIdx.x * dilate_height + blockIdx.y;
  if (h_in >= input_height) {
    return;
  }
  const int kWeightSize = c_filter * c_filter * c_filter_multiplier + 1;
  T r_weight[kWeightSize];
  const int batch = blockIdx.z;
  // Number of input columns per dilation residue, rounded up.
  const int wi_size = (input_width + dilate_width - 1) / dilate_width;
  const int h_out_start =
      h_in - (c_filter - 1) * dilate_height + padding_height;

  for (int c_in = threadIdx.x; c_in < input_channels; c_in += blockDim.x) {
    // Cache the flipped filters; layout is tap major, output channel minor.
    for (int c_i = 0; c_i < c_filter_multiplier; c_i++) {
      const int c_out = c_in * c_filter_multiplier + c_i;
      for (int i = 0; i < c_filter * c_filter; i++) {
        r_weight[i + c_i * c_filter * c_filter] =
            filter_data[(c_filter * c_filter - i - 1) * output_channels +
                        c_out];
      }
    }

    for (int i = threadIdx.y; i < wi_size * dilate_width; i += blockDim.y) {
      const int i_dw = i / wi_size;
      const int i_wi = i - i_dw * wi_size;
      const int w_in = i_wi * dilate_width + i_dw;
      // The rounded-up enumeration can produce columns past the end.
      if (w_in >= input_width) {
        continue;
      }

      const int index = ((batch * input_height + h_in) * input_width + w_in) *
                            input_channels +
                        c_in;
      // Fused ReLU: no gradient flows through non-positive inputs.
      if (fuse_relu_before_conv && input_data[index] <= T(0)) {
        input_grad_data[index] = 0;
        continue;
      }

      const int w_out_start =
          w_in - (c_filter - 1) * dilate_width + padding_width;

      T grad_sum(0);
      for (int c_i = 0; c_i < c_filter_multiplier; c_i++) {
        const int c_out = c_in * c_filter_multiplier + c_i;
        for (int h_f = 0; h_f < c_filter; h_f++) {
          const int h_out = h_out_start + h_f * dilate_height;
          for (int w_f = 0; w_f < c_filter; w_f++) {
            const int w_out = w_out_start + w_f * dilate_width;
            const int s_h_out = h_out / stride_height;
            const int s_w_out = w_out / stride_width;
            // Contribute only when the position lies on the stride grid and
            // inside the output.
            if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
                s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
                s_w_out < output_width) {
              const int output_grad_offset =
                  ((batch * output_height + s_h_out) * output_width + s_w_out) *
                      output_channels +
                  c_out;
              grad_sum +=
                  output_grad_data[output_grad_offset] *
                  r_weight[h_f * c_filter + w_f + c_i * c_filter * c_filter];
            }
          }
        }
      }
      input_grad_data[index] = grad_sum;
    }
  }
}

H
hong 已提交
711 712 713 714 715 716
// Global entry point of the depthwise-convolution input gradient.
//
// Selects one of four device routines from the template arguments:
//   c_filter_multiplier == 0 or c_filter == -1 : generic routine,
//   otherwise                                  : fixed-filter routine.
// The NHWC variants are taken when data_layout is DataLayout::kNHWC, the
// NCHW variants for every other layout. The generic routines receive the
// runtime filter multiplier and strides unless c_filter_multiplier is
// non-zero, in which case the template values are used; the fixed-filter
// routines always receive c_filter_multiplier and c_stride.
template <typename T,
          int c_filter_multiplier,
          int c_stride,
          int c_filter,
          DataLayout data_layout,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvInputGradSp(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const bool use_template_values = (c_filter_multiplier != 0);
  const int final_filter_multiplier =
      use_template_values ? c_filter_multiplier : filter_multiplier;
  const int h_stride = use_template_values ? c_stride : stride_height;
  const int w_stride = use_template_values ? c_stride : stride_width;

  const bool generic_path = (c_filter_multiplier == 0 || c_filter == -1);
  const bool nhwc = (data_layout == DataLayout::kNHWC);

  if (generic_path && !nhwc) {
    KernelDepthwiseConvInputGradNCHW<T, c_filter, fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size,
        output_channels, output_height, output_width,
        input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width,
        h_stride, w_stride,
        padding_height, padding_width,
        dilate_height, dilate_width,
        input_grad_data);
  } else if (generic_path) {
    KernelDepthwiseConvInputGradNHWC<T, fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size,
        output_channels, output_height, output_width,
        input_channels, input_height, input_width,
        final_filter_multiplier, filter_height, filter_width,
        h_stride, w_stride,
        padding_height, padding_width,
        dilate_height, dilate_width,
        input_grad_data);
  } else if (!nhwc) {
    KernelDepthwiseConvInputGradCFilterNCHW<T,
                                            c_filter,
                                            c_filter_multiplier,
                                            fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size,
        output_channels, output_height, output_width,
        input_channels, input_height, input_width,
        c_filter_multiplier, filter_height, filter_width,
        c_stride, c_stride,
        padding_height, padding_width,
        dilate_height, dilate_width,
        input_grad_data);
  } else {
    KernelDepthwiseConvInputGradCFilterNHWC<T,
                                            c_filter,
                                            c_filter_multiplier,
                                            fuse_relu_before_conv>(
        input_data, output_grad_data, filter_data, batch_size,
        output_channels, output_height, output_width,
        input_channels, input_height, input_width,
        c_filter_multiplier, filter_height, filter_width,
        c_stride, c_stride,
        padding_height, padding_width,
        dilate_height, dilate_width,
        input_grad_data);
  }
}

829
// CUDA device routine computing the depthwise convolution gradient w.r.t. the
// filter for NCHW tensors.
//
// One thread block produces one filter-gradient element: blockIdx.x is the
// kernel column, blockIdx.y the kernel row and blockIdx.z the output channel.
// Each thread accumulates a partial sum over its share of output positions;
// the partial sums are combined with BlockReduceSum and thread (0, 0) stores
// the result. When fuse_relu_before_conv is set, the input value is clamped
// at zero before it is multiplied with the output gradient.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
    const T* output_grad_data,
    const T* input_data,
    const int num,
    const int output_channels,
    const int output_height,
    const int output_width,
    const int input_channels,
    const int input_height,
    const int input_width,
    const int filter_multiplier,
    const int filter_height,
    const int filter_width,
    const int stride_height,
    const int stride_width,
    const int padding_height,
    const int padding_width,
    const int dilate_height,
    const int dilate_width,
    T* filter_grad_data) {
  const int tap_w = blockIdx.x;
  const int tap_h = blockIdx.y;
  const int out_c = blockIdx.z;
  const int in_c = out_c / filter_multiplier;
  // Flat position of this block's filter-gradient element.
  const int filter_grad_idx =
      ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;

  const int plane = output_height * output_width;
  const int total = num * plane;
  // Offset of the current filter tap relative to the strided output position.
  const int row_shift = tap_h * dilate_height - padding_height;
  const int col_shift = tap_w * dilate_width - padding_width;

  T partial(0);
  if (plane >= WARP_SIZE) {
    // Large output plane: threads along x cover columns, threads along y
    // cover rows, and every thread walks the whole batch.
    for (int out_w = threadIdx.x; out_w < output_width; out_w += blockDim.x) {
      const int in_w = out_w * stride_width + col_shift;
      for (int n = 0; n < num; ++n) {
        for (int out_h = threadIdx.y; out_h < output_height;
             out_h += blockDim.y) {
          const int in_h = out_h * stride_height + row_shift;
          if (in_w < 0 || in_w >= input_width || in_h < 0 ||
              in_h >= input_height) {
            continue;
          }
          const int in_idx =
              ((n * input_channels + in_c) * input_height + in_h) *
                  input_width +
              in_w;
          const int og_idx =
              ((n * output_channels + out_c) * output_height + out_h) *
                  output_width +
              out_w;
          T in_val = input_data[in_idx];
          if (fuse_relu_before_conv) {
            in_val = static_cast<T>(max(0.0f, static_cast<double>(in_val)));
          }
          partial += output_grad_data[og_idx] * in_val;
        }
      }
    }
  } else {
    // Small output plane: threads along x stride over the flattened
    // (batch, row, column) index space.
    for (int flat = threadIdx.x; flat < total; flat += blockDim.x) {
      const int n = flat / plane;
      const int hw = flat - n * plane;
      const int out_h = hw / output_width;
      const int out_w = hw - out_h * output_width;

      const int in_h = out_h * stride_height + row_shift;
      const int in_w = out_w * stride_width + col_shift;
      if (in_w < 0 || in_w >= input_width || in_h < 0 ||
          in_h >= input_height) {
        continue;
      }
      const int in_idx =
          ((n * input_channels + in_c) * input_height + in_h) * input_width +
          in_w;
      const int og_idx = (n * output_channels + out_c) * plane + hw;
      T in_val = input_data[in_idx];
      if (fuse_relu_before_conv) {
        in_val = static_cast<T>(max(0.0f, static_cast<double>(in_val)));
      }
      partial += output_grad_data[og_idx] * in_val;
    }
  }

  const T block_total = BlockReduceSum<T>(partial);
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    filter_grad_data[filter_grad_idx] = block_total;
  }
}

// CUDA device routine computing the depthwise convolution gradient w.r.t. the
// filter for NHWC tensors.
//
// blockIdx.z selects the batch item, blockIdx.y the output row and blockIdx.x
// the flattened (kernel_row, kernel_col) filter tap. Threads along x stride
// over output channels and threads along y stride over output columns. Every
// thread adds its partial sum into filter_grad_data with an atomic add.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGradNHWC(
    const T* output_grad_data,
    const T* input_data,
    const int num,
    const int output_channels,
    const int output_height,
    const int output_width,
    const int input_channels,
    const int input_height,
    const int input_width,
    const int filter_multiplier,
    const int filter_height,
    const int filter_width,
    const int stride_height,
    const int stride_width,
    const int padding_height,
    const int padding_width,
    const int dilate_height,
    const int dilate_width,
    T* filter_grad_data) {
  const int n = blockIdx.z;
  const int out_h = blockIdx.y;
  const int tap_w = blockIdx.x % filter_width;
  const int tap_h = blockIdx.x / filter_width;

  // Input coordinates touched by this filter tap; the row is the same for
  // every thread of the block.
  const int row_shift = tap_h * dilate_height - padding_height;
  const int col_shift = tap_w * dilate_width - padding_width;
  const int in_h = out_h * stride_height + row_shift;
  const bool row_in_range = (in_h >= 0 && in_h < input_height);

  for (int out_c = threadIdx.x; out_c < output_channels;
       out_c += blockDim.x) {
    const int filter_grad_idx =
        ((out_c * filter_height) + tap_h) * filter_width + tap_w;
    T acc(0);
    if (row_in_range) {
      for (int out_w = threadIdx.y; out_w < output_width;
           out_w += blockDim.y) {
        const int in_w = out_w * stride_width + col_shift;
        if (in_w < 0 || in_w >= input_width) {
          continue;
        }
        const int og_idx =
            (((n * output_height + out_h) * output_width + out_w) *
                 output_channels +
             out_c);
        const int in_idx =
            ((n * input_height + in_h) * input_width + in_w) *
                input_channels +
            out_c / filter_multiplier;
        T in_val = input_data[in_idx];
        if (fuse_relu_before_conv) {
          in_val = static_cast<T>(max(0.0f, static_cast<double>(in_val)));
        }
        acc += output_grad_data[og_idx] * in_val;
      }
    }
    // The atomic add is issued even when nothing was accumulated.
    phi::CudaAtomicAdd(&filter_grad_data[filter_grad_idx], acc);
  }
}

// CUDA device routine computing the depthwise convolution gradient w.r.t. the
// filter for NHWC tensors when the filter is c_filter x c_filter.
//
// blockIdx.z selects the batch item; the output row is
// blockIdx.x * dilate_height + blockIdx.y. Threads along x stride over output
// channels and threads along y stride over output columns. Each thread keeps
// one partial sum per filter tap in a local array and adds all of them into
// filter_grad_data with atomic adds once its columns are processed.
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC(
    const T* output_grad_data,
    const T* input_data,
    const int num,
    const int output_channels,
    const int output_height,
    const int output_width,
    const int input_channels,
    const int input_height,
    const int input_width,
    const int filter_multiplier,
    const int filter_height,
    const int filter_width,
    const int stride_height,
    const int stride_width,
    const int padding_height,
    const int padding_width,
    const int dilate_height,
    const int dilate_width,
    T* filter_grad_data) {
  const int out_h = blockIdx.x * dilate_height + blockIdx.y;
  if (out_h >= output_height) {
    return;
  }
  const int n = blockIdx.z;

  const int kTaps = c_filter * c_filter;
  T tap_grad[kTaps];

  // Output columns are visited grouped by their remainder modulo
  // dilate_width: group `phase` holds columns phase, phase + dilate_width, ...
  const int cols_per_phase = (output_width + dilate_width - 1) / dilate_width;
  const int padded_cols = cols_per_phase * dilate_width;

  for (int out_c = threadIdx.x; out_c < output_channels;
       out_c += blockDim.x) {
    for (int t = 0; t < kTaps; ++t) {
      tap_grad[t] = 0;
    }

    for (int j = threadIdx.y; j < padded_cols; j += blockDim.y) {
      const int phase = j / cols_per_phase;
      const int slot = j - phase * cols_per_phase;
      const int out_w = slot * dilate_width + phase;
      if (out_w >= output_width) {
        continue;
      }
      const int og_idx =
          ((n * output_height + out_h) * output_width + out_w) *
              output_channels +
          out_c;

      for (int kh = 0; kh < c_filter; ++kh) {
        const int row_shift = kh * dilate_height - padding_height;
        const int in_h = out_h * stride_height + row_shift;
        if (in_h < 0 || in_h >= input_height) {
          continue;
        }
        for (int kw = 0; kw < c_filter; ++kw) {
          const int col_shift = kw * dilate_width - padding_width;
          const int in_w = out_w * stride_width + col_shift;
          if (in_w < 0 || in_w >= input_width) {
            continue;
          }
          const int in_idx =
              ((n * input_height + in_h) * input_width + in_w) *
                  input_channels +
              out_c / filter_multiplier;

          T in_val = input_data[in_idx];
          if (fuse_relu_before_conv) {
            in_val = static_cast<T>(max(0.0f, static_cast<double>(in_val)));
          }
          const T contribution = output_grad_data[og_idx] * in_val;
          tap_grad[kh * c_filter + kw] += contribution;
        }
      }
    }

    // Filter gradient is indexed as [tap][output channel].
    for (int t = 0; t < kTaps; ++t) {
      T* dst = filter_grad_data + t * output_channels + out_c;
      phi::CudaAtomicAdd(dst, tap_grad[t]);
    }
  }
}

H
hong 已提交
1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088
// Kernel entry point for the depthwise convolution filter gradient.
//
// The template parameters pick a specialization at compile time:
//   c_filter_multiplier  non-zero: overrides the runtime filter_multiplier
//                        and replaces both runtime strides with c_stride.
//   c_filter             filter edge length used by the fixed-size NHWC
//                        routine; -1 selects the generic routine.
//   data_layout          kNHWC selects the channel-last routines; any other
//                        value selects the NCHW routine.
template <typename T,
          int c_filter_multiplier,
          int c_stride,
          int c_filter,
          DataLayout data_layout,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvFilterGradSp(const T* output_grad_data,
                                                const T* input_data,
                                                const int num,
                                                const int output_channels,
                                                const int output_height,
                                                const int output_width,
                                                const int input_channels,
                                                const int input_height,
                                                const int input_width,
                                                const int filter_multiplier,
                                                const int filter_height,
                                                const int filter_width,
                                                const int stride_height,
                                                const int stride_width,
                                                const int padding_height,
                                                const int padding_width,
                                                const int dilate_height,
                                                const int dilate_width,
                                                T* filter_grad_data) {
  // Compile-time values take precedence over the runtime arguments whenever
  // a specialized multiplier is requested.
  const bool specialized = (c_filter_multiplier != 0);
  const int final_filter_multiplier =
      specialized ? c_filter_multiplier : filter_multiplier;
  const int h_stride = specialized ? c_stride : stride_height;
  const int w_stride = specialized ? c_stride : stride_width;

  // NCHW always uses the same routine, whatever the filter specialization.
  if (data_layout != DataLayout::kNHWC) {
    KernelDepthwiseConvFilterGradNCHW<T, fuse_relu_before_conv>(
        output_grad_data,
        input_data,
        num,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        final_filter_multiplier,
        filter_height,
        filter_width,
        h_stride,
        w_stride,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        filter_grad_data);
    return;
  }

  const bool generic_filter = (c_filter_multiplier == 0 || c_filter == -1);
  if (generic_filter) {
    KernelDepthwiseConvFilterGradNHWC<T, fuse_relu_before_conv>(
        output_grad_data,
        input_data,
        num,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        final_filter_multiplier,
        filter_height,
        filter_width,
        h_stride,
        w_stride,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        filter_grad_data);
  } else {
    KernelDepthwiseConvFilterGradCFilterNHWC<T,
                                             c_filter,
                                             fuse_relu_before_conv>(
        output_grad_data,
        input_data,
        num,
        output_channels,
        output_height,
        output_width,
        input_channels,
        input_height,
        input_width,
        final_filter_multiplier,
        filter_height,
        filter_width,
        h_stride,
        w_stride,
        padding_height,
        padding_width,
        dilate_height,
        dilate_width,
        filter_grad_data);
  }
}

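/*
 * Tensors are in NCHW format by default; NHWC is selected through the
 * data_layout argument. Strides, paddings and dilations each hold two
 * elements, which represent height and width, respectively. The kernel size
 * is read from dims 2 and 3 of the filter tensor.
 */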
1195
template <class T, bool fuse_relu_before_conv>
class DepthwiseConvFunctor<phi::GPUContext, T, fuse_relu_before_conv> {
 public:
  // Runs the depthwise convolution forward pass on the GPU.
  //
  //   context      GPU context used for allocation and as the launch stream.
  //   input        Input tensor, laid out according to data_layout.
  //   filter       Filter tensor; dims 2 and 3 are read as kernel height and
  //                kernel width.
  //   strides      {stride_height, stride_width}.
  //   paddings     {padding_height, padding_width}.
  //   dilations    {dilate_height, dilate_width}.
  //   output       Output tensor whose dims are already set; its storage is
  //                allocated here.
  //   data_layout  kNHWC selects the channel-last path; any other value is
  //                handled as NCHW.
  //
  // An empty output (numel() == 0) is allocated and then left untouched: no
  // kernel is launched for it.
  void operator()(const phi::GPUContext& context,
                  const phi::DenseTensor& input,
                  const phi::DenseTensor& filter,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  phi::DenseTensor* output,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const int batch_size = input.dims()[0];
    const int input_channels =
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
    const int input_height =
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
    const int input_width =
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
    const int output_channels =
        (data_layout != DataLayout::kNHWC ? output->dims()[1]
                                          : output->dims()[3]);
    const int output_height =
        (data_layout != DataLayout::kNHWC ? output->dims()[2]
                                          : output->dims()[1]);
    const int output_width =
        (data_layout != DataLayout::kNHWC ? output->dims()[3]
                                          : output->dims()[2]);
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* filter_data = filter.data<T>();
    T* output_data = context.template Alloc<T>(output);

    // Nothing to compute for an empty output. Returning here also avoids the
    // divisions by output_width / input_channels below and a kernel launch
    // with a zero-sized grid, which is an invalid launch configuration.
    if (output->numel() == 0) {
      return;
    }

    // The NHWC kernels read the filter with its dims permuted by {2, 3, 0, 1},
    // so a transposed copy is made for that layout.
    phi::DenseTensor filter_hwc;
    if (data_layout == DataLayout::kNHWC) {
      framework::DDim filter_hwc_dims({filter.dims()[2],
                                       filter.dims()[3],
                                       filter.dims()[0],
                                       filter.dims()[1]});
      filter_hwc.Resize(filter_hwc_dims);
      context.template Alloc<T>(&filter_hwc);
      std::vector<int> perm_axis({2, 3, 0, 1});
      phi::funcs::TransposeNormal<phi::GPUContext, T> trans;
      trans(context, filter, &filter_hwc, perm_axis);
      filter_data = filter_hwc.data<T>();
    }

    int thread = 512;
    int blocks;
    dim3 threads;
    dim3 grid;

    if (data_layout != DataLayout::kNHWC) {
      // NCHW: threads along x cover output columns, threads along y cover
      // output rows; one block per (output channel, batch item).
      if (output_width > 1024 && output_width <= 2048)
        thread = (output_width - 1) / 2 + 1;
      else if (output_width > 512 && output_width <= 1024)
        thread = output_width;
#ifdef __HIPCC__
      thread = std::min(thread, 256);
#endif
      blocks = std::min(std::max(thread / output_width, 1), output_height);
      threads = dim3(std::min(output_width, thread), blocks, 1);
      grid = dim3(output_channels, batch_size, 1);
    } else {
      // NHWC: threads along x cover output channels, threads along y cover
      // output columns; the grid is
      // (ceil(output_height / dilate_height), dilate_height, batch_size).
#ifdef __HIPCC__
      thread = std::min(thread, 256);
#endif
      blocks = std::min(
          std::max(thread / output_channels, 1),
          ((output_width + dilate_width - 1) / dilate_width) * dilate_width);
      threads = dim3(std::min(output_channels, thread), blocks, 1);
      grid = dim3((output_height + dilate_height - 1) / dilate_height,
                  dilate_height,
                  batch_size);
    }
    int filter_multiplier = output_channels / input_channels;
    int nums_output = output->numel();
#ifdef __HIPCC__
    int block_size = 256;
#else
    int block_size = 512;
#endif
    int grid_size = (nums_output + block_size - 1) / block_size;

// check_case(m, s, f) launches the kernel specialized for filter multiplier
// m, stride s and filter size f when the runtime values match, then returns.
// f == -1 accepts any filter size and switches to a flat 1-D launch with one
// thread per output element. m == 0 matches unconditionally.
#define check_case(c_filter_multiplier, c_stride, c_filter)               \
  if (c_filter_multiplier == 0 ||                                         \
      (filter_multiplier == c_filter_multiplier &&                        \
       stride_height == stride_width && stride_height == c_stride &&      \
       ((ksize_height == ksize_width && ksize_height == c_filter) ||      \
        c_filter == -1))) {                                               \
    if (c_filter == -1) {                                                 \
      threads.x = block_size;                                             \
      grid.x = grid_size;                                                 \
      threads.y = threads.z = grid.y = grid.z = 1;                        \
    }                                                                     \
    if (data_layout != DataLayout::kNHWC) {                               \
      KernelDepthwiseConvSp<T,                                            \
                            c_filter_multiplier,                          \
                            c_stride,                                     \
                            c_filter,                                     \
                            DataLayout::kNCHW,                            \
                            fuse_relu_before_conv>                        \
          <<<grid, threads, 0, context.stream()>>>(input_data,            \
                                                   filter_data,           \
                                                   batch_size,            \
                                                   output_channels,       \
                                                   output_height,         \
                                                   output_width,          \
                                                   input_channels,        \
                                                   input_height,          \
                                                   input_width,           \
                                                   filter_multiplier,     \
                                                   ksize_height,          \
                                                   ksize_width,           \
                                                   stride_height,         \
                                                   stride_width,          \
                                                   padding_height,        \
                                                   padding_width,         \
                                                   dilate_height,         \
                                                   dilate_width,          \
                                                   output_data);          \
    } else {                                                              \
      KernelDepthwiseConvSp<T,                                            \
                            c_filter_multiplier,                          \
                            c_stride,                                     \
                            c_filter,                                     \
                            DataLayout::kNHWC,                            \
                            fuse_relu_before_conv>                        \
          <<<grid, threads, 0, context.stream()>>>(input_data,            \
                                                   filter_data,           \
                                                   batch_size,            \
                                                   output_channels,       \
                                                   output_height,         \
                                                   output_width,          \
                                                   input_channels,        \
                                                   input_height,          \
                                                   input_width,           \
                                                   filter_multiplier,     \
                                                   ksize_height,          \
                                                   ksize_width,           \
                                                   stride_height,         \
                                                   stride_width,          \
                                                   padding_height,        \
                                                   padding_width,         \
                                                   dilate_height,         \
                                                   dilate_width,          \
                                                   output_data);          \
    }                                                                     \
    return;                                                               \
  }
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    check_case(0, 0, -1);
// NOTE(liangdun): 0,0 for other case
// add other case if needed, e.g. check_case(2^n,1)
#undef check_case
  }
};

1371
template <typename T, bool fuse_relu_before_conv>
H
hong 已提交
1372
class DepthwiseConvInputGradFunctor<phi::GPUContext, T, fuse_relu_before_conv> {
Z
zlx 已提交
1373
 public:
H
hong 已提交
1374
  void operator()(const phi::GPUContext& context,
1375 1376 1377
                  const phi::DenseTensor& input,
                  const phi::DenseTensor& filter,
                  const phi::DenseTensor& output_grad,
X
xzl 已提交
1378 1379
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
1380
                  const std::vector<int>& dilations,
1381
                  phi::DenseTensor* input_grad,
1382
                  const DataLayout data_layout = DataLayout::kNCHW) {
Z
zlx 已提交
1383
    const int batch_size = input.dims()[0];
1384
    const int input_channels =
1385
        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
1386
    const int input_height =
1387
        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
1388
    const int input_width =
1389
        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
1390
    const int output_channels =
1391
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
1392 1393
                                          : output_grad.dims()[3]);
    const int output_height =
1394
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
1395 1396
                                          : output_grad.dims()[1]);
    const int output_width =
1397
        (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
1398
                                          : output_grad.dims()[2]);
1399 1400 1401
    const int ksize_height = filter.dims()[2];
    const int ksize_width = filter.dims()[3];
    const int stride_height = strides[0];
Z
zlx 已提交
1402 1403 1404
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
1405 1406
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];
Z
zlx 已提交
1407

1408
    const T* input_data = input.data<T>();
1409
    const T* filter_data = filter.data<T>();
Z
zlx 已提交
1410
    const T* output_grad_data = output_grad.data<T>();
1411
    T* input_grad_data = context.template Alloc<T>(input_grad);
Z
zlx 已提交
1412

1413
    phi::DenseTensor filter_hwc;
1414
    if (data_layout == DataLayout::kNHWC) {
H
hong 已提交
1415 1416 1417 1418
      framework::DDim filter_hwc_dims({filter.dims()[2],
                                       filter.dims()[3],
                                       filter.dims()[0],
                                       filter.dims()[1]});
1419
      filter_hwc.Resize(filter_hwc_dims);
1420
      context.template Alloc<T>(&filter_hwc);
1421
      std::vector<int> perm_axis({2, 3, 0, 1});
H
hong 已提交
1422
      phi::funcs::TransposeNormal<phi::GPUContext, T> trans;
1423 1424 1425 1426
      trans(context, filter, &filter_hwc, perm_axis);
      filter_data = filter_hwc.data<T>();
    }

1427
    int thread = 512;
1428 1429 1430
    int blocks;
    dim3 threads;
    dim3 grid;
W
wangguanzhong 已提交
1431

1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446
    if (data_layout != DataLayout::kNHWC) {
      if (input_width > 1024 && input_width <= 2048) {
        thread = (input_width - 1) / 2 + 1;
      } else if (input_width > 512 && input_width <= 1024) {
        thread = input_width;
      }
      blocks = std::min(std::max(thread / input_width, 1), input_height);
      threads = dim3(std::min(input_width, thread), blocks, 1);
      grid = dim3(input_channels, batch_size, 1);
    } else {
      blocks = std::min(
          std::max(thread / input_channels, 1),
          ((input_width + dilate_width - 1) / dilate_width) * dilate_width);
      threads = dim3(std::min(input_channels, thread), blocks, 1);
      grid = dim3((input_height + dilate_height - 1) / dilate_height,
H
hong 已提交
1447 1448
                  dilate_height,
                  batch_size);
1449
    }
1450
    int filter_multiplier = output_channels / input_channels;
1451 1452 1453 1454 1455 1456 1457
    int nums_input = input_grad->numel();
#ifdef __HIPCC__
    int block_size = 256;
#else
    int block_size = 512;
#endif
    int grid_size = (nums_input + block_size - 1) / block_size;
1458

1459 1460 1461 1462 1463 1464 1465
#define check_case(c_filter_multiplier, c_stride, c_filter)             \
  if (c_filter_multiplier == 0 ||                                       \
      filter_multiplier == c_filter_multiplier &&                       \
          stride_height == stride_width && stride_height == c_stride && \
          (ksize_height == ksize_width && ksize_height == c_filter ||   \
           c_filter == -1)) {                                           \
    if (data_layout != DataLayout::kNHWC) {                             \
1466 1467 1468 1469 1470
      if (c_filter == -1) {                                             \
        threads.x = block_size;                                         \
        grid.x = grid_size;                                             \
        threads.y = threads.z = grid.y = grid.z = 1;                    \
      }                                                                 \
1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525
      KernelDepthwiseConvInputGradSp<T,                                 \
                                     c_filter_multiplier,               \
                                     c_stride,                          \
                                     c_filter,                          \
                                     DataLayout::kNCHW,                 \
                                     fuse_relu_before_conv>             \
          <<<grid, threads, 0, context.stream()>>>(input_data,          \
                                                   output_grad_data,    \
                                                   filter_data,         \
                                                   batch_size,          \
                                                   output_channels,     \
                                                   output_height,       \
                                                   output_width,        \
                                                   input_channels,      \
                                                   input_height,        \
                                                   input_width,         \
                                                   filter_multiplier,   \
                                                   ksize_height,        \
                                                   ksize_width,         \
                                                   stride_height,       \
                                                   stride_width,        \
                                                   padding_height,      \
                                                   padding_width,       \
                                                   dilate_height,       \
                                                   dilate_width,        \
                                                   input_grad_data);    \
    } else {                                                            \
      KernelDepthwiseConvInputGradSp<T,                                 \
                                     c_filter_multiplier,               \
                                     c_stride,                          \
                                     c_filter,                          \
                                     DataLayout::kNHWC,                 \
                                     fuse_relu_before_conv>             \
          <<<grid, threads, 0, context.stream()>>>(input_data,          \
                                                   output_grad_data,    \
                                                   filter_data,         \
                                                   batch_size,          \
                                                   output_channels,     \
                                                   output_height,       \
                                                   output_width,        \
                                                   input_channels,      \
                                                   input_height,        \
                                                   input_width,         \
                                                   filter_multiplier,   \
                                                   ksize_height,        \
                                                   ksize_width,         \
                                                   stride_height,       \
                                                   stride_width,        \
                                                   padding_height,      \
                                                   padding_width,       \
                                                   dilate_height,       \
                                                   dilate_width,        \
                                                   input_grad_data);    \
    }                                                                   \
    return;                                                             \
1526
  }
1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    check_case(0, 0, -1);
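// NOTE(liangdun): check_case(0, 0, -1) is the generic fallback that handles
// every configuration not matched by a specialized case above. Add further
// specialized cases if needed, e.g. check_case(2^n, 1).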
1542
#undef check_case
Z
zlx 已提交
1543 1544 1545
  }
};

1546
/*
 * \brief GPU specialization that launches the depthwise-convolution
 * filter-gradient kernel (KernelDepthwiseConvFilterGradSp).
 *
 * operator() reads the problem sizes from `input`, `output_grad` and
 * `filter_grad`, derives a launch configuration that depends on the data
 * layout, and dispatches to a kernel instantiation selected by
 * (filter multiplier, stride, kernel size). Configurations that do not match
 * a specialized instantiation fall through to the generic (0, 0, -1) case.
 *
 * \param context      GPU context; the kernel is launched on context.stream()
 *                     and the context allocates `filter_grad`.
 * \param input        Forward input. Indexed as [N, C, H, W] unless
 *                     data_layout is kNHWC, in which case [N, H, W, C].
 * \param output_grad  Gradient w.r.t. the forward output, same layout rule.
 * \param strides      {stride_height, stride_width}.
 * \param paddings     {padding_height, padding_width}.
 * \param dilations    {dilate_height, dilate_width}.
 * \param filter_grad  Output tensor; dims()[2] / dims()[3] are used as the
 *                     kernel height / width.
 * \param data_layout  kNCHW (default) or kNHWC.
 */
template <typename T, bool fuse_relu_before_conv>
class DepthwiseConvFilterGradFunctor<phi::GPUContext,
                                     T,
                                     fuse_relu_before_conv> {
 public:
  void operator()(const phi::GPUContext& context,
                  const phi::DenseTensor& input,
                  const phi::DenseTensor& output_grad,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  phi::DenseTensor* filter_grad,
                  const DataLayout data_layout = DataLayout::kNCHW) {
    const bool channel_last = (data_layout == DataLayout::kNHWC);

    // Problem sizes; the axis that holds C/H/W depends on the layout.
    const int batch_size = static_cast<int>(input.dims()[0]);
    const int input_channels =
        static_cast<int>(channel_last ? input.dims()[3] : input.dims()[1]);
    const int input_height =
        static_cast<int>(channel_last ? input.dims()[1] : input.dims()[2]);
    const int input_width =
        static_cast<int>(channel_last ? input.dims()[2] : input.dims()[3]);
    const int output_channels = static_cast<int>(
        channel_last ? output_grad.dims()[3] : output_grad.dims()[1]);
    const int output_height = static_cast<int>(
        channel_last ? output_grad.dims()[1] : output_grad.dims()[2]);
    const int output_width = static_cast<int>(
        channel_last ? output_grad.dims()[2] : output_grad.dims()[3]);
    const int ksize_height = static_cast<int>(filter_grad->dims()[2]);
    const int ksize_width = static_cast<int>(filter_grad->dims()[3]);
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int dilate_height = dilations[0];
    const int dilate_width = dilations[1];

    const T* input_data = input.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* filter_grad_data = context.template Alloc<T>(filter_grad);

    // Zero-sized tensors: the launch-configuration code below divides by
    // output_width / output_channels / input_channels and would build grids
    // or blocks with a zero dimension. There is nothing to accumulate in that
    // case, so the gradient is all zeros.
    if (input.numel() == 0 || output_grad.numel() == 0) {
      if (filter_grad->numel() > 0) {
        phi::funcs::SetConstant<phi::GPUContext, T> set_zero;
        set_zero(context, filter_grad, static_cast<T>(0));
      }
      return;
    }

    // Default launch configuration. The generic NHWC case (c_filter == -1)
    // recomputes block_size / blocks / grid / threads inside check_case.
    int block_size = 512;
    int blocks;
    dim3 threads;
    dim3 grid;
    if (!channel_last) {
      // Wide rows: pick a block size so that one or two passes cover a row.
      if (output_width > 1024 && output_width <= 2048) {
        block_size = (output_width - 1) / 2 + 1;
      } else if (output_width > 512 && output_width <= 1024) {
        block_size = output_width;
      }
      blocks = std::min(std::max(block_size / output_width, 1), output_height);
      grid = dim3(ksize_width, ksize_height, output_channels);
      threads = dim3(std::min(output_width, block_size), blocks, 1);
      // Small feature maps use a 1-D block sized by the total number of
      // (batch, row, column) positions, capped at block_size.
      if (output_height * output_width < WARP_SIZE) {
        threads = dim3(
            std::min(block_size, batch_size * output_height * output_width));
      }
    } else {
      blocks = std::min(
          std::max(block_size / output_channels, 1),
          ((output_width + dilate_width - 1) / dilate_width) * dilate_width);
      grid = dim3((output_height + dilate_height - 1) / dilate_height,
                  dilate_height,
                  batch_size);
      threads = dim3(std::min(output_channels, block_size), blocks, 1);
    }
    int filter_multiplier = output_channels / input_channels;

/* check_case(m, s, k): if the runtime configuration matches filter
 * multiplier m, a square stride s and a square kernel k (k == -1 accepts any
 * kernel size; m == 0 accepts everything), launch the matching kernel
 * instantiation and return from operator().
 *
 * NHWC with a fixed kernel size writes into a zero-filled temporary whose
 * dims are filter_grad's dims permuted as (2, 3, 0, 1); the temporary is then
 * transposed into filter_grad with perm {2, 3, 0, 1}.
 * NOTE(review): the zero fill suggests the kernel accumulates into its
 * output -- confirm against KernelDepthwiseConvFilterGradSp. */
#define check_case(c_filter_multiplier, c_stride, c_filter)                    \
  if (c_filter_multiplier == 0 ||                                              \
      (filter_multiplier == c_filter_multiplier &&                             \
       stride_height == stride_width && stride_height == c_stride &&           \
       ((ksize_height == ksize_width && ksize_height == c_filter) ||           \
        c_filter == -1))) {                                                    \
    if (!channel_last) {                                                       \
      KernelDepthwiseConvFilterGradSp<T,                                       \
                                      c_filter_multiplier,                     \
                                      c_stride,                                \
                                      c_filter,                                \
                                      DataLayout::kNCHW,                       \
                                      fuse_relu_before_conv>                   \
          <<<grid, threads, 0, context.stream()>>>(output_grad_data,           \
                                                   input_data,                 \
                                                   batch_size,                 \
                                                   output_channels,            \
                                                   output_height,              \
                                                   output_width,               \
                                                   input_channels,             \
                                                   input_height,               \
                                                   input_width,                \
                                                   filter_multiplier,          \
                                                   ksize_height,               \
                                                   ksize_width,                \
                                                   stride_height,              \
                                                   stride_width,               \
                                                   padding_height,             \
                                                   padding_width,              \
                                                   dilate_height,              \
                                                   dilate_width,               \
                                                   filter_grad_data);          \
    } else {                                                                   \
      phi::DenseTensor filter_grad_hwc;                                        \
      if (c_filter != -1) {                                                    \
        /* Fixed kernel size: write into a zeroed, permuted temporary. */      \
        framework::DDim filter_grad_hwc_dims({filter_grad->dims()[2],          \
                                              filter_grad->dims()[3],          \
                                              filter_grad->dims()[0],          \
                                              filter_grad->dims()[1]});        \
        filter_grad_hwc.Resize(filter_grad_hwc_dims);                          \
        context.template Alloc<T>(&filter_grad_hwc);                           \
        phi::funcs::SetConstant<phi::GPUContext, T> set_zero;                  \
        set_zero(context, &filter_grad_hwc, static_cast<T>(0));                \
        filter_grad_data = filter_grad_hwc.data<T>();                          \
      } else {                                                                 \
        /* Generic kernel size: rebuild the launch configuration. */           \
        block_size = 512;                                                      \
        if (output_channels > 1024 && output_channels <= 2048) {               \
          block_size = (output_channels - 1) / 2 + 1;                          \
        } else if (output_channels > 512 && output_channels <= 1024) {         \
          block_size = output_channels;                                        \
        }                                                                      \
        blocks =                                                               \
            std::min(std::max(block_size / output_channels, 1), output_width); \
        grid = dim3(ksize_width * ksize_height, output_height, batch_size);    \
        threads = dim3(std::min(output_channels, block_size), blocks, 1);      \
      }                                                                        \
      KernelDepthwiseConvFilterGradSp<T,                                       \
                                      c_filter_multiplier,                     \
                                      c_stride,                                \
                                      c_filter,                                \
                                      DataLayout::kNHWC,                       \
                                      fuse_relu_before_conv>                   \
          <<<grid, threads, 0, context.stream()>>>(output_grad_data,           \
                                                   input_data,                 \
                                                   batch_size,                 \
                                                   output_channels,            \
                                                   output_height,              \
                                                   output_width,               \
                                                   input_channels,             \
                                                   input_height,               \
                                                   input_width,                \
                                                   filter_multiplier,          \
                                                   ksize_height,               \
                                                   ksize_width,                \
                                                   stride_height,              \
                                                   stride_width,               \
                                                   padding_height,             \
                                                   padding_width,              \
                                                   dilate_height,              \
                                                   dilate_width,               \
                                                   filter_grad_data);          \
      if (c_filter != -1) {                                                    \
        /* Copy the temporary into filter_grad's own dimension order. */       \
        std::vector<int> perm_axis({2, 3, 0, 1});                              \
        phi::funcs::TransposeNormal<phi::GPUContext, T> trans;                 \
        trans(context, filter_grad_hwc, filter_grad, perm_axis);               \
      }                                                                        \
    }                                                                          \
    return;                                                                    \
  }
    // Specialized instantiations first; the first match launches and returns.
    check_case(1, 1, 3);
    check_case(1, 1, 5);
    check_case(1, 1, -1);
    check_case(1, 2, 3);
    check_case(1, 2, 5);
    check_case(1, 2, -1);
    check_case(2, 1, 3);
    check_case(2, 1, 5);
    check_case(2, 1, -1);
    check_case(2, 2, 3);
    check_case(2, 2, 5);
    check_case(2, 2, -1);
    // Generic fallback: always matches.
    check_case(0, 0, -1);
#undef check_case
  }
};

H
hong 已提交
1722 1723
// Explicit instantiations of the GPU functors for float, double and float16.

// fuse_relu_before_conv = false
template class DepthwiseConvFunctor<phi::GPUContext, float, false>;
template class DepthwiseConvFunctor<phi::GPUContext, double, false>;
template class DepthwiseConvFunctor<phi::GPUContext, platform::float16, false>;

template class DepthwiseConvInputGradFunctor<phi::GPUContext, float, false>;
template class DepthwiseConvInputGradFunctor<phi::GPUContext, double, false>;
template class DepthwiseConvInputGradFunctor<phi::GPUContext,
                                             platform::float16,
                                             false>;

template class DepthwiseConvFilterGradFunctor<phi::GPUContext, float, false>;
template class DepthwiseConvFilterGradFunctor<phi::GPUContext, double, false>;
template class DepthwiseConvFilterGradFunctor<phi::GPUContext,
                                              platform::float16,
                                              false>;

// fuse_relu_before_conv = true
template class DepthwiseConvFunctor<phi::GPUContext, float, true>;
template class DepthwiseConvFunctor<phi::GPUContext, double, true>;
template class DepthwiseConvFunctor<phi::GPUContext, platform::float16, true>;

template class DepthwiseConvInputGradFunctor<phi::GPUContext, float, true>;
template class DepthwiseConvInputGradFunctor<phi::GPUContext, double, true>;
template class DepthwiseConvInputGradFunctor<phi::GPUContext,
                                             platform::float16,
                                             true>;

template class DepthwiseConvFilterGradFunctor<phi::GPUContext, float, true>;
template class DepthwiseConvFilterGradFunctor<phi::GPUContext, double, true>;
template class DepthwiseConvFilterGradFunctor<phi::GPUContext,
                                              platform::float16,
                                              true>;
Z
zlx 已提交
1753 1754 1755 1756

}  // namespace math
}  // namespace operators
}  // namespace paddle