From f4b2ce44ba3ae4a9bfd8e69002b21417c60758e5 Mon Sep 17 00:00:00 2001 From: Thomas Young <35565423+HexToString@users.noreply.github.com> Date: Wed, 14 Apr 2021 10:50:43 +0800 Subject: [PATCH] fix expand op lack of float16 (#32238) --- python/paddle/fluid/layers/nn.py | 3 ++- python/paddle/tensor/manipulation.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index c4f4754cc7..565c134ae9 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -10332,7 +10332,8 @@ def expand(x, expand_times, name=None): inputs = {"X": [x]} attrs = {} check_variable_and_dtype( - x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'expand') + x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], + 'expand') check_type(expand_times, 'expand_times', (list, tuple, Variable), 'expand') if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == True: raise ValueError( diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 377435a500..696775434b 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1432,7 +1432,8 @@ def expand(x, shape, name=None): 'Elements in shape must be 1-D Tensors or integers.') check_variable_and_dtype( - x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'expand') + x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], + 'expand') check_type(shape, 'shape', (list, tuple, Variable), 'expand') if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == False: raise ValueError("When the data type of input 'x' for expand is bool, " -- GitLab