Unverified · Commit 2e6548a9 · authored by sneaxiy, committed by GitHub

vec scale kernel (#40011)

Parent 5898e9ab
@@ -304,14 +304,30 @@ struct AndFunctor {
   HOSTDEVICE bool operator()(bool x, bool y) const { return x && y; }
 };
 
-template <typename T1, typename T2>
+template <typename T1, typename T2, int VecSize>
 static __global__ void ScaleCUDAKernel(const T1 *__restrict__ x,
                                        const T2 *__restrict__ scale,
                                        T1 *__restrict__ y, int num) {
   static_assert(sizeof(T1) <= sizeof(T2),
                 "sizeof(T1) must be not greater than sizeof(T2).");
   T2 s = scale[0];
-  CUDA_KERNEL_LOOP(i, num) {
+
+  int i = (threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
+  int stride = blockDim.x * gridDim.x * VecSize;
+
+  for (; i + VecSize <= num; i += stride) {
+    platform::AlignedVector<T1, VecSize> x_vec;
+    platform::AlignedVector<T1, VecSize> y_vec;
+
+    platform::Load(x + i, &x_vec);
+#pragma unroll
+    for (int j = 0; j < VecSize; ++j) {
+      y_vec[j] = static_cast<T1>(static_cast<T2>(x_vec[j]) * s);
+    }
+    platform::Store(y_vec, y + i);
+  }
+
+  for (; i < num; ++i) {
     y[i] = static_cast<T1>(static_cast<T2>(x[i]) * s);
   }
 }
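Note (not part of the commit): the hunk above is the heart of the change. It replaces the scalar `CUDA_KERNEL_LOOP` grid-stride loop with aligned vector loads and stores, plus a scalar tail loop for the trailing `num % VecSize` elements. Below is a minimal self-contained CUDA sketch of the same pattern; `AlignedVec` and `VecScaleKernel` are illustrative stand-ins for Paddle's `platform::AlignedVector`, `platform::Load`, and `platform::Store`, and the scale is passed by value rather than read from device memory as in the commit.

```cuda
// Illustrative stand-in for platform::AlignedVector<T, N>: an over-aligned
// POD wrapper so one load/store instruction can move N elements at once.
template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVec {
  T val[N];
};

template <typename T, int VecSize>
__global__ void VecScaleKernel(const T *__restrict__ x, T s,
                               T *__restrict__ y, int num) {
  // Each thread starts at a VecSize-aligned offset and strides by the
  // whole grid's footprint, mirroring the kernel in the diff above.
  int i = (threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
  int stride = blockDim.x * gridDim.x * VecSize;

  // Vectorized main loop: VecSize elements per iteration.
  for (; i + VecSize <= num; i += stride) {
    AlignedVec<T, VecSize> x_vec =
        *reinterpret_cast<const AlignedVec<T, VecSize> *>(x + i);
    AlignedVec<T, VecSize> y_vec;
#pragma unroll
    for (int j = 0; j < VecSize; ++j) {
      y_vec.val[j] = x_vec.val[j] * s;
    }
    *reinterpret_cast<AlignedVec<T, VecSize> *>(y + i) = y_vec;
  }

  // Scalar tail: since every i is a multiple of VecSize, at most one thread
  // enters this loop, and it handles all trailing num % VecSize elements.
  for (; i < num; ++i) {
    y[i] = x[i] * s;
  }
}
```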
@@ -396,7 +412,6 @@ static __global__ void UpdateLambMomentAndTrustRatioDivCUDAKernel(
   for (; i + VecSize <= num; i += stride) {
     platform::AlignedVector<T, VecSize> param_vec;
     platform::AlignedVector<GradT, VecSize> grad_vec;
-    platform::AlignedVector<T, VecSize> weight_decay_vec;
     platform::AlignedVector<T, VecSize> mom1_vec;
     platform::AlignedVector<T, VecSize> mom2_vec;
     platform::AlignedVector<T, VecSize> trust_ratio_div_vec;
@@ -760,6 +775,24 @@ static bool CreatePreMulScaleOpIfSupported(ncclDataType_t dtype,
   return false;
 }
 
+template <typename T1, typename T2>
+static void LaunchScaleKernel(const platform::CUDADeviceContext &dev_ctx,
+                              const T1 *x, const T2 *scale, T1 *y, int n,
+                              gpuStream_t stream) {
+  int vec_size = std::min(GetChunkedVecSize(x, 0), GetChunkedVecSize(y, 0));
+  auto config = platform::GetGpuLaunchConfig1D(dev_ctx, n, vec_size);
+
+#define PD_LAMB_VEC_SCALE_KERNEL_CASE                                \
+  do {                                                               \
+    ScaleCUDAKernel<T1, T2, kVecSize><<<config.block_per_grid,       \
+        config.thread_per_block, 0, stream>>>(                       \
+        x, scale, y, n);                                             \
+  } while (0)
+
+  PD_VEC_LAUNCH_KERNEL(vec_size, PD_LAMB_VEC_SCALE_KERNEL_CASE);
+#undef PD_LAMB_VEC_SCALE_KERNEL_CASE
+}
+
 template <typename T>
 static void NCCLReduceScatterWithScale(
     const T *sendbuff, T *recvbuff, size_t recvcount, size_t nranks,
@@ -775,10 +808,8 @@ static void NCCLReduceScatterWithScale(
       PADDLE_ENFORCE_EQ(nranks, 1,
                         platform::errors::InvalidArgument(
                             "nranks must be 1 when scale != nullptr."));
-      auto numel = recvcount * nranks;
-      auto config = platform::GetGpuLaunchConfig1D(dev_ctx, numel);
-      ScaleCUDAKernel<<<config.block_per_grid, config.thread_per_block, 0,
-                        stream>>>(sendbuff, scale, recvbuff, numel);
+      LaunchScaleKernel(dev_ctx, sendbuff, scale, recvbuff, recvcount * nranks,
+                        stream);
     }
     return;
   }
@@ -792,9 +823,7 @@ static void NCCLReduceScatterWithScale(
   if (scale && !should_destroy_op) {
     size_t numel = recvcount * nranks;
     T *new_sendbuff = buffer.Alloc<T>(numel);
-    auto config = platform::GetGpuLaunchConfig1D(dev_ctx, numel);
-    ScaleCUDAKernel<<<config.block_per_grid, config.thread_per_block, 0,
-                      stream>>>(sendbuff, scale, new_sendbuff, numel);
+    LaunchScaleKernel(dev_ctx, sendbuff, scale, new_sendbuff, numel, stream);
     sendbuff = new_sendbuff;
   }
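Note (not part of the commit): `LaunchScaleKernel` picks a runtime `vec_size` from the alignment of the input and output pointers, but the kernel template needs a compile-time `kVecSize`; `PD_VEC_LAUNCH_KERNEL` bridges the two. The sketch below shows one way such a dispatch pair can be written. `PickVecSize`, `VEC_LAUNCH_KERNEL_CASE`, and `VEC_LAUNCH_KERNEL` are hypothetical analogues; the actual definitions of `GetChunkedVecSize` and `PD_VEC_LAUNCH_KERNEL` in the Paddle repo may differ.

```cuda
#include <cstdint>

// Hypothetical analogue of GetChunkedVecSize: the widest vector width
// (in elements) whose byte span evenly divides the pointer address.
template <typename T>
int PickVecSize(const T *ptr) {
  auto addr = reinterpret_cast<std::uintptr_t>(ptr);
  for (int vec = 8; vec > 1; vec /= 2) {
    if (addr % (vec * sizeof(T)) == 0) return vec;
  }
  return 1;
}

// Hypothetical analogue of PD_VEC_LAUNCH_KERNEL: a switch that turns the
// runtime vec size into the compile-time constant kVecSize that the case
// body (the macro passed in as __VA_ARGS__) uses to instantiate the kernel.
#define VEC_LAUNCH_KERNEL_CASE(size, ...) \
  case size: {                            \
    constexpr int kVecSize = size;        \
    __VA_ARGS__;                          \
    break;                                \
  }

// Usage sketch:
//   int vec_size = std::min(PickVecSize(x), PickVecSize(y));
//   VEC_LAUNCH_KERNEL(vec_size, MY_KERNEL_CASE);
#define VEC_LAUNCH_KERNEL(vec_size, ...)      \
  do {                                        \
    switch (vec_size) {                       \
      VEC_LAUNCH_KERNEL_CASE(8, __VA_ARGS__)  \
      VEC_LAUNCH_KERNEL_CASE(4, __VA_ARGS__)  \
      VEC_LAUNCH_KERNEL_CASE(2, __VA_ARGS__)  \
      VEC_LAUNCH_KERNEL_CASE(1, __VA_ARGS__)  \
    }                                         \
  } while (0)
```

With this pattern the kernel template is compiled once per supported width while each launch site stays a single call, which is why the two `NCCLReduceScatterWithScale` call sites in the hunks above each shrink to one `LaunchScaleKernel` line.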