未验证 提交 4283e19e 编写于 作者: W wangna11BD 提交者: GitHub

support to_static for SpectralNorm (#51622)

* support to_static for SpectralNorm
上级 e5616448
...@@ -23,7 +23,7 @@ __all__ = [] ...@@ -23,7 +23,7 @@ __all__ = []
def normal_(x, mean=0.0, std=1.0): def normal_(x, mean=0.0, std=1.0):
temp_value = paddle.normal(mean, std, shape=x.shape) temp_value = paddle.normal(mean, std, shape=x.shape)
x.set_value(temp_value) paddle.assign(temp_value, x)
return x return x
...@@ -61,7 +61,7 @@ class SpectralNorm: ...@@ -61,7 +61,7 @@ class SpectralNorm:
if do_power_iteration: if do_power_iteration:
with paddle.no_grad(): with paddle.no_grad():
for _ in range(self.n_power_iterations): for _ in range(self.n_power_iterations):
v.set_value( paddle.assign(
F.normalize( F.normalize(
paddle.matmul( paddle.matmul(
weight_mat, weight_mat,
...@@ -71,15 +71,17 @@ class SpectralNorm: ...@@ -71,15 +71,17 @@ class SpectralNorm:
), ),
axis=0, axis=0,
epsilon=self.eps, epsilon=self.eps,
) ),
v,
) )
u.set_value( paddle.assign(
F.normalize( F.normalize(
paddle.matmul(weight_mat, v), paddle.matmul(weight_mat, v),
axis=0, axis=0,
epsilon=self.eps, epsilon=self.eps,
) ),
u,
) )
if self.n_power_iterations > 0: if self.n_power_iterations > 0:
u = u.clone() u = u.clone()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册