未验证 提交 120e9ed0 编写于 作者: A Aurelius84 提交者: GitHub

[cherry-pick] Fix fp16 in input.dtype check in layers.fc (#20467) test=develop, (#20500)

test=release/1.6

* Add fp16 in input.dtype check test=develop

* Add warning of fp16 in CPU test=develop

* add unittest code for fp16 test=develop

* fix float16 list error test=develop
上级 78adff12
...@@ -355,9 +355,12 @@ def fc(input, ...@@ -355,9 +355,12 @@ def fc(input,
"The type of 'input' in fc must be Variable, but received %s" % "The type of 'input' in fc must be Variable, but received %s" %
(type(input))) (type(input)))
dtype = helper.input_dtype() dtype = helper.input_dtype()
if convert_dtype(dtype) not in ['float32', 'float64']: if convert_dtype(dtype) in ['float16']:
warnings.warn(
"The data type of 'input' in fc only support float16 in GPU now.")
if convert_dtype(dtype) not in ['float16', 'float32', 'float64']:
raise TypeError( raise TypeError(
"The data type of 'input' in fc must be float32 or float64, but received %s." "The data type of 'input' in fc must be float16, float32 or float64, but received %s."
% (convert_dtype(dtype))) % (convert_dtype(dtype)))
mul_results = [] mul_results = []
......
...@@ -148,6 +148,10 @@ class TestFCOpError(OpTest): ...@@ -148,6 +148,10 @@ class TestFCOpError(OpTest):
self.assertRaises(TypeError, test_type) self.assertRaises(TypeError, test_type)
# The input dtype of fc can be float16 in GPU, test for warning
x3 = fluid.layers.data(name='x3', shape=[4], dtype='float16')
fluid.layers.fc(input=x3, size=1)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册