diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index caefd496978af2b6c34b61572fbb33725b33fab4..e90f1136fd30daa86b6a9afcd75c877a4de97a1a 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -28,7 +28,7 @@ limitations under the License. */ #define LARS_BLOCK_SIZE 512 #endif -#define LARS_MAX_MERGED_OPS 150 +#define LARS_MAX_MERGED_OPS 60 namespace paddle { namespace operators { @@ -256,11 +256,8 @@ template struct LarsParamWarpper { int64_t numel_arr[LARS_MAX_MERGED_OPS]; int repeat_arr[LARS_MAX_MERGED_OPS]; - const T* __restrict__ p_arr[LARS_MAX_MERGED_OPS]; const T* __restrict__ g_arr[LARS_MAX_MERGED_OPS]; - const MT* __restrict__ v_arr[LARS_MAX_MERGED_OPS]; const MT* __restrict__ lr_arr[LARS_MAX_MERGED_OPS]; - const MT* __restrict__ master_p_arr[LARS_MAX_MERGED_OPS]; T* __restrict__ p_out_arr[LARS_MAX_MERGED_OPS]; MT* __restrict__ v_out_arr[LARS_MAX_MERGED_OPS]; MT* __restrict__ master_p_out_arr[LARS_MAX_MERGED_OPS]; @@ -268,7 +265,7 @@ struct LarsParamWarpper { }; template -__global__ void MergedMomentumLarsKernel(LarsParamWarpper* lars_warpper, +__global__ void MergedMomentumLarsKernel(LarsParamWarpper lars_warpper, MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const int op_num, const MT mu, @@ -279,18 +276,18 @@ __global__ void MergedMomentumLarsKernel(LarsParamWarpper* lars_warpper, int tid = threadIdx.x + blockIdx.x * blockDim.x; const cooperative_groups::grid_group cg = cooperative_groups::this_grid(); for (int i = 0; i < op_num; ++i) { - int numel = lars_warpper->numel_arr[i]; + int numel = lars_warpper.numel_arr[i]; MT param_norm = static_cast(0); MT grad_norm = static_cast(0); - L2NormKernel(&cg, lars_warpper->p_arr[i], lars_warpper->g_arr[i], - p_buffer, g_buffer, numel, lars_warpper->repeat_arr[i], + L2NormKernel(&cg, lars_warpper.p_out_arr[i], lars_warpper.g_arr[i], + p_buffer, g_buffer, numel, lars_warpper.repeat_arr[i], rescale_grad, 0, ¶m_norm, &grad_norm); MomentumUpdate( - lars_warpper->p_arr[i], lars_warpper->g_arr[i], - lars_warpper->v_out_arr[i], lars_warpper->p_out_arr[i], - lars_warpper->v_out_arr[i], lars_warpper->master_p_arr[i], - lars_warpper->master_p_out_arr[i], lars_warpper->lr_arr[i], mu, - lars_warpper->weight_decay_arr[i], lars_coeff, epsilon, rescale_grad, + lars_warpper.p_out_arr[i], lars_warpper.g_arr[i], + lars_warpper.v_out_arr[i], lars_warpper.p_out_arr[i], + lars_warpper.v_out_arr[i], lars_warpper.master_p_out_arr[i], + lars_warpper.master_p_out_arr[i], lars_warpper.lr_arr[i], mu, + lars_warpper.weight_decay_arr[i], lars_coeff, epsilon, rescale_grad, param_norm, grad_norm, tid, grid_stride, numel, is_amp); } } @@ -410,15 +407,21 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { size_t temp_numel = param[i]->numel(); total_numel += temp_numel; lars_warpper.numel_arr[i] = temp_numel; - lars_warpper.p_arr[i] = param[i]->data(); lars_warpper.g_arr[i] = grad[i]->data(); - lars_warpper.v_arr[i] = velocity[i]->data(); lars_warpper.lr_arr[i] = learning_rate[i]->data(); lars_warpper.p_out_arr[i] = param_out[i]->mutable_data(ctx.GetPlace()); lars_warpper.v_out_arr[i] = velocity_out[i]->mutable_data(ctx.GetPlace()); lars_warpper.weight_decay_arr[i] = static_cast(weight_decay_arr[i]); + PADDLE_ENFORCE_EQ( + param[i]->data(), lars_warpper.p_out_arr[i], + platform::errors::InvalidArgument( + "Input(Param) and Output(ParamOut) must be the same Tensors.")); + PADDLE_ENFORCE_EQ(velocity[i]->data(), lars_warpper.v_out_arr[i], + platform::errors::InvalidArgument( + "Input(Velocity) and Output(VelocityOut) must be " + "the same Tensors.")); } int64_t avg_numel = total_numel / op_num; LarsThreadConfig lars_thread_config(avg_numel, sm_num, @@ -429,19 +432,16 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { } if (multi_precision) { for (int i = 0; i < op_num; ++i) { - lars_warpper.master_p_arr[i] = master_param[i]->data(); lars_warpper.master_p_out_arr[i] = master_param_out[i]->mutable_data(ctx.GetPlace()); + PADDLE_ENFORCE_EQ(master_param[i]->data(), + lars_warpper.master_p_out_arr[i], + platform::errors::InvalidArgument( + "Input(MasterParam) and Output(MasterParamOut) " + "must be the same Tensors.")); } } - auto merged_buf = memory::Alloc(cuda_ctx, sizeof(lars_warpper)); - auto* merged_ptr = - reinterpret_cast*>(merged_buf->ptr()); - memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, cuda_ctx.GetPlace()), - reinterpret_cast(merged_ptr), platform::CPUPlace(), - reinterpret_cast(&lars_warpper), sizeof(lars_warpper), - cuda_ctx.stream()); - void* cuda_param[] = {reinterpret_cast(&merged_ptr), + void* cuda_param[] = {reinterpret_cast(&lars_warpper), reinterpret_cast(&p_buffer), reinterpret_cast(&g_buffer), reinterpret_cast(&op_num),