未验证 提交 b720873d 编写于 作者: L Leo Chen 提交者: GitHub

fix add_n kernel of large shape (#53751)

上级 6fee5a3e
...@@ -23,34 +23,20 @@ namespace phi { ...@@ -23,34 +23,20 @@ namespace phi {
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y)) #define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
template <class T>
__global__ void Sum2CUDAKernel(const T *in_0,
const T *in_1,
T *out,
int64_t N) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
while (id < N) {
out[id] = in_0[id] + in_1[id];
id += blockDim.x * gridDim.x;
}
}
template <class T> template <class T>
__global__ void SumArrayCUDAKernel( __global__ void SumArrayCUDAKernel(
T **in, T *out, int64_t N, size_t in_size, bool read_dst) { T **in, T *out, int64_t N, size_t in_size, bool read_dst) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type; using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
int id = blockIdx.x * blockDim.x + threadIdx.x; CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
while (id < N) { MPType total(read_dst ? static_cast<MPType>(out[idx])
MPType total(read_dst ? static_cast<MPType>(out[id])
: static_cast<MPType>(0)); : static_cast<MPType>(0));
for (int i = 0; i < in_size; ++i) { for (int i = 0; i < in_size; ++i) {
const T *tmp = in[i]; const T *tmp = in[i];
if (tmp) { if (tmp) {
total += static_cast<MPType>(tmp[id]); total += static_cast<MPType>(tmp[idx]);
} }
} }
out[id] = static_cast<T>(total); out[idx] = static_cast<T>(total);
id += blockDim.x * gridDim.x;
} }
} }
...@@ -58,16 +44,14 @@ template <class T> ...@@ -58,16 +44,14 @@ template <class T>
__global__ void SumSelectedRowsCUDAKernel(T **sr_in_out, __global__ void SumSelectedRowsCUDAKernel(T **sr_in_out,
int64_t N, int64_t N,
size_t rows) { size_t rows) {
int id = blockIdx.x * blockDim.x + threadIdx.x; CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
while (id < N) {
for (int i = 0; i < 2 * rows; i += 2) { for (int i = 0; i < 2 * rows; i += 2) {
const T *tmp = sr_in_out[i]; const T *tmp = sr_in_out[i];
T *tmp_out = sr_in_out[i + 1]; T *tmp_out = sr_in_out[i + 1];
if (tmp && tmp_out) { if (tmp && tmp_out) {
tmp_out[id] += tmp[id]; tmp_out[idx] += tmp[idx];
} }
} }
id += blockDim.x * gridDim.x;
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册