diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 3da448522740d36d2fe4b90e2d29d7b735a0c1b4..a5d504c79aa1ad6eb57b2a106ef234853df87f9a 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -2989,6 +2989,8 @@ class SpectralNorm(layers.Layer): self.weight_v.stop_gradient = True def forward(self, weight): + check_variable_and_dtype(weight, "weight", ['float32', 'float64'], + 'SpectralNorm') inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v} out = self._helper.create_variable_for_type_inference(self._dtype) self._helper.append_op( diff --git a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py index 14c2bf50892546866d20fc25c96dac98a36bdf65..7dd0c7625983ed01ceb9e803b886c45f91840e5d 100644 --- a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py @@ -16,9 +16,11 @@ from __future__ import division import unittest import numpy as np +import paddle.fluid as fluid from op_test import OpTest, skip_check_grad_ci from paddle.fluid import core +from paddle.fluid.framework import program_guard, Program def spectral_norm(weight, u, v, dim, power_iters, eps): @@ -125,5 +127,46 @@ class TestSpectralNormOp2(TestSpectralNormOp): self.eps = 1e-12 +class TestSpectralNormOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + + def test_Variable(): + weight_1 = np.random.random((2, 4)).astype("float32") + fluid.layers.spectral_norm(weight_1, dim=1, power_iters=2) + + # the weight type of spectral_norm must be Variable + self.assertRaises(TypeError, test_Variable) + + def test_weight_dtype(): + weight_2 = np.random.random((2, 4)).astype("int32") + fluid.layers.spectral_norm(weight_2, dim=1, power_iters=2) + + # the data type of type must be float32 or float64 + self.assertRaises(TypeError, test_weight_dtype) + + +class TestDygraphSpectralNormOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + shape = (2, 4, 3, 3) + spectralNorm = fluid.dygraph.nn.SpectralNorm( + shape, dim=1, power_iters=2) + + def test_Variable(): + weight_1 = np.random.random((2, 4)).astype("float32") + spectralNorm(weight_1) + + # the weight type of SpectralNorm must be Variable + self.assertRaises(TypeError, test_Variable) + + def test_weight_dtype(): + weight_2 = np.random.random((2, 4)).astype("int32") + spectralNorm(weight_2) + + # the data type of type must be float32 or float64 + self.assertRaises(TypeError, test_weight_dtype) + + if __name__ == "__main__": unittest.main()