Unverified commit 63fd7d66, authored by Zeng Jinle, committed by GitHub

refine merge lars (#36428)

Parent 3e6d9dbb
@@ -28,7 +28,7 @@ limitations under the License. */
 #define LARS_BLOCK_SIZE 512
 #endif
-#define LARS_MAX_MERGED_OPS 150
+#define LARS_MAX_MERGED_OPS 60
 namespace paddle {
 namespace operators {
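
A plausible reason for lowering LARS_MAX_MERGED_OPS from 150 to 60 is that the wrapper struct is now passed to the merged kernel by value (see the signature change below), and CUDA kernel arguments must fit into the parameter space, which is 4096 bytes on most architectures. The snippet below is a rough, hypothetical size check rather than code from this patch; it lists only the arrays visible in this diff and assumes MT is float.

    #include <cstdint>

    constexpr int kMaxMergedOps = 60;  // mirrors LARS_MAX_MERGED_OPS after this change

    // Hypothetical stand-in for LarsParamWarpper<T, MT>, restricted to the
    // arrays that appear in this diff.
    template <typename T, typename MT>
    struct WrapperSketch {
      int64_t numel_arr[kMaxMergedOps];
      int repeat_arr[kMaxMergedOps];
      const T* g_arr[kMaxMergedOps];
      const MT* lr_arr[kMaxMergedOps];
      T* p_out_arr[kMaxMergedOps];
      MT* v_out_arr[kMaxMergedOps];
      MT* master_p_out_arr[kMaxMergedOps];
      MT weight_decay_arr[kMaxMergedOps];
    };

    // With 60 slots this layout takes roughly 3.3 KB and fits in the 4 KB
    // kernel-parameter space; 150 slots (about 8.4 KB) would not.
    static_assert(sizeof(WrapperSketch<float, float>) <= 4096,
                  "by-value kernel argument must fit in the parameter space");
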
@@ -256,11 +256,8 @@ template <typename T, typename MT>
 struct LarsParamWarpper {
   int64_t numel_arr[LARS_MAX_MERGED_OPS];
   int repeat_arr[LARS_MAX_MERGED_OPS];
-  const T* __restrict__ p_arr[LARS_MAX_MERGED_OPS];
   const T* __restrict__ g_arr[LARS_MAX_MERGED_OPS];
-  const MT* __restrict__ v_arr[LARS_MAX_MERGED_OPS];
   const MT* __restrict__ lr_arr[LARS_MAX_MERGED_OPS];
-  const MT* __restrict__ master_p_arr[LARS_MAX_MERGED_OPS];
   T* __restrict__ p_out_arr[LARS_MAX_MERGED_OPS];
   MT* __restrict__ v_out_arr[LARS_MAX_MERGED_OPS];
   MT* __restrict__ master_p_out_arr[LARS_MAX_MERGED_OPS];
@@ -268,7 +265,7 @@ struct LarsParamWarpper {
 };
 template <typename T, typename MT>
-__global__ void MergedMomentumLarsKernel(LarsParamWarpper<T, MT>* lars_warpper,
+__global__ void MergedMomentumLarsKernel(LarsParamWarpper<T, MT> lars_warpper,
                                          MT* __restrict__ p_buffer,
                                          MT* __restrict__ g_buffer,
                                          const int op_num, const MT mu,
@@ -279,18 +276,18 @@ __global__ void MergedMomentumLarsKernel(LarsParamWarpper<T, MT>* lars_warpper,
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   const cooperative_groups::grid_group cg = cooperative_groups::this_grid();
   for (int i = 0; i < op_num; ++i) {
-    int numel = lars_warpper->numel_arr[i];
+    int numel = lars_warpper.numel_arr[i];
     MT param_norm = static_cast<MT>(0);
     MT grad_norm = static_cast<MT>(0);
-    L2NormKernel<T, MT>(&cg, lars_warpper->p_arr[i], lars_warpper->g_arr[i],
-                        p_buffer, g_buffer, numel, lars_warpper->repeat_arr[i],
+    L2NormKernel<T, MT>(&cg, lars_warpper.p_out_arr[i], lars_warpper.g_arr[i],
+                        p_buffer, g_buffer, numel, lars_warpper.repeat_arr[i],
                         rescale_grad, 0, &param_norm, &grad_norm);
     MomentumUpdate<T, MT>(
-        lars_warpper->p_arr[i], lars_warpper->g_arr[i],
-        lars_warpper->v_out_arr[i], lars_warpper->p_out_arr[i],
-        lars_warpper->v_out_arr[i], lars_warpper->master_p_arr[i],
-        lars_warpper->master_p_out_arr[i], lars_warpper->lr_arr[i], mu,
-        lars_warpper->weight_decay_arr[i], lars_coeff, epsilon, rescale_grad,
+        lars_warpper.p_out_arr[i], lars_warpper.g_arr[i],
+        lars_warpper.v_out_arr[i], lars_warpper.p_out_arr[i],
+        lars_warpper.v_out_arr[i], lars_warpper.master_p_out_arr[i],
+        lars_warpper.master_p_out_arr[i], lars_warpper.lr_arr[i], mu,
+        lars_warpper.weight_decay_arr[i], lars_coeff, epsilon, rescale_grad,
         param_norm, grad_norm, tid, grid_stride, numel, is_amp);
   }
 }
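
For reference, the kernel above relies on cooperative_groups::this_grid() so that a single launch can compute each op's L2 norms across the whole grid before applying the momentum update. The sketch below shows that grid-wide reduction mechanism in isolation; it is not Paddle's L2NormKernel, and it assumes a cooperative launch with blockDim.x a power of two no larger than 512.

    #include <cooperative_groups.h>
    namespace cg = cooperative_groups;

    __global__ void GridSumSketch(const float* __restrict__ x,
                                  float* __restrict__ block_buf,
                                  float* __restrict__ out, int n) {
      cg::grid_group grid = cg::this_grid();
      // Grid-stride accumulation into a per-thread partial sum.
      float local = 0.f;
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
           i += gridDim.x * blockDim.x) {
        local += x[i];
      }
      // Block-level tree reduction in shared memory.
      __shared__ float smem[512];
      smem[threadIdx.x] = local;
      __syncthreads();
      for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
        __syncthreads();
      }
      if (threadIdx.x == 0) block_buf[blockIdx.x] = smem[0];
      // Grid-wide barrier: every block's partial sum is visible afterwards.
      grid.sync();
      if (blockIdx.x == 0 && threadIdx.x == 0) {
        float total = 0.f;
        for (int b = 0; b < gridDim.x; ++b) total += block_buf[b];
        *out = total;
      }
    }
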
@@ -410,15 +407,21 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
        size_t temp_numel = param[i]->numel();
        total_numel += temp_numel;
        lars_warpper.numel_arr[i] = temp_numel;
-       lars_warpper.p_arr[i] = param[i]->data<T>();
        lars_warpper.g_arr[i] = grad[i]->data<T>();
-       lars_warpper.v_arr[i] = velocity[i]->data<MT>();
        lars_warpper.lr_arr[i] = learning_rate[i]->data<MT>();
        lars_warpper.p_out_arr[i] =
            param_out[i]->mutable_data<T>(ctx.GetPlace());
        lars_warpper.v_out_arr[i] =
            velocity_out[i]->mutable_data<MT>(ctx.GetPlace());
        lars_warpper.weight_decay_arr[i] = static_cast<MT>(weight_decay_arr[i]);
+       PADDLE_ENFORCE_EQ(
+           param[i]->data<T>(), lars_warpper.p_out_arr[i],
+           platform::errors::InvalidArgument(
+               "Input(Param) and Output(ParamOut) must be the same Tensors."));
+       PADDLE_ENFORCE_EQ(velocity[i]->data<MT>(), lars_warpper.v_out_arr[i],
+                         platform::errors::InvalidArgument(
+                             "Input(Velocity) and Output(VelocityOut) must be "
+                             "the same Tensors."));
      }
      int64_t avg_numel = total_numel / op_num;
      LarsThreadConfig<float> lars_thread_config(avg_numel, sm_num,
@@ -429,19 +432,16 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
      }
      if (multi_precision) {
        for (int i = 0; i < op_num; ++i) {
-         lars_warpper.master_p_arr[i] = master_param[i]->data<MT>();
          lars_warpper.master_p_out_arr[i] =
              master_param_out[i]->mutable_data<MT>(ctx.GetPlace());
+         PADDLE_ENFORCE_EQ(master_param[i]->data<MT>(),
+                           lars_warpper.master_p_out_arr[i],
+                           platform::errors::InvalidArgument(
+                               "Input(MasterParam) and Output(MasterParamOut) "
+                               "must be the same Tensors."));
        }
      }
-     auto merged_buf = memory::Alloc(cuda_ctx, sizeof(lars_warpper));
-     auto* merged_ptr =
-         reinterpret_cast<LarsParamWarpper<T, MT>*>(merged_buf->ptr());
-     memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, cuda_ctx.GetPlace()),
-                  reinterpret_cast<void*>(merged_ptr), platform::CPUPlace(),
-                  reinterpret_cast<void*>(&lars_warpper), sizeof(lars_warpper),
-                  cuda_ctx.stream());
-     void* cuda_param[] = {reinterpret_cast<void*>(&merged_ptr),
+     void* cuda_param[] = {reinterpret_cast<void*>(&lars_warpper),
                            reinterpret_cast<void*>(&p_buffer),
                            reinterpret_cast<void*>(&g_buffer),
                            reinterpret_cast<void*>(&op_num),
......
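
The memory::Alloc / memory::Copy block deleted above existed only to stage the wrapper in device memory so the kernel could take it by pointer; once the struct is a by-value argument, its host address goes straight into the argument array. A minimal sketch of that launch pattern follows; the kernel name, wrapper fields, and launch shape are illustrative, not the operator's actual code.

    #include <cstdint>
    #include <cuda_runtime.h>

    struct MergedArgsSketch {  // stand-in for LarsParamWarpper<T, MT>
      int64_t numel_arr[60];
    };

    __global__ void MergedKernelSketch(MergedArgsSketch w, float* p_buffer,
                                       float* g_buffer, int op_num) {
      // The struct arrives in the kernel parameter space, so members are read
      // directly (w.numel_arr[i]) rather than through a device pointer.
    }

    void LaunchSketch(MergedArgsSketch wrapper, float* p_buffer, float* g_buffer,
                      int op_num, cudaStream_t stream) {
      // cudaLaunchCooperativeKernel copies each pointed-to argument into the
      // kernel's parameter space; no explicit host-to-device staging is needed.
      void* args[] = {&wrapper, &p_buffer, &g_buffer, &op_num};
      dim3 grid(80), block(512);
      cudaLaunchCooperativeKernel(reinterpret_cast<void*>(MergedKernelSketch),
                                  grid, block, args, /*sharedMem=*/0, stream);
    }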