elementwise_add_op.h 17.2 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
F
fengjiayi 已提交
14 15
#pragma once

16 17
#include <algorithm>
#include <utility>
W
Wu Yi 已提交
18
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
19
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
W
Wu Yi 已提交
20
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
21
#include "paddle/fluid/operators/math/blas.h"
22 23 24 25 26
#ifdef PADDLE_WITH_CUDA
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#endif
W
wanghuancoder 已提交
27

G
gongweibao 已提交
28 29 30
namespace paddle {
namespace operators {

31
template <typename DeviceContext, typename T>
32 33 34
void default_elementwise_add(const framework::ExecutionContext &ctx,
                             const framework::Tensor *x,
                             const framework::Tensor *y, framework::Tensor *z) {
35
  int axis = ctx.Attr<int>("axis");
36 37 38
  auto x_dims = x->dims();
  auto y_dims = y->dims();
  if (x_dims.size() >= y_dims.size()) {
39 40 41 42 43 44
    ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
                                                          AddFunctor<T>(), z);
  } else {
    ElementwiseComputeEx<InverseAddFunctor<T>, DeviceContext, T>(
        ctx, x, y, axis, InverseAddFunctor<T>(), z);
  }
45 46
}

47 48 49 50 51 52
// Same-shape fast path for elementwise add (x, y and z share dims).
// Only the primary template's operator() is declared here. Device-specific
// definitions presumably live in the corresponding .cc/.cu files — TODO
// confirm. `Enable` is an extra parameter, presumably for enable_if-based
// partial specializations.
template <typename DeviceContext, typename T, typename Enable = void>
struct SameDimsElemwiseAdd {
  void operator()(const framework::ExecutionContext &ctx,
                  const framework::Tensor *x, const framework::Tensor *y,
                  framework::Tensor *z);
};
53

Q
QI JUN 已提交
54
template <typename DeviceContext, typename T>
Y
Yu Yang 已提交
55
class ElementwiseAddKernel : public framework::OpKernel<T> {
G
gongweibao 已提交
56
 public:
C
chengduo 已提交
57 58 59 60
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *x = ctx.Input<framework::LoDTensor>("X");
    auto *y = ctx.Input<framework::LoDTensor>("Y");
    auto *z = ctx.Output<framework::LoDTensor>("Out");
C
chengduoZH 已提交
61
    z->mutable_data<T>(ctx.GetPlace());
62
    auto dims_equal = x->dims() == y->dims();
63
    if (dims_equal) {
64 65
      SameDimsElemwiseAdd<DeviceContext, T> same_dims_add;
      same_dims_add(ctx, x, y, z);
66
    } else {
67
      default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
68
    }
G
gongweibao 已提交
69 70 71 72
  }
};

template <typename T>
Y
Yu Yang 已提交
73 74
struct IdentityGrad {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
G
gongweibao 已提交
75 76
};

77
template <typename DeviceContext, typename T>
78 79 80 81 82 83 84
void default_elementwise_add_grad(const framework::ExecutionContext &ctx,
                                  const framework::Tensor *x,
                                  const framework::Tensor *y,
                                  const framework::Tensor *out,
                                  const framework::Tensor *dout,
                                  framework::Tensor *dx,
                                  framework::Tensor *dy) {
85 86
  int axis = ctx.Attr<int>("axis");

87 88 89 90
  ElemwiseExplicitGradCompute<DeviceContext, T, IdentityGrad<T>,
                              IdentityGrad<T>>(ctx, *x, *y, *out, *dout, axis,
                                               dx, dy, IdentityGrad<T>(),
                                               IdentityGrad<T>());
91 92
}

