// TensorEvaluate.h (3.0 KB), committed by hedaoyuan
/**
 * TensorEvaluate.h
 *
 * Author: hedaoyuan (hedaoyuan@baidu.com)
 * Created on: 2016-06-06
 *
 * Copyright (c) Baidu.com, Inc. All Rights Reserved
 *
 */

#pragma once

#include <algorithm>
#include "paddle/utils/Logging.h"
#include "hl_base.h"

namespace paddle {

/**
 * \brief The tensor cpu evaluate api.
 *
 * Evaluates \p rhs element by element on the host and assigns each result
 * to the element of \p lhs at the same position.
 *
 * Both operands are wrapped in TensorApply. Their width, height and
 * useGpu() values must be equal; a mismatch fails the matching CHECK_EQ.
 *
 * \tparam T  type argument forwarded to TensorApply (presumably the element
 *            type -- confirm against the TensorApply definition).
 * \param lhs destination; written through applyRef().
 * \param rhs source; read through apply().
 */
template<class T, typename LeftType, typename RightType>
inline void TensorCpuApply(LeftType& lhs, const RightType& rhs) {
  TensorApply<LeftType, T> lhs_(lhs);
  TensorApply<const RightType, T> rhs_(rhs);
  CHECK_EQ(lhs_.getWidth(), rhs_.getWidth());
  CHECK_EQ(lhs_.getHeight(), rhs_.getHeight());
  CHECK_EQ(lhs_.useGpu(), rhs_.useGpu());

  int height = lhs_.getHeight();
  int width = lhs_.getWidth();
  if (lhs_.isContiguous() && rhs_.isContiguous()) {
    // Both sides report contiguous storage: use the single-index accessors
    // over all height * width elements.
    int size = height * width;
    for (int index = 0; index < size; index++) {
      lhs_.applyRef(index) = rhs_.apply(index);
    }
  } else {
    // At least one side is not contiguous: address every element by its
    // (row, column) pair instead.
    for (int i = 0; i < height; i++) {
      for (int j = 0; j < width; j++) {
        lhs_.applyRef(i, j) = rhs_.apply(i, j);
      }
    }
  }
}

#ifdef __NVCC__
/**
 * Flat element-wise assignment kernel.
 *
 * Each thread derives one index from the x components of blockIdx, blockDim
 * and threadIdx and, when that index is below \p border, stores
 * rhs.apply(index) into lhs.applyRef(index). Threads at or beyond
 * \p border do nothing, so the launch may cover more threads than elements.
 */
template<typename LeftType, typename RightType>
__global__
void TensorElementWiseOp(LeftType lhs, RightType rhs, const int border) {
  const int offset = blockIdx.x * blockDim.x + threadIdx.x;
  if (offset >= border) {
    return;
  }
  lhs.applyRef(offset) = rhs.apply(offset);
}

/**
 * Two-dimensional element-wise assignment kernel.
 *
 * A thread starts at the (row, column) given by its y / x position in the
 * launch and then advances by the total number of launched threads along
 * each axis until it passes lhs.getHeight() / lhs.getWidth(). Every visited
 * position receives rhs.apply(row, col) through lhs.applyRef(row, col), so
 * the launch does not need one thread per element.
 */
template<typename LeftType, typename RightType>
__global__ void TensorElementWiseOp(LeftType lhs, RightType rhs) {
  // First position handled by this thread.
  const int firstCol = blockIdx.x * blockDim.x + threadIdx.x;
  const int firstRow = blockIdx.y * blockDim.y + threadIdx.y;
  // Distance between consecutive positions of the same thread on each axis.
  const unsigned int colStep = gridDim.x * blockDim.x;
  const unsigned int rowStep = gridDim.y * blockDim.y;

  for (int row = firstRow; row < lhs.getHeight(); row += rowStep) {
    for (int col = firstCol; col < lhs.getWidth(); col += colStep) {
      lhs.applyRef(row, col) = rhs.apply(row, col);
    }
  }
}

/**
 * \brief The tensor gpu evaluate api.
 *
 * Launches a TensorElementWiseOp kernel on STREAM_DEFAULT that assigns the
 * value of \p rhs to \p lhs element by element, then invokes CHECK_SYNC
 * with the message "TensorGpuApply failed".
 *
 * Both operands are wrapped in TensorApply. Their width, height and
 * useGpu() values must be equal; a mismatch fails the matching CHECK_EQ.
 * An empty tensor (height or width of zero) passes the checks and returns
 * without launching anything.
 *
 * \tparam T  type argument forwarded to TensorApply (presumably the element
 *            type -- confirm against the TensorApply definition).
 * \param lhs destination; written by the kernel through applyRef().
 * \param rhs source; read by the kernel through apply().
 */
template<class T, typename LeftType, typename RightType>
inline void TensorGpuApply(LeftType& lhs, const RightType& rhs) {
  TensorApply<LeftType, T> lhs_(lhs);
  TensorApply<const RightType, T> rhs_(rhs);
  CHECK_EQ(lhs_.getWidth(), rhs_.getWidth());
  CHECK_EQ(lhs_.getHeight(), rhs_.getHeight());
  CHECK_EQ(lhs_.useGpu(), rhs_.useGpu());

  int dimM = lhs_.getHeight();
  int dimN = lhs_.getWidth();

  // Nothing to compute for an empty tensor. Continuing would request a
  // launch with zero-sized grid/block dimensions in the contiguous case and
  // divide by blockSizeY == 0 in the strided case.
  if (dimM == 0 || dimN == 0) {
    return;
  }

  if (lhs_.isContiguous() && rhs_.isContiguous()) {
    // One thread per element on a one-dimensional launch; the kernel's
    // border argument masks the surplus threads of the last block.
    const int kThreadsPerBlock = 1024;
    int size = dimM * dimN;
    int blockSize = std::min(size, kThreadsPerBlock);
    int gridSize = (size + kThreadsPerBlock - 1) / kThreadsPerBlock;
    TensorElementWiseOp
      <<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(lhs_, rhs_, size);
  } else {
    // Two-dimensional launch. blockSizeY is at most 32, so 32 / blockSizeY
    // is at least 1 and blockSizeX * blockSizeY never exceeds 1024.
    int blockSizeY = std::min(32, dimM);
    int blockSizeX = (32 / blockSizeY) * 32;
    // The grid is capped at 32 x 32 blocks; the two-argument kernel strides
    // over the remaining rows and columns itself.
    int gridSizeX = std::min(32, (dimN + blockSizeX - 1) / blockSizeX);
    int gridSizeY = std::min(32, (dimM + blockSizeY - 1) / blockSizeY);
    dim3 threads(blockSizeX, blockSizeY);
    dim3 grid(gridSizeX, gridSizeY);
    TensorElementWiseOp
      <<<grid, threads, 0, STREAM_DEFAULT>>>(lhs_, rhs_);
  }

  CHECK_SYNC("TensorGpuApply failed");
}
#else
/**
 * \brief Stand-in for the tensor gpu evaluate api when __NVCC__ is not
 * defined.
 *
 * This branch contains no kernel launch, so the requested assignment cannot
 * be performed here. Reaching this function is reported as a fatal error
 * instead of returning with \p lhs left unmodified.
 */
template<class T, typename LeftType, typename RightType>
inline void TensorGpuApply(LeftType& lhs, RightType& rhs) {
  LOG(FATAL) << "TensorGpuApply is unavailable: this translation unit was "
                "not compiled by nvcc, so no GPU implementation exists.";
}
#endif

}  // namespace paddle