未验证 提交 b888a4c5 编写于 作者: H Hongyu Liu 提交者: GitHub

fix regularizer lod bug (#17848)

* fix regularizer lod bug; test=develop

* fix exception bug and one_hot expand; test=develop
上级 8062bd51
......@@ -21,6 +21,7 @@ import functools
from . import layers
from . import framework
from . import core
from .dygraph import not_support
__all__ = [
'ErrorClipByValue',
......@@ -335,6 +336,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
return param, new_grad
@not_support
def set_gradient_clip(clip, param_list=None, program=None):
"""
To specify parameters that require gradient clip.
......
......@@ -654,6 +654,8 @@ class Variable(object):
@property
def lod_level(self):
# TODO(minqiyang): Support lod_level in dygraph mode
if in_dygraph_mode():
raise Exception("Dygraph model DO NOT supprt lod")
return self.desc.lod_level()
@property
......
......@@ -6576,6 +6576,7 @@ def one_hot(input, depth):
inputs = {'X': input}
attrs = {'depth': depth}
else:
depth.stop_gradient = True
inputs = {'X': input, 'depth_tensor': depth}
attrs = {}
helper.append_op(
......@@ -9383,6 +9384,7 @@ def expand(x, expand_times, name=None):
new_expand_times = []
for ele in expand_times:
if isinstance(ele, Variable):
ele.stop_gradient = True
new_expand_times.append(ele)
else:
assert (isinstance(ele, int))
......
......@@ -162,8 +162,11 @@ class L2DecayRegularizer(WeightDecayRegularizer):
assert isinstance(param, framework.Parameter)
assert isinstance(block, framework.Block)
decay = block.create_var(
dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)
if framework.in_dygraph_mode():
decay = block.create_var(dtype=param.dtype, shape=param.shape)
else:
decay = block.create_var(
dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)
# Append Op to calculate decay
block.append_op(
......@@ -231,8 +234,11 @@ class L1DecayRegularizer(WeightDecayRegularizer):
assert isinstance(param, framework.Parameter)
assert isinstance(block, framework.Block)
decay = block.create_var(
dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)
if framework.in_dygraph_mode():
decay = block.create_var(dtype=param.dtype, shape=param.shape)
else:
decay = block.create_var(
dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)
# Append sign op
block.append_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册