未验证 提交 d8a09e25 编写于 作者: C Charles-hit 提交者: GitHub

replace reshape op with reshape2 op (#45735)

上级 6df93364
...@@ -163,10 +163,14 @@ class LayerHelperBase(object): ...@@ -163,10 +163,14 @@ class LayerHelperBase(object):
[self.name, 'weight_norm_reshape'])), [self.name, 'weight_norm_reshape'])),
dtype=dtype, dtype=dtype,
persistable=False) persistable=False)
block.append_op(type='reshape', x_shape = block.create_var(name="Xshape", dtype=x.dtype)
block.append_op(type="reshape2",
inputs={'X': x}, inputs={'X': x},
outputs={'Out': out}, attrs={'shape': shape},
attrs={'shape': shape}) outputs={
"Out": out,
"XShape": x_shape
})
return out return out
def __transpose_op(x, def __transpose_op(x,
......
...@@ -1833,7 +1833,7 @@ def eye(num_rows, ...@@ -1833,7 +1833,7 @@ def eye(num_rows,
re_shape = re_shape + [num_rows, num_columns] re_shape = re_shape + [num_rows, num_columns]
expand_times = batch_shape + [1, 1] expand_times = batch_shape + [1, 1]
if _non_static_mode(): if _non_static_mode():
out = _legacy_C_ops.reshape(out, 'shape', re_shape) out, _ = _legacy_C_ops.reshape2(out, None, 'shape', re_shape)
return _legacy_C_ops.expand(out, None, 'expand_times', expand_times) return _legacy_C_ops.expand(out, None, 'expand_times', expand_times)
if not isinstance(batch_shape, list): if not isinstance(batch_shape, list):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册