93
template <typename DeviceContext, typename T>
94 95 96
typename std::enable_if<
    std::is_floating_point<T>::value &&
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
97 98 99 100 101
elementwise_add_grad(const framework::ExecutionContext &ctx,
                     const framework::Tensor *x, const framework::Tensor *y,
                     const framework::Tensor *out,
                     const framework::Tensor *dout, framework::Tensor *dx,
                     framework::Tensor *dy) {
102 103 104 105 106 107 108 109 110 111 112 113
  auto blas = math::GetBlas<DeviceContext, T>(ctx);
  if (dx) {
    blas.VCOPY(dout->numel(), dout->data<T>(),
               dx->mutable_data<T>(ctx.GetPlace()));
  }

  if (dy) {
    blas.VCOPY(dout->numel(), dout->data<T>(),
               dy->mutable_data<T>(ctx.GetPlace()));
  }
}

114
template <typename DeviceContext, typename T>
115
typename std::enable_if<
116 117
    !std::is_floating_point<T>::value &&
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
118 119 120 121 122 123
elementwise_add_grad(const framework::ExecutionContext &ctx,
                     const framework::Tensor *x, const framework::Tensor *y,
                     const framework::Tensor *out,
                     const framework::Tensor *dout, framework::Tensor *dx,
                     framework::Tensor *dy) {
  default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
124 125
}

126 127 128
#ifdef PADDLE_WITH_CUDA
#ifdef __NVCC__

129 130 131 132 133 134 135 136 137 138 139 140 141 142
// Aggregate of `Size` values of type T.
// Its alignment equals its total byte size (Size * sizeof(T)), so a suitably
// aligned address can be reinterpreted as one AlignedVector and read in a
// single wide access (as VecMatrixReduceLongWidth does).
template <typename T, int Size>
struct alignas(Size * sizeof(T)) AlignedVector {
  T val[Size];  // element storage, accessed directly by users of the type
};

template <typename T>
inline int VectorizedSize(const T *pointer) {
  uint64_t address = reinterpret_cast<uint64_t>(pointer);
  constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value;  // NOLINT
  if (address % vec4 == 0) {
    return 4;
  }
  return 1;
}
143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
// Column-wise sum of a row-major `height` x `width` matrix:
//   out[w] = sum over h of in[h * width + w].
// Assumed launch shape, matching the 32 x 32 launch in
// ElementwiseAddGradKernel:
//   - blockDim == (BLOCK_W, BLOCK_H);
//   - BLOCK_W == BLOCK_H == warpSize.
// The transposed shared-memory read and the warp-wide shuffle reduction
// below rely on this. BLOCK_W must be a power of two for the `full_width`
// bit mask.
template <typename T, int BLOCK_W, int BLOCK_H>
__global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out,
                                size_t width, size_t height) {
  // Per-thread partial sums for the current tile of BLOCK_W columns. The
  // extra column is presumably padding against shared-memory bank conflicts
  // on the transposed read below.
  __shared__ T sdata[BLOCK_H][BLOCK_W + 1];
  size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
  size_t width_stride = gridDim.x * blockDim.x;
  // `width` rounded up to a multiple of BLOCK_W. All threads of a block then
  // run the same number of outer iterations and reach every __syncthreads().
  // Columns w >= width add zeros and are never stored.
  size_t full_width = (width & (~((uint64_t)(BLOCK_W - 1)))) +
                      ((width & (BLOCK_W - 1)) ? BLOCK_W : 0);

#pragma unroll
  for (size_t w = idx; w < full_width; w += width_stride) {
    sdata[threadIdx.y][threadIdx.x] = 0;
    __syncthreads();
    // Thread (x, y) accumulates rows y, y + BLOCK_H, y + 2 * BLOCK_H, ... of
    // column w.
    size_t offset = w + threadIdx.y * width;
#pragma unroll
    for (size_t h = threadIdx.y; h < height;
         h += BLOCK_H) {  // block-stride loop across matrix height
      sdata[threadIdx.y][threadIdx.x] +=
          (w < width) ? in[offset] : (static_cast<T>(0));
      offset += width * BLOCK_H;
    }
    __syncthreads();

    // Transposed read: the warp with a fixed threadIdx.y collects the
    // BLOCK_H partial sums of tile column threadIdx.y. The XOR-butterfly
    // shuffle then leaves their total in every lane (assuming
    // CudaShuffleXorSync wraps __shfl_xor_sync).
    T val = sdata[threadIdx.x][threadIdx.y];
    for (int i = warpSize >> 1; i > 0; i >>= 1)
      val += platform::CudaShuffleXorSync(0xFFFFFFFF, val, i);

    // Wait until every warp has finished reading sdata before row 0 is
    // overwritten with the per-column totals.
    __syncthreads();
    if (threadIdx.x == 0) sdata[0][threadIdx.y] = val;
    __syncthreads();
    // Threads in row y == 0 store one column each: thread x owns column w.
    if ((threadIdx.y == 0) && ((w) < width)) out[w] = sdata[0][threadIdx.x];
  }
}

