未验证 提交 945f777f 编写于 作者: R Roc 提交者: GitHub

Revert params in paddle.nn.SpectralNorm and paddle.nnFlatten.forward (#49311)

上级 73aa98cf
...@@ -942,7 +942,7 @@ class TestLayer(LayerTest): ...@@ -942,7 +942,7 @@ class TestLayer(LayerTest):
lod_level=1, lod_level=1,
append_batch_size=False, append_batch_size=False,
) )
spectralNorm = paddle.nn.SpectralNorm(shape, axis=1, power_iters=2) spectralNorm = paddle.nn.SpectralNorm(shape, dim=1, power_iters=2)
ret = spectralNorm(Weight) ret = spectralNorm(Weight)
static_ret2 = self.get_static_graph_result( static_ret2 = self.get_static_graph_result(
feed={ feed={
...@@ -955,7 +955,7 @@ class TestLayer(LayerTest): ...@@ -955,7 +955,7 @@ class TestLayer(LayerTest):
)[0] )[0]
with self.dynamic_graph(): with self.dynamic_graph():
spectralNorm = paddle.nn.SpectralNorm(shape, axis=1, power_iters=2) spectralNorm = paddle.nn.SpectralNorm(shape, dim=1, power_iters=2)
dy_ret = spectralNorm(base.to_variable(input)) dy_ret = spectralNorm(base.to_variable(input))
dy_rlt_value = dy_ret.numpy() dy_rlt_value = dy_ret.numpy()
......
...@@ -154,7 +154,7 @@ class TestDygraphSpectralNormOpError(unittest.TestCase): ...@@ -154,7 +154,7 @@ class TestDygraphSpectralNormOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
shape = (2, 4, 3, 3) shape = (2, 4, 3, 3)
spectralNorm = paddle.nn.SpectralNorm(shape, axis=1, power_iters=2) spectralNorm = paddle.nn.SpectralNorm(shape, dim=1, power_iters=2)
def test_Variable(): def test_Variable():
weight_1 = np.random.random((2, 4)).astype("float32") weight_1 = np.random.random((2, 4)).astype("float32")
......
...@@ -1737,8 +1737,8 @@ class Flatten(Layer): ...@@ -1737,8 +1737,8 @@ class Flatten(Layer):
self.start_axis = start_axis self.start_axis = start_axis
self.stop_axis = stop_axis self.stop_axis = stop_axis
def forward(self, x): def forward(self, input):
out = paddle.flatten( out = paddle.flatten(
x, start_axis=self.start_axis, stop_axis=self.stop_axis input, start_axis=self.start_axis, stop_axis=self.stop_axis
) )
return out return out
...@@ -1812,7 +1812,7 @@ class SpectralNorm(Layer): ...@@ -1812,7 +1812,7 @@ class SpectralNorm(Layer):
Step 1: Step 1:
Generate vector U in shape of [H], and V in shape of [W]. Generate vector U in shape of [H], and V in shape of [W].
While H is the :attr:`axis` th dimension of the input weights, While H is the :attr:`dim` th dimension of the input weights,
and W is the product result of remaining dimensions. and W is the product result of remaining dimensions.
Step 2: Step 2:
...@@ -1839,9 +1839,9 @@ class SpectralNorm(Layer): ...@@ -1839,9 +1839,9 @@ class SpectralNorm(Layer):
Parameters: Parameters:
weight_shape(list or tuple): The shape of weight parameter. weight_shape(list or tuple): The shape of weight parameter.
axis(int, optional): The index of dimension which should be permuted to the first before reshaping Input(Weight) to matrix, it should be set as 0 if Input(Weight) is the weight of fc layer, and should be set as 1 if Input(Weight) is the weight of conv layer. Default: 0. dim(int, optional): The index of dimension which should be permuted to the first before reshaping Input(Weight) to matrix, it should be set as 0 if Input(Weight) is the weight of fc layer, and should be set as 1 if Input(Weight) is the weight of conv layer. Default: 0.
power_iters(int, optional): The number of power iterations to calculate spectral norm. Default: 1. power_iters(int, optional): The number of power iterations to calculate spectral norm. Default: 1.
epsilon(float, optional): The epsilon for numerical stability in calculating norms. Default: 1e-12. eps(float, optional): The epsilon for numerical stability in calculating norms. Default: 1e-12.
name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` .
dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32". dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
...@@ -1854,7 +1854,7 @@ class SpectralNorm(Layer): ...@@ -1854,7 +1854,7 @@ class SpectralNorm(Layer):
import paddle import paddle
x = paddle.rand((2,8,32,32)) x = paddle.rand((2,8,32,32))
spectral_norm = paddle.nn.SpectralNorm(x.shape, axis=1, power_iters=2) spectral_norm = paddle.nn.SpectralNorm(x.shape, dim=1, power_iters=2)
spectral_norm_out = spectral_norm(x) spectral_norm_out = spectral_norm(x)
print(spectral_norm_out.shape) # [2, 8, 32, 32] print(spectral_norm_out.shape) # [2, 8, 32, 32]
...@@ -1864,25 +1864,25 @@ class SpectralNorm(Layer): ...@@ -1864,25 +1864,25 @@ class SpectralNorm(Layer):
def __init__( def __init__(
self, self,
weight_shape, weight_shape,
axis=0, dim=0,
power_iters=1, power_iters=1,
epsilon=1e-12, eps=1e-12,
dtype='float32', dtype='float32',
): ):
super().__init__() super().__init__()
self._power_iters = power_iters self._power_iters = power_iters
self._epsilon = epsilon self._epsilon = eps
self._dim = axis self._dim = dim
self._dtype = dtype self._dtype = dtype
self._weight_shape = list(weight_shape) self._weight_shape = list(weight_shape)
assert ( assert (
np.prod(self._weight_shape) > 0 np.prod(self._weight_shape) > 0
), "Any dimension of `weight_shape` cannot be equal to 0." ), "Any dimension of `weight_shape` cannot be equal to 0."
assert axis < len(self._weight_shape), ( assert dim < len(self._weight_shape), (
"The input `axis` should be less than the " "The input `dim` should be less than the "
"length of `weight_shape`, but received axis=" "length of `weight_shape`, but received dim="
"{}".format(axis) "{}".format(dim)
) )
h = self._weight_shape[self._dim] h = self._weight_shape[self._dim]
w = np.prod(self._weight_shape) // h w = np.prod(self._weight_shape) // h
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册