Unverified commit 0342f012, authored by gongweibao, committed by GitHub

Fix dgc bug. (#16602)

Parent 2ca0de3c
@@ -24,19 +24,22 @@ class DGCClipByNormKernel : public ClipByNormKernel<DeviceContext, T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto rampup_begin_step = context.Attr<float>("rampup_begin_step");
-    if (static_cast<int>(rampup_begin_step) >= 0) {
-      auto current_step_tensor =
-          context.Input<framework::Tensor>("current_step");
-      auto* current_step = current_step_tensor->data<T>();
-
-      if (static_cast<int>(*current_step) <
-          static_cast<int>(rampup_begin_step)) {
-        VLOG(10) << "current_step:" << *current_step
-                 << " < rampup_begin_step:" << rampup_begin_step
-                 << " so does't use dgc_clip_by_norm";
-        return;
-      }
+    if (static_cast<int>(rampup_begin_step) < 0) {
+      return;
+    }
+
+    auto current_step_tensor = context.Input<framework::Tensor>("current_step");
+    auto* current_step = current_step_tensor->data<T>();
+    VLOG(10) << "current_step:" << *current_step
+             << ", rampup_begin_step:" << rampup_begin_step;
+
+    if (static_cast<int>(*current_step) <
+        static_cast<int>(rampup_begin_step)) {
+      VLOG(10) << "current_step:" << *current_step
+               << " < rampup_begin_step:" << rampup_begin_step
+               << " so does't use dgc_clip_by_norm";
+      return;
     }
 
     return ClipByNormKernel<DeviceContext, T>::Compute(context);
   };
...
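In short, the C++ hunk inverts the guard on rampup_begin_step into an early return: a negative rampup_begin_step now skips dgc_clip_by_norm entirely instead of falling through to the parent kernel, clipping is still skipped while current_step is below rampup_begin_step, and a new VLOG line records both step values on every call. A minimal sketch of the resulting gating decision, written as a hypothetical Python helper (should_clip and its arguments are illustrative only; the real logic lives in the kernel above):

```python
def should_clip(rampup_begin_step, current_step):
    """Illustrative mirror of DGCClipByNormKernel's new control flow."""
    # A negative rampup_begin_step now disables the clip outright
    # (before this commit it fell through and clipped unconditionally).
    if int(rampup_begin_step) < 0:
        return False
    # While current_step is still below rampup_begin_step, skip clipping.
    if int(current_step) < int(rampup_begin_step):
        return False
    # Otherwise control reaches ClipByNormKernel::Compute.
    return True

assert not should_clip(-1, 100)  # gating disabled: no clip at all
assert not should_clip(10, 5)    # before rampup begins: clip skipped
assert should_clip(10, 15)       # rampup reached: clip runs
```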
@@ -832,7 +832,7 @@ class DGCMomentumOptimizer(MomentumOptimizer):
             type=x.type, name=name, dtype=x.dtype, persistable=False)
 
         helper.append_op(
-            type="clip_by_norm",
+            type="dgc_clip_by_norm",
             inputs={"X": x,
                     "current_step": self._global_step_var},
             attrs={
@@ -845,7 +845,7 @@ class DGCMomentumOptimizer(MomentumOptimizer):
     def _append_clip_norm(self, grad_var, clip_norm):
         with grad_var.block.program._backward_role_guard():
             return self._clip_by_norm(
-                x=grad_var, max_norm=clip_norm, name=grad_var.name + "@DGC")
+                x=grad_var, max_norm=clip_norm, name=grad_var.name)
 
     def _dgc_op(self, param_var, clip_var, grad_var, u_var, v_var, k_var,
                 encoded_var):
...
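On the Python side, the optimizer had been appending a plain clip_by_norm op while feeding it a "current_step" input that, judging from the kernel above, belongs to the DGC variant; the first hunk reroutes it to dgc_clip_by_norm, and the second drops the "@DGC" suffix from the clipped gradient's variable name. A hedged usage sketch, assuming the fluid API of this release (the network here is a made-up placeholder):

```python
import paddle.fluid as fluid

# Placeholder network; only the optimizer wiring matters here.
x = fluid.layers.data(name="x", shape=[13], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
y_pred = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(fluid.layers.square_error_cost(y_pred, y))

# rampup_begin_step gates the dgc_clip_by_norm op this optimizer appends:
# clipping (and DGC itself) only engages once the step counter reaches it.
optimizer = fluid.optimizer.DGCMomentumOptimizer(
    learning_rate=0.001, momentum=0.9, rampup_begin_step=1252)
optimizer.minimize(loss)
```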