TensorAssign.h 4.9 KB
Newer Older
H
hedaoyuan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
H
hedaoyuan 已提交
14 15 16 17 18 19 20 21

#pragma once

#include <algorithm>
#include "paddle/utils/Logging.h"

namespace paddle {

H
hedaoyuan 已提交
22 23 24 25
/**
 * \brief Tensor assignment expression (returned by lazyAssign and
 * evaluated by AssignEvaluate).
 */
H
hedaoyuan 已提交
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
template<typename LhsType, typename RhsType, class T>
class TensorAssignOp {
public:
  /**
   * Pairs a left-hand side (the assignment target) with a right-hand-side
   * expression. Outside device compilation (__CUDA_ARCH__ undefined) it
   * checks that both sides agree on width, height and useGpu().
   */
  explicit TensorAssignOp(const LhsType& lhs, const RhsType& rhs)
    : lhs_(lhs), rhs_(rhs) {
    #ifndef __CUDA_ARCH__
      CHECK_EQ(lhs_.getWidth(), rhs_.getWidth());
      CHECK_EQ(lhs_.getHeight(), rhs_.getHeight());
      CHECK_EQ(lhs_.useGpu(), rhs_.useGpu());
    #endif
  }

  /** Writes rhs element (i, j) into lhs element (i, j). */
  INLINE void apply(const int i, const int j) {
    lhs_.applyRef(i, j) = rhs_.apply(i, j);
  }
  /** Writes rhs element `index` into lhs element `index` (flat addressing). */
  INLINE void apply(const int index) {
    lhs_.applyRef(index) = rhs_.apply(index);
  }

  // Both dimensions are reported from the lhs, the assignment target, so the
  // shape always comes from one consistent source. (The rhs is only verified
  // to match by the host-side checks in the constructor.)
  INLINE size_t getWidth() const { return lhs_.getWidth(); }
  INLINE size_t getHeight() const { return lhs_.getHeight(); }
  /** Flat-index evaluation is allowed only if both sides are contiguous. */
  INLINE bool isContiguous() const {
    return lhs_.isContiguous() && rhs_.isContiguous();
  }
  INLINE bool useGpu() const { return lhs_.useGpu(); }

private:
  // Target of the assignment; written through applyRef().
  TensorApply<LhsType, T> lhs_;
  // Source expression; only read through apply().
  TensorApply<const RhsType, T> rhs_;
};

/**
 * \brief Host-side evaluation of one or more assignment expressions.
 *
 * For every element, `assign` is applied first and then each expression in
 * `args`, left to right. If isContiguous is true, elements are addressed by a
 * flat index in [0, height * width). Otherwise they are addressed by (i, j)
 * with i in [0, height) as the outer loop and j in [0, width) as the inner loop.
 */
template <typename Assign, typename... AssignOp>
void AssignCpuEvaluate(int height, int width, bool isContiguous,
                       Assign&& assign, AssignOp&& ... args) {
  if (isContiguous) {
    const int size = height * width;
    for (int index = 0; index < size; index++) {
      assign.apply(index);
      // Calls args[k].apply(index) for each k in order (braced initializers
      // are evaluated left to right). The leading 0 keeps the array non-empty
      // when `args` is empty: a zero-length array is not valid ISO C++.
      int dummy[] = { 0, ((void)args.apply(index), 0)... };
      (void)dummy;
    }
  } else {
    for (int i = 0; i < height; i++) {
      for (int j = 0; j < width; j++) {
        assign.apply(i, j);
        // Same expansion as above, using (i, j) addressing.
        int dummy[] = { 0, ((void)args.apply(i, j), 0)... };
        (void)dummy;
      }
    }
  }
}

#ifdef __NVCC__
/**
 * \brief 1D (flat-index) evaluation kernel.
 *
 * The thread with global index idx = blockIdx.x * blockDim.x + threadIdx.x
 * applies `assign` and then every expression in `args` to element idx, if
 * idx < border. Each thread handles at most one element (there is no stride
 * loop), so the launch must provide at least `border` threads in total.
 */
template <typename Assign, typename... AssignOp>
__global__
void AssignGpuEvaluate1(const int border, Assign assign, AssignOp ... args) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < border) {
    assign.apply(idx);
    // Calls args[k].apply(idx) for each k, left to right. The leading 0 keeps
    // the array non-empty when `args` is empty (zero-length arrays are not
    // valid ISO C++).
    int dummy[] = { 0, ((void)args.apply(idx), 0)... };
    (void)dummy;
  }
}

