未验证 提交 7780badb 编写于 作者: Z Zhou Wei 提交者: GitHub

【cherry-pick2.0】Polish and Optimize the print/repr information of Layer (#29998) (#30893)

cherry-pick #29998
* Polish and Optimize the print/repr message of all layer
* fix some code format
上级 963d54d1
...@@ -267,7 +267,7 @@ void* GetCUDNNDsoHandle() { ...@@ -267,7 +267,7 @@ void* GetCUDNNDsoHandle() {
"For instance, download cudnn-10.0-windows10-x64-v7.6.5.32.zip from " "For instance, download cudnn-10.0-windows10-x64-v7.6.5.32.zip from "
"NVIDIA's official website, \n" "NVIDIA's official website, \n"
"then, unzip it and copy it into C:\\Program Files\\NVIDIA GPU Computing " "then, unzip it and copy it into C:\\Program Files\\NVIDIA GPU Computing "
"Toolkit\\CUDA/v10.0\n" "Toolkit\\CUDA\\v10.0\n"
"You should do this according to your CUDA installation directory and " "You should do this according to your CUDA installation directory and "
"CUDNN version."); "CUDNN version.");
return GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, win_cudnn_lib, true, return GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, win_cudnn_lib, true,
......
...@@ -46,6 +46,17 @@ def _convert_camel_to_snake(name): ...@@ -46,6 +46,17 @@ def _convert_camel_to_snake(name):
return _all_cap_re.sub(r'\1_\2', s1).lower() return _all_cap_re.sub(r'\1_\2', s1).lower()
def _addindent(string, indent):
s1 = string.split('\n')
if len(s1) == 1:
return string
s2 = []
for idx, line in enumerate(s1):
if idx > 0:
s2.append(str((indent * ' ') + line))
return s1[0] + '\n' + '\n'.join(s2)
class HookRemoveHelper(object): class HookRemoveHelper(object):
""" A HookRemoveHelper that can be used to remove hook. """ """ A HookRemoveHelper that can be used to remove hook. """
...@@ -1166,6 +1177,35 @@ class Layer(core.Layer): ...@@ -1166,6 +1177,35 @@ class Layer(core.Layer):
return keys return keys
def extra_repr(self):
"""
Extra representation of this layer, you can have custom implementation
of your own layer.
"""
return ''
def __repr__(self):
extra_lines = []
extra_repr = self.extra_repr()
extra_lines = extra_repr.split('\n')
sublayer_lines = []
for name, layer in self._sub_layers.items():
sublayer_str = repr(layer)
sublayer_str = _addindent(sublayer_str, 2)
sublayer_lines.append('(' + name + '): ' + sublayer_str)
final_str = self.__class__.__name__ + '('
if extra_lines:
if len(extra_lines) > 1:
final_str += '\n ' + '\n '.join(extra_lines) + '\n'
elif len(extra_lines) == 1:
final_str += extra_lines[0]
if sublayer_lines:
final_str += '\n ' + '\n '.join(sublayer_lines) + '\n'
final_str += ')'
return final_str
def state_dict(self, def state_dict(self,
destination=None, destination=None,
include_sublayers=True, include_sublayers=True,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.nn as nn
class TestLayerPrint(unittest.TestCase):
def test_layer_str(self):
module = nn.ELU(0.2)
self.assertEqual(str(module), 'ELU(alpha=0.2)')
module = nn.GELU(True)
self.assertEqual(str(module), 'GELU(approximate=True)')
module = nn.Hardshrink()
self.assertEqual(str(module), 'Hardshrink(threshold=0.5)')
module = nn.Hardswish(name="Hardswish")
self.assertEqual(str(module), 'Hardswish(name=Hardswish)')
module = nn.Tanh(name="Tanh")
self.assertEqual(str(module), 'Tanh(name=Tanh)')
module = nn.Hardtanh(name="Hardtanh")
self.assertEqual(
str(module), 'Hardtanh(min=-1.0, max=1.0, name=Hardtanh)')
module = nn.PReLU(1, 0.25, name="PReLU")
self.assertEqual(
str(module),
'PReLU(num_parameters=1, init=0.25, dtype=float32, name=PReLU)')
module = nn.ReLU()
self.assertEqual(str(module), 'ReLU()')
module = nn.ReLU6()
self.assertEqual(str(module), 'ReLU6()')
module = nn.SELU()
self.assertEqual(
str(module),
'SELU(scale=1.0507009873554805, alpha=1.6732632423543772)')
module = nn.LeakyReLU()
self.assertEqual(str(module), 'LeakyReLU(negative_slope=0.01)')
module = nn.Sigmoid()
self.assertEqual(str(module), 'Sigmoid()')
module = nn.Hardsigmoid()
self.assertEqual(str(module), 'Hardsigmoid()')
module = nn.Softplus()
self.assertEqual(str(module), 'Softplus(beta=1, threshold=20)')
module = nn.Softshrink()
self.assertEqual(str(module), 'Softshrink(threshold=0.5)')
module = nn.Softsign()
self.assertEqual(str(module), 'Softsign()')
module = nn.Swish()
self.assertEqual(str(module), 'Swish()')
module = nn.Tanhshrink()
self.assertEqual(str(module), 'Tanhshrink()')
module = nn.ThresholdedReLU()
self.assertEqual(str(module), 'ThresholdedReLU(threshold=1.0)')
module = nn.LogSigmoid()
self.assertEqual(str(module), 'LogSigmoid()')
module = nn.Softmax()
self.assertEqual(str(module), 'Softmax(axis=-1)')
module = nn.LogSoftmax()
self.assertEqual(str(module), 'LogSoftmax(axis=-1)')
module = nn.Maxout(groups=2)
self.assertEqual(str(module), 'Maxout(groups=2, axis=1)')
module = nn.Linear(2, 4, name='linear')
self.assertEqual(
str(module),
'Linear(in_features=2, out_features=4, dtype=float32, name=linear)')
module = nn.Upsample(size=[12, 12])
self.assertEqual(
str(module),
'Upsample(size=[12, 12], mode=nearest, align_corners=False, align_mode=0, data_format=NCHW)'
)
module = nn.UpsamplingNearest2D(size=[12, 12])
self.assertEqual(
str(module), 'UpsamplingNearest2D(size=[12, 12], data_format=NCHW)')
module = nn.UpsamplingBilinear2D(size=[12, 12])
self.assertEqual(
str(module),
'UpsamplingBilinear2D(size=[12, 12], data_format=NCHW)')
module = nn.Bilinear(in1_features=5, in2_features=4, out_features=1000)
self.assertEqual(
str(module),
'Bilinear(in1_features=5, in2_features=4, out_features=1000, dtype=float32)'
)
module = nn.Dropout(p=0.5)
self.assertEqual(
str(module), 'Dropout(p=0.5, axis=None, mode=upscale_in_train)')
module = nn.Dropout2D(p=0.5)
self.assertEqual(str(module), 'Dropout2D(p=0.5, data_format=NCHW)')
module = nn.Dropout3D(p=0.5)
self.assertEqual(str(module), 'Dropout3D(p=0.5, data_format=NCDHW)')
module = nn.AlphaDropout(p=0.5)
self.assertEqual(str(module), 'AlphaDropout(p=0.5)')
module = nn.Pad1D(padding=[1, 2], mode='constant')
self.assertEqual(
str(module),
'Pad1D(padding=[1, 2], mode=constant, value=0.0, data_format=NCL)')
module = nn.Pad2D(padding=[1, 0, 1, 2], mode='constant')
self.assertEqual(
str(module),
'Pad2D(padding=[1, 0, 1, 2], mode=constant, value=0.0, data_format=NCHW)'
)
module = nn.Pad3D(padding=[1, 0, 1, 2, 0, 0], mode='constant')
self.assertEqual(
str(module),
'Pad3D(padding=[1, 0, 1, 2, 0, 0], mode=constant, value=0.0, data_format=NCDHW)'
)
module = nn.CosineSimilarity(axis=0)
self.assertEqual(str(module), 'CosineSimilarity(axis=0, eps=1e-08)')
module = nn.Embedding(10, 3, sparse=True)
self.assertEqual(str(module), 'Embedding(10, 3, sparse=True)')
module = nn.Conv1D(3, 2, 3)
self.assertEqual(
str(module), 'Conv1D(3, 2, kernel_size=[3], data_format=NCL)')
module = nn.Conv1DTranspose(2, 1, 2)
self.assertEqual(
str(module),
'Conv1DTranspose(2, 1, kernel_size=[2], data_format=NCL)')
module = nn.Conv2D(4, 6, (3, 3))
self.assertEqual(
str(module), 'Conv2D(4, 6, kernel_size=[3, 3], data_format=NCHW)')
module = nn.Conv2DTranspose(4, 6, (3, 3))
self.assertEqual(
str(module),
'Conv2DTranspose(4, 6, kernel_size=[3, 3], data_format=NCHW)')
module = nn.Conv3D(4, 6, (3, 3, 3))
self.assertEqual(
str(module),
'Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)')
module = nn.Conv3DTranspose(4, 6, (3, 3, 3))
self.assertEqual(
str(module),
'Conv3DTranspose(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)')
module = nn.PairwiseDistance()
self.assertEqual(str(module), 'PairwiseDistance(p=2.0)')
module = nn.InstanceNorm1D(2)
self.assertEqual(
str(module), 'InstanceNorm1D(num_features=2, epsilon=1e-05)')
module = nn.InstanceNorm2D(2)
self.assertEqual(
str(module), 'InstanceNorm2D(num_features=2, epsilon=1e-05)')
module = nn.InstanceNorm3D(2)
self.assertEqual(
str(module), 'InstanceNorm3D(num_features=2, epsilon=1e-05)')
module = nn.GroupNorm(num_channels=6, num_groups=6)
self.assertEqual(
str(module),
'GroupNorm(num_groups=6, num_channels=6, epsilon=1e-05)')
module = nn.LayerNorm([2, 2, 3])
self.assertEqual(
str(module), 'LayerNorm(normalized_shape=[2, 2, 3], epsilon=1e-05)')
module = nn.BatchNorm1D(1)
self.assertEqual(
str(module),
'BatchNorm1D(num_features=1, momentum=0.9, epsilon=1e-05)')
module = nn.BatchNorm2D(1)
self.assertEqual(
str(module),
'BatchNorm2D(num_features=1, momentum=0.9, epsilon=1e-05)')
module = nn.BatchNorm3D(1)
self.assertEqual(
str(module),
'BatchNorm3D(num_features=1, momentum=0.9, epsilon=1e-05)')
module = nn.SyncBatchNorm(2)
self.assertEqual(
str(module),
'SyncBatchNorm(num_features=2, momentum=0.9, epsilon=1e-05)')
module = nn.LocalResponseNorm(size=5)
self.assertEqual(
str(module),
'LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=1.0)')
module = nn.AvgPool1D(kernel_size=2, stride=2, padding=0)
self.assertEqual(
str(module), 'AvgPool1D(kernel_size=2, stride=2, padding=0)')
module = nn.AvgPool2D(kernel_size=2, stride=2, padding=0)
self.assertEqual(
str(module), 'AvgPool2D(kernel_size=2, stride=2, padding=0)')
module = nn.AvgPool3D(kernel_size=2, stride=2, padding=0)
self.assertEqual(
str(module), 'AvgPool3D(kernel_size=2, stride=2, padding=0)')
module = nn.MaxPool1D(kernel_size=2, stride=2, padding=0)
self.assertEqual(
str(module), 'MaxPool1D(kernel_size=2, stride=2, padding=0)')
module = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
self.assertEqual(
str(module), 'MaxPool2D(kernel_size=2, stride=2, padding=0)')
module = nn.MaxPool3D(kernel_size=2, stride=2, padding=0)
self.assertEqual(
str(module), 'MaxPool3D(kernel_size=2, stride=2, padding=0)')
module = nn.AdaptiveAvgPool1D(output_size=16)
self.assertEqual(str(module), 'AdaptiveAvgPool1D(output_size=16)')
module = nn.AdaptiveAvgPool2D(output_size=3)
self.assertEqual(str(module), 'AdaptiveAvgPool2D(output_size=3)')
module = nn.AdaptiveAvgPool3D(output_size=3)
self.assertEqual(str(module), 'AdaptiveAvgPool3D(output_size=3)')
module = nn.AdaptiveMaxPool1D(output_size=16, return_mask=True)
self.assertEqual(
str(module), 'AdaptiveMaxPool1D(output_size=16, return_mask=True)')
module = nn.AdaptiveMaxPool2D(output_size=3, return_mask=True)
self.assertEqual(
str(module), 'AdaptiveMaxPool2D(output_size=3, return_mask=True)')
module = nn.AdaptiveMaxPool3D(output_size=3, return_mask=True)
self.assertEqual(
str(module), 'AdaptiveMaxPool3D(output_size=3, return_mask=True)')
module = nn.SimpleRNNCell(16, 32)
self.assertEqual(str(module), 'SimpleRNNCell(16, 32)')
module = nn.LSTMCell(16, 32)
self.assertEqual(str(module), 'LSTMCell(16, 32)')
module = nn.GRUCell(16, 32)
self.assertEqual(str(module), 'GRUCell(16, 32)')
module = nn.PixelShuffle(3)
self.assertEqual(str(module), 'PixelShuffle(upscale_factor=3)')
module = nn.SimpleRNN(16, 32, 2)
self.assertEqual(
str(module),
'SimpleRNN(16, 32, num_layers=2\n (0): RNN(\n (cell): SimpleRNNCell(16, 32)\n )\n (1): RNN(\n (cell): SimpleRNNCell(32, 32)\n )\n)'
)
module = nn.LSTM(16, 32, 2)
self.assertEqual(
str(module),
'LSTM(16, 32, num_layers=2\n (0): RNN(\n (cell): LSTMCell(16, 32)\n )\n (1): RNN(\n (cell): LSTMCell(32, 32)\n )\n)'
)
module = nn.GRU(16, 32, 2)
self.assertEqual(
str(module),
'GRU(16, 32, num_layers=2\n (0): RNN(\n (cell): GRUCell(16, 32)\n )\n (1): RNN(\n (cell): GRUCell(32, 32)\n )\n)'
)
module1 = nn.Sequential(
('conv1', nn.Conv2D(1, 20, 5)), ('relu1', nn.ReLU()),
('conv2', nn.Conv2D(20, 64, 5)), ('relu2', nn.ReLU()))
self.assertEqual(
str(module1),
'Sequential(\n '\
'(conv1): Conv2D(1, 20, kernel_size=[5, 5], data_format=NCHW)\n '\
'(relu1): ReLU()\n '\
'(conv2): Conv2D(20, 64, kernel_size=[5, 5], data_format=NCHW)\n '\
'(relu2): ReLU()\n)'
)
module2 = nn.Sequential(
nn.Conv3DTranspose(4, 6, (3, 3, 3)),
nn.AvgPool3D(
kernel_size=2, stride=2, padding=0),
nn.Tanh(name="Tanh"),
module1,
nn.Conv3D(4, 6, (3, 3, 3)),
nn.MaxPool3D(
kernel_size=2, stride=2, padding=0),
nn.GELU(True))
self.assertEqual(
str(module2),
'Sequential(\n '\
'(0): Conv3DTranspose(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)\n '\
'(1): AvgPool3D(kernel_size=2, stride=2, padding=0)\n '\
'(2): Tanh(name=Tanh)\n '\
'(3): Sequential(\n (conv1): Conv2D(1, 20, kernel_size=[5, 5], data_format=NCHW)\n (relu1): ReLU()\n'\
' (conv2): Conv2D(20, 64, kernel_size=[5, 5], data_format=NCHW)\n (relu2): ReLU()\n )\n '\
'(4): Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)\n '\
'(5): MaxPool3D(kernel_size=2, stride=2, padding=0)\n '\
'(6): GELU(approximate=True)\n)'
)
if __name__ == '__main__':
unittest.main()
...@@ -86,6 +86,10 @@ class ELU(layers.Layer): ...@@ -86,6 +86,10 @@ class ELU(layers.Layer):
def forward(self, x): def forward(self, x):
return F.elu(x, self._alpha, self._name) return F.elu(x, self._alpha, self._name)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
return 'alpha={}{}'.format(self._alpha, name_str)
class GELU(layers.Layer): class GELU(layers.Layer):
r""" r"""
...@@ -135,6 +139,10 @@ class GELU(layers.Layer): ...@@ -135,6 +139,10 @@ class GELU(layers.Layer):
def forward(self, x): def forward(self, x):
return F.gelu(x, self._approximate, self._name) return F.gelu(x, self._approximate, self._name)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
return 'approximate={}{}'.format(self._approximate, name_str)
class Hardshrink(layers.Layer): class Hardshrink(layers.Layer):
r""" r"""
...@@ -179,6 +187,10 @@ class Hardshrink(layers.Layer): ...@@ -179,6 +187,10 @@ class Hardshrink(layers.Layer):
def forward(self, x): def forward(self, x):
return F.hardshrink(x, self._threshold, self._name) return F.hardshrink(x, self._threshold, self._name)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
return 'threshold={}{}'.format(self._threshold, name_str)
class Hardswish(layers.Layer): class Hardswish(layers.Layer):
r""" r"""
...@@ -225,6 +237,10 @@ class Hardswish(layers.Layer): ...@@ -225,6 +237,10 @@ class Hardswish(layers.Layer):
def forward(self, x): def forward(self, x):
return F.hardswish(x, self._name) return F.hardswish(x, self._name)
def extra_repr(self):
name_str = 'name={}'.format(self._name) if self._name else ''
return name_str
class Tanh(layers.Layer): class Tanh(layers.Layer):
r""" r"""
...@@ -262,6 +278,10 @@ class Tanh(layers.Layer): ...@@ -262,6 +278,10 @@ class Tanh(layers.Layer):
def forward(self, x): def forward(self, x):
return F.tanh(x, self._name) return F.tanh(x, self._name)
def extra_repr(self):
name_str = 'name={}'.format(self._name) if self._name else ''
return name_str
class Hardtanh(layers.Layer): class Hardtanh(layers.Layer):
r""" r"""
...@@ -304,6 +324,10 @@ class Hardtanh(layers.Layer): ...@@ -304,6 +324,10 @@ class Hardtanh(layers.Layer):
def forward(self, x): def forward(self, x):
return F.hardtanh(x, self._min, self._max, self._name) return F.hardtanh(x, self._min, self._max, self._name)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
return 'min={}, max={}{}'.format(self._min, self._max, name_str)
class PReLU(layers.Layer): class PReLU(layers.Layer):
""" """
...@@ -371,6 +395,11 @@ class PReLU(layers.Layer): ...@@ -371,6 +395,11 @@ class PReLU(layers.Layer):
def forward(self, x): def forward(self, x):
return F.prelu(x, self._weight) return F.prelu(x, self._weight)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
return 'num_parameters={}, init={}, dtype={}{}'.format(
self._num_parameters, self._init, self._dtype, name_str)
class ReLU(layers.Layer): class ReLU(layers.Layer):
""" """
...@@ -405,6 +434,10 @@ class ReLU(layers.Layer): ...@@ -405,6 +434,10 @@ class ReLU(layers.Layer):
def forward(self, x): def forward(self, x):
return F.relu(x, self._name) return F.relu(x, self._name)
def extra_repr(self):
name_str = 'name={}'.format(self._name) if self._name else ''
return name_str
class ReLU6(layers.Layer): class ReLU6(layers.Layer):
""" """
...@@ -440,6 +473,10 @@ class ReLU6(layers.Layer): ...@@ -440,6 +473,10 @@ class ReLU6(layers.Layer):
def forward(self, x): def forward(self, x):
return F.relu6(x, self._name) return F.relu6(x, self._name)
def extra_repr(self):
name_str = 'name={}'.format(self._name) if self._name else ''
return name_str
class SELU(layers.Layer): class SELU(layers.Layer):
r""" r"""
...@@ -486,6 +523,11 @@ class SELU(layers.Layer): ...@@ -486,6 +523,11 @@ class SELU(layers.Layer):
def forward(self, x): def forward(self, x):
return F.selu(x, self._scale, self._alpha, self._name) return F.selu(x, self._scale, self._alpha, self._name)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
return 'scale={:.16f}, alpha={:.16f}{}'.format(self._scale, self._alpha,
name_str)
class LeakyReLU(layers.Layer): class LeakyReLU(layers.Layer):
r""" r"""
...@@ -530,6 +572,10 @@ class LeakyReLU(layers.Layer): ...@@ -530,6 +572,10 @@ class LeakyReLU(layers.Layer):
def forward(self, x): def forward(self, x):
return F.leaky_relu(x, self._negative_slope, self._name) return F.leaky_relu(x, self._negative_slope, self._name)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
return 'negative_slope={}{}'.format(self._negative_slope, name_str)
class Sigmoid(layers.Layer): class Sigmoid(layers.Layer):
""" """
...@@ -566,6 +612,10 @@ class Sigmoid(layers.Layer): ...@@ -566,6 +612,10 @@ class Sigmoid(layers.Layer):
def forward(self, x): def forward(self, x):
return F.sigmoid(x, self.name) return F.sigmoid(x, self.name)
def extra_repr(self):
name_str = 'name={}'.format(self.name) if self.name else ''
return name_str
class Hardsigmoid(layers.Layer): class Hardsigmoid(layers.Layer):
r""" r"""
...@@ -613,6 +663,10 @@ class Hardsigmoid(layers.Layer): ...@@ -613,6 +663,10 @@ class Hardsigmoid(layers.Layer):
def forward(self, x): def forward(self, x):
return F.hardsigmoid(x, name=self.name) return F.hardsigmoid(x, name=self.name)
def extra_repr(self):
name_str = 'name={}'.format(self.name) if self.name else ''
return name_str
class Softplus(layers.Layer): class Softplus(layers.Layer):
r""" r"""
...@@ -653,6 +707,11 @@ class Softplus(layers.Layer): ...@@ -653,6 +707,11 @@ class Softplus(layers.Layer):
def forward(self, x): def forward(self, x):
return F.softplus(x, self._beta, self._threshold, self._name) return F.softplus(x, self._beta, self._threshold, self._name)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
return 'beta={}, threshold={}{}'.format(self._beta, self._threshold,
name_str)
class Softshrink(layers.Layer): class Softshrink(layers.Layer):
r""" r"""
...@@ -694,6 +753,10 @@ class Softshrink(layers.Layer): ...@@ -694,6 +753,10 @@ class Softshrink(layers.Layer):
def forward(self, x): def forward(self, x):
return F.softshrink(x, self._threshold, self._name) return F.softshrink(x, self._threshold, self._name)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
return 'threshold={}{}'.format(self._threshold, name_str)
class Softsign(layers.Layer): class Softsign(layers.Layer):
r""" r"""
...@@ -729,6 +792,10 @@ class Softsign(layers.Layer): ...@@ -729,6 +792,10 @@ class Softsign(layers.Layer):
def forward(self, x): def forward(self, x):
return F.softsign(x, self._name) return F.softsign(x, self._name)
def extra_repr(self):
name_str = 'name={}'.format(self._name) if self._name else ''
return name_str
class Swish(layers.Layer): class Swish(layers.Layer):
r""" r"""
...@@ -764,6 +831,10 @@ class Swish(layers.Layer): ...@@ -764,6 +831,10 @@ class Swish(layers.Layer):
def forward(self, x): def forward(self, x):
return F.swish(x, self._name) return F.swish(x, self._name)
def extra_repr(self):
name_str = 'name={}'.format(self._name) if self._name else ''
return name_str
class Tanhshrink(layers.Layer): class Tanhshrink(layers.Layer):
""" """
...@@ -799,6 +870,10 @@ class Tanhshrink(layers.Layer): ...@@ -799,6 +870,10 @@ class Tanhshrink(layers.Layer):
def forward(self, x): def forward(self, x):
return F.tanhshrink(x, self._name) return F.tanhshrink(x, self._name)
def extra_repr(self):
name_str = 'name={}'.format(self._name) if self._name else ''
return name_str
class ThresholdedReLU(layers.Layer): class ThresholdedReLU(layers.Layer):
r""" r"""
...@@ -839,6 +914,10 @@ class ThresholdedReLU(layers.Layer): ...@@ -839,6 +914,10 @@ class ThresholdedReLU(layers.Layer):
def forward(self, x): def forward(self, x):
return F.thresholded_relu(x, self._threshold, self._name) return F.thresholded_relu(x, self._threshold, self._name)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
return 'threshold={}{}'.format(self._threshold, name_str)
class LogSigmoid(layers.Layer): class LogSigmoid(layers.Layer):
r""" r"""
...@@ -874,6 +953,10 @@ class LogSigmoid(layers.Layer): ...@@ -874,6 +953,10 @@ class LogSigmoid(layers.Layer):
def forward(self, x): def forward(self, x):
return F.log_sigmoid(x, self._name) return F.log_sigmoid(x, self._name)
def extra_repr(self):
name_str = 'name={}'.format(self._name) if self._name else ''
return name_str
class Softmax(layers.Layer): class Softmax(layers.Layer):
r""" r"""
...@@ -997,6 +1080,10 @@ class Softmax(layers.Layer): ...@@ -997,6 +1080,10 @@ class Softmax(layers.Layer):
def forward(self, x): def forward(self, x):
return F.softmax(x, self._axis, self._dtype, self._name) return F.softmax(x, self._axis, self._dtype, self._name)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
return 'axis={}{}'.format(self._axis, name_str)
class LogSoftmax(layers.Layer): class LogSoftmax(layers.Layer):
r""" r"""
...@@ -1051,6 +1138,10 @@ class LogSoftmax(layers.Layer): ...@@ -1051,6 +1138,10 @@ class LogSoftmax(layers.Layer):
def forward(self, x): def forward(self, x):
return F.log_softmax(x, self._axis) return F.log_softmax(x, self._axis)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
return 'axis={}{}'.format(self._axis, name_str)
class Maxout(layers.Layer): class Maxout(layers.Layer):
r""" r"""
...@@ -1111,3 +1202,7 @@ class Maxout(layers.Layer): ...@@ -1111,3 +1202,7 @@ class Maxout(layers.Layer):
def forward(self, x): def forward(self, x):
return F.maxout(x, self._groups, self._axis, self._name) return F.maxout(x, self._groups, self._axis, self._name)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
return 'groups={}, axis={}{}'.format(self._groups, self._axis, name_str)
...@@ -120,7 +120,6 @@ class Linear(layers.Layer): ...@@ -120,7 +120,6 @@ class Linear(layers.Layer):
self._dtype = self._helper.get_default_dtype() self._dtype = self._helper.get_default_dtype()
self._weight_attr = weight_attr self._weight_attr = weight_attr
self._bias_attr = bias_attr self._bias_attr = bias_attr
self.name = name
self.weight = self.create_parameter( self.weight = self.create_parameter(
shape=[in_features, out_features], shape=[in_features, out_features],
attr=self._weight_attr, attr=self._weight_attr,
...@@ -138,6 +137,11 @@ class Linear(layers.Layer): ...@@ -138,6 +137,11 @@ class Linear(layers.Layer):
x=input, weight=self.weight, bias=self.bias, name=self.name) x=input, weight=self.weight, bias=self.bias, name=self.name)
return out return out
def extra_repr(self):
name_str = ', name={}'.format(self.name) if self.name else ''
return 'in_features={}, out_features={}, dtype={}{}'.format(
self.weight.shape[0], self.weight.shape[1], self._dtype, name_str)
class Upsample(layers.Layer): class Upsample(layers.Layer):
""" """
...@@ -378,6 +382,16 @@ class Upsample(layers.Layer): ...@@ -378,6 +382,16 @@ class Upsample(layers.Layer):
return out return out
def extra_repr(self):
if self.scale_factor is not None:
main_str = 'scale_factor={}'.format(self.scale_factor)
else:
main_str = 'size={}'.format(self.size)
name_str = ', name={}'.format(self.name) if self.name else ''
return '{}, mode={}, align_corners={}, align_mode={}, data_format={}{}'.format(
main_str, self.mode, self.align_corners, self.align_mode,
self.data_format, name_str)
class UpsamplingNearest2D(layers.Layer): class UpsamplingNearest2D(layers.Layer):
""" """
...@@ -454,6 +468,15 @@ class UpsamplingNearest2D(layers.Layer): ...@@ -454,6 +468,15 @@ class UpsamplingNearest2D(layers.Layer):
return out return out
def extra_repr(self):
if self.scale_factor is not None:
main_str = 'scale_factor={}'.format(self.scale_factor)
else:
main_str = 'size={}'.format(self.size)
name_str = ', name={}'.format(self.name) if self.name else ''
return '{}, data_format={}{}'.format(main_str, self.data_format,
name_str)
class UpsamplingBilinear2D(layers.Layer): class UpsamplingBilinear2D(layers.Layer):
""" """
...@@ -531,6 +554,15 @@ class UpsamplingBilinear2D(layers.Layer): ...@@ -531,6 +554,15 @@ class UpsamplingBilinear2D(layers.Layer):
return out return out
def extra_repr(self):
if self.scale_factor is not None:
main_str = 'scale_factor={}'.format(self.scale_factor)
else:
main_str = 'size={}'.format(self.size)
name_str = ', name={}'.format(self.name) if self.name else ''
return '{}, data_format={}{}'.format(main_str, self.data_format,
name_str)
class Bilinear(layers.Layer): class Bilinear(layers.Layer):
r""" r"""
...@@ -620,6 +652,12 @@ class Bilinear(layers.Layer): ...@@ -620,6 +652,12 @@ class Bilinear(layers.Layer):
def forward(self, x1, x2): def forward(self, x1, x2):
return F.bilinear(x1, x2, self.weight, self.bias, self._name) return F.bilinear(x1, x2, self.weight, self.bias, self._name)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
return 'in1_features={}, in2_features={}, out_features={}, dtype={}{}'.format(
self._in1_features, self._in2_features, self._out_features,
self._dtype, name_str)
class Dropout(layers.Layer): class Dropout(layers.Layer):
""" """
...@@ -689,6 +727,11 @@ class Dropout(layers.Layer): ...@@ -689,6 +727,11 @@ class Dropout(layers.Layer):
name=self.name) name=self.name)
return out return out
def extra_repr(self):
name_str = ', name={}'.format(self.name) if self.name else ''
return 'p={}, axis={}, mode={}{}'.format(self.p, self.axis, self.mode,
name_str)
class Dropout2D(layers.Layer): class Dropout2D(layers.Layer):
""" """
...@@ -745,6 +788,11 @@ class Dropout2D(layers.Layer): ...@@ -745,6 +788,11 @@ class Dropout2D(layers.Layer):
name=self.name) name=self.name)
return out return out
def extra_repr(self):
name_str = ', name={}'.format(self.name) if self.name else ''
return 'p={}, data_format={}{}'.format(self.p, self.data_format,
name_str)
class Dropout3D(layers.Layer): class Dropout3D(layers.Layer):
""" """
...@@ -801,6 +849,11 @@ class Dropout3D(layers.Layer): ...@@ -801,6 +849,11 @@ class Dropout3D(layers.Layer):
name=self.name) name=self.name)
return out return out
def extra_repr(self):
name_str = ', name={}'.format(self.name) if self.name else ''
return 'p={}, data_format={}{}'.format(self.p, self.data_format,
name_str)
class AlphaDropout(layers.Layer): class AlphaDropout(layers.Layer):
""" """
...@@ -850,6 +903,10 @@ class AlphaDropout(layers.Layer): ...@@ -850,6 +903,10 @@ class AlphaDropout(layers.Layer):
input, p=self.p, training=self.training, name=self.name) input, p=self.p, training=self.training, name=self.name)
return out return out
def extra_repr(self):
name_str = ', name={}'.format(self.name) if self.name else ''
return 'p={}{}'.format(self.p, name_str)
class Pad1D(layers.Layer): class Pad1D(layers.Layer):
""" """
...@@ -925,6 +982,11 @@ class Pad1D(layers.Layer): ...@@ -925,6 +982,11 @@ class Pad1D(layers.Layer):
data_format=self._data_format, data_format=self._data_format,
name=self._name) name=self._name)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
return 'padding={}, mode={}, value={}, data_format={}{}'.format(
self._pad, self._mode, self._value, self._data_format, name_str)
class Pad2D(layers.Layer): class Pad2D(layers.Layer):
""" """
...@@ -1003,6 +1065,11 @@ class Pad2D(layers.Layer): ...@@ -1003,6 +1065,11 @@ class Pad2D(layers.Layer):
data_format=self._data_format, data_format=self._data_format,
name=self._name) name=self._name)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
return 'padding={}, mode={}, value={}, data_format={}{}'.format(
self._pad, self._mode, self._value, self._data_format, name_str)
class Pad3D(layers.Layer): class Pad3D(layers.Layer):
""" """
...@@ -1081,6 +1148,11 @@ class Pad3D(layers.Layer): ...@@ -1081,6 +1148,11 @@ class Pad3D(layers.Layer):
data_format=self._data_format, data_format=self._data_format,
name=self._name) name=self._name)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
return 'padding={}, mode={}, value={}, data_format={}{}'.format(
self._pad, self._mode, self._value, self._data_format, name_str)
class CosineSimilarity(layers.Layer): class CosineSimilarity(layers.Layer):
""" """
...@@ -1135,6 +1207,9 @@ class CosineSimilarity(layers.Layer): ...@@ -1135,6 +1207,9 @@ class CosineSimilarity(layers.Layer):
def forward(self, x1, x2): def forward(self, x1, x2):
return F.cosine_similarity(x1, x2, axis=self._axis, eps=self._eps) return F.cosine_similarity(x1, x2, axis=self._axis, eps=self._eps)
def extra_repr(self):
return 'axis={_axis}, eps={_eps}'.format(**self.__dict__)
class Embedding(layers.Layer): class Embedding(layers.Layer):
r""" r"""
...@@ -1288,3 +1363,12 @@ class Embedding(layers.Layer): ...@@ -1288,3 +1363,12 @@ class Embedding(layers.Layer):
padding_idx=self._padding_idx, padding_idx=self._padding_idx,
sparse=self._sparse, sparse=self._sparse,
name=self._name) name=self._name)
def extra_repr(self):
main_str = '{_num_embeddings}, {_embedding_dim}'
if self._padding_idx is not None:
main_str += ', padding_idx={_padding_idx}'
main_str += ', sparse={_sparse}'
if self._name is not None:
main_str += ', name={_name}'
return main_str.format(**self.__dict__)
...@@ -148,6 +148,23 @@ class _ConvNd(layers.Layer): ...@@ -148,6 +148,23 @@ class _ConvNd(layers.Layer):
self._op_type = 'depthwise_conv2d' self._op_type = 'depthwise_conv2d'
self._use_cudnn = False self._use_cudnn = False
def extra_repr(self):
main_str = '{_in_channels}, {_out_channels}, kernel_size={_kernel_size}'
if self._stride != [1] * len(self._stride):
main_str += ', stride={_stride}'
if self._padding != 0:
main_str += ', padding={_padding}'
if self._padding_mode is not 'zeros':
main_str += ', padding_mode={_padding_mode}'
if self.output_padding != 0:
main_str += ', output_padding={_output_padding}'
if self._dilation != [1] * len(self._dilation):
main_str += ', dilation={_dilation}'
if self._groups != 1:
main_str += ', groups={_groups}'
main_str += ', data_format={_data_format}'
return main_str.format(**self.__dict__)
class Conv1D(_ConvNd): class Conv1D(_ConvNd):
r""" r"""
......
...@@ -100,3 +100,13 @@ class PairwiseDistance(layers.Layer): ...@@ -100,3 +100,13 @@ class PairwiseDistance(layers.Layer):
type='p_norm', inputs={'X': sub}, outputs={'Out': out}, attrs=attrs) type='p_norm', inputs={'X': sub}, outputs={'Out': out}, attrs=attrs)
return out return out
def extra_repr(self):
main_str = 'p={p}'
if self.epsilon != 1e-6:
main_str += ', epsilon={epsilon}'
if self.keepdim != False:
main_str += ', keepdim={keepdim}'
if self.name != None:
main_str += ', name={name}'
return main_str.format(**self.__dict__)
...@@ -107,6 +107,10 @@ class _InstanceNormBase(layers.Layer): ...@@ -107,6 +107,10 @@ class _InstanceNormBase(layers.Layer):
return instance_norm( return instance_norm(
input, weight=self.scale, bias=self.bias, eps=self._epsilon) input, weight=self.scale, bias=self.bias, eps=self._epsilon)
def extra_repr(self):
return 'num_features={}, epsilon={}'.format(self.scale.shape[0],
self._epsilon)
class InstanceNorm1D(_InstanceNormBase): class InstanceNorm1D(_InstanceNormBase):
r""" r"""
...@@ -433,6 +437,10 @@ class GroupNorm(layers.Layer): ...@@ -433,6 +437,10 @@ class GroupNorm(layers.Layer):
return self._helper.append_activation(group_norm_out, None) return self._helper.append_activation(group_norm_out, None)
def extra_repr(self):
return 'num_groups={}, num_channels={}, epsilon={}'.format(
self._num_groups, self._num_channels, self._epsilon)
class LayerNorm(layers.Layer): class LayerNorm(layers.Layer):
r""" r"""
...@@ -537,6 +545,10 @@ class LayerNorm(layers.Layer): ...@@ -537,6 +545,10 @@ class LayerNorm(layers.Layer):
bias=self.bias, bias=self.bias,
epsilon=self._epsilon) epsilon=self._epsilon)
def extra_repr(self):
return 'normalized_shape={}, epsilon={}'.format(self._normalized_shape,
self._epsilon)
class _BatchNormBase(layers.Layer): class _BatchNormBase(layers.Layer):
""" """
...@@ -647,6 +659,15 @@ class _BatchNormBase(layers.Layer): ...@@ -647,6 +659,15 @@ class _BatchNormBase(layers.Layer):
data_format=self._data_format, data_format=self._data_format,
use_global_stats=self._use_global_stats) use_global_stats=self._use_global_stats)
def extra_repr(self):
main_str = 'num_features={}, momentum={}, epsilon={}'.format(
self._num_features, self._momentum, self._epsilon)
if self._data_format is not 'NCHW':
main_str += ', data_format={}'.format(self._data_format)
if self._name is not None:
main_str += ', name={}'.format(self._name)
return main_str
class BatchNorm1D(_BatchNormBase): class BatchNorm1D(_BatchNormBase):
r""" r"""
...@@ -1186,3 +1207,12 @@ class LocalResponseNorm(layers.Layer): ...@@ -1186,3 +1207,12 @@ class LocalResponseNorm(layers.Layer):
out = F.local_response_norm(input, self.size, self.alpha, self.beta, out = F.local_response_norm(input, self.size, self.alpha, self.beta,
self.k, self.data_format, self.name) self.k, self.data_format, self.name)
return out return out
def extra_repr(self):
main_str = 'size={}, alpha={}, beta={}, k={}'.format(
self.size, self.alpha, self.beta, self.k)
if self.data_format is not 'NCHW':
main_str += ', data_format={}'.format(self.data_format)
if self.name is not None:
main_str += ', name={}'.format(self.name)
return main_str
...@@ -119,6 +119,10 @@ class AvgPool1D(layers.Layer): ...@@ -119,6 +119,10 @@ class AvgPool1D(layers.Layer):
self.exclusive, self.ceil_mode, self.name) self.exclusive, self.ceil_mode, self.name)
return out return out
def extra_repr(self):
return 'kernel_size={kernel_size}, stride={stride}, padding={padding}'.format(
**self.__dict__)
class AvgPool2D(layers.Layer): class AvgPool2D(layers.Layer):
r""" r"""
...@@ -222,6 +226,10 @@ class AvgPool2D(layers.Layer): ...@@ -222,6 +226,10 @@ class AvgPool2D(layers.Layer):
data_format=self.data_format, data_format=self.data_format,
name=self.name) name=self.name)
def extra_repr(self):
return 'kernel_size={ksize}, stride={stride}, padding={padding}'.format(
**self.__dict__)
class AvgPool3D(layers.Layer): class AvgPool3D(layers.Layer):
""" """
...@@ -313,6 +321,10 @@ class AvgPool3D(layers.Layer): ...@@ -313,6 +321,10 @@ class AvgPool3D(layers.Layer):
data_format=self.data_format, data_format=self.data_format,
name=self.name) name=self.name)
def extra_repr(self):
return 'kernel_size={ksize}, stride={stride}, padding={padding}'.format(
**self.__dict__)
class MaxPool1D(layers.Layer): class MaxPool1D(layers.Layer):
""" """
...@@ -401,6 +413,10 @@ class MaxPool1D(layers.Layer): ...@@ -401,6 +413,10 @@ class MaxPool1D(layers.Layer):
self.return_mask, self.ceil_mode, self.name) self.return_mask, self.ceil_mode, self.name)
return out return out
def extra_repr(self):
return 'kernel_size={kernel_size}, stride={stride}, padding={padding}'.format(
**self.__dict__)
class MaxPool2D(layers.Layer): class MaxPool2D(layers.Layer):
r""" r"""
...@@ -504,6 +520,10 @@ class MaxPool2D(layers.Layer): ...@@ -504,6 +520,10 @@ class MaxPool2D(layers.Layer):
data_format=self.data_format, data_format=self.data_format,
name=self.name) name=self.name)
def extra_repr(self):
return 'kernel_size={ksize}, stride={stride}, padding={padding}'.format(
**self.__dict__)
class MaxPool3D(layers.Layer): class MaxPool3D(layers.Layer):
""" """
...@@ -595,6 +615,10 @@ class MaxPool3D(layers.Layer): ...@@ -595,6 +615,10 @@ class MaxPool3D(layers.Layer):
data_format=self.data_format, data_format=self.data_format,
name=self.name) name=self.name)
def extra_repr(self):
return 'kernel_size={ksize}, stride={stride}, padding={padding}'.format(
**self.__dict__)
class AdaptiveAvgPool1D(layers.Layer): class AdaptiveAvgPool1D(layers.Layer):
r""" r"""
...@@ -664,6 +688,9 @@ class AdaptiveAvgPool1D(layers.Layer): ...@@ -664,6 +688,9 @@ class AdaptiveAvgPool1D(layers.Layer):
def forward(self, input): def forward(self, input):
return F.adaptive_avg_pool1d(input, self.output_size, self.name) return F.adaptive_avg_pool1d(input, self.output_size, self.name)
def extra_repr(self):
return 'output_size={}'.format(self.output_size)
class AdaptiveAvgPool2D(layers.Layer): class AdaptiveAvgPool2D(layers.Layer):
r""" r"""
...@@ -746,6 +773,9 @@ class AdaptiveAvgPool2D(layers.Layer): ...@@ -746,6 +773,9 @@ class AdaptiveAvgPool2D(layers.Layer):
data_format=self._data_format, data_format=self._data_format,
name=self._name) name=self._name)
def extra_repr(self):
return 'output_size={}'.format(self._output_size)
class AdaptiveAvgPool3D(layers.Layer): class AdaptiveAvgPool3D(layers.Layer):
r""" r"""
...@@ -834,6 +864,9 @@ class AdaptiveAvgPool3D(layers.Layer): ...@@ -834,6 +864,9 @@ class AdaptiveAvgPool3D(layers.Layer):
data_format=self._data_format, data_format=self._data_format,
name=self._name) name=self._name)
def extra_repr(self):
return 'output_size={}'.format(self._output_size)
class AdaptiveMaxPool1D(layers.Layer): class AdaptiveMaxPool1D(layers.Layer):
""" """
...@@ -913,6 +946,10 @@ class AdaptiveMaxPool1D(layers.Layer): ...@@ -913,6 +946,10 @@ class AdaptiveMaxPool1D(layers.Layer):
return F.adaptive_max_pool1d(input, self.output_size, self.return_mask, return F.adaptive_max_pool1d(input, self.output_size, self.return_mask,
self.name) self.name)
def extra_repr(self):
return 'output_size={}, return_mask={}'.format(self.output_size,
self.return_mask)
class AdaptiveMaxPool2D(layers.Layer): class AdaptiveMaxPool2D(layers.Layer):
""" """
...@@ -985,6 +1022,10 @@ class AdaptiveMaxPool2D(layers.Layer): ...@@ -985,6 +1022,10 @@ class AdaptiveMaxPool2D(layers.Layer):
return_mask=self._return_mask, return_mask=self._return_mask,
name=self._name) name=self._name)
def extra_repr(self):
return 'output_size={}, return_mask={}'.format(self._output_size,
self._return_mask)
class AdaptiveMaxPool3D(layers.Layer): class AdaptiveMaxPool3D(layers.Layer):
""" """
...@@ -1067,3 +1108,7 @@ class AdaptiveMaxPool3D(layers.Layer): ...@@ -1067,3 +1108,7 @@ class AdaptiveMaxPool3D(layers.Layer):
output_size=self._output_size, output_size=self._output_size,
return_mask=self._return_mask, return_mask=self._return_mask,
name=self._name) name=self._name)
def extra_repr(self):
return 'output_size={}, return_mask={}'.format(self._output_size,
self._return_mask)
...@@ -390,6 +390,12 @@ class SimpleRNNCell(RNNCellBase): ...@@ -390,6 +390,12 @@ class SimpleRNNCell(RNNCellBase):
def state_shape(self): def state_shape(self):
return (self.hidden_size, ) return (self.hidden_size, )
def extra_repr(self):
s = '{input_size}, {hidden_size}'
if self.activation is not "tanh":
s += ', activation={activation}'
return s.format(**self.__dict__)
class LSTMCell(RNNCellBase): class LSTMCell(RNNCellBase):
r""" r"""
...@@ -540,6 +546,9 @@ class LSTMCell(RNNCellBase): ...@@ -540,6 +546,9 @@ class LSTMCell(RNNCellBase):
""" """
return ((self.hidden_size, ), (self.hidden_size, )) return ((self.hidden_size, ), (self.hidden_size, ))
def extra_repr(self):
return '{input_size}, {hidden_size}'.format(**self.__dict__)
class GRUCell(RNNCellBase): class GRUCell(RNNCellBase):
r""" r"""
...@@ -684,6 +693,9 @@ class GRUCell(RNNCellBase): ...@@ -684,6 +693,9 @@ class GRUCell(RNNCellBase):
""" """
return (self.hidden_size, ) return (self.hidden_size, )
def extra_repr(self):
return '{input_size}, {hidden_size}'.format(**self.__dict__)
class RNN(Layer): class RNN(Layer):
r""" r"""
...@@ -1053,6 +1065,16 @@ class RNNBase(LayerList): ...@@ -1053,6 +1065,16 @@ class RNNBase(LayerList):
self.state_components) self.state_components)
return outputs, final_states return outputs, final_states
def extra_repr(self):
main_str = '{input_size}, {hidden_size}'
if self.num_layers != 1:
main_str += ', num_layers={num_layers}'
if self.time_major != False:
main_str += ', time_major={time_major}'
if self.dropout != 0:
main_str += ', dropout={dropout}'
return main_str.format(**self.__dict__)
class SimpleRNN(RNNBase): class SimpleRNN(RNNBase):
r""" r"""
......
...@@ -79,3 +79,11 @@ class PixelShuffle(layers.Layer): ...@@ -79,3 +79,11 @@ class PixelShuffle(layers.Layer):
def forward(self, x): def forward(self, x):
return functional.pixel_shuffle(x, self._upscale_factor, return functional.pixel_shuffle(x, self._upscale_factor,
self._data_format, self._name) self._data_format, self._name)
def extra_repr(self):
main_str = 'upscale_factor={}'.format(self._upscale_factor)
if self._data_format is not 'NCHW':
main_str += ', data_format={}'.format(self._data_format)
if self._name is not None:
main_str += ', name={}'.format(self._name)
return main_str
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册