diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 15acc9f0ccd30c1a6b5ad7282415a9aa9af71571..15bc5e8281515bcb8b92ca4ef9258e9ccf2f5fde 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5948,20 +5948,13 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): warnings.warn( "Inplace on reshape is not allowed and will be discarded in dygraph mode currently." ) - attrs = {} if isinstance(shape, (list, tuple)): - if utils._contain_var(shape): - raise TypeError( - "The type of 'shape' in reshape must be list[int] or tuple(int) in Dygraph mode, but " - "received %s, which contains Variable." % type(shape)) - attrs['shape'] = shape - else: - raise TypeError( - "The type of 'shape' in reshape must be list[int] or tuple(int) in Dygraph mode, but " - "received %s." % type(shape)) - - out, _ = core.ops.reshape2(x, 'shape', shape) - return dygraph_utils._append_activation_in_dygraph(out, act) + shape = [ + item.numpy()[0] if isinstance(item, Variable) else item + for item in shape + ] + out, _ = core.ops.reshape2(x, 'shape', shape) + return dygraph_utils._append_activation_in_dygraph(out, act) check_variable_and_dtype( x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'reshape') @@ -9770,16 +9763,12 @@ def expand(x, expand_times, name=None): """ if in_dygraph_mode(): if isinstance(expand_times, (list, tuple)): - if utils._contain_var(expand_times): - raise TypeError( - "The type of 'expand_times' in expand must be list[int] or tuple(int) in Dygraph mode, but " - "received %s, which contains Variable." % type(shape)) - else: - raise TypeError( - "The type of 'expand_times' in expand must be list[int] or tuple(int) in Dygraph mode, but " - "received %s." % type(shape)) + expand_times = [ + item.numpy()[0] if isinstance(item, Variable) else item + for item in expand_times + ] - return core.ops.expand(x, 'expand_times', expand_times) + return core.ops.expand(x, 'expand_times', expand_times) inputs = {"X": [x]} attrs = {} @@ -10318,28 +10307,19 @@ def slice(input, axes, starts, ends): """ if in_dygraph_mode(): infer_flags = list(1 for i in range(len(axes))) - if isinstance(starts, (list, tuple)): - if utils._contain_var(starts): - raise TypeError( - "The type of 'starts' in slice must be list[int] or tuple(int) in Dygraph mode, but " - "received %s, which contains Variable." % type(shape)) - else: - raise TypeError( - "The type of 'starts' in slice must be list[int] or tuple(int) in Dygraph mode, but " - "received %s." % type(shape)) - - if isinstance(ends, (list, tuple)): - if utils._contain_var(ends): - raise TypeError( - "The type of 'ends' in slice must be list[int] or tuple(int) in Dygraph mode, but " - "received %s, which contains Variable." % type(shape)) - else: - raise TypeError( - "The type of 'ends' in slice must be list[int] or tuple(int) in Dygraph mode, but " - "received %s." % type(shape)) + if isinstance(starts, (list, tuple)) and isinstance(ends, + (list, tuple)): + starts = [ + item.numpy()[0] if isinstance(item, Variable) else item + for item in starts + ] + ends = [ + item.numpy()[0] if isinstance(item, Variable) else item + for item in ends + ] - return core.ops.slice(input, 'axes', axes, 'starts', starts, 'ends', - ends, 'infer_flags', infer_flags) + return core.ops.slice(input, 'axes', axes, 'starts', starts, 'ends', + ends, 'infer_flags', infer_flags) if not isinstance(starts, (list, tuple, Variable)): raise ValueError( diff --git a/python/paddle/fluid/layers/rnn.py b/python/paddle/fluid/layers/rnn.py index e4596b1ac2f840e7f3a60154c5d97912a9247271..4c048355d2f73a8d627c48dedff149ab56cd029a 100644 --- a/python/paddle/fluid/layers/rnn.py +++ b/python/paddle/fluid/layers/rnn.py @@ -725,7 +725,7 @@ class BeamSearchDecoder(Decoder): data type is same as `x`. """ # TODO: avoid fake shape in compile-time like tile_beam_merge_with_batch - return nn.reshape(x, shape=(-1, self.beam_size) + x.shape[1:]) + return nn.reshape(x, shape=[-1, self.beam_size] + list(x.shape[1:])) def _merge_batch_beams(self, x): """ @@ -741,7 +741,7 @@ class BeamSearchDecoder(Decoder): data type is same as `x`. """ # TODO: avoid fake shape in compile-time like tile_beam_merge_with_batch - return nn.reshape(x, shape=(-1, ) + x.shape[2:]) + return nn.reshape(x, shape=[-1] + list(x.shape[2:])) def _expand_to_beam_size(self, x): """