From 1f3f02e3db5e2eb6edf3ff5a8cf767b8f0f5602d Mon Sep 17 00:00:00 2001 From: PuQing Date: Tue, 11 Oct 2022 14:53:12 +0800 Subject: [PATCH] fix crop shape None error (#46813) * fix crop shape None error * add test case * fix testcase * fix import * fix testcase * fix codestyle --- .../fluid/tests/unittests/test_crop_op.py | 24 +++++++++++++++++++ python/paddle/tensor/manipulation.py | 6 ++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_crop_op.py b/python/paddle/fluid/tests/unittests/test_crop_op.py index 38bd1b3194d..32e31e5366f 100644 --- a/python/paddle/fluid/tests/unittests/test_crop_op.py +++ b/python/paddle/fluid/tests/unittests/test_crop_op.py @@ -15,6 +15,8 @@ import unittest import numpy as np from op_test import OpTest +import paddle +import paddle.fluid as fluid def crop(data, offsets, crop_shape): @@ -61,6 +63,11 @@ class TestCropOp(OpTest): self.inputs['Offsets'] = np.array(self.offsets).astype('int32') else: self.attrs['offsets'] = self.offsets + if self.offsets is None: + self.offsets = [0] * len(self.crop_shape) + if self.crop_shape is None: + self.crop_shape = self.x_shape + self.outputs = { 'Out': crop(self.inputs['X'], self.offsets, self.crop_shape) } @@ -130,6 +137,23 @@ class TestCase6(TestCropOp): self.offset_by_input = True +class TestCropNoneOffset(unittest.TestCase): + + def test_crop_none_offset(self): + x = fluid.data(name="input1", shape=[3, 6, 6], dtype="float32") + crop_shape = [2, 2, 2] + crop = paddle.crop(x, crop_shape, None) + self.assertEqual(crop.shape, (2, 2, 2)) + + +class TestCropNoneShape(unittest.TestCase): + + def test_crop_none_shape(self): + x = fluid.data(name="input1", shape=[3, 6, 6], dtype="float32") + crop = paddle.crop(x) + self.assertEqual(crop.shape, (3, 6, 6)) + + if __name__ == '__main__': import paddle paddle.enable_static() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index a161734b1df..d5285351a29 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -634,13 +634,17 @@ def crop(x, shape=None, offsets=None, name=None): helper = LayerHelper('crop_tensor', **locals()) check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'], 'crop_tensor') - check_type(shape, 'shape', (list, tuple, Variable), 'crop_tensor') + check_type(shape, 'shape', (list, tuple, Variable, type(None)), + 'crop_tensor') check_type(offsets, 'offsets', (list, tuple, Variable, type(None)), 'crop_tensor') if offsets is None: offsets = [0] * len(x.shape) + if shape is None: + shape = x.shape + if in_dygraph_mode(): return _C_ops.crop_tensor(x, shape, offsets) -- GitLab