From f9e2a27963c8e1d4f6e137a50a3dac0cda3741c1 Mon Sep 17 00:00:00 2001 From: silingtong123 <35439432+silingtong123@users.noreply.github.com> Date: Fri, 10 Apr 2020 15:05:07 +0800 Subject: [PATCH] error message of SpectralNorm OP enhancement (#23516) --- python/paddle/fluid/dygraph/nn.py | 2 + .../tests/unittests/test_spectral_norm_op.py | 43 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 3da44852274..a5d504c79aa 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -2989,6 +2989,8 @@ class SpectralNorm(layers.Layer): self.weight_v.stop_gradient = True def forward(self, weight): + check_variable_and_dtype(weight, "weight", ['float32', 'float64'], + 'SpectralNorm') inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v} out = self._helper.create_variable_for_type_inference(self._dtype) self._helper.append_op( diff --git a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py index 14c2bf50892..7dd0c762598 100644 --- a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py @@ -16,9 +16,11 @@ from __future__ import division import unittest import numpy as np +import paddle.fluid as fluid from op_test import OpTest, skip_check_grad_ci from paddle.fluid import core +from paddle.fluid.framework import program_guard, Program def spectral_norm(weight, u, v, dim, power_iters, eps): @@ -125,5 +127,46 @@ class TestSpectralNormOp2(TestSpectralNormOp): self.eps = 1e-12 +class TestSpectralNormOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + + def test_Variable(): + weight_1 = np.random.random((2, 4)).astype("float32") + fluid.layers.spectral_norm(weight_1, dim=1, power_iters=2) + + # the weight type of spectral_norm must be Variable + self.assertRaises(TypeError, test_Variable) + + def test_weight_dtype(): + weight_2 = np.random.random((2, 4)).astype("int32") + fluid.layers.spectral_norm(weight_2, dim=1, power_iters=2) + + # the data type of type must be float32 or float64 + self.assertRaises(TypeError, test_weight_dtype) + + +class TestDygraphSpectralNormOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + shape = (2, 4, 3, 3) + spectralNorm = fluid.dygraph.nn.SpectralNorm( + shape, dim=1, power_iters=2) + + def test_Variable(): + weight_1 = np.random.random((2, 4)).astype("float32") + spectralNorm(weight_1) + + # the weight type of SpectralNorm must be Variable + self.assertRaises(TypeError, test_Variable) + + def test_weight_dtype(): + weight_2 = np.random.random((2, 4)).astype("int32") + spectralNorm(weight_2) + + # the data type of type must be float32 or float64 + self.assertRaises(TypeError, test_weight_dtype) + + if __name__ == "__main__": unittest.main() -- GitLab