未验证 提交 45c7d905 编写于 作者: J JamesLim 提交者: GitHub

Optimization of elementwise CUDA kernel (#30801)

上级 0b3c2296
......@@ -99,6 +99,7 @@ inline void get_mid_dims(const framework::DDim &x_dims,
(*post) *= x_dims[i];
}
}
inline int GetElementwiseIndex(const int *x_dims_array, const int max_dim,
const int *index_array) {
int index_ = 0;
......@@ -202,12 +203,16 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x,
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename Functor, typename T, typename OutType>
__global__ void ElementwiseKernel(const T *x, const T *y, OutType *out, int pre,
int n, int post, int total, Functor func) {
__global__ void ElementwiseKernel(const T *__restrict__ x_data,
const T *__restrict__ y_data,
OutType *__restrict__ out_data, int n,
int post, const size_t total, Functor func) {
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int idx = tid / post % n;
if (tid < total) {
out[tid] = func(x[tid], y[idx]);
int stride = blockDim.x * gridDim.x;
for (int i = tid; i < total; i += stride) {
int idx = i / post % n;
out_data[i] = func(x_data[i], y_data[idx]);
}
}
......@@ -224,14 +229,16 @@ void ComputeElementwiseCUDA(const framework::Tensor *x,
int numel = pre * n * post;
int threads = 256;
int blocks = (numel + threads - 1) / threads;
if (is_xsize_larger) {
ElementwiseKernel<Functor, T,
OutType><<<blocks, threads, 0, ctx.stream()>>>(
x_data, y_data, out_data, pre, n, post, numel, func);
x_data, y_data, out_data, n, post, numel, func);
} else {
ElementwiseKernel<Functor, T,
OutType><<<blocks, threads, 0, ctx.stream()>>>(
y_data, x_data, out_data, pre, n, post, numel, func);
y_data, x_data, out_data, n, post, numel, func);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册