未验证 提交 b888a4c5 编写于 作者: H Hongyu Liu 提交者: GitHub

fix regularizer lod bug (#17848)

* fix regularizer lod bug; test=develop

* fix exception bug and one_hot expand; test=develop
上级 8062bd51
...@@ -21,6 +21,7 @@ import functools ...@@ -21,6 +21,7 @@ import functools
from . import layers from . import layers
from . import framework from . import framework
from . import core from . import core
from .dygraph import not_support
__all__ = [ __all__ = [
'ErrorClipByValue', 'ErrorClipByValue',
...@@ -335,6 +336,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): ...@@ -335,6 +336,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
return param, new_grad return param, new_grad
@not_support
def set_gradient_clip(clip, param_list=None, program=None): def set_gradient_clip(clip, param_list=None, program=None):
""" """
To specify parameters that require gradient clip. To specify parameters that require gradient clip.
......
...@@ -654,6 +654,8 @@ class Variable(object): ...@@ -654,6 +654,8 @@ class Variable(object):
@property @property
def lod_level(self): def lod_level(self):
# TODO(minqiyang): Support lod_level in dygraph mode # TODO(minqiyang): Support lod_level in dygraph mode
if in_dygraph_mode():
raise Exception("Dygraph model DO NOT supprt lod")
return self.desc.lod_level() return self.desc.lod_level()
@property @property
......
...@@ -6576,6 +6576,7 @@ def one_hot(input, depth): ...@@ -6576,6 +6576,7 @@ def one_hot(input, depth):
inputs = {'X': input} inputs = {'X': input}
attrs = {'depth': depth} attrs = {'depth': depth}
else: else:
depth.stop_gradient = True
inputs = {'X': input, 'depth_tensor': depth} inputs = {'X': input, 'depth_tensor': depth}
attrs = {} attrs = {}
helper.append_op( helper.append_op(
...@@ -9383,6 +9384,7 @@ def expand(x, expand_times, name=None): ...@@ -9383,6 +9384,7 @@ def expand(x, expand_times, name=None):
new_expand_times = [] new_expand_times = []
for ele in expand_times: for ele in expand_times:
if isinstance(ele, Variable): if isinstance(ele, Variable):
ele.stop_gradient = True
new_expand_times.append(ele) new_expand_times.append(ele)
else: else:
assert (isinstance(ele, int)) assert (isinstance(ele, int))
......
...@@ -162,6 +162,9 @@ class L2DecayRegularizer(WeightDecayRegularizer): ...@@ -162,6 +162,9 @@ class L2DecayRegularizer(WeightDecayRegularizer):
assert isinstance(param, framework.Parameter) assert isinstance(param, framework.Parameter)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
if framework.in_dygraph_mode():
decay = block.create_var(dtype=param.dtype, shape=param.shape)
else:
decay = block.create_var( decay = block.create_var(
dtype=param.dtype, shape=param.shape, lod_level=param.lod_level) dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)
...@@ -231,6 +234,9 @@ class L1DecayRegularizer(WeightDecayRegularizer): ...@@ -231,6 +234,9 @@ class L1DecayRegularizer(WeightDecayRegularizer):
assert isinstance(param, framework.Parameter) assert isinstance(param, framework.Parameter)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
if framework.in_dygraph_mode():
decay = block.create_var(dtype=param.dtype, shape=param.shape)
else:
decay = block.create_var( decay = block.create_var(
dtype=param.dtype, shape=param.shape, lod_level=param.lod_level) dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册