From d6e888cace07bbe0c8f7290b7b8fccfaf67345a5 Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Mon, 24 Aug 2020 16:13:44 +0800 Subject: [PATCH] fix Flatten api test=develop (#26346) --- python/paddle/fluid/dygraph/nn.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 45744841fc5..69d27e2c234 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -14,6 +14,7 @@ from __future__ import print_function +import paddle from six.moves import reduce from .. import core from ..layers import utils @@ -3457,19 +3458,6 @@ class Flatten(layers.Layer): self.stop_axis = stop_axis def forward(self, input): - out = self._helper.create_variable_for_type_inference(input.dtype) - x_shape = self._helper.create_variable_for_type_inference(input.dtype) - - if in_dygraph_mode(): - dy_out, _ = core.ops.flatten_contiguous_range( - input, 'start_axis', self.start_axis, 'stop_axis', - self.stop_axis) - return dy_out - self._helper.append_op( - type="flatten_contiguous_range", - inputs={"X": input}, - outputs={"Out": out, - "XShape": x_shape}, - attrs={"start_axis": self.start_axis, - "stop_axis": self.stop_axis}) + out = paddle.tensor.manipulation.flatten( + input, start_axis=self.start_axis, stop_axis=self.stop_axis) return out -- GitLab