未验证 提交 7277df47 编写于 作者: S silingtong123 提交者: GitHub

error message of NCE API enhancement (#23544)

* error message of NCE API enhancement
上级 f10100eb
...@@ -26,6 +26,7 @@ from ..param_attr import ParamAttr ...@@ -26,6 +26,7 @@ from ..param_attr import ParamAttr
from ..initializer import Normal, Constant, NumpyArrayInitializer from ..initializer import Normal, Constant, NumpyArrayInitializer
from .. import unique_name from .. import unique_name
from .layer_object_helper import LayerObjectHelper from .layer_object_helper import LayerObjectHelper
from ..data_feeder import check_variable_and_dtype, check_type
import numpy as np import numpy as np
import numbers import numbers
import logging import logging
...@@ -2019,6 +2020,10 @@ class NCE(layers.Layer): ...@@ -2019,6 +2020,10 @@ class NCE(layers.Layer):
self._inputs['Weight'] = self.weight self._inputs['Weight'] = self.weight
def forward(self, input, label, sample_weight=None): def forward(self, input, label, sample_weight=None):
check_variable_and_dtype(input, "input", ['float32', 'float64'], "NCE")
check_variable_and_dtype(label, "label", ['int64'], "NCE")
check_type(sample_weight, 'sample_weight', (Variable, type(None)),
'NCE')
assert isinstance(input, Variable) assert isinstance(input, Variable)
assert isinstance(label, Variable) assert isinstance(label, Variable)
......
...@@ -252,5 +252,47 @@ class TestNCE_OpError(unittest.TestCase): ...@@ -252,5 +252,47 @@ class TestNCE_OpError(unittest.TestCase):
self.assertRaises(TypeError, fluid.layers.nce, input4, label4, 5) self.assertRaises(TypeError, fluid.layers.nce, input4, label4, 5)
class TestDygraphNCE_OpError(unittest.TestCase):
def test_NCE_errors(self):
with program_guard(Program(), Program()):
nce = fluid.NCE(20, 5)
input1 = fluid.create_lod_tensor(
np.array([0.0, 3.0, 2.0, 4.0]), [[1, 1, 2]], fluid.CPUPlace())
label1 = fluid.layers.data(
name='label1', shape=[-1, 4], dtype="int64")
# the input(input) of NCE layer must be Variable.
self.assertRaises(TypeError, nce, input1, label1)
input2 = fluid.layers.data(
name='input2', shape=[-1, 4], dtype="float32")
label2 = fluid.create_lod_tensor(
np.array([0.0, 3.0, 2.0, 4.0]), [[1, 1, 2]], fluid.CPUPlace())
# the input(label) of NCE layer must be Variable.
self.assertRaises(TypeError, nce, input2, label2)
input3 = fluid.layers.data(
name='input3', shape=[-1, 4], dtype="float16")
label3 = fluid.layers.data(
name='label3', shape=[-1, 1], dtype="int64")
# the data type of input(input) must be float32 or float64.
self.assertRaises(TypeError, nce, input3, label3)
input4 = fluid.layers.data(
name='input4', shape=[-1, 4], dtype="float32")
label4 = fluid.layers.data(
name='label4', shape=[-1, 1], dtype="int32")
# the data type of input(label) must be int64.
self.assertRaises(TypeError, nce, input4, label4)
input5 = fluid.layers.data(
name='input5', shape=[-1, 4], dtype="float32")
label5 = fluid.layers.data(
name='label5', shape=[-1, 1], dtype="int64")
sample_weight = fluid.create_lod_tensor(
np.array([0.0, 3.0, 2.0, 4.0]), [[1, 1, 2]], fluid.CPUPlace())
# the sample_weight of nce must be Variable or None.
self.assertRaises(TypeError, nce, input5, label5, sample_weight)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册