未验证 提交 065ffcce 编写于 作者: G gongweibao 提交者: GitHub

fix dgcclipnorm bug test=develop (#16629)

上级 79643663
...@@ -24,18 +24,21 @@ class DGCClipByNormKernel : public ClipByNormKernel<DeviceContext, T> { ...@@ -24,18 +24,21 @@ class DGCClipByNormKernel : public ClipByNormKernel<DeviceContext, T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto rampup_begin_step = context.Attr<float>("rampup_begin_step"); auto rampup_begin_step = context.Attr<float>("rampup_begin_step");
if (static_cast<int>(rampup_begin_step) >= 0) { if (static_cast<int>(rampup_begin_step) < 0) {
auto current_step_tensor = return;
context.Input<framework::Tensor>("current_step"); }
auto* current_step = current_step_tensor->data<T>();
auto current_step_tensor = context.Input<framework::Tensor>("current_step");
if (static_cast<int>(*current_step) < auto* current_step = current_step_tensor->data<T>();
static_cast<int>(rampup_begin_step)) {
VLOG(10) << "current_step:" << *current_step VLOG(10) << "current_step:" << *current_step
<< " < rampup_begin_step:" << rampup_begin_step << ", rampup_begin_step:" << rampup_begin_step;
<< " so does't use dgc_clip_by_norm";
return; if (static_cast<int>(*current_step) < static_cast<int>(rampup_begin_step)) {
} VLOG(10) << "current_step:" << *current_step
<< " < rampup_begin_step:" << rampup_begin_step
<< " so does't use dgc_clip_by_norm";
return;
} }
return ClipByNormKernel<DeviceContext, T>::Compute(context); return ClipByNormKernel<DeviceContext, T>::Compute(context);
......
...@@ -832,7 +832,7 @@ class DGCMomentumOptimizer(MomentumOptimizer): ...@@ -832,7 +832,7 @@ class DGCMomentumOptimizer(MomentumOptimizer):
type=x.type, name=name, dtype=x.dtype, persistable=False) type=x.type, name=name, dtype=x.dtype, persistable=False)
helper.append_op( helper.append_op(
type="clip_by_norm", type="dgc_clip_by_norm",
inputs={"X": x, inputs={"X": x,
"current_step": self._global_step_var}, "current_step": self._global_step_var},
attrs={ attrs={
...@@ -845,7 +845,7 @@ class DGCMomentumOptimizer(MomentumOptimizer): ...@@ -845,7 +845,7 @@ class DGCMomentumOptimizer(MomentumOptimizer):
def _append_clip_norm(self, grad_var, clip_norm): def _append_clip_norm(self, grad_var, clip_norm):
with grad_var.block.program._backward_role_guard(): with grad_var.block.program._backward_role_guard():
return self._clip_by_norm( return self._clip_by_norm(
x=grad_var, max_norm=clip_norm, name=grad_var.name + "@DGC") x=grad_var, max_norm=clip_norm, name=grad_var.name)
def _dgc_op(self, param_var, clip_var, grad_var, u_var, v_var, k_var, def _dgc_op(self, param_var, clip_var, grad_var, u_var, v_var, k_var,
encoded_var): encoded_var):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册