未验证 提交 945f777f 编写于 作者: R Roc 提交者: GitHub

Revert params in paddle.nn.SpectralNorm and paddle.nnFlatten.forward (#49311)

上级 73aa98cf
......@@ -942,7 +942,7 @@ class TestLayer(LayerTest):
lod_level=1,
append_batch_size=False,
)
spectralNorm = paddle.nn.SpectralNorm(shape, axis=1, power_iters=2)
spectralNorm = paddle.nn.SpectralNorm(shape, dim=1, power_iters=2)
ret = spectralNorm(Weight)
static_ret2 = self.get_static_graph_result(
feed={
......@@ -955,7 +955,7 @@ class TestLayer(LayerTest):
)[0]
with self.dynamic_graph():
spectralNorm = paddle.nn.SpectralNorm(shape, axis=1, power_iters=2)
spectralNorm = paddle.nn.SpectralNorm(shape, dim=1, power_iters=2)
dy_ret = spectralNorm(base.to_variable(input))
dy_rlt_value = dy_ret.numpy()
......
......@@ -154,7 +154,7 @@ class TestDygraphSpectralNormOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
shape = (2, 4, 3, 3)
spectralNorm = paddle.nn.SpectralNorm(shape, axis=1, power_iters=2)
spectralNorm = paddle.nn.SpectralNorm(shape, dim=1, power_iters=2)
def test_Variable():
weight_1 = np.random.random((2, 4)).astype("float32")
......
......@@ -1737,8 +1737,8 @@ class Flatten(Layer):
self.start_axis = start_axis
self.stop_axis = stop_axis
def forward(self, x):
def forward(self, input):
out = paddle.flatten(
x, start_axis=self.start_axis, stop_axis=self.stop_axis
input, start_axis=self.start_axis, stop_axis=self.stop_axis
)
return out
......@@ -1812,7 +1812,7 @@ class SpectralNorm(Layer):
Step 1:
Generate vector U in shape of [H], and V in shape of [W].
While H is the :attr:`axis` th dimension of the input weights,
While H is the :attr:`dim` th dimension of the input weights,
and W is the product result of remaining dimensions.
Step 2:
......@@ -1839,9 +1839,9 @@ class SpectralNorm(Layer):
Parameters:
weight_shape(list or tuple): The shape of weight parameter.
axis(int, optional): The index of dimension which should be permuted to the first before reshaping Input(Weight) to matrix, it should be set as 0 if Input(Weight) is the weight of fc layer, and should be set as 1 if Input(Weight) is the weight of conv layer. Default: 0.
dim(int, optional): The index of dimension which should be permuted to the first before reshaping Input(Weight) to matrix, it should be set as 0 if Input(Weight) is the weight of fc layer, and should be set as 1 if Input(Weight) is the weight of conv layer. Default: 0.
power_iters(int, optional): The number of power iterations to calculate spectral norm. Default: 1.
epsilon(float, optional): The epsilon for numerical stability in calculating norms. Default: 1e-12.
eps(float, optional): The epsilon for numerical stability in calculating norms. Default: 1e-12.
name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` .
dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
......@@ -1854,7 +1854,7 @@ class SpectralNorm(Layer):
import paddle
x = paddle.rand((2,8,32,32))
spectral_norm = paddle.nn.SpectralNorm(x.shape, axis=1, power_iters=2)
spectral_norm = paddle.nn.SpectralNorm(x.shape, dim=1, power_iters=2)
spectral_norm_out = spectral_norm(x)
print(spectral_norm_out.shape) # [2, 8, 32, 32]
......@@ -1864,25 +1864,25 @@ class SpectralNorm(Layer):
def __init__(
self,
weight_shape,
axis=0,
dim=0,
power_iters=1,
epsilon=1e-12,
eps=1e-12,
dtype='float32',
):
super().__init__()
self._power_iters = power_iters
self._epsilon = epsilon
self._dim = axis
self._epsilon = eps
self._dim = dim
self._dtype = dtype
self._weight_shape = list(weight_shape)
assert (
np.prod(self._weight_shape) > 0
), "Any dimension of `weight_shape` cannot be equal to 0."
assert axis < len(self._weight_shape), (
"The input `axis` should be less than the "
"length of `weight_shape`, but received axis="
"{}".format(axis)
assert dim < len(self._weight_shape), (
"The input `dim` should be less than the "
"length of `weight_shape`, but received dim="
"{}".format(dim)
)
h = self._weight_shape[self._dim]
w = np.prod(self._weight_shape) // h
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册