未验证 提交 d6e888ca 编写于 作者: Y yaoxuefeng 提交者: GitHub

fix Flatten api test=develop (#26346)

上级 0d71cffd
......@@ -14,6 +14,7 @@
from __future__ import print_function
import paddle
from six.moves import reduce
from .. import core
from ..layers import utils
......@@ -3457,19 +3458,6 @@ class Flatten(layers.Layer):
self.stop_axis = stop_axis
def forward(self, input):
out = self._helper.create_variable_for_type_inference(input.dtype)
x_shape = self._helper.create_variable_for_type_inference(input.dtype)
if in_dygraph_mode():
dy_out, _ = core.ops.flatten_contiguous_range(
input, 'start_axis', self.start_axis, 'stop_axis',
self.stop_axis)
return dy_out
self._helper.append_op(
type="flatten_contiguous_range",
inputs={"X": input},
outputs={"Out": out,
"XShape": x_shape},
attrs={"start_axis": self.start_axis,
"stop_axis": self.stop_axis})
out = paddle.tensor.manipulation.flatten(
input, start_axis=self.start_axis, stop_axis=self.stop_axis)
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册