未验证 提交 e4459a40 编写于 作者: S sneaxiy 提交者: GitHub

Add Output(Step) to DistributedFusedLamb optimizer (#41249)

* add Output(Step) to distributed fused lamb op

* add _set_step
上级 f78cc3da
......@@ -94,6 +94,7 @@ class DistributedFusedLambInitOpMaker
AddOutput("GradOut", "The output gradient list.").AsDuplicable();
AddOutput("GlobalScale",
"The global scale. It is usually the scale factor for AMP.");
AddOutput("Step", "The global step which excludes the NaN/Inf step.");
AddAttr<float>("beta1", "The initial value of Beta1Pow.");
AddAttr<float>("beta2", "The initial value of Beta2Pow.");
......
......@@ -698,6 +698,10 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
TensorFillConstant<float>(dev_ctx, global_scale, {1}, 1.0f);
}
VLOG(10) << "Init global scale ends";
TensorFillConstant<int64_t>(dev_ctx, ctx.Output<framework::Tensor>("Step"),
{1}, static_cast<int64_t>(0));
dev_ctx.Wait();
VLOG(10) << "Wait for H2D copy";
}
......
......@@ -110,6 +110,7 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDuplicable();
AddOutput("FoundInf", "Whether there is NaN/Inf");
AddOutput("Step", "The global step which excludes the NaN/Inf step.");
AddAttr<float>("beta1", "The initial Beta1Pow value.");
AddAttr<float>("beta2", "The initial Beta2Pow value.");
......
......@@ -381,8 +381,9 @@ static __global__ void UpdateLambMomentAndTrustRatioDivCUDAKernel(
const T *__restrict__ square_grad_norm_p,
const T *__restrict__ global_scale, const T *__restrict__ beta1pow_p,
const T *__restrict__ beta2pow_p, T *__restrict__ mom1_p,
T *__restrict__ mom2_p, T *__restrict__ trust_ratio_div_p, bool *found_inf,
T weight_decay, int weight_decay_end_numel, T beta1, T beta2, T epsilon,
T *__restrict__ mom2_p, T *__restrict__ trust_ratio_div_p,
bool *__restrict__ found_inf, int64_t *__restrict__ step, T weight_decay,
int weight_decay_end_numel, T beta1, T beta2, T epsilon,
T max_global_grad_norm, int num, T rescale_grad) {
T square_grad_norm = *square_grad_norm_p;
bool need_update_found_inf =
......@@ -392,6 +393,7 @@ static __global__ void UpdateLambMomentAndTrustRatioDivCUDAKernel(
return;
} else if (need_update_found_inf) {
*found_inf = false;
++(*step);
}
T scale = rescale_grad / global_scale[0];
......@@ -467,8 +469,8 @@ static void MultiTensorUpdateLambMomentAndTrustRatioDiv(
const platform::CUDADeviceContext &dev_ctx, const int *offsets, int n,
const T *param_p, const GradT *grad_p, const T *square_grad_norm_p,
const T *global_scale, const T *beta1pow_p, const T *beta2pow_p, T *mom1_p,
T *mom2_p, T *trust_ratio_div_p, bool *found_inf_p, T weight_decay,
int weight_decay_end_idx, T beta1, T beta2, T epsilon,
T *mom2_p, T *trust_ratio_div_p, bool *found_inf_p, int64_t *step,
T weight_decay, int weight_decay_end_idx, T beta1, T beta2, T epsilon,
T max_global_grad_norm, T rescale_grad) {
if (n <= 0) return;
int numel = offsets[n] - offsets[0];
......@@ -496,15 +498,24 @@ static void MultiTensorUpdateLambMomentAndTrustRatioDiv(
auto stream = dev_ctx.stream();
auto config = platform::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
if (found_inf_p == nullptr) {
PADDLE_ENFORCE_EQ(
step, nullptr,
platform::errors::InvalidArgument(
"Output(Step) cannot be updated twice in one mini-batch."));
} else {
PADDLE_ENFORCE_NOT_NULL(step, platform::errors::InvalidArgument(
"Output(Step) cannot be nullptr."));
}
#define PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL \
do { \
UpdateLambMomentAndTrustRatioDivCUDAKernel<T, GradT, kVecSize><<< \
config.block_per_grid, config.thread_per_block, 0, stream>>>( \
param_p, grad_p, square_grad_norm_p, global_scale, beta1pow_p, \
beta2pow_p, mom1_p, mom2_p, trust_ratio_div_p, found_inf_p, \
weight_decay, weight_decay_end_numel, beta1, beta2, epsilon, \
max_global_grad_norm, numel, rescale_grad); \
#define PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL \
do { \
UpdateLambMomentAndTrustRatioDivCUDAKernel<T, GradT, kVecSize><<< \
config.block_per_grid, config.thread_per_block, 0, stream>>>( \
param_p, grad_p, square_grad_norm_p, global_scale, beta1pow_p, \
beta2pow_p, mom1_p, mom2_p, trust_ratio_div_p, found_inf_p, step, \
weight_decay, weight_decay_end_numel, beta1, beta2, epsilon, \
max_global_grad_norm, numel, rescale_grad); \
} while (0)
PD_VEC_LAUNCH_KERNEL(vec_size, PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL);
......@@ -1315,6 +1326,8 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
const auto *fp16_partial_fused_offsets =
fp16_partial_fused_offsets_t->data<int>();
auto *step = ctx.Output<framework::Tensor>("Step")->data<int64_t>();
VLOG(1) << "FusedParamOffsets: "
<< FlattenToString(fused_offsets, fused_offsets_t->numel(),
fused_offsets_t->place());
......@@ -1337,8 +1350,8 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
dev_ctx, fp32_partial_fused_offsets, fp32_local_param_num,
fp32_param + fp32_offset, fp32_sum_grad, fp32_square_grad_norm,
global_scale, beta1pow, beta2pow, moment1, moment2, trust_ratio_div,
found_inf, weight_decay, fp32_weight_decay_end_idx, beta1, beta2,
epsilon, max_global_grad_norm, rescale_grad);
found_inf, step, weight_decay, fp32_weight_decay_end_idx, beta1,
beta2, epsilon, max_global_grad_norm, rescale_grad);
VLOG(10) << "Update FP32 Moment and TrustRatioDiv done";
}
float *master_param = nullptr;
......@@ -1346,13 +1359,14 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
master_param = fp32_param + fp32_numel;
VLOG(10) << "Update FP16 Moment and TrustRatioDiv starts";
auto tmp_found_inf = has_fp32_param ? nullptr : found_inf;
auto tmp_step = has_fp32_param ? nullptr : step;
MultiTensorUpdateLambMomentAndTrustRatioDiv(
dev_ctx, fp16_partial_fused_offsets, fp16_local_param_num,
master_param + fp16_offset, fp16_sum_grad, fp32_square_grad_norm,
global_scale, beta1pow, beta2pow, moment1 + fp32_numel_each_device,
moment2 + fp32_numel_each_device,
trust_ratio_div + fp32_numel_each_device, tmp_found_inf, weight_decay,
fp16_weight_decay_end_idx, beta1, beta2, epsilon,
trust_ratio_div + fp32_numel_each_device, tmp_found_inf, tmp_step,
weight_decay, fp16_weight_decay_end_idx, beta1, beta2, epsilon,
max_global_grad_norm, rescale_grad);
VLOG(10) << "Update FP16 Moment and TrustRatioDiv done";
}
......
......@@ -75,9 +75,18 @@ class DistributedFusedLamb(Optimizer):
name=unique_name.generate('found_inf'),
shape=[1],
dtype=core.VarDesc.VarType.BOOL)
self._step = None
self._param_to_master_param = {}
def _set_step(self, step):
self._step = step
def _get_or_create_step(self):
if self._step is None:
self._step = self._create_persistable_var('step', dtype='int64')
return self._step
def _set_scale(self, scale):
assert scale is not None
if not isinstance(scale, Variable):
......@@ -189,6 +198,8 @@ class DistributedFusedLamb(Optimizer):
param_order = self._create_persistable_var('param_order', dtype='int32')
param_order.is_distributed = True
step = self._get_or_create_step()
rank = get_rank()
nranks = get_world_size()
scale = self._get_or_create_scale()
......@@ -234,6 +245,7 @@ class DistributedFusedLamb(Optimizer):
'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets],
'FusedParamOffsets': [fused_offsets],
'ParamOrder': [param_order],
'Step': [step],
},
attrs={
'alignment': self._alignment,
......@@ -290,6 +302,7 @@ class DistributedFusedLamb(Optimizer):
'ParamOut': params,
'GradOut': grads,
'FoundInf': [self._found_inf],
'Step': [step],
},
attrs={
'weight_decay': self._weight_decay,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册