// float16 column-wise sum of a row-major `height` x `width` matrix:
//   out[w] = sum over h of in[h * width + w].
// Threads are laid out BLOCK_W x BLOCK_W (the launch in ElementwiseAddGradKernel
// uses a 32 x 32 block), while the shared tile holds BLOCK_H rows. Each
// thread therefore owns `repeats = BLOCK_H / BLOCK_W` tile rows.
// Assumptions:
//   - blockDim == (BLOCK_W, BLOCK_W);
//   - BLOCK_W == warpSize and BLOCK_W is a power of two;
//   - BLOCK_H is a multiple of BLOCK_W.
template <int BLOCK_W, int BLOCK_H>
__global__ void FP16MatrixColReduce(
    const paddle::platform::float16 *__restrict__ in,
    paddle::platform::float16 *__restrict__ out, size_t width, size_t height) {
  constexpr int repeats = BLOCK_H / BLOCK_W;
  __shared__ paddle::platform::float16 sdata[BLOCK_H][BLOCK_W + 1];
  size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
  size_t width_stride = gridDim.x * blockDim.x;
  // `width` rounded up to a multiple of BLOCK_W. All threads of a block then
  // run the same number of outer iterations and reach every __syncthreads().
  size_t full_width = (width & (~((uint64_t)(BLOCK_W - 1)))) +
                      ((width & (BLOCK_W - 1)) ? BLOCK_W : 0);

  for (size_t w = idx; w < full_width; w += width_stride) {
#pragma unroll
    for (int r = 0; r < repeats; r++) {
      sdata[threadIdx.y + r * BLOCK_W][threadIdx.x] = 0;
    }
    __syncthreads();
#pragma unroll
    for (int r = 0; r < repeats; r++) {
      // Tile row owned by this thread for repeat r. It accumulates matrix
      // rows row, row + BLOCK_H, row + 2 * BLOCK_H, ... The starting rows
      // over all (r, threadIdx.y) span exactly [0, BLOCK_H), so every matrix
      // row is summed exactly once. The loop bound `h` tracks the row
      // actually read.
      const size_t row = r * BLOCK_W + threadIdx.y;
      size_t offset = w + row * width;
      for (size_t h = row; h < height;
           h += BLOCK_H) {  // block-stride loop across matrix height
        sdata[row][threadIdx.x] +=
            (w < width) ? in[offset]
                        : (static_cast<paddle::platform::float16>(0));
        offset += width * BLOCK_H;
      }
    }
    __syncthreads();

    // For each repeat, the warp with a fixed threadIdx.y reduces BLOCK_W
    // partial sums of tile column threadIdx.y via an XOR shuffle. The
    // per-repeat totals are then added into `result`.
    paddle::platform::float16 result =
        static_cast<paddle::platform::float16>(0);
    for (int r = 0; r < repeats; r++) {
      paddle::platform::float16 val =
          sdata[threadIdx.x + r * BLOCK_W][threadIdx.y];
      for (int i = warpSize >> 1; i > 0; i >>= 1)
        val += platform::CudaShuffleXorSync(0xFFFFFFFF, val, i);
      // Also guarantees every read of sdata is done before row 0 is
      // overwritten below.
      __syncthreads();
      result += val;
    }
    if (threadIdx.x == 0) sdata[0][threadIdx.y] = result;
    __syncthreads();
    // Threads in row y == 0 store one column each: thread x owns column w.
    if ((threadIdx.y == 0) && ((w) < width)) out[w] = sdata[0][threadIdx.x];
  }
}
222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260

