未验证 提交 f9e2a279 编写于 作者: S silingtong123 提交者: GitHub

error message of SpectralNorm OP enhancement (#23516)

上级 076dcdfd
......@@ -2989,6 +2989,8 @@ class SpectralNorm(layers.Layer):
self.weight_v.stop_gradient = True
def forward(self, weight):
check_variable_and_dtype(weight, "weight", ['float32', 'float64'],
'SpectralNorm')
inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v}
out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
......
......@@ -16,9 +16,11 @@ from __future__ import division
import unittest
import numpy as np
import paddle.fluid as fluid
from op_test import OpTest, skip_check_grad_ci
from paddle.fluid import core
from paddle.fluid.framework import program_guard, Program
def spectral_norm(weight, u, v, dim, power_iters, eps):
......@@ -125,5 +127,46 @@ class TestSpectralNormOp2(TestSpectralNormOp):
self.eps = 1e-12
class TestSpectralNormOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
def test_Variable():
weight_1 = np.random.random((2, 4)).astype("float32")
fluid.layers.spectral_norm(weight_1, dim=1, power_iters=2)
# the weight type of spectral_norm must be Variable
self.assertRaises(TypeError, test_Variable)
def test_weight_dtype():
weight_2 = np.random.random((2, 4)).astype("int32")
fluid.layers.spectral_norm(weight_2, dim=1, power_iters=2)
# the data type of type must be float32 or float64
self.assertRaises(TypeError, test_weight_dtype)
class TestDygraphSpectralNormOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
shape = (2, 4, 3, 3)
spectralNorm = fluid.dygraph.nn.SpectralNorm(
shape, dim=1, power_iters=2)
def test_Variable():
weight_1 = np.random.random((2, 4)).astype("float32")
spectralNorm(weight_1)
# the weight type of SpectralNorm must be Variable
self.assertRaises(TypeError, test_Variable)
def test_weight_dtype():
weight_2 = np.random.random((2, 4)).astype("int32")
spectralNorm(weight_2)
# the data type of type must be float32 or float64
self.assertRaises(TypeError, test_weight_dtype)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册