未验证 提交 1f3f02e3 编写于 作者: P PuQing 提交者: GitHub

fix crop shape None error (#46813)

* fix crop shape None error

* add test case

* fix testcase

* fix import

* fix testcase

* fix codestyle
上级 ceea5d02
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle
import paddle.fluid as fluid
def crop(data, offsets, crop_shape): def crop(data, offsets, crop_shape):
...@@ -61,6 +63,11 @@ class TestCropOp(OpTest): ...@@ -61,6 +63,11 @@ class TestCropOp(OpTest):
self.inputs['Offsets'] = np.array(self.offsets).astype('int32') self.inputs['Offsets'] = np.array(self.offsets).astype('int32')
else: else:
self.attrs['offsets'] = self.offsets self.attrs['offsets'] = self.offsets
if self.offsets is None:
self.offsets = [0] * len(self.crop_shape)
if self.crop_shape is None:
self.crop_shape = self.x_shape
self.outputs = { self.outputs = {
'Out': crop(self.inputs['X'], self.offsets, self.crop_shape) 'Out': crop(self.inputs['X'], self.offsets, self.crop_shape)
} }
...@@ -130,6 +137,23 @@ class TestCase6(TestCropOp): ...@@ -130,6 +137,23 @@ class TestCase6(TestCropOp):
self.offset_by_input = True self.offset_by_input = True
class TestCropNoneOffset(unittest.TestCase):
def test_crop_none_offset(self):
x = fluid.data(name="input1", shape=[3, 6, 6], dtype="float32")
crop_shape = [2, 2, 2]
crop = paddle.crop(x, crop_shape, None)
self.assertEqual(crop.shape, (2, 2, 2))
class TestCropNoneShape(unittest.TestCase):
def test_crop_none_shape(self):
x = fluid.data(name="input1", shape=[3, 6, 6], dtype="float32")
crop = paddle.crop(x)
self.assertEqual(crop.shape, (3, 6, 6))
if __name__ == '__main__': if __name__ == '__main__':
import paddle import paddle
paddle.enable_static() paddle.enable_static()
......
...@@ -634,13 +634,17 @@ def crop(x, shape=None, offsets=None, name=None): ...@@ -634,13 +634,17 @@ def crop(x, shape=None, offsets=None, name=None):
helper = LayerHelper('crop_tensor', **locals()) helper = LayerHelper('crop_tensor', **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'], check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
'crop_tensor') 'crop_tensor')
check_type(shape, 'shape', (list, tuple, Variable), 'crop_tensor') check_type(shape, 'shape', (list, tuple, Variable, type(None)),
'crop_tensor')
check_type(offsets, 'offsets', (list, tuple, Variable, type(None)), check_type(offsets, 'offsets', (list, tuple, Variable, type(None)),
'crop_tensor') 'crop_tensor')
if offsets is None: if offsets is None:
offsets = [0] * len(x.shape) offsets = [0] * len(x.shape)
if shape is None:
shape = x.shape
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.crop_tensor(x, shape, offsets) return _C_ops.crop_tensor(x, shape, offsets)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册