// Column-wise sum for wide matrices:
//   out[col] = sum over row of in[row * width + col]
// for a row-major `height` x `width` matrix. There is one thread per column,
// and columns advance by the total thread count of the grid (grid-stride
// loop).
// Indices are size_t, matching the size_t width/height parameters. This
// avoids signed/unsigned comparisons and int overflow on very large inputs.
template <typename T>
__global__ void MatrixReduceLongWidth(const T *__restrict__ in, T *out,
                                      size_t width, size_t height) {
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t col = threadIdx.x + static_cast<size_t>(blockIdx.x) * blockDim.x;
       col < width; col += stride) {
    T sum = static_cast<T>(0);
    for (size_t row = 0; row < height; ++row) {
      sum += in[row * width + col];
    }
    out[col] = sum;
  }
}

// Vectorized column-wise sum for wide matrices:
//   out[w + v] = sum over row of in[row * width + w + v],  v in [0, VEC_SIZE)
// for a row-major `height` x `width` matrix. Each thread handles VEC_SIZE
// adjacent columns per step and advances by the whole grid's column span
// (grid-stride loop).
// Preconditions, checked by the caller in ElementwiseAddGradKernel:
//   - width % VEC_SIZE == 0;
//   - `in` is aligned to alignof(AlignedVector<T, VEC_SIZE>).
template <typename T, int VEC_SIZE>
__global__ void VecMatrixReduceLongWidth(const T *__restrict__ in, T *out,
                                         size_t width, size_t height) {
  using LoadT = AlignedVector<T, VEC_SIZE>;
  // Stride over the whole grid, counted in columns. The original loop
  // advanced by `width` instead, which ended every thread after one step.
  const size_t width_stride =
      static_cast<size_t>(blockDim.x) * gridDim.x * VEC_SIZE;
  for (size_t w =
           (threadIdx.x + static_cast<size_t>(blockIdx.x) * blockDim.x) *
           VEC_SIZE;
       w < width; w += width_stride) {
    T sum[VEC_SIZE];
#pragma unroll
    for (int v = 0; v < VEC_SIZE; v++) sum[v] = static_cast<T>(0);
    for (size_t row = 0; row < height; row++) {
      // One aligned wide load into a genuine LoadT object, rather than
      // writing through a LoadT* into a T[] buffer that only has alignof(T)
      // alignment.
      const LoadT vec = *reinterpret_cast<const LoadT *>(&in[row * width + w]);
#pragma unroll
      for (int v = 0; v < VEC_SIZE; v++) sum[v] += vec.val[v];
    }
#pragma unroll
    for (int v = 0; v < VEC_SIZE; v++) out[w + v] = sum[v];
  }
}
261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307
#endif
#endif
// Decides whether the dx/dy shapes fit the CUDA column-reduction fast path
// in ElementwiseAddGradKernel. After leading size-1 dimensions are ignored,
// all of these must hold:
//   - the two shapes have different ranks;
//   - the lower-rank shape equals the trailing dimensions of the other;
//   - `axis` is -1 or equals the difference of those stripped ranks.
// `dout_dims` is currently unused.
static bool RunSpecialDims(const framework::DDim &dx_dims,
                           const framework::DDim &dy_dims,
                           const framework::DDim &dout_dims, int axis) {
  // Rank left after skipping leading dimensions equal to 1.
  auto stripped_rank = [](const framework::DDim &dims) -> int {
    int leading_ones = 0;
    while (leading_ones < dims.size() && dims[leading_ones] == 1) {
      ++leading_ones;
    }
    return dims.size() - leading_ones;
  };

  const int dx_rank = stripped_rank(dx_dims);
  const int dy_rank = stripped_rank(dy_dims);
  if (dx_rank == dy_rank) return false;

  const bool dx_is_lower = dx_rank < dy_rank;
  const framework::DDim &lo_dims = dx_is_lower ? dx_dims : dy_dims;
  const framework::DDim &hi_dims = dx_is_lower ? dy_dims : dx_dims;
  const int lo_rank = dx_is_lower ? dx_rank : dy_rank;
  const int hi_rank = dx_is_lower ? dy_rank : dx_rank;

  // The lower-rank shape must match the trailing dims of the higher-rank one.
  const int lo_size = lo_dims.size();
  const int hi_size = hi_dims.size();
  for (int back = 1; back <= lo_rank; ++back) {
    if (hi_dims[hi_size - back] != lo_dims[lo_size - back]) return false;
  }

  return axis == -1 || axis == hi_rank - lo_rank;
}