/**
 * \brief 2D (row, column) evaluation kernel for non-contiguous expressions.
 *
 * Each thread starts at row blockIdx.y * blockDim.y + threadIdx.y and column
 * blockIdx.x * blockDim.x + threadIdx.x. It then strides by the total grid
 * extent in each dimension (grid-stride loops), so the full height x width
 * range is covered for any non-empty launch shape. For every element it visits,
 * `assign` is applied first, then each expression in `args` in order.
 */
template <typename Assign, typename... AssignOp>
__global__
void AssignGpuEvaluate2(const int height, const int width,
                        Assign assign, AssignOp ... args) {
  const int colIdx = blockIdx.x * blockDim.x + threadIdx.x;
  const int rowIdx = blockIdx.y * blockDim.y + threadIdx.y;
  for (int i = rowIdx; i < height; i += gridDim.y * blockDim.y) {
    for (int j = colIdx; j < width; j += gridDim.x * blockDim.x) {
      assign.apply(i, j);
      // Calls args[k].apply(i, j) for each k, left to right. The leading 0
      // keeps the array non-empty when `args` is empty.
      int dummy[] = { 0, ((void)args.apply(i, j), 0)... };
      (void)dummy;
    }
  }
}
#endif

H
hedaoyuan 已提交
102 103 104 105 106
/**
 * \brief Evaluate one or more TensorAssignOp objects.
 *
 * \note At least one assignment expression is required
 */
H
hedaoyuan 已提交
107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
template <typename Assign, typename... AssignOp>
void AssignEvaluate(Assign&& assign, AssignOp&& ... args) {
  // `assign` defines the device, shape and layout; every expression in
  // `args` must agree with it on device and shape.
  const bool useGpu_ = assign.useGpu();
  bool isContiguous_ = assign.isContiguous();
  const size_t height = assign.getHeight();
  const size_t width = assign.getWidth();

  // Per-expression properties. Element 0 mirrors `assign` itself, which keeps
  // every array non-empty when `args` is empty (a zero-length array is not
  // valid ISO C++). The pack occupies elements 1..packSize.
  const int packSize = sizeof...(args);
  const bool packUseGpu[] = { useGpu_, ((args)).useGpu()... };
  const bool packIsContiguous[] = { isContiguous_, ((args)).isContiguous()... };
  const size_t packHeight[] = { height, ((args)).getHeight()... };
  const size_t packWidth[] = { width, ((args)).getWidth()... };

  for (int i = 1; i <= packSize; i++) {
    CHECK_EQ(useGpu_, packUseGpu[i]);
    CHECK_EQ(height, packHeight[i]);
    CHECK_EQ(width, packWidth[i]);
    // The flat-index path is used only if every expression is contiguous.
    isContiguous_ = isContiguous_ && packIsContiguous[i];
  }

  // Nothing to assign. Returning here also avoids a zero-sized CUDA launch
  // (an invalid configuration) and the division by zero in the 2D block
  // sizing below when height == 0.
  if (height == 0 || width == 0) {
    return;
  }

  if (useGpu_) {
#ifdef __NVCC__
    if (isContiguous_) {
      // One thread per element, at most 1024 threads per block.
      int size = height * width;
      int blockSize = size <= 1024 ? size : 1024;
      int gridSize = (size + 1024 - 1) / 1024;
      AssignGpuEvaluate1
        <<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(size, assign, args...);
    } else {
      // Blocks of at most 1024 threads (up to 32 rows tall) and at most 32x32
      // blocks; AssignGpuEvaluate2 strides over whatever the grid leaves
      // uncovered.
      int blockSizeY = std::min(32, (int)height);
      int blockSizeX = (32 / blockSizeY) * 32;
      int gridSizeX = std::min(32, (int)(width + blockSizeX - 1) / blockSizeX);
      int gridSizeY = std::min(32, (int)(height + blockSizeY - 1) / blockSizeY);
      dim3 threads(blockSizeX, blockSizeY);
      dim3 grid(gridSizeX, gridSizeY);
      AssignGpuEvaluate2
        <<<grid, threads, 0, STREAM_DEFAULT>>>(height, width, assign, args...);
    }

    CHECK_SYNC("AssignEvaluate failed");
#else
    // The GPU kernels are only compiled by nvcc. Fail loudly instead of
    // silently leaving the GPU targets unassigned.
    CHECK(!assign.useGpu())
        << "AssignEvaluate: GPU expressions must be evaluated from a "
           "translation unit compiled by nvcc";
#endif
  } else {
    AssignCpuEvaluate(height, width, isContiguous_, assign, args...);
  }
}

}  // namespace paddle