Unverified commit 2e6548a9, authored by sneaxiy, committed by GitHub

vec scale kernel (#40011)

Parent 5898e9ab
@@ -304,14 +304,30 @@ struct AndFunctor {
   HOSTDEVICE bool operator()(bool x, bool y) const { return x && y; }
 };
 
-template <typename T1, typename T2>
+template <typename T1, typename T2, int VecSize>
 static __global__ void ScaleCUDAKernel(const T1 *__restrict__ x,
                                        const T2 *__restrict__ scale,
                                        T1 *__restrict__ y, int num) {
   static_assert(sizeof(T1) <= sizeof(T2),
                 "sizeof(T1) must be not greater than sizeof(T2).");
   T2 s = scale[0];
-  CUDA_KERNEL_LOOP(i, num) {
+
+  int i = (threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
+  int stride = blockDim.x * gridDim.x * VecSize;
+
+  for (; i + VecSize <= num; i += stride) {
+    platform::AlignedVector<T1, VecSize> x_vec;
+    platform::AlignedVector<T1, VecSize> y_vec;
+
+    platform::Load(x + i, &x_vec);
+#pragma unroll
+    for (int j = 0; j < VecSize; ++j) {
+      y_vec[j] = static_cast<T1>(static_cast<T2>(x_vec[j]) * s);
+    }
+    platform::Store(y_vec, y + i);
+  }
+
+  for (; i < num; ++i) {
     y[i] = static_cast<T1>(static_cast<T2>(x[i]) * s);
   }
 }
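The hunk above replaces the scalar CUDA_KERNEL_LOOP with a vectorized grid-stride loop: each thread loads VecSize contiguous elements through platform::AlignedVector in one wide transaction, scales them in registers, stores them back, and a scalar tail loop handles the final num % VecSize elements. Since the kernel is memory-bound, issuing fewer and wider load/store instructions is the whole point. Below is a minimal standalone sketch of the same pattern for context; AlignedVec and VecScaleKernel are illustrative stand-ins, not Paddle's platform::AlignedVector/Load/Store API.

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative stand-in for platform::AlignedVector<T, N>: a trivially
// copyable array whose alignment lets the compiler emit one wide
// (e.g. 128-bit) load/store instead of N scalar accesses.
template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVec {
  T val[N];
};

template <typename T, int VecSize>
__global__ void VecScaleKernel(const T *__restrict__ x, T s,
                               T *__restrict__ y, int num) {
  int i = (threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
  int stride = blockDim.x * gridDim.x * VecSize;

  // Vectorized main loop: one aligned load and one aligned store
  // per VecSize elements.
  for (; i + VecSize <= num; i += stride) {
    auto v = *reinterpret_cast<const AlignedVec<T, VecSize> *>(x + i);
#pragma unroll
    for (int j = 0; j < VecSize; ++j) v.val[j] *= s;
    *reinterpret_cast<AlignedVec<T, VecSize> *>(y + i) = v;
  }

  // Scalar tail: only the thread whose chunk overlaps the end of the
  // buffer reaches here with i < num.
  for (; i < num; ++i) y[i] = x[i] * s;
}

int main() {
  constexpr int kN = 1000003, kVec = 4;
  float *x, *y;
  cudaMallocManaged(&x, kN * sizeof(float));
  cudaMallocManaged(&y, kN * sizeof(float));
  for (int i = 0; i < kN; ++i) x[i] = 1.0f;

  int threads = 256;
  int blocks = (kN / kVec + threads - 1) / threads;
  VecScaleKernel<float, kVec><<<blocks, threads>>>(x, 2.0f, y, kN);
  cudaDeviceSynchronize();

  printf("y[0]=%f y[last]=%f\n", y[0], y[kN - 1]);  // expect 2.0 and 2.0
  cudaFree(x);
  cudaFree(y);
  return 0;
}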
@@ -396,7 +412,6 @@ static __global__ void UpdateLambMomentAndTrustRatioDivCUDAKernel(
   for (; i + VecSize <= num; i += stride) {
     platform::AlignedVector<T, VecSize> param_vec;
     platform::AlignedVector<GradT, VecSize> grad_vec;
-    platform::AlignedVector<T, VecSize> weight_decay_vec;
     platform::AlignedVector<T, VecSize> mom1_vec;
     platform::AlignedVector<T, VecSize> mom2_vec;
     platform::AlignedVector<T, VecSize> trust_ratio_div_vec;
@@ -760,6 +775,24 @@ static bool CreatePreMulScaleOpIfSupported(ncclDataType_t dtype,
   return false;
 }
 
+template <typename T1, typename T2>
+static void LaunchScaleKernel(const platform::CUDADeviceContext &dev_ctx,
+                              const T1 *x, const T2 *scale, T1 *y, int n,
+                              gpuStream_t stream) {
+  int vec_size = std::min(GetChunkedVecSize(x, 0), GetChunkedVecSize(y, 0));
+  auto config = platform::GetGpuLaunchConfig1D(dev_ctx, n, vec_size);
+
+#define PD_LAMB_VEC_SCALE_KERNEL_CASE                            \
+  do {                                                           \
+    ScaleCUDAKernel<T1, T2, kVecSize><<<config.block_per_grid,   \
+        config.thread_per_block, 0, stream>>>(                   \
+        x, scale, y, n);                                         \
+  } while (0)
+
+  PD_VEC_LAUNCH_KERNEL(vec_size, PD_LAMB_VEC_SCALE_KERNEL_CASE);
+#undef PD_LAMB_VEC_SCALE_KERNEL_CASE
+}
+
 template <typename T>
 static void NCCLReduceScatterWithScale(
     const T *sendbuff, T *recvbuff, size_t recvcount, size_t nranks,
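LaunchScaleKernel chooses vec_size at runtime from the alignment of x and y (via GetChunkedVecSize) and then relies on PD_VEC_LAUNCH_KERNEL to map that runtime value onto the compile-time kVecSize template argument that ScaleCUDAKernel needs. The macro's definition is not part of this diff; the sketch below shows the switch-on-width dispatch such a helper typically expands to, and is an assumption rather than Paddle's actual implementation.

// Assumed shape of the dispatch helper (not the real PD_VEC_LAUNCH_KERNEL):
// a switch that instantiates the passed-in body once per supported width,
// exposing the matching width under the constexpr name kVecSize.
#define PD_VEC_LAUNCH_KERNEL_CASE(width, ...) \
  case width: {                               \
    constexpr int kVecSize = width;           \
    __VA_ARGS__;                              \
  } break

#define PD_VEC_LAUNCH_KERNEL(vec_size, ...)       \
  do {                                            \
    switch (vec_size) {                           \
      PD_VEC_LAUNCH_KERNEL_CASE(8, __VA_ARGS__);  \
      PD_VEC_LAUNCH_KERNEL_CASE(4, __VA_ARGS__);  \
      PD_VEC_LAUNCH_KERNEL_CASE(2, __VA_ARGS__);  \
      default: {                                  \
        constexpr int kVecSize = 1;               \
        __VA_ARGS__;                              \
      } break;                                    \
    }                                             \
  } while (0)

With this shape, the call PD_VEC_LAUNCH_KERNEL(vec_size, PD_LAMB_VEC_SCALE_KERNEL_CASE) compiles the kernel once per supported width and executes only the case matching the runtime vec_size.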
@@ -775,10 +808,8 @@ static void NCCLReduceScatterWithScale(
       PADDLE_ENFORCE_EQ(nranks, 1,
                         platform::errors::InvalidArgument(
                             "nranks must be 1 when scale != nullptr."));
-      auto numel = recvcount * nranks;
-      auto config = platform::GetGpuLaunchConfig1D(dev_ctx, numel);
-      ScaleCUDAKernel<<<config.block_per_grid, config.thread_per_block, 0,
-                        stream>>>(sendbuff, scale, recvbuff, numel);
+      LaunchScaleKernel(dev_ctx, sendbuff, scale, recvbuff, recvcount * nranks,
+                        stream);
     }
     return;
   }
@@ -792,9 +823,7 @@ static void NCCLReduceScatterWithScale(
   if (scale && !should_destroy_op) {
     size_t numel = recvcount * nranks;
     T *new_sendbuff = buffer.Alloc<T>(numel);
-    auto config = platform::GetGpuLaunchConfig1D(dev_ctx, numel);
-    ScaleCUDAKernel<<<config.block_per_grid, config.thread_per_block, 0,
-                      stream>>>(sendbuff, scale, new_sendbuff, numel);
+    LaunchScaleKernel(dev_ctx, sendbuff, scale, new_sendbuff, numel, stream);
     sendbuff = new_sendbuff;
   }
...
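Both NCCLReduceScatterWithScale call sites now delegate to LaunchScaleKernel instead of configuring and launching ScaleCUDAKernel by hand, so the pre-scale path and the single-rank path share the vectorized kernel. GetChunkedVecSize is likewise not shown in this diff; an alignment-based width picker roughly along the following lines is assumed here purely for illustration.

#include <cstdint>

// Assumed behaviour of an alignment-based width picker (GetChunkedVecSize
// itself is not in this diff): return the largest vector width, up to a
// 128-bit access, for which the pointer plus chunk offset stays aligned.
template <typename T>
static int PickVecSize(const T *ptr, int chunk_offset) {
  constexpr int kMaxBytes = 16;  // widest common GPU load/store: 128 bits
  auto addr = reinterpret_cast<std::uintptr_t>(ptr) + chunk_offset * sizeof(T);
  for (int vec_size = static_cast<int>(kMaxBytes / sizeof(T)); vec_size > 1;
       vec_size /= 2) {
    if (addr % (vec_size * sizeof(T)) == 0) return vec_size;
  }
  return 1;
}

LaunchScaleKernel takes the minimum of the widths reported for x and y, so a misaligned output buffer alone is enough to force the scalar path.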