308 309 310 311 312
#ifdef PADDLE_WITH_CUDA
// cuda definition
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
313 314 315 316 317
elementwise_add_grad(const framework::ExecutionContext &ctx,
                     const framework::Tensor *x, const framework::Tensor *y,
                     const framework::Tensor *out,
                     const framework::Tensor *dout, framework::Tensor *dx,
                     framework::Tensor *dy);
318 319
#endif

Q
QI JUN 已提交
320
template <typename DeviceContext, typename T>
321
class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
G
gongweibao 已提交
322
 public:
C
chengduo 已提交
323
  void Compute(const framework::ExecutionContext &ctx) const override {
324 325
    ElemwiseGradKernel<T>::Compute(ctx);

C
chengduoZH 已提交
326 327
    using Tensor = framework::Tensor;

328 329
    auto *x = ctx.Input<Tensor>("X");
    auto *y = ctx.Input<Tensor>("Y");
C
chengduo 已提交
330 331 332
    auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
333
    // skip out
C
chengduo 已提交
334
    auto *out = dout;
335

336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361
#ifdef PADDLE_WITH_CUDA
#ifdef __NVCC__

    int axis = ctx.Attr<int>("axis");
    if (ctx.GetPlace() == platform::CUDAPlace() && dx != nullptr &&
        dy != nullptr && dout != nullptr && dx->numel() != dy->numel() &&
        RunSpecialDims(dx->dims(), dy->dims(), dout->dims(), axis)) {
      auto *dx_data = dx->mutable_data<T>(ctx.GetPlace());
      auto *dy_data = dy->mutable_data<T>(ctx.GetPlace());
      auto *dout_data = dout->data<T>();
      auto stream = ctx.cuda_device_context().stream();
      auto *out_data = dx_data;
      int width = dx->numel();
      int height = dout->numel() / width;
      if (dx->dims() == dout->dims()) {
        width = dy->numel();
        height = dout->numel() / width;
        out_data = dy_data;
        framework::TensorCopy(
            *dout, ctx.GetPlace(),
            ctx.template device_context<platform::DeviceContext>(), dx);
      } else {
        framework::TensorCopy(
            *dout, ctx.GetPlace(),
            ctx.template device_context<platform::DeviceContext>(), dy);
      }
362 363 364 365 366 367 368 369 370 371 372 373 374 375 376
      // special optimization using cub
      if (width == 1) {
        int nums = height;
        size_t temp_storage_bytes = 0;
        auto err = cub::DeviceReduce::Sum(nullptr, temp_storage_bytes,
                                          dout_data, out_data, nums, stream);
        PADDLE_ENFORCE_CUDA_SUCCESS(err);
        framework::Tensor tmp;
        auto *temp_storage = tmp.mutable_data<uint8_t>(
            framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
            ctx.GetPlace());
        err = cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes,
                                     dout_data, out_data, nums, stream);
        PADDLE_ENFORCE_CUDA_SUCCESS(err);
      }
377 378 379 380 381 382 383 384 385 386

      constexpr int block_x = 32;
      constexpr int block_y = 32;
      dim3 blocks(block_x, block_y);

      int max_physical_threads =
          ctx.cuda_device_context().GetMaxPhysicalThreadCount();
      int max_blocks = std::max(max_physical_threads / (block_x * block_y), 1);
      int theory_block = (width + blocks.x - 1) / blocks.x;
      dim3 grids(std::min(theory_block, max_blocks));
387 388
      if (std::is_same<T, paddle::platform::float16>::value &&
          (width / height) < 32) {
389 390 391 392 393 394 395 396 397 398 399 400 401
        const paddle::platform::float16 *ptr1 =
            reinterpret_cast<const paddle::platform::float16 *>(dout_data);
        paddle::platform::float16 *ptr2 =
            reinterpret_cast<paddle::platform::float16 *>(out_data);
        if (height <= 32) {
          FP16MatrixColReduce<32, 32><<<grids, blocks, 0, stream>>>(
              ptr1, ptr2, width, height);
        } else {
          FP16MatrixColReduce<32, 64><<<grids, blocks, 0, stream>>>(
              ptr1, ptr2, width, height);
        }
        return;
      }
402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419

      if (width / height < 32) {
        MatrixColReduce<T, block_x, block_y><<<grids, blocks, 0, stream>>>(
            dout_data, out_data, width, height);
      } else {
        size_t thread_nums = 1024;
        size_t block_nums = (width + thread_nums - 1) / thread_nums;
        int vec_size = VectorizedSize<T>(dx_data);
        if (vec_size == 4 && width % 4 == 0) {
          block_nums = (width / vec_size + thread_nums - 1) / thread_nums;
          VecMatrixReduceLongWidth<T,
                                   4><<<block_nums, thread_nums, 0, stream>>>(
              dout_data, out_data, width, height);
        } else {
          MatrixReduceLongWidth<T><<<block_nums, thread_nums, 0, stream>>>(
              dout_data, out_data, width, height);
        }
      }
420 421 422 423 424
      return;
    }

