提交 768f9242 编写于 作者: W WangXi 提交者: gongweibao

Fix dgc clip & rampup step, test=develop (#21491)

上级 9144ae42
...@@ -28,7 +28,7 @@ inline float get_period_sparcity(const std::vector<float>& sparsity, ...@@ -28,7 +28,7 @@ inline float get_period_sparcity(const std::vector<float>& sparsity,
size_t idx = static_cast<int>(cur_step * sparsity.size() / rampup_steps); size_t idx = static_cast<int>(cur_step * sparsity.size() / rampup_steps);
if (idx >= sparsity.size()) { if (idx >= sparsity.size()) {
return 0.999; idx = sparsity.size() - 1;
} }
PADDLE_ENFORCE_LT(idx, sparsity.size()); PADDLE_ENFORCE_LT(idx, sparsity.size());
...@@ -102,7 +102,8 @@ class DGCOpKernel : public framework::OpKernel<T> { ...@@ -102,7 +102,8 @@ class DGCOpKernel : public framework::OpKernel<T> {
} }
float ratio = float ratio =
1 - get_period_sparcity(sparsity, static_cast<float>(*current_step), 1 - get_period_sparcity(
sparsity, static_cast<float>(*current_step - rampup_begin_step),
rampup_step); rampup_step);
PADDLE_ENFORCE_GE(ratio, 0.0); PADDLE_ENFORCE_GE(ratio, 0.0);
PADDLE_ENFORCE_LT(ratio, 1.0); PADDLE_ENFORCE_LT(ratio, 1.0);
......
...@@ -949,6 +949,7 @@ class DGCMomentumOptimizer(Optimizer): ...@@ -949,6 +949,7 @@ class DGCMomentumOptimizer(Optimizer):
self._momentum = momentum self._momentum = momentum
self._use_nesterov = bool(use_nesterov) self._use_nesterov = bool(use_nesterov)
assert rampup_begin_step >= 0, "rampup_begin_step must >= 0"
self._rampup_begin_step = rampup_begin_step self._rampup_begin_step = rampup_begin_step
self._rampup_step = rampup_step self._rampup_step = rampup_step
self._sparsity = sparsity self._sparsity = sparsity
...@@ -965,8 +966,7 @@ class DGCMomentumOptimizer(Optimizer): ...@@ -965,8 +966,7 @@ class DGCMomentumOptimizer(Optimizer):
self._local_grad_clip_norm = local_grad_clip_norm self._local_grad_clip_norm = local_grad_clip_norm
self._num_trainers = num_trainers self._num_trainers = num_trainers
self._clip_norm = local_grad_clip_norm / (num_trainers * self._clip_norm = local_grad_clip_norm * (num_trainers**-0.5)
num_trainers)
self._get_dgc_regularization_param() self._get_dgc_regularization_param()
......
...@@ -67,6 +67,8 @@ class TestDGCMomentumOptimizer(unittest.TestCase): ...@@ -67,6 +67,8 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
learning_rate=learning_rate, learning_rate=learning_rate,
momentum=0.2, momentum=0.2,
rampup_begin_step=0, rampup_begin_step=0,
local_grad_clip_norm=1.0,
num_trainers=2,
regularization=regularization) regularization=regularization)
mean_out = block.create_var( mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out") dtype="float32", shape=[1], lod_level=0, name="mean.out")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册