Unverified commit 120e9ed0, authored by Aurelius84, committed by GitHub

[cherry-pick] Fix fp16 in input.dtype check in layers.fc (#20467) test=develop, (#20500)

test=release/1.6

* Add fp16 in input.dtype check test=develop

* Add warning of fp16 in CPU test=develop

* add unittest code for fp16 test=develop

* fix float16 list error test=develop
Parent 78adff12
@@ -355,9 +355,12 @@ def fc(input,
                 "The type of 'input' in fc must be Variable, but received %s" %
                 (type(input)))
     dtype = helper.input_dtype()
-    if convert_dtype(dtype) not in ['float32', 'float64']:
+    if convert_dtype(dtype) in ['float16']:
+        warnings.warn(
+            "The data type of 'input' in fc only support float16 in GPU now.")
+    if convert_dtype(dtype) not in ['float16', 'float32', 'float64']:
         raise TypeError(
-            "The data type of 'input' in fc must be float32 or float64, but received %s."
+            "The data type of 'input' in fc must be float16, float32 or float64, but received %s."
             % (convert_dtype(dtype)))

     mul_results = []
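For reference, a minimal sketch (not part of this commit) of how the updated check behaves at network-construction time. It assumes paddle.fluid 1.6; the variable names here are illustrative only.

import paddle.fluid as fluid

# float16 input: before this change fc raised TypeError here; now it only
# warns that float16 is supported on GPU.
x_fp16 = fluid.layers.data(name='x_fp16', shape=[4], dtype='float16')
out = fluid.layers.fc(input=x_fp16, size=1)

# A dtype outside float16/float32/float64 (e.g. int32) still raises TypeError.
x_int32 = fluid.layers.data(name='x_int32', shape=[4], dtype='int32')
try:
    fluid.layers.fc(input=x_int32, size=1)
except TypeError as err:
    print(err)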
@@ -148,6 +148,10 @@ class TestFCOpError(OpTest):
             self.assertRaises(TypeError, test_type)
+            # The input dtype of fc can be float16 in GPU, test for warning
+            x3 = fluid.layers.data(name='x3', shape=[4], dtype='float16')
+            fluid.layers.fc(input=x3, size=1)


 if __name__ == "__main__":
     unittest.main()
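The added test only checks that graph construction with a float16 input no longer raises. If one also wanted to assert that the warning is actually emitted, a sketch using the standard warnings module could look like the following; the test class and method names are hypothetical and not part of this commit.

import unittest
import warnings

import paddle.fluid as fluid


class TestFCFP16Warning(unittest.TestCase):  # hypothetical, not in the repo
    def test_fp16_warning(self):
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            x = fluid.layers.data(name='x_fp16_w', shape=[4], dtype='float16')
            fluid.layers.fc(input=x, size=1)
        # fc should have warned that float16 is only supported on GPU.
        self.assertTrue(any('float16' in str(w.message) for w in caught))


if __name__ == '__main__':
    unittest.main()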