#endif
#endif
425 426 427 428 429 430 431 432 433 434 435 436 437 438
    // Special case when dy is not needed and dx doesn't reduce
    if (dx != nullptr && dy == nullptr && dx->dims() == dout->dims()) {
      VLOG(4) << "Special case when dy is not needed and dx doesn't "
                 "reduce";
      framework::TensorCopy(
          *dout, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), dx);
    } else if (dx == nullptr && dy != nullptr && dy->dims() == dout->dims()) {
      VLOG(4) << "Special case when dx is not needed and dy doesn't "
                 "reduce";
      framework::TensorCopy(
          *dout, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), dy);
    } else if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
439
      elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
440
    } else {
441 442
      default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
                                                     dy);
443
    }
G
gongweibao 已提交
444 445 446
  }
};

447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466
// Double-backward kernel of elementwise add: DDOut = DDX + DDY, computed with
// the broadcasting default_elementwise_add.
// Possibly-missing DDX / DDY inputs go through GetDoubleGradSafeTensor, with
// DOut and Y as the respective reference tensors (presumably substituting a
// zero tensor shaped like the reference — TODO confirm).
// Nothing is done when DDOut is not requested.
template <typename DeviceContext, typename T>
class ElementwiseAddDoubleGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using Tensor = framework::Tensor;

    auto *y = ctx.Input<Tensor>("Y");
    auto *dout = ctx.Input<Tensor>("DOut");
    auto *ddx = ctx.Input<Tensor>("DDX");
    auto *ddy = ctx.Input<Tensor>("DDY");
    auto *ddout = ctx.Output<Tensor>("DDOut");
    if (ddout == nullptr) return;

    Tensor ddx_safe;
    Tensor ddy_safe;
    GetDoubleGradSafeTensor<DeviceContext, T>(ctx, dout, ddx, &ddx_safe);
    GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);

    ddout->mutable_data<T>(ctx.GetPlace());
    default_elementwise_add<DeviceContext, T>(ctx, &ddx_safe, &ddy_safe,
                                              ddout);
  }
};

G
gongweibao 已提交
473 474
}  // namespace operators
}  // namespace paddle