diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index bf0a2a28364998bacfe43162a337a86938328b9a..a0aed3ee1f1f88f0ac4e0c1fee24b0dd3b74a9d1 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -83,10 +83,14 @@ def select_input(inputs, mask): if isinstance(inputs, list) or isinstance(inputs, tuple): input_dtype = inputs[0].dtype input_shape = inputs[0].shape + input_type = inputs[0].type else: input_dtype = inputs.dtype input_shape = inputs.shape - out = helper.create_variable(dtype=input_dtype, shape=input_shape) + input_type = inputs.type + + out = helper.create_variable( + dtype=input_dtype, shape=input_shape, type=input_type) helper.append_op( type='select_input', inputs={'X': inputs,