未验证 提交 da803415 编写于 作者: G Guo Sheng 提交者: GitHub

Make layers.reshape/expand/slice in dygraph support var inputs. (#22920)

* Make layers.reshape/expand/slice in dygraph support var inputs.
Make transpose support size 0.
test=develop

* Update layers.expand and layers.slice to support var inputs.
test=develop
上级 a3b02e44
......@@ -5948,18 +5948,11 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
warnings.warn(
"Inplace on reshape is not allowed and will be discarded in dygraph mode currently."
)
attrs = {}
if isinstance(shape, (list, tuple)):
if utils._contain_var(shape):
raise TypeError(
"The type of 'shape' in reshape must be list[int] or tuple(int) in Dygraph mode, but "
"received %s, which contains Variable." % type(shape))
attrs['shape'] = shape
else:
raise TypeError(
"The type of 'shape' in reshape must be list[int] or tuple(int) in Dygraph mode, but "
"received %s." % type(shape))
shape = [
item.numpy()[0] if isinstance(item, Variable) else item
for item in shape
]
out, _ = core.ops.reshape2(x, 'shape', shape)
return dygraph_utils._append_activation_in_dygraph(out, act)
......@@ -9770,14 +9763,10 @@ def expand(x, expand_times, name=None):
"""
if in_dygraph_mode():
if isinstance(expand_times, (list, tuple)):
if utils._contain_var(expand_times):
raise TypeError(
"The type of 'expand_times' in expand must be list[int] or tuple(int) in Dygraph mode, but "
"received %s, which contains Variable." % type(shape))
else:
raise TypeError(
"The type of 'expand_times' in expand must be list[int] or tuple(int) in Dygraph mode, but "
"received %s." % type(shape))
expand_times = [
item.numpy()[0] if isinstance(item, Variable) else item
for item in expand_times
]
return core.ops.expand(x, 'expand_times', expand_times)
......@@ -10318,25 +10307,16 @@ def slice(input, axes, starts, ends):
"""
if in_dygraph_mode():
infer_flags = list(1 for i in range(len(axes)))
if isinstance(starts, (list, tuple)):
if utils._contain_var(starts):
raise TypeError(
"The type of 'starts' in slice must be list[int] or tuple(int) in Dygraph mode, but "
"received %s, which contains Variable." % type(shape))
else:
raise TypeError(
"The type of 'starts' in slice must be list[int] or tuple(int) in Dygraph mode, but "
"received %s." % type(shape))
if isinstance(ends, (list, tuple)):
if utils._contain_var(ends):
raise TypeError(
"The type of 'ends' in slice must be list[int] or tuple(int) in Dygraph mode, but "
"received %s, which contains Variable." % type(shape))
else:
raise TypeError(
"The type of 'ends' in slice must be list[int] or tuple(int) in Dygraph mode, but "
"received %s." % type(shape))
if isinstance(starts, (list, tuple)) and isinstance(ends,
(list, tuple)):
starts = [
item.numpy()[0] if isinstance(item, Variable) else item
for item in starts
]
ends = [
item.numpy()[0] if isinstance(item, Variable) else item
for item in ends
]
return core.ops.slice(input, 'axes', axes, 'starts', starts, 'ends',
ends, 'infer_flags', infer_flags)
......
......@@ -725,7 +725,7 @@ class BeamSearchDecoder(Decoder):
data type is same as `x`.
"""
# TODO: avoid fake shape in compile-time like tile_beam_merge_with_batch
return nn.reshape(x, shape=(-1, self.beam_size) + x.shape[1:])
return nn.reshape(x, shape=[-1, self.beam_size] + list(x.shape[1:]))
def _merge_batch_beams(self, x):
"""
......@@ -741,7 +741,7 @@ class BeamSearchDecoder(Decoder):
data type is same as `x`.
"""
# TODO: avoid fake shape in compile-time like tile_beam_merge_with_batch
return nn.reshape(x, shape=(-1, ) + x.shape[2:])
return nn.reshape(x, shape=[-1] + list(x.shape[2:]))
def _expand_to_beam_size(self, x):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册