未验证 提交 ab385ca4 编写于 作者: L Leo Chen 提交者: GitHub

fix add_n kernel of large shape (#53767)

上级 268156f8
......@@ -21,34 +21,20 @@ namespace phi {
// Integer ceiling division: rounds x / y up for non-negative x and positive y.
// Both arguments are fully parenthesized; note that `y` is expanded twice, so
// do not pass an expression with side effects.
#define CEIL_DIV(x, y) (((x) + ((y) - 1)) / (y))
/**
 * Element-wise sum of two arrays: out[i] = in_0[i] + in_1[i] for i in [0, N).
 *
 * Each thread starts at its flattened 1-D global index and advances by the
 * total number of launched threads (blockDim.x * gridDim.x) until it reaches
 * N, so any 1-D launch configuration covers the whole range.
 *
 * The index and the stride are computed in int64_t. N is an int64_t, and a
 * 32-bit index would wrap around once N (or the running index) exceeds the
 * range of `int`, leading to out-of-range accesses for very large inputs.
 *
 * @param in_0  first input; read at indices [0, N)
 * @param in_1  second input; read at indices [0, N)
 * @param out   output; written at indices [0, N)
 * @param N     number of elements to process
 */
template <class T>
__global__ void Sum2CUDAKernel(const T *in_0,
                               const T *in_1,
                               T *out,
                               int64_t N) {
  // Widen before multiplying: blockIdx/blockDim/gridDim are 32-bit unsigned.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t idx =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < N;
       idx += stride) {
    out[idx] = in_0[idx] + in_1[idx];
  }
}
/**
 * Accumulates `in_size` input arrays element-wise into `out`.
 *
 * For every element index idx in [0, N):
 *   out[idx] = (read_dst ? out[idx] : 0) + sum over i of in[i][idx]
 * Null entries of `in` are skipped. The accumulation is carried out in
 * MPType (taken from phi::dtype::MPTypeTrait<T>) and cast back to T on store.
 *
 * The element loop is driven by CUDA_KERNEL_LOOP_TYPE with an int64_t index
 * type, matching the int64_t element count N.
 * NOTE(review): CUDA_KERNEL_LOOP_TYPE is defined outside this view; it is
 * presumably a loop over idx in [0, N) distributed across the launched
 * threads -- confirm against its definition.
 *
 * @param in        array of `in_size` pointers to input arrays (may contain
 *                  null pointers, which are ignored)
 * @param out       output array; also read as the initial value when
 *                  read_dst is true
 * @param N         number of elements per array
 * @param in_size   number of entries in `in`
 * @param read_dst  when true, start each sum from the current out[idx];
 *                  otherwise start from zero
 */
template <class T>
__global__ void SumArrayCUDAKernel(
    T **in, T *out, int64_t N, size_t in_size, bool read_dst) {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
    MPType total(read_dst ? static_cast<MPType>(out[idx])
                          : static_cast<MPType>(0));
    // size_t counter matches the type of in_size (no signed/unsigned mix).
    for (size_t i = 0; i < in_size; ++i) {
      const T *tmp = in[i];
      if (tmp) {
        total += static_cast<MPType>(tmp[idx]);
      }
    }
    out[idx] = static_cast<T>(total);
  }
}
......@@ -56,16 +42,14 @@ template <class T>
// In-place accumulation over (source, destination) pointer pairs.
//
// `sr_in_out` holds 2 * rows pointers laid out as consecutive pairs:
// sr_in_out[2k] is a source array and sr_in_out[2k + 1] is its destination.
// For every element index idx in [0, N) and every pair whose two pointers are
// both non-null, the kernel performs dst[idx] += src[idx].
//
// The element loop is driven by CUDA_KERNEL_LOOP_TYPE with an int64_t index
// type, matching the int64_t element count N.
// NOTE(review): CUDA_KERNEL_LOOP_TYPE is defined outside this view; it is
// presumably a loop over idx in [0, N) distributed across the launched
// threads -- confirm against its definition.
//
// sr_in_out: array of 2 * rows pointers (source/destination pairs)
// N:         number of elements processed per pair
// rows:      number of pairs
__global__ void SumSelectedRowsCUDAKernel(T **sr_in_out,
                                          int64_t N,
                                          size_t rows) {
  CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
    // size_t counter matches the type of 2 * rows (no signed/unsigned mix).
    for (size_t i = 0; i < 2 * rows; i += 2) {
      const T *tmp = sr_in_out[i];
      T *tmp_out = sr_in_out[i + 1];
      if (tmp && tmp_out) {
        tmp_out[idx] += tmp[idx];
      }
    }
  }
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册