未验证 提交 4283e19e 编写于 作者: W wangna11BD 提交者: GitHub

support to_static for SpectralNorm (#51622)

* support to_static for SpectralNorm
上级 e5616448
......@@ -23,7 +23,7 @@ __all__ = []
def normal_(x, mean=0.0, std=1.0):
temp_value = paddle.normal(mean, std, shape=x.shape)
x.set_value(temp_value)
paddle.assign(temp_value, x)
return x
......@@ -61,7 +61,7 @@ class SpectralNorm:
if do_power_iteration:
with paddle.no_grad():
for _ in range(self.n_power_iterations):
v.set_value(
paddle.assign(
F.normalize(
paddle.matmul(
weight_mat,
......@@ -71,15 +71,17 @@ class SpectralNorm:
),
axis=0,
epsilon=self.eps,
)
),
v,
)
u.set_value(
paddle.assign(
F.normalize(
paddle.matmul(weight_mat, v),
axis=0,
epsilon=self.eps,
)
),
u,
)
if self.n_power_iterations > 0:
u = u.clone()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册