未验证 提交 da803415 编写于 作者: G Guo Sheng 提交者: GitHub

Make layers.reshape/expand/slice in dygraph support var inputs. (#22920)

* Make layers.reshape/expand/slice in dygraph support var inputs.
Make transpose support size 0.
test=develop

* Update layers.expand and layers.slice to support var inputs.
test=develop
上级 a3b02e44
...@@ -5948,20 +5948,13 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): ...@@ -5948,20 +5948,13 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
warnings.warn( warnings.warn(
"Inplace on reshape is not allowed and will be discarded in dygraph mode currently." "Inplace on reshape is not allowed and will be discarded in dygraph mode currently."
) )
attrs = {}
if isinstance(shape, (list, tuple)): if isinstance(shape, (list, tuple)):
if utils._contain_var(shape): shape = [
raise TypeError( item.numpy()[0] if isinstance(item, Variable) else item
"The type of 'shape' in reshape must be list[int] or tuple(int) in Dygraph mode, but " for item in shape
"received %s, which contains Variable." % type(shape)) ]
attrs['shape'] = shape out, _ = core.ops.reshape2(x, 'shape', shape)
else: return dygraph_utils._append_activation_in_dygraph(out, act)
raise TypeError(
"The type of 'shape' in reshape must be list[int] or tuple(int) in Dygraph mode, but "
"received %s." % type(shape))
out, _ = core.ops.reshape2(x, 'shape', shape)
return dygraph_utils._append_activation_in_dygraph(out, act)
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'reshape') x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'reshape')
...@@ -9770,16 +9763,12 @@ def expand(x, expand_times, name=None): ...@@ -9770,16 +9763,12 @@ def expand(x, expand_times, name=None):
""" """
if in_dygraph_mode(): if in_dygraph_mode():
if isinstance(expand_times, (list, tuple)): if isinstance(expand_times, (list, tuple)):
if utils._contain_var(expand_times): expand_times = [
raise TypeError( item.numpy()[0] if isinstance(item, Variable) else item
"The type of 'expand_times' in expand must be list[int] or tuple(int) in Dygraph mode, but " for item in expand_times
"received %s, which contains Variable." % type(shape)) ]
else:
raise TypeError(
"The type of 'expand_times' in expand must be list[int] or tuple(int) in Dygraph mode, but "
"received %s." % type(shape))
return core.ops.expand(x, 'expand_times', expand_times) return core.ops.expand(x, 'expand_times', expand_times)
inputs = {"X": [x]} inputs = {"X": [x]}
attrs = {} attrs = {}
...@@ -10318,28 +10307,19 @@ def slice(input, axes, starts, ends): ...@@ -10318,28 +10307,19 @@ def slice(input, axes, starts, ends):
""" """
if in_dygraph_mode(): if in_dygraph_mode():
infer_flags = list(1 for i in range(len(axes))) infer_flags = list(1 for i in range(len(axes)))
if isinstance(starts, (list, tuple)): if isinstance(starts, (list, tuple)) and isinstance(ends,
if utils._contain_var(starts): (list, tuple)):
raise TypeError( starts = [
"The type of 'starts' in slice must be list[int] or tuple(int) in Dygraph mode, but " item.numpy()[0] if isinstance(item, Variable) else item
"received %s, which contains Variable." % type(shape)) for item in starts
else: ]
raise TypeError( ends = [
"The type of 'starts' in slice must be list[int] or tuple(int) in Dygraph mode, but " item.numpy()[0] if isinstance(item, Variable) else item
"received %s." % type(shape)) for item in ends
]
if isinstance(ends, (list, tuple)):
if utils._contain_var(ends):
raise TypeError(
"The type of 'ends' in slice must be list[int] or tuple(int) in Dygraph mode, but "
"received %s, which contains Variable." % type(shape))
else:
raise TypeError(
"The type of 'ends' in slice must be list[int] or tuple(int) in Dygraph mode, but "
"received %s." % type(shape))
return core.ops.slice(input, 'axes', axes, 'starts', starts, 'ends', return core.ops.slice(input, 'axes', axes, 'starts', starts, 'ends',
ends, 'infer_flags', infer_flags) ends, 'infer_flags', infer_flags)
if not isinstance(starts, (list, tuple, Variable)): if not isinstance(starts, (list, tuple, Variable)):
raise ValueError( raise ValueError(
......
...@@ -725,7 +725,7 @@ class BeamSearchDecoder(Decoder): ...@@ -725,7 +725,7 @@ class BeamSearchDecoder(Decoder):
data type is same as `x`. data type is same as `x`.
""" """
# TODO: avoid fake shape in compile-time like tile_beam_merge_with_batch # TODO: avoid fake shape in compile-time like tile_beam_merge_with_batch
return nn.reshape(x, shape=(-1, self.beam_size) + x.shape[1:]) return nn.reshape(x, shape=[-1, self.beam_size] + list(x.shape[1:]))
def _merge_batch_beams(self, x): def _merge_batch_beams(self, x):
""" """
...@@ -741,7 +741,7 @@ class BeamSearchDecoder(Decoder): ...@@ -741,7 +741,7 @@ class BeamSearchDecoder(Decoder):
data type is same as `x`. data type is same as `x`.
""" """
# TODO: avoid fake shape in compile-time like tile_beam_merge_with_batch # TODO: avoid fake shape in compile-time like tile_beam_merge_with_batch
return nn.reshape(x, shape=(-1, ) + x.shape[2:]) return nn.reshape(x, shape=[-1] + list(x.shape[2:]))
def _expand_to_beam_size(self, x): def _expand_to_beam_size(self, x):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册