未验证 提交 555c3463 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat] fix unittest failed (#33438)

上级 23290929
...@@ -256,6 +256,7 @@ class PartialProgramLayer(layers.Layer): ...@@ -256,6 +256,7 @@ class PartialProgramLayer(layers.Layer):
place=framework._current_expected_place(), place=framework._current_expected_place(),
zero_copy=True) zero_copy=True)
elif isinstance(value, core.VarBase): elif isinstance(value, core.VarBase):
value.name = self._inputs[i].desc.name()
if value.stop_gradient: if value.stop_gradient:
# NOTE(Aurelius84): If var is on CPUPlace, it will be transformed multi times # NOTE(Aurelius84): If var is on CPUPlace, it will be transformed multi times
# into CUDAPlace when it's as input of multi Ops. so we move it in advance # into CUDAPlace when it's as input of multi Ops. so we move it in advance
...@@ -265,9 +266,9 @@ class PartialProgramLayer(layers.Layer): ...@@ -265,9 +266,9 @@ class PartialProgramLayer(layers.Layer):
dtype=value.dtype, dtype=value.dtype,
place=framework._current_expected_place(), place=framework._current_expected_place(),
stop_gradient=True) stop_gradient=True)
var.name = value.name
else: else:
var = value var = value
var.name = self._inputs[i].desc.name()
else: else:
continue continue
input_vars.append(var) input_vars.append(var)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册