未验证 提交 f4b2ce44 编写于 作者: T Thomas Young 提交者: GitHub

fix expand op lack of float16 (#32238)

上级 4281eb49
...@@ -10332,7 +10332,8 @@ def expand(x, expand_times, name=None): ...@@ -10332,7 +10332,8 @@ def expand(x, expand_times, name=None):
inputs = {"X": [x]} inputs = {"X": [x]}
attrs = {} attrs = {}
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'expand') x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'expand')
check_type(expand_times, 'expand_times', (list, tuple, Variable), 'expand') check_type(expand_times, 'expand_times', (list, tuple, Variable), 'expand')
if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == True: if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == True:
raise ValueError( raise ValueError(
......
...@@ -1432,7 +1432,8 @@ def expand(x, shape, name=None): ...@@ -1432,7 +1432,8 @@ def expand(x, shape, name=None):
'Elements in shape must be 1-D Tensors or integers.') 'Elements in shape must be 1-D Tensors or integers.')
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'expand') x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'expand')
check_type(shape, 'shape', (list, tuple, Variable), 'expand') check_type(shape, 'shape', (list, tuple, Variable), 'expand')
if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == False: if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == False:
raise ValueError("When the data type of input 'x' for expand is bool, " raise ValueError("When the data type of input 'x' for expand is bool, "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册