diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index db1222fa421ef61e6f68f0d69ad0fe7f5d80f6d5..9de407841fb461713d00f997afdf33a38a531245 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -21,6 +21,7 @@ from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_t from ..fluid.layers.tensor import fill_constant from ..fluid.layers import utils import numpy as np +import six # TODO: define functions to manipulate a tensor from ..fluid.layers import cast #DEFINE_ALIAS from ..fluid.layers import slice #DEFINE_ALIAS @@ -1056,10 +1057,25 @@ def tile(x, repeat_times, name=None): """ if in_dygraph_mode(): return core.ops.tile(x, 'repeat_times', repeat_times) + check_type(repeat_times, 'repeat_times', (list, tuple, Variable), 'tile') + if isinstance(repeat_times, Variable): + assert len(repeat_times.shape) == 1, ( + 'repeat_times must be an 1-D Tensor.') + else: + for elem in repeat_times: + if isinstance(elem, Variable): + assert len(elem.shape) == 1, ( + 'Elements in repeat_times must be 1-D Tensors or integers.') + else: + if six.PY3: + type_tuple = (int, np.int32, np.int64) + elif six.PY2: + type_tuple = (int, long, np.int32, np.int64) + assert isinstance(elem, type_tuple), ( + 'Elements in repeat_times must be 1-D Tensors or integers.') check_variable_and_dtype( x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'tile') - check_type(repeat_times, 'repeat_times', (list, tuple, Variable), 'tile') if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == False: raise ValueError( "When the date type is bool for the input 'x' of tile op, you " @@ -1181,18 +1197,33 @@ def expand(x, shape, name=None): if in_dygraph_mode(): return core.ops.expand_v2(x, 'shape', shape) + if isinstance(shape, Variable): + assert len(shape.shape) == 1, ('shape must be an 1-D Tensor.') + else: + for elem in shape: + if isinstance(elem, Variable): + assert len(elem.shape) == 1, ( + 'Elements in shape must be 1-D Tensors or integers.') + else: + if six.PY3: + type_tuple = (int, np.int32, np.int64) + elif six.PY2: + type_tuple = (int, long, np.int32, np.int64) + assert isinstance(elem, type_tuple), ( + 'Elements in shape must be 1-D Tensors or integers.') + check_variable_and_dtype( x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'expand') check_type(shape, 'shape', (list, tuple, Variable), 'expand') - - inputs = {"X": [x]} - attrs = {} if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == False: raise ValueError("When the data type of input 'x' for expand is bool, " "you must set its stop_gradient to be False by " "some_var.stop_gradient = True, supporting " "some_var as the input.") + inputs = {"X": [x]} + attrs = {} + helper = LayerHelper('expand', **locals()) def get_attr_expand_shape(list_expand_shape):