未验证 提交 adb80494 编写于 作者: Z Zeng Jinle 提交者: GitHub

Remove wrong __restrict__ of CUDA LarsMomentumOpKernel (#36460)

* remove wrong restrict

* remove master_param_out __restrict__

* update
上级 e703a2ed
......@@ -84,22 +84,18 @@ class LarsThreadConfig {
template <typename T, typename MT, int VecSize, bool IsAmp = false>
__device__ inline void VectorizeLarsUpdate(
const T* __restrict__ grad, const MT* __restrict__ param,
const MT* __restrict__ velocity, T* __restrict__ param_out,
MT* __restrict__ velocity_out, const MT mu, MT local_lr,
const T* __restrict__ grad, const MT* param, const MT* velocity,
T* param_out, MT* velocity_out, const MT mu, MT local_lr,
const MT lars_weight_decay, const MT rescale_grad, const int tid,
const int grid_stride, const int numel,
MT* __restrict__ master_param_out = nullptr) {
const int grid_stride, const int numel, MT* master_param_out = nullptr) {
using VecType = paddle::platform::AlignedVector<T, VecSize>;
using VecMType = paddle::platform::AlignedVector<MT, VecSize>;
int main = numel >> (VecSize >> 1);
int tail_offset = main * VecSize;
const VecType* __restrict__ grad_vec = reinterpret_cast<const VecType*>(grad);
const VecMType* __restrict__ param_vec =
reinterpret_cast<const VecMType*>(param);
const VecMType* __restrict__ velocity_vec =
reinterpret_cast<const VecMType*>(velocity);
const VecType* grad_vec = reinterpret_cast<const VecType*>(grad);
const VecMType* param_vec = reinterpret_cast<const VecMType*>(param);
const VecMType* velocity_vec = reinterpret_cast<const VecMType*>(velocity);
VecType* param_out_vec = reinterpret_cast<VecType*>(param_out);
VecMType* velocity_out_vec = reinterpret_cast<VecMType*>(velocity_out);
......@@ -157,66 +153,30 @@ __forceinline__ __device__ void L2NormKernel(
template <typename T, typename MT>
__global__ void L2NormKernel(
#endif
const T* __restrict__ p_data, const T* __restrict__ g_data,
MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const int64_t numel,
const int repeat_times, const MT rescale_grad, const int thresh = 0,
MT* __restrict__ p_n = nullptr, MT* __restrict__ g_n = nullptr) {
const T* p_data, const T* __restrict__ g_data, MT* __restrict__ p_buffer,
MT* __restrict__ g_buffer, const int64_t numel, const int repeat_times,
const MT rescale_grad, const int thresh = 0, MT* __restrict__ p_n = nullptr,
MT* __restrict__ g_n = nullptr) {
__shared__ MT s_buffer[2];
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int grid_stride = LARS_BLOCK_SIZE * gridDim.x;
const MT rescale_pow = rescale_grad * rescale_grad;
if (threadIdx.x == 0) {
s_buffer[0] = static_cast<MT>(0);
s_buffer[1] = static_cast<MT>(0);
}
MT p_tmp = static_cast<MT>(0);
MT g_tmp = static_cast<MT>(0);
if (repeat_times == 0) {
if (tid < numel) {
p_tmp = static_cast<MT>(p_data[tid]);
g_tmp = static_cast<MT>(g_data[tid]);
}
MT tmp0 = math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
MT tmp1 = math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
if (threadIdx.x == 0) {
s_buffer[0] += tmp0;
s_buffer[1] += tmp1;
}
} else {
/* Avoid occupy too much temp buffer. Slice the whole data into 2 parts,
the front of data whose quantity is excatly multiple of grid-thread
number, and delt in for loop, the rest is delt with another step. */
for (int i = 0; i < repeat_times; ++i) {
p_tmp = static_cast<MT>(p_data[tid]);
g_tmp = static_cast<MT>(g_data[tid]);
tid += grid_stride;
MT tmp0 = math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
MT tmp1 = math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
if (threadIdx.x == 0) {
s_buffer[0] += tmp0;
s_buffer[1] += tmp1;
}
__syncthreads();
}
MT p_val = 0;
MT g_val = 0;
if (tid < numel) {
p_val = static_cast<MT>(p_data[tid]);
g_val = static_cast<MT>(g_data[tid]);
}
MT tmp0 = math::blockReduceSum<MT>(p_val * p_val, FINAL_MASK);
MT tmp1 = math::blockReduceSum<MT>(g_val * g_val, FINAL_MASK);
if (threadIdx.x == 0) {
s_buffer[0] += tmp0;
s_buffer[1] += tmp1;
}
while (tid < numel) {
MT tmp0 = static_cast<MT>(p_data[tid]);
MT tmp1 = static_cast<MT>(g_data[tid]);
p_tmp += (tmp0 * tmp0);
g_tmp += (tmp1 * tmp1);
tid += grid_stride;
}
__syncthreads();
p_tmp = math::blockReduceSum<MT>(p_tmp, FINAL_MASK);
g_tmp = math::blockReduceSum<MT>(g_tmp, FINAL_MASK);
if (threadIdx.x == 0) {
p_buffer[blockIdx.x] = s_buffer[0];
g_buffer[blockIdx.x] = s_buffer[1];
p_buffer[blockIdx.x] = p_tmp;
g_buffer[blockIdx.x] = g_tmp;
}
#if CUDA_VERSION >= 11000
cg->sync(); // Grid sync for writring partial result to gloabl memory
......@@ -236,10 +196,9 @@ __global__ void L2NormKernel(
template <typename T, typename MT>
__forceinline__ __device__ void MomentumUpdate(
const T* __restrict__ param, const T* __restrict__ grad,
const MT* __restrict__ velocity, T* param_out, MT* velocity_out,
const MT* __restrict__ master_param, MT* __restrict__ master_param_out,
const MT* __restrict__ learning_rate, const MT mu,
const T* param, const T* __restrict__ grad, const MT* velocity,
T* param_out, MT* velocity_out, const MT* master_param,
MT* master_param_out, const MT* __restrict__ learning_rate, const MT mu,
const MT lars_weight_decay, const MT lars_coeff, const MT epsilon,
const MT rescale_grad, const MT param_norm, const MT grad_norm,
const int tid, const int grid_stride, const int64_t numel,
......@@ -316,14 +275,13 @@ __global__ void MergedMomentumLarsKernel(LarsParamWarpper<T, MT> lars_warpper,
template <typename T, typename MT>
__global__ void MomentumLarsKernel(
const T* __restrict__ param, const T* __restrict__ grad,
const MT* __restrict__ velocity, T* param_out, MT* velocity_out,
const MT* __restrict__ master_param, MT* __restrict__ master_param_out,
const MT* __restrict__ learning_rate, MT* __restrict__ p_buffer,
MT* __restrict__ g_buffer, const MT mu, const MT lars_coeff,
const MT lars_weight_decay, const MT epsilon, const MT rescale_grad,
const int repeat_times, const int thresh, const int64_t numel,
const bool is_amp) {
const T* param, const T* __restrict__ grad, const MT* velocity,
T* param_out, MT* velocity_out, const MT* master_param,
MT* master_param_out, const MT* __restrict__ learning_rate,
MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const MT mu,
const MT lars_coeff, const MT lars_weight_decay, const MT epsilon,
const MT rescale_grad, const int repeat_times, const int thresh,
const int64_t numel, const bool is_amp) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
int grid_stride = gridDim.x * LARS_BLOCK_SIZE;
#if CUDA_VERSION >= 11000
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册