Unverified · Commit e4459a40 authored by sneaxiy, committed by GitHub

Add Output(Step) to DistributedFusedLamb optimizer (#41249)

* add Output(Step) to distributed fused lamb op

* add _set_step

Parent f78cc3da
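In short, the new `Step` output is a global step counter that advances only on iterations whose gradients are free of NaN/Inf: when the squared global gradient norm is not finite (e.g., an AMP loss-scale overflow), the CUDA kernel sets `FoundInf` and returns before touching the parameters or the counter. A minimal Python sketch of that bookkeeping (illustrative only; `lamb_iteration` is a hypothetical stand-in for the kernel's control flow):

```python
import math

def lamb_iteration(square_grad_norm, step):
    """Hypothetical sketch: Step advances only on non-NaN/Inf iterations."""
    if not math.isfinite(square_grad_norm):
        # Overflowed gradients: report FoundInf and skip the whole update,
        # leaving the step counter untouched.
        return step, True
    # Healthy gradients: clear FoundInf and count this step.
    return step + 1, False

step = 0
for norm in [1.0, float('inf'), 2.0]:
    step, found_inf = lamb_iteration(norm, step)
print(step)  # 2 -- the Inf iteration was not counted
```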
@@ -94,6 +94,7 @@ class DistributedFusedLambInitOpMaker
     AddOutput("GradOut", "The output gradient list.").AsDuplicable();
     AddOutput("GlobalScale",
               "The global scale. It is usually the scale factor for AMP.");
+    AddOutput("Step", "The global step which excludes the NaN/Inf step.");
     AddAttr<float>("beta1", "The initial value of Beta1Pow.");
     AddAttr<float>("beta2", "The initial value of Beta2Pow.");
...
@@ -698,6 +698,10 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
       TensorFillConstant<float>(dev_ctx, global_scale, {1}, 1.0f);
     }
     VLOG(10) << "Init global scale ends";
+    TensorFillConstant<int64_t>(dev_ctx, ctx.Output<framework::Tensor>("Step"),
+                                {1}, static_cast<int64_t>(0));
     dev_ctx.Wait();
     VLOG(10) << "Wait for H2D copy";
   }
...
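The init kernel above seeds `Step` with a single int64 zero on the device, alongside the existing `GlobalScale` initialization. A rough Python analogue using the public `paddle.full` API (illustrative, not what the kernel actually calls):

```python
import paddle

# Rough analogue of TensorFillConstant<int64_t>(dev_ctx, Step, {1}, 0):
step = paddle.full(shape=[1], fill_value=0, dtype='int64')
```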
@@ -110,6 +110,7 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker {
         .AsDuplicable();
     AddOutput("FoundInf", "Whether there is NaN/Inf");
+    AddOutput("Step", "The global step which excludes the NaN/Inf step.");
     AddAttr<float>("beta1", "The initial Beta1Pow value.");
     AddAttr<float>("beta2", "The initial Beta2Pow value.");
...
@@ -381,8 +381,9 @@ static __global__ void UpdateLambMomentAndTrustRatioDivCUDAKernel(
     const T *__restrict__ square_grad_norm_p,
     const T *__restrict__ global_scale, const T *__restrict__ beta1pow_p,
     const T *__restrict__ beta2pow_p, T *__restrict__ mom1_p,
-    T *__restrict__ mom2_p, T *__restrict__ trust_ratio_div_p, bool *found_inf,
-    T weight_decay, int weight_decay_end_numel, T beta1, T beta2, T epsilon,
+    T *__restrict__ mom2_p, T *__restrict__ trust_ratio_div_p,
+    bool *__restrict__ found_inf, int64_t *__restrict__ step, T weight_decay,
+    int weight_decay_end_numel, T beta1, T beta2, T epsilon,
     T max_global_grad_norm, int num, T rescale_grad) {
   T square_grad_norm = *square_grad_norm_p;
   bool need_update_found_inf =
@@ -392,6 +393,7 @@ static __global__ void UpdateLambMomentAndTrustRatioDivCUDAKernel(
     return;
   } else if (need_update_found_inf) {
     *found_inf = false;
+    ++(*step);
   }
   T scale = rescale_grad / global_scale[0];
@@ -467,8 +469,8 @@ static void MultiTensorUpdateLambMomentAndTrustRatioDiv(
     const platform::CUDADeviceContext &dev_ctx, const int *offsets, int n,
     const T *param_p, const GradT *grad_p, const T *square_grad_norm_p,
     const T *global_scale, const T *beta1pow_p, const T *beta2pow_p, T *mom1_p,
-    T *mom2_p, T *trust_ratio_div_p, bool *found_inf_p, T weight_decay,
-    int weight_decay_end_idx, T beta1, T beta2, T epsilon,
+    T *mom2_p, T *trust_ratio_div_p, bool *found_inf_p, int64_t *step,
+    T weight_decay, int weight_decay_end_idx, T beta1, T beta2, T epsilon,
     T max_global_grad_norm, T rescale_grad) {
   if (n <= 0) return;
   int numel = offsets[n] - offsets[0];
@@ -496,13 +498,22 @@ static void MultiTensorUpdateLambMomentAndTrustRatioDiv(
   auto stream = dev_ctx.stream();
   auto config = platform::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
+  if (found_inf_p == nullptr) {
+    PADDLE_ENFORCE_EQ(
+        step, nullptr,
+        platform::errors::InvalidArgument(
+            "Output(Step) cannot be updated twice in one mini-batch."));
+  } else {
+    PADDLE_ENFORCE_NOT_NULL(step, platform::errors::InvalidArgument(
+                                       "Output(Step) cannot be nullptr."));
+  }
 #define PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL                        \
   do {                                                                   \
     UpdateLambMomentAndTrustRatioDivCUDAKernel<T, GradT, kVecSize><<<    \
         config.block_per_grid, config.thread_per_block, 0, stream>>>(    \
         param_p, grad_p, square_grad_norm_p, global_scale, beta1pow_p,   \
-        beta2pow_p, mom1_p, mom2_p, trust_ratio_div_p, found_inf_p,      \
+        beta2pow_p, mom1_p, mom2_p, trust_ratio_div_p, found_inf_p, step, \
         weight_decay, weight_decay_end_numel, beta1, beta2, epsilon,     \
         max_global_grad_norm, numel, rescale_grad);                      \
   } while (0)
@@ -1315,6 +1326,8 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
     const auto *fp16_partial_fused_offsets =
         fp16_partial_fused_offsets_t->data<int>();
+    auto *step = ctx.Output<framework::Tensor>("Step")->data<int64_t>();
+
     VLOG(1) << "FusedParamOffsets: "
             << FlattenToString(fused_offsets, fused_offsets_t->numel(),
                                fused_offsets_t->place());
@@ -1337,8 +1350,8 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
           dev_ctx, fp32_partial_fused_offsets, fp32_local_param_num,
           fp32_param + fp32_offset, fp32_sum_grad, fp32_square_grad_norm,
           global_scale, beta1pow, beta2pow, moment1, moment2, trust_ratio_div,
-          found_inf, weight_decay, fp32_weight_decay_end_idx, beta1, beta2,
-          epsilon, max_global_grad_norm, rescale_grad);
+          found_inf, step, weight_decay, fp32_weight_decay_end_idx, beta1,
+          beta2, epsilon, max_global_grad_norm, rescale_grad);
       VLOG(10) << "Update FP32 Moment and TrustRatioDiv done";
     }
     float *master_param = nullptr;
@@ -1346,13 +1359,14 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
       master_param = fp32_param + fp32_numel;
       VLOG(10) << "Update FP16 Moment and TrustRatioDiv starts";
       auto tmp_found_inf = has_fp32_param ? nullptr : found_inf;
+      auto tmp_step = has_fp32_param ? nullptr : step;
       MultiTensorUpdateLambMomentAndTrustRatioDiv(
           dev_ctx, fp16_partial_fused_offsets, fp16_local_param_num,
           master_param + fp16_offset, fp16_sum_grad, fp32_square_grad_norm,
           global_scale, beta1pow, beta2pow, moment1 + fp32_numel_each_device,
           moment2 + fp32_numel_each_device,
-          trust_ratio_div + fp32_numel_each_device, tmp_found_inf, weight_decay,
-          fp16_weight_decay_end_idx, beta1, beta2, epsilon,
+          trust_ratio_div + fp32_numel_each_device, tmp_found_inf, tmp_step,
+          weight_decay, fp16_weight_decay_end_idx, beta1, beta2, epsilon,
           max_global_grad_norm, rescale_grad);
       VLOG(10) << "Update FP16 Moment and TrustRatioDiv done";
     }
...
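When a model mixes FP32 and FP16 parameters, `MultiTensorUpdateLambMomentAndTrustRatioDiv` runs twice per mini-batch; the `found_inf`/`step` pointers are forwarded only on the first call and nulled on the second (`tmp_found_inf`/`tmp_step`), and the new `PADDLE_ENFORCE` checks require the two pointers to be null or non-null together, so `Step` is incremented at most once. A condensed Python sketch of that guard (function and variable names are illustrative):

```python
def update_moments(found_inf, step):
    # Mirrors the nullptr checks added above: found_inf and step must be
    # provided (or omitted) together, so Step is bumped exactly once.
    if found_inf is None:
        assert step is None, \
            "Output(Step) cannot be updated twice in one mini-batch."
    else:
        assert step is not None, "Output(Step) cannot be nullptr."
    # ... moment / trust-ratio-div math would go here ...

has_fp32_param = True
found_inf, step = [False], [0]       # stand-ins for the device pointers
update_moments(found_inf, step)      # FP32 call owns the flag and counter
update_moments(None if has_fp32_param else found_inf,
               None if has_fp32_param else step)   # FP16 call passes nulls
```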
@@ -75,9 +75,18 @@ class DistributedFusedLamb(Optimizer):
             name=unique_name.generate('found_inf'),
             shape=[1],
             dtype=core.VarDesc.VarType.BOOL)
+        self._step = None
+
         self._param_to_master_param = {}
 
+    def _set_step(self, step):
+        self._step = step
+
+    def _get_or_create_step(self):
+        if self._step is None:
+            self._step = self._create_persistable_var('step', dtype='int64')
+        return self._step
+
     def _set_scale(self, scale):
         assert scale is not None
         if not isinstance(scale, Variable):
@@ -189,6 +198,8 @@ class DistributedFusedLamb(Optimizer):
         param_order = self._create_persistable_var('param_order', dtype='int32')
         param_order.is_distributed = True
 
+        step = self._get_or_create_step()
+
         rank = get_rank()
         nranks = get_world_size()
         scale = self._get_or_create_scale()
@@ -234,6 +245,7 @@ class DistributedFusedLamb(Optimizer):
                 'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets],
                 'FusedParamOffsets': [fused_offsets],
                 'ParamOrder': [param_order],
+                'Step': [step],
             },
             attrs={
                 'alignment': self._alignment,
@@ -290,6 +302,7 @@ class DistributedFusedLamb(Optimizer):
                 'ParamOut': params,
                 'GradOut': grads,
                 'FoundInf': [self._found_inf],
+                'Step': [step],
             },
             attrs={
                 'weight_decay': self._weight_decay,
...
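On the Python side, `_get_or_create_step()` lazily creates a persistable int64 variable (filled with 0 by the init op), and `_set_step()` lets a training loop inject an existing counter, e.g., one restored from a checkpoint, before the ops are built. A hedged usage sketch (these are private helpers and may change; the import path and constructor arguments are assumptions based on the usual `paddle.incubate.optimizer` location):

```python
from paddle.incubate.optimizer import DistributedFusedLamb

opt = DistributedFusedLamb(learning_rate=1e-3)

# Option 1: let the optimizer create its own counter lazily.
step_var = opt._get_or_create_step()   # persistable int64 var named 'step'

# Option 2: inject a counter created elsewhere (e.g., loaded from a
# checkpoint) before the first minimize() call:
# opt._set_step(restored_step_var)
```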