diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index 03cd5814afdb52dc8d84a3c8ac0ce0f7497e5f34..d95b812e930c7cf2d2135fd4f9d098fa0e76a7c8 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -267,7 +267,7 @@ void* GetCUDNNDsoHandle() { "For instance, download cudnn-10.0-windows10-x64-v7.6.5.32.zip from " "NVIDIA's official website, \n" "then, unzip it and copy it into C:\\Program Files\\NVIDIA GPU Computing " - "Toolkit\\CUDA/v10.0\n" + "Toolkit\\CUDA\\v10.0\n" "You should do this according to your CUDA installation directory and " "CUDNN version."); return GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, win_cudnn_lib, true, diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index a9237e108049360dfe4a8f5fd7f80b103f9bf92a..9da12a9116854ce2e72781a2c47cb4010b453580 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -46,6 +46,17 @@ def _convert_camel_to_snake(name): return _all_cap_re.sub(r'\1_\2', s1).lower() +def _addindent(string, indent): + s1 = string.split('\n') + if len(s1) == 1: + return string + s2 = [] + for idx, line in enumerate(s1): + if idx > 0: + s2.append(str((indent * ' ') + line)) + return s1[0] + '\n' + '\n'.join(s2) + + class HookRemoveHelper(object): """ A HookRemoveHelper that can be used to remove hook. """ @@ -1166,6 +1177,35 @@ class Layer(core.Layer): return keys + def extra_repr(self): + """ + Extra representation of this layer, you can have custom implementation + of your own layer. + """ + return '' + + def __repr__(self): + extra_lines = [] + extra_repr = self.extra_repr() + extra_lines = extra_repr.split('\n') + sublayer_lines = [] + for name, layer in self._sub_layers.items(): + sublayer_str = repr(layer) + sublayer_str = _addindent(sublayer_str, 2) + sublayer_lines.append('(' + name + '): ' + sublayer_str) + + final_str = self.__class__.__name__ + '(' + if extra_lines: + if len(extra_lines) > 1: + final_str += '\n ' + '\n '.join(extra_lines) + '\n' + elif len(extra_lines) == 1: + final_str += extra_lines[0] + if sublayer_lines: + final_str += '\n ' + '\n '.join(sublayer_lines) + '\n' + + final_str += ')' + return final_str + def state_dict(self, destination=None, include_sublayers=True, diff --git a/python/paddle/fluid/tests/unittests/test_imperative_layers.py b/python/paddle/fluid/tests/unittests/test_imperative_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..214339c50d60dd626e9b7eaf931d24114e13705b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_imperative_layers.py @@ -0,0 +1,347 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import paddle.nn as nn + + +class TestLayerPrint(unittest.TestCase): + def test_layer_str(self): + module = nn.ELU(0.2) + self.assertEqual(str(module), 'ELU(alpha=0.2)') + + module = nn.GELU(True) + self.assertEqual(str(module), 'GELU(approximate=True)') + + module = nn.Hardshrink() + self.assertEqual(str(module), 'Hardshrink(threshold=0.5)') + + module = nn.Hardswish(name="Hardswish") + self.assertEqual(str(module), 'Hardswish(name=Hardswish)') + + module = nn.Tanh(name="Tanh") + self.assertEqual(str(module), 'Tanh(name=Tanh)') + + module = nn.Hardtanh(name="Hardtanh") + self.assertEqual( + str(module), 'Hardtanh(min=-1.0, max=1.0, name=Hardtanh)') + + module = nn.PReLU(1, 0.25, name="PReLU") + self.assertEqual( + str(module), + 'PReLU(num_parameters=1, init=0.25, dtype=float32, name=PReLU)') + + module = nn.ReLU() + self.assertEqual(str(module), 'ReLU()') + + module = nn.ReLU6() + self.assertEqual(str(module), 'ReLU6()') + + module = nn.SELU() + self.assertEqual( + str(module), + 'SELU(scale=1.0507009873554805, alpha=1.6732632423543772)') + + module = nn.LeakyReLU() + self.assertEqual(str(module), 'LeakyReLU(negative_slope=0.01)') + + module = nn.Sigmoid() + self.assertEqual(str(module), 'Sigmoid()') + + module = nn.Hardsigmoid() + self.assertEqual(str(module), 'Hardsigmoid()') + + module = nn.Softplus() + self.assertEqual(str(module), 'Softplus(beta=1, threshold=20)') + + module = nn.Softshrink() + self.assertEqual(str(module), 'Softshrink(threshold=0.5)') + + module = nn.Softsign() + self.assertEqual(str(module), 'Softsign()') + + module = nn.Swish() + self.assertEqual(str(module), 'Swish()') + + module = nn.Tanhshrink() + self.assertEqual(str(module), 'Tanhshrink()') + + module = nn.ThresholdedReLU() + self.assertEqual(str(module), 'ThresholdedReLU(threshold=1.0)') + + module = nn.LogSigmoid() + self.assertEqual(str(module), 'LogSigmoid()') + + module = nn.Softmax() + self.assertEqual(str(module), 'Softmax(axis=-1)') + + module = nn.LogSoftmax() + self.assertEqual(str(module), 'LogSoftmax(axis=-1)') + + module = nn.Maxout(groups=2) + self.assertEqual(str(module), 'Maxout(groups=2, axis=1)') + + module = nn.Linear(2, 4, name='linear') + self.assertEqual( + str(module), + 'Linear(in_features=2, out_features=4, dtype=float32, name=linear)') + + module = nn.Upsample(size=[12, 12]) + self.assertEqual( + str(module), + 'Upsample(size=[12, 12], mode=nearest, align_corners=False, align_mode=0, data_format=NCHW)' + ) + + module = nn.UpsamplingNearest2D(size=[12, 12]) + self.assertEqual( + str(module), 'UpsamplingNearest2D(size=[12, 12], data_format=NCHW)') + + module = nn.UpsamplingBilinear2D(size=[12, 12]) + self.assertEqual( + str(module), + 'UpsamplingBilinear2D(size=[12, 12], data_format=NCHW)') + + module = nn.Bilinear(in1_features=5, in2_features=4, out_features=1000) + self.assertEqual( + str(module), + 'Bilinear(in1_features=5, in2_features=4, out_features=1000, dtype=float32)' + ) + + module = nn.Dropout(p=0.5) + self.assertEqual( + str(module), 'Dropout(p=0.5, axis=None, mode=upscale_in_train)') + + module = nn.Dropout2D(p=0.5) + self.assertEqual(str(module), 'Dropout2D(p=0.5, data_format=NCHW)') + + module = nn.Dropout3D(p=0.5) + self.assertEqual(str(module), 'Dropout3D(p=0.5, data_format=NCDHW)') + + module = nn.AlphaDropout(p=0.5) + self.assertEqual(str(module), 'AlphaDropout(p=0.5)') + + module = nn.Pad1D(padding=[1, 2], mode='constant') + self.assertEqual( + str(module), + 'Pad1D(padding=[1, 2], mode=constant, value=0.0, data_format=NCL)') + + module = nn.Pad2D(padding=[1, 0, 1, 2], mode='constant') + self.assertEqual( + str(module), + 'Pad2D(padding=[1, 0, 1, 2], mode=constant, value=0.0, data_format=NCHW)' + ) + + module = nn.Pad3D(padding=[1, 0, 1, 2, 0, 0], mode='constant') + self.assertEqual( + str(module), + 'Pad3D(padding=[1, 0, 1, 2, 0, 0], mode=constant, value=0.0, data_format=NCDHW)' + ) + + module = nn.CosineSimilarity(axis=0) + self.assertEqual(str(module), 'CosineSimilarity(axis=0, eps=1e-08)') + + module = nn.Embedding(10, 3, sparse=True) + self.assertEqual(str(module), 'Embedding(10, 3, sparse=True)') + + module = nn.Conv1D(3, 2, 3) + self.assertEqual( + str(module), 'Conv1D(3, 2, kernel_size=[3], data_format=NCL)') + + module = nn.Conv1DTranspose(2, 1, 2) + self.assertEqual( + str(module), + 'Conv1DTranspose(2, 1, kernel_size=[2], data_format=NCL)') + + module = nn.Conv2D(4, 6, (3, 3)) + self.assertEqual( + str(module), 'Conv2D(4, 6, kernel_size=[3, 3], data_format=NCHW)') + + module = nn.Conv2DTranspose(4, 6, (3, 3)) + self.assertEqual( + str(module), + 'Conv2DTranspose(4, 6, kernel_size=[3, 3], data_format=NCHW)') + + module = nn.Conv3D(4, 6, (3, 3, 3)) + self.assertEqual( + str(module), + 'Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)') + + module = nn.Conv3DTranspose(4, 6, (3, 3, 3)) + self.assertEqual( + str(module), + 'Conv3DTranspose(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)') + + module = nn.PairwiseDistance() + self.assertEqual(str(module), 'PairwiseDistance(p=2.0)') + + module = nn.InstanceNorm1D(2) + self.assertEqual( + str(module), 'InstanceNorm1D(num_features=2, epsilon=1e-05)') + + module = nn.InstanceNorm2D(2) + self.assertEqual( + str(module), 'InstanceNorm2D(num_features=2, epsilon=1e-05)') + + module = nn.InstanceNorm3D(2) + self.assertEqual( + str(module), 'InstanceNorm3D(num_features=2, epsilon=1e-05)') + + module = nn.GroupNorm(num_channels=6, num_groups=6) + self.assertEqual( + str(module), + 'GroupNorm(num_groups=6, num_channels=6, epsilon=1e-05)') + + module = nn.LayerNorm([2, 2, 3]) + self.assertEqual( + str(module), 'LayerNorm(normalized_shape=[2, 2, 3], epsilon=1e-05)') + + module = nn.BatchNorm1D(1) + self.assertEqual( + str(module), + 'BatchNorm1D(num_features=1, momentum=0.9, epsilon=1e-05)') + + module = nn.BatchNorm2D(1) + self.assertEqual( + str(module), + 'BatchNorm2D(num_features=1, momentum=0.9, epsilon=1e-05)') + + module = nn.BatchNorm3D(1) + self.assertEqual( + str(module), + 'BatchNorm3D(num_features=1, momentum=0.9, epsilon=1e-05)') + + module = nn.SyncBatchNorm(2) + self.assertEqual( + str(module), + 'SyncBatchNorm(num_features=2, momentum=0.9, epsilon=1e-05)') + + module = nn.LocalResponseNorm(size=5) + self.assertEqual( + str(module), + 'LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=1.0)') + + module = nn.AvgPool1D(kernel_size=2, stride=2, padding=0) + self.assertEqual( + str(module), 'AvgPool1D(kernel_size=2, stride=2, padding=0)') + + module = nn.AvgPool2D(kernel_size=2, stride=2, padding=0) + self.assertEqual( + str(module), 'AvgPool2D(kernel_size=2, stride=2, padding=0)') + + module = nn.AvgPool3D(kernel_size=2, stride=2, padding=0) + self.assertEqual( + str(module), 'AvgPool3D(kernel_size=2, stride=2, padding=0)') + + module = nn.MaxPool1D(kernel_size=2, stride=2, padding=0) + self.assertEqual( + str(module), 'MaxPool1D(kernel_size=2, stride=2, padding=0)') + + module = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) + self.assertEqual( + str(module), 'MaxPool2D(kernel_size=2, stride=2, padding=0)') + + module = nn.MaxPool3D(kernel_size=2, stride=2, padding=0) + self.assertEqual( + str(module), 'MaxPool3D(kernel_size=2, stride=2, padding=0)') + + module = nn.AdaptiveAvgPool1D(output_size=16) + self.assertEqual(str(module), 'AdaptiveAvgPool1D(output_size=16)') + + module = nn.AdaptiveAvgPool2D(output_size=3) + self.assertEqual(str(module), 'AdaptiveAvgPool2D(output_size=3)') + + module = nn.AdaptiveAvgPool3D(output_size=3) + self.assertEqual(str(module), 'AdaptiveAvgPool3D(output_size=3)') + + module = nn.AdaptiveMaxPool1D(output_size=16, return_mask=True) + self.assertEqual( + str(module), 'AdaptiveMaxPool1D(output_size=16, return_mask=True)') + + module = nn.AdaptiveMaxPool2D(output_size=3, return_mask=True) + self.assertEqual( + str(module), 'AdaptiveMaxPool2D(output_size=3, return_mask=True)') + + module = nn.AdaptiveMaxPool3D(output_size=3, return_mask=True) + self.assertEqual( + str(module), 'AdaptiveMaxPool3D(output_size=3, return_mask=True)') + + module = nn.SimpleRNNCell(16, 32) + self.assertEqual(str(module), 'SimpleRNNCell(16, 32)') + + module = nn.LSTMCell(16, 32) + self.assertEqual(str(module), 'LSTMCell(16, 32)') + + module = nn.GRUCell(16, 32) + self.assertEqual(str(module), 'GRUCell(16, 32)') + + module = nn.PixelShuffle(3) + self.assertEqual(str(module), 'PixelShuffle(upscale_factor=3)') + + module = nn.SimpleRNN(16, 32, 2) + self.assertEqual( + str(module), + 'SimpleRNN(16, 32, num_layers=2\n (0): RNN(\n (cell): SimpleRNNCell(16, 32)\n )\n (1): RNN(\n (cell): SimpleRNNCell(32, 32)\n )\n)' + ) + + module = nn.LSTM(16, 32, 2) + self.assertEqual( + str(module), + 'LSTM(16, 32, num_layers=2\n (0): RNN(\n (cell): LSTMCell(16, 32)\n )\n (1): RNN(\n (cell): LSTMCell(32, 32)\n )\n)' + ) + + module = nn.GRU(16, 32, 2) + self.assertEqual( + str(module), + 'GRU(16, 32, num_layers=2\n (0): RNN(\n (cell): GRUCell(16, 32)\n )\n (1): RNN(\n (cell): GRUCell(32, 32)\n )\n)' + ) + + module1 = nn.Sequential( + ('conv1', nn.Conv2D(1, 20, 5)), ('relu1', nn.ReLU()), + ('conv2', nn.Conv2D(20, 64, 5)), ('relu2', nn.ReLU())) + self.assertEqual( + str(module1), + 'Sequential(\n '\ + '(conv1): Conv2D(1, 20, kernel_size=[5, 5], data_format=NCHW)\n '\ + '(relu1): ReLU()\n '\ + '(conv2): Conv2D(20, 64, kernel_size=[5, 5], data_format=NCHW)\n '\ + '(relu2): ReLU()\n)' + ) + + module2 = nn.Sequential( + nn.Conv3DTranspose(4, 6, (3, 3, 3)), + nn.AvgPool3D( + kernel_size=2, stride=2, padding=0), + nn.Tanh(name="Tanh"), + module1, + nn.Conv3D(4, 6, (3, 3, 3)), + nn.MaxPool3D( + kernel_size=2, stride=2, padding=0), + nn.GELU(True)) + self.assertEqual( + str(module2), + 'Sequential(\n '\ + '(0): Conv3DTranspose(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)\n '\ + '(1): AvgPool3D(kernel_size=2, stride=2, padding=0)\n '\ + '(2): Tanh(name=Tanh)\n '\ + '(3): Sequential(\n (conv1): Conv2D(1, 20, kernel_size=[5, 5], data_format=NCHW)\n (relu1): ReLU()\n'\ + ' (conv2): Conv2D(20, 64, kernel_size=[5, 5], data_format=NCHW)\n (relu2): ReLU()\n )\n '\ + '(4): Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)\n '\ + '(5): MaxPool3D(kernel_size=2, stride=2, padding=0)\n '\ + '(6): GELU(approximate=True)\n)' + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index 3350ab64057a3832b835070c244dcd7d1ef12164..69cdb7381716b5a1866a8519b4ee7662dfb7b2bb 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -86,6 +86,10 @@ class ELU(layers.Layer): def forward(self, x): return F.elu(x, self._alpha, self._name) + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'alpha={}{}'.format(self._alpha, name_str) + class GELU(layers.Layer): r""" @@ -135,6 +139,10 @@ class GELU(layers.Layer): def forward(self, x): return F.gelu(x, self._approximate, self._name) + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'approximate={}{}'.format(self._approximate, name_str) + class Hardshrink(layers.Layer): r""" @@ -179,6 +187,10 @@ class Hardshrink(layers.Layer): def forward(self, x): return F.hardshrink(x, self._threshold, self._name) + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'threshold={}{}'.format(self._threshold, name_str) + class Hardswish(layers.Layer): r""" @@ -225,6 +237,10 @@ class Hardswish(layers.Layer): def forward(self, x): return F.hardswish(x, self._name) + def extra_repr(self): + name_str = 'name={}'.format(self._name) if self._name else '' + return name_str + class Tanh(layers.Layer): r""" @@ -262,6 +278,10 @@ class Tanh(layers.Layer): def forward(self, x): return F.tanh(x, self._name) + def extra_repr(self): + name_str = 'name={}'.format(self._name) if self._name else '' + return name_str + class Hardtanh(layers.Layer): r""" @@ -304,6 +324,10 @@ class Hardtanh(layers.Layer): def forward(self, x): return F.hardtanh(x, self._min, self._max, self._name) + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'min={}, max={}{}'.format(self._min, self._max, name_str) + class PReLU(layers.Layer): """ @@ -371,6 +395,11 @@ class PReLU(layers.Layer): def forward(self, x): return F.prelu(x, self._weight) + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'num_parameters={}, init={}, dtype={}{}'.format( + self._num_parameters, self._init, self._dtype, name_str) + class ReLU(layers.Layer): """ @@ -405,6 +434,10 @@ class ReLU(layers.Layer): def forward(self, x): return F.relu(x, self._name) + def extra_repr(self): + name_str = 'name={}'.format(self._name) if self._name else '' + return name_str + class ReLU6(layers.Layer): """ @@ -440,6 +473,10 @@ class ReLU6(layers.Layer): def forward(self, x): return F.relu6(x, self._name) + def extra_repr(self): + name_str = 'name={}'.format(self._name) if self._name else '' + return name_str + class SELU(layers.Layer): r""" @@ -486,6 +523,11 @@ class SELU(layers.Layer): def forward(self, x): return F.selu(x, self._scale, self._alpha, self._name) + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'scale={:.16f}, alpha={:.16f}{}'.format(self._scale, self._alpha, + name_str) + class LeakyReLU(layers.Layer): r""" @@ -530,6 +572,10 @@ class LeakyReLU(layers.Layer): def forward(self, x): return F.leaky_relu(x, self._negative_slope, self._name) + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'negative_slope={}{}'.format(self._negative_slope, name_str) + class Sigmoid(layers.Layer): """ @@ -566,6 +612,10 @@ class Sigmoid(layers.Layer): def forward(self, x): return F.sigmoid(x, self.name) + def extra_repr(self): + name_str = 'name={}'.format(self.name) if self.name else '' + return name_str + class Hardsigmoid(layers.Layer): r""" @@ -613,6 +663,10 @@ class Hardsigmoid(layers.Layer): def forward(self, x): return F.hardsigmoid(x, name=self.name) + def extra_repr(self): + name_str = 'name={}'.format(self.name) if self.name else '' + return name_str + class Softplus(layers.Layer): r""" @@ -653,6 +707,11 @@ class Softplus(layers.Layer): def forward(self, x): return F.softplus(x, self._beta, self._threshold, self._name) + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'beta={}, threshold={}{}'.format(self._beta, self._threshold, + name_str) + class Softshrink(layers.Layer): r""" @@ -694,6 +753,10 @@ class Softshrink(layers.Layer): def forward(self, x): return F.softshrink(x, self._threshold, self._name) + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'threshold={}{}'.format(self._threshold, name_str) + class Softsign(layers.Layer): r""" @@ -729,6 +792,10 @@ class Softsign(layers.Layer): def forward(self, x): return F.softsign(x, self._name) + def extra_repr(self): + name_str = 'name={}'.format(self._name) if self._name else '' + return name_str + class Swish(layers.Layer): r""" @@ -764,6 +831,10 @@ class Swish(layers.Layer): def forward(self, x): return F.swish(x, self._name) + def extra_repr(self): + name_str = 'name={}'.format(self._name) if self._name else '' + return name_str + class Tanhshrink(layers.Layer): """ @@ -799,6 +870,10 @@ class Tanhshrink(layers.Layer): def forward(self, x): return F.tanhshrink(x, self._name) + def extra_repr(self): + name_str = 'name={}'.format(self._name) if self._name else '' + return name_str + class ThresholdedReLU(layers.Layer): r""" @@ -839,6 +914,10 @@ class ThresholdedReLU(layers.Layer): def forward(self, x): return F.thresholded_relu(x, self._threshold, self._name) + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'threshold={}{}'.format(self._threshold, name_str) + class LogSigmoid(layers.Layer): r""" @@ -874,6 +953,10 @@ class LogSigmoid(layers.Layer): def forward(self, x): return F.log_sigmoid(x, self._name) + def extra_repr(self): + name_str = 'name={}'.format(self._name) if self._name else '' + return name_str + class Softmax(layers.Layer): r""" @@ -997,6 +1080,10 @@ class Softmax(layers.Layer): def forward(self, x): return F.softmax(x, self._axis, self._dtype, self._name) + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'axis={}{}'.format(self._axis, name_str) + class LogSoftmax(layers.Layer): r""" @@ -1051,6 +1138,10 @@ class LogSoftmax(layers.Layer): def forward(self, x): return F.log_softmax(x, self._axis) + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'axis={}{}'.format(self._axis, name_str) + class Maxout(layers.Layer): r""" @@ -1111,3 +1202,7 @@ class Maxout(layers.Layer): def forward(self, x): return F.maxout(x, self._groups, self._axis, self._name) + + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'groups={}, axis={}{}'.format(self._groups, self._axis, name_str) diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 9675524f938e544d0e6dbcc94bf8e7a979b070ae..05d619bd729d8754b36ecf8aa8e386ca839860cc 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -120,7 +120,6 @@ class Linear(layers.Layer): self._dtype = self._helper.get_default_dtype() self._weight_attr = weight_attr self._bias_attr = bias_attr - self.name = name self.weight = self.create_parameter( shape=[in_features, out_features], attr=self._weight_attr, @@ -138,6 +137,11 @@ class Linear(layers.Layer): x=input, weight=self.weight, bias=self.bias, name=self.name) return out + def extra_repr(self): + name_str = ', name={}'.format(self.name) if self.name else '' + return 'in_features={}, out_features={}, dtype={}{}'.format( + self.weight.shape[0], self.weight.shape[1], self._dtype, name_str) + class Upsample(layers.Layer): """ @@ -378,6 +382,16 @@ class Upsample(layers.Layer): return out + def extra_repr(self): + if self.scale_factor is not None: + main_str = 'scale_factor={}'.format(self.scale_factor) + else: + main_str = 'size={}'.format(self.size) + name_str = ', name={}'.format(self.name) if self.name else '' + return '{}, mode={}, align_corners={}, align_mode={}, data_format={}{}'.format( + main_str, self.mode, self.align_corners, self.align_mode, + self.data_format, name_str) + class UpsamplingNearest2D(layers.Layer): """ @@ -454,6 +468,15 @@ class UpsamplingNearest2D(layers.Layer): return out + def extra_repr(self): + if self.scale_factor is not None: + main_str = 'scale_factor={}'.format(self.scale_factor) + else: + main_str = 'size={}'.format(self.size) + name_str = ', name={}'.format(self.name) if self.name else '' + return '{}, data_format={}{}'.format(main_str, self.data_format, + name_str) + class UpsamplingBilinear2D(layers.Layer): """ @@ -531,6 +554,15 @@ class UpsamplingBilinear2D(layers.Layer): return out + def extra_repr(self): + if self.scale_factor is not None: + main_str = 'scale_factor={}'.format(self.scale_factor) + else: + main_str = 'size={}'.format(self.size) + name_str = ', name={}'.format(self.name) if self.name else '' + return '{}, data_format={}{}'.format(main_str, self.data_format, + name_str) + class Bilinear(layers.Layer): r""" @@ -620,6 +652,12 @@ class Bilinear(layers.Layer): def forward(self, x1, x2): return F.bilinear(x1, x2, self.weight, self.bias, self._name) + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'in1_features={}, in2_features={}, out_features={}, dtype={}{}'.format( + self._in1_features, self._in2_features, self._out_features, + self._dtype, name_str) + class Dropout(layers.Layer): """ @@ -689,6 +727,11 @@ class Dropout(layers.Layer): name=self.name) return out + def extra_repr(self): + name_str = ', name={}'.format(self.name) if self.name else '' + return 'p={}, axis={}, mode={}{}'.format(self.p, self.axis, self.mode, + name_str) + class Dropout2D(layers.Layer): """ @@ -745,6 +788,11 @@ class Dropout2D(layers.Layer): name=self.name) return out + def extra_repr(self): + name_str = ', name={}'.format(self.name) if self.name else '' + return 'p={}, data_format={}{}'.format(self.p, self.data_format, + name_str) + class Dropout3D(layers.Layer): """ @@ -801,6 +849,11 @@ class Dropout3D(layers.Layer): name=self.name) return out + def extra_repr(self): + name_str = ', name={}'.format(self.name) if self.name else '' + return 'p={}, data_format={}{}'.format(self.p, self.data_format, + name_str) + class AlphaDropout(layers.Layer): """ @@ -850,6 +903,10 @@ class AlphaDropout(layers.Layer): input, p=self.p, training=self.training, name=self.name) return out + def extra_repr(self): + name_str = ', name={}'.format(self.name) if self.name else '' + return 'p={}{}'.format(self.p, name_str) + class Pad1D(layers.Layer): """ @@ -925,6 +982,11 @@ class Pad1D(layers.Layer): data_format=self._data_format, name=self._name) + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'padding={}, mode={}, value={}, data_format={}{}'.format( + self._pad, self._mode, self._value, self._data_format, name_str) + class Pad2D(layers.Layer): """ @@ -1003,6 +1065,11 @@ class Pad2D(layers.Layer): data_format=self._data_format, name=self._name) + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'padding={}, mode={}, value={}, data_format={}{}'.format( + self._pad, self._mode, self._value, self._data_format, name_str) + class Pad3D(layers.Layer): """ @@ -1081,6 +1148,11 @@ class Pad3D(layers.Layer): data_format=self._data_format, name=self._name) + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'padding={}, mode={}, value={}, data_format={}{}'.format( + self._pad, self._mode, self._value, self._data_format, name_str) + class CosineSimilarity(layers.Layer): """ @@ -1135,6 +1207,9 @@ class CosineSimilarity(layers.Layer): def forward(self, x1, x2): return F.cosine_similarity(x1, x2, axis=self._axis, eps=self._eps) + def extra_repr(self): + return 'axis={_axis}, eps={_eps}'.format(**self.__dict__) + class Embedding(layers.Layer): r""" @@ -1288,3 +1363,12 @@ class Embedding(layers.Layer): padding_idx=self._padding_idx, sparse=self._sparse, name=self._name) + + def extra_repr(self): + main_str = '{_num_embeddings}, {_embedding_dim}' + if self._padding_idx is not None: + main_str += ', padding_idx={_padding_idx}' + main_str += ', sparse={_sparse}' + if self._name is not None: + main_str += ', name={_name}' + return main_str.format(**self.__dict__) diff --git a/python/paddle/nn/layer/conv.py b/python/paddle/nn/layer/conv.py index da76f0f11e52cb8bff6fe418cf4cd6e9b964c220..2c6308d11292563bcc512a6cee23b1e0a33d6a3c 100644 --- a/python/paddle/nn/layer/conv.py +++ b/python/paddle/nn/layer/conv.py @@ -148,6 +148,23 @@ class _ConvNd(layers.Layer): self._op_type = 'depthwise_conv2d' self._use_cudnn = False + def extra_repr(self): + main_str = '{_in_channels}, {_out_channels}, kernel_size={_kernel_size}' + if self._stride != [1] * len(self._stride): + main_str += ', stride={_stride}' + if self._padding != 0: + main_str += ', padding={_padding}' + if self._padding_mode is not 'zeros': + main_str += ', padding_mode={_padding_mode}' + if self.output_padding != 0: + main_str += ', output_padding={_output_padding}' + if self._dilation != [1] * len(self._dilation): + main_str += ', dilation={_dilation}' + if self._groups != 1: + main_str += ', groups={_groups}' + main_str += ', data_format={_data_format}' + return main_str.format(**self.__dict__) + class Conv1D(_ConvNd): r""" diff --git a/python/paddle/nn/layer/distance.py b/python/paddle/nn/layer/distance.py index 5a3c611b3c447edcfa34f49c0a38c36b9e7d3ec9..72e0a1b2d6d2009e0edb2674b13299460996c104 100644 --- a/python/paddle/nn/layer/distance.py +++ b/python/paddle/nn/layer/distance.py @@ -100,3 +100,13 @@ class PairwiseDistance(layers.Layer): type='p_norm', inputs={'X': sub}, outputs={'Out': out}, attrs=attrs) return out + + def extra_repr(self): + main_str = 'p={p}' + if self.epsilon != 1e-6: + main_str += ', epsilon={epsilon}' + if self.keepdim != False: + main_str += ', keepdim={keepdim}' + if self.name != None: + main_str += ', name={name}' + return main_str.format(**self.__dict__) diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 256694c7af67dd23a5056d929b8788676cddcb56..a1cc41f39120ca9918c470488ff90ae443cd92c1 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -107,6 +107,10 @@ class _InstanceNormBase(layers.Layer): return instance_norm( input, weight=self.scale, bias=self.bias, eps=self._epsilon) + def extra_repr(self): + return 'num_features={}, epsilon={}'.format(self.scale.shape[0], + self._epsilon) + class InstanceNorm1D(_InstanceNormBase): r""" @@ -433,6 +437,10 @@ class GroupNorm(layers.Layer): return self._helper.append_activation(group_norm_out, None) + def extra_repr(self): + return 'num_groups={}, num_channels={}, epsilon={}'.format( + self._num_groups, self._num_channels, self._epsilon) + class LayerNorm(layers.Layer): r""" @@ -537,6 +545,10 @@ class LayerNorm(layers.Layer): bias=self.bias, epsilon=self._epsilon) + def extra_repr(self): + return 'normalized_shape={}, epsilon={}'.format(self._normalized_shape, + self._epsilon) + class _BatchNormBase(layers.Layer): """ @@ -647,6 +659,15 @@ class _BatchNormBase(layers.Layer): data_format=self._data_format, use_global_stats=self._use_global_stats) + def extra_repr(self): + main_str = 'num_features={}, momentum={}, epsilon={}'.format( + self._num_features, self._momentum, self._epsilon) + if self._data_format is not 'NCHW': + main_str += ', data_format={}'.format(self._data_format) + if self._name is not None: + main_str += ', name={}'.format(self._name) + return main_str + class BatchNorm1D(_BatchNormBase): r""" @@ -1186,3 +1207,12 @@ class LocalResponseNorm(layers.Layer): out = F.local_response_norm(input, self.size, self.alpha, self.beta, self.k, self.data_format, self.name) return out + + def extra_repr(self): + main_str = 'size={}, alpha={}, beta={}, k={}'.format( + self.size, self.alpha, self.beta, self.k) + if self.data_format is not 'NCHW': + main_str += ', data_format={}'.format(self.data_format) + if self.name is not None: + main_str += ', name={}'.format(self.name) + return main_str diff --git a/python/paddle/nn/layer/pooling.py b/python/paddle/nn/layer/pooling.py index 1d9875d45b40ffbf46d08573775cd8becb1c8eb6..0f3c4449a3f20de271c61be68c0c78b39b19f676 100755 --- a/python/paddle/nn/layer/pooling.py +++ b/python/paddle/nn/layer/pooling.py @@ -119,6 +119,10 @@ class AvgPool1D(layers.Layer): self.exclusive, self.ceil_mode, self.name) return out + def extra_repr(self): + return 'kernel_size={kernel_size}, stride={stride}, padding={padding}'.format( + **self.__dict__) + class AvgPool2D(layers.Layer): r""" @@ -222,6 +226,10 @@ class AvgPool2D(layers.Layer): data_format=self.data_format, name=self.name) + def extra_repr(self): + return 'kernel_size={ksize}, stride={stride}, padding={padding}'.format( + **self.__dict__) + class AvgPool3D(layers.Layer): """ @@ -313,6 +321,10 @@ class AvgPool3D(layers.Layer): data_format=self.data_format, name=self.name) + def extra_repr(self): + return 'kernel_size={ksize}, stride={stride}, padding={padding}'.format( + **self.__dict__) + class MaxPool1D(layers.Layer): """ @@ -401,6 +413,10 @@ class MaxPool1D(layers.Layer): self.return_mask, self.ceil_mode, self.name) return out + def extra_repr(self): + return 'kernel_size={kernel_size}, stride={stride}, padding={padding}'.format( + **self.__dict__) + class MaxPool2D(layers.Layer): r""" @@ -504,6 +520,10 @@ class MaxPool2D(layers.Layer): data_format=self.data_format, name=self.name) + def extra_repr(self): + return 'kernel_size={ksize}, stride={stride}, padding={padding}'.format( + **self.__dict__) + class MaxPool3D(layers.Layer): """ @@ -595,6 +615,10 @@ class MaxPool3D(layers.Layer): data_format=self.data_format, name=self.name) + def extra_repr(self): + return 'kernel_size={ksize}, stride={stride}, padding={padding}'.format( + **self.__dict__) + class AdaptiveAvgPool1D(layers.Layer): r""" @@ -664,6 +688,9 @@ class AdaptiveAvgPool1D(layers.Layer): def forward(self, input): return F.adaptive_avg_pool1d(input, self.output_size, self.name) + def extra_repr(self): + return 'output_size={}'.format(self.output_size) + class AdaptiveAvgPool2D(layers.Layer): r""" @@ -746,6 +773,9 @@ class AdaptiveAvgPool2D(layers.Layer): data_format=self._data_format, name=self._name) + def extra_repr(self): + return 'output_size={}'.format(self._output_size) + class AdaptiveAvgPool3D(layers.Layer): r""" @@ -834,6 +864,9 @@ class AdaptiveAvgPool3D(layers.Layer): data_format=self._data_format, name=self._name) + def extra_repr(self): + return 'output_size={}'.format(self._output_size) + class AdaptiveMaxPool1D(layers.Layer): """ @@ -913,6 +946,10 @@ class AdaptiveMaxPool1D(layers.Layer): return F.adaptive_max_pool1d(input, self.output_size, self.return_mask, self.name) + def extra_repr(self): + return 'output_size={}, return_mask={}'.format(self.output_size, + self.return_mask) + class AdaptiveMaxPool2D(layers.Layer): """ @@ -985,6 +1022,10 @@ class AdaptiveMaxPool2D(layers.Layer): return_mask=self._return_mask, name=self._name) + def extra_repr(self): + return 'output_size={}, return_mask={}'.format(self._output_size, + self._return_mask) + class AdaptiveMaxPool3D(layers.Layer): """ @@ -1067,3 +1108,7 @@ class AdaptiveMaxPool3D(layers.Layer): output_size=self._output_size, return_mask=self._return_mask, name=self._name) + + def extra_repr(self): + return 'output_size={}, return_mask={}'.format(self._output_size, + self._return_mask) diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index fefef52ba6b1988f5942c4258ff85d35421eee3a..c9bb4d245a655d7ab59aa5c2bbcd562d92579209 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -390,6 +390,12 @@ class SimpleRNNCell(RNNCellBase): def state_shape(self): return (self.hidden_size, ) + def extra_repr(self): + s = '{input_size}, {hidden_size}' + if self.activation is not "tanh": + s += ', activation={activation}' + return s.format(**self.__dict__) + class LSTMCell(RNNCellBase): r""" @@ -540,6 +546,9 @@ class LSTMCell(RNNCellBase): """ return ((self.hidden_size, ), (self.hidden_size, )) + def extra_repr(self): + return '{input_size}, {hidden_size}'.format(**self.__dict__) + class GRUCell(RNNCellBase): r""" @@ -684,6 +693,9 @@ class GRUCell(RNNCellBase): """ return (self.hidden_size, ) + def extra_repr(self): + return '{input_size}, {hidden_size}'.format(**self.__dict__) + class RNN(Layer): r""" @@ -1053,6 +1065,16 @@ class RNNBase(LayerList): self.state_components) return outputs, final_states + def extra_repr(self): + main_str = '{input_size}, {hidden_size}' + if self.num_layers != 1: + main_str += ', num_layers={num_layers}' + if self.time_major != False: + main_str += ', time_major={time_major}' + if self.dropout != 0: + main_str += ', dropout={dropout}' + return main_str.format(**self.__dict__) + class SimpleRNN(RNNBase): r""" diff --git a/python/paddle/nn/layer/vision.py b/python/paddle/nn/layer/vision.py index dc1402a4e737aaa992f597c67c608fa78eb87efc..d9c948a848a939c0427c14aee793e2c9c439c47b 100644 --- a/python/paddle/nn/layer/vision.py +++ b/python/paddle/nn/layer/vision.py @@ -79,3 +79,11 @@ class PixelShuffle(layers.Layer): def forward(self, x): return functional.pixel_shuffle(x, self._upscale_factor, self._data_format, self._name) + + def extra_repr(self): + main_str = 'upscale_factor={}'.format(self._upscale_factor) + if self._data_format is not 'NCHW': + main_str += ', data_format={}'.format(self._data_format) + if self._name is not None: + main_str += ', name={}'.format(self._name) + return main_str