提交 5a8d885d 编写于 作者: Z Zhang Ting 提交者: Aurelius84

All elements in attr(shape) of crop_tensor can be -1 and int32/64 kernel registered (#20756)

* All elements in attr(shape) of crop_tensor can be -1, test=develop, test=document_preview

* fix the bug that attr(offsets) should be initialized, test=develop
上级 9171f737
......@@ -31,8 +31,9 @@ class CropTensorOp : public framework::OperatorWithKernel {
"Input(X) of Op(crop_tensor) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of Op(crop_tensor) should not be null.");
auto x_dim = ctx->GetInputDim("X");
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
auto offsets = ctx->Attrs().Get<std::vector<int>>("offsets");
if (ctx->HasInputs("ShapeTensor")) {
// top prority shape
auto inputs_name = ctx->Inputs("ShapeTensor");
......@@ -43,15 +44,19 @@ class CropTensorOp : public framework::OperatorWithKernel {
"Op(fluid.layers.crop_tensor).");
auto out_dims = std::vector<int>(inputs_name.size(), -1);
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] != -1) {
if (shape[i] > 0) {
out_dims[i] = static_cast<int64_t>(shape[i]);
} else {
if (shape[i] == -1 && offsets[i] != -1 && x_dim[i] != -1) {
out_dims[i] = x_dim[i] - static_cast<int64_t>(offsets[i]);
}
}
}
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
return;
}
auto x_dim = ctx->GetInputDim("X");
if (ctx->HasInput("Shape")) {
auto shape_dim = ctx->GetInputDim("Shape");
PADDLE_ENFORCE_EQ(
......@@ -78,11 +83,17 @@ class CropTensorOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(int64_t(shape.size()), x_dim.size(),
"Attr(shape)'size of Op(crop_tensor) should be equal to "
"dimention size of input tensor.");
std::vector<int64_t> tensor_shape(shape.size());
std::vector<int64_t> out_shape(shape.size(), -1);
for (size_t i = 0; i < shape.size(); ++i) {
tensor_shape[i] = static_cast<int64_t>(shape[i]);
if (shape[i] > 0) {
out_shape[i] = static_cast<int64_t>(shape[i]);
} else {
if (shape[i] == -1 && offsets[i] != -1 && x_dim[i] != -1) {
out_shape[i] = x_dim[i] - static_cast<int64_t>(offsets[i]);
}
}
}
ctx->SetOutputDim("Out", framework::make_ddim(tensor_shape));
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
}
framework::OpKernelType GetExpectedKernelType(
......@@ -293,8 +304,12 @@ REGISTER_OPERATOR(crop_tensor_grad, ops::CropTensorOpGrad);
REGISTER_OP_CPU_KERNEL(
crop_tensor,
ops::CropTensorKernel<paddle::platform::CPUDeviceContext, float>,
ops::CropTensorKernel<paddle::platform::CPUDeviceContext, double>);
ops::CropTensorKernel<paddle::platform::CPUDeviceContext, double>,
ops::CropTensorKernel<paddle::platform::CPUDeviceContext, int>,
ops::CropTensorKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
crop_tensor_grad,
ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, double>);
ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
......@@ -17,8 +17,12 @@ namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
crop_tensor,
ops::CropTensorKernel<paddle::platform::CUDADeviceContext, float>,
ops::CropTensorKernel<paddle::platform::CUDADeviceContext, double>);
ops::CropTensorKernel<paddle::platform::CUDADeviceContext, double>,
ops::CropTensorKernel<paddle::platform::CUDADeviceContext, int>,
ops::CropTensorKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
crop_tensor_grad,
ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, double>);
ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
......@@ -50,29 +50,28 @@ inline std::vector<int> get_new_data(
}
static framework::DDim ValidateShape(const std::vector<int> shape,
const std::vector<int> offsets,
const framework::DDim& in_dims) {
auto in_dim_size = in_dims.size();
auto shape_size = shape.size();
PADDLE_ENFORCE_EQ(
in_dim_size, shape_size,
"Input(ShapeTensor)'s dimension size of Op(crop_tensor) should be equal "
"to that of input tensor. "
"Attr(shape)'s size of Op(crop_tensor) should be equal "
"to that of input Tensor. "
"Please check the Attr(shape)'s size of Op(fluid.layers.crop_tensor).");
const int64_t unk_dim_val = -1;
int unk_dim_idx = -1;
std::vector<int64_t> output_shape(shape.size(), 0);
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] == unk_dim_val) {
PADDLE_ENFORCE_EQ(unk_dim_idx, -1,
"Only one element of shape can be unknown.");
PADDLE_ENFORCE_EQ(i, 0, "Only the first element of shape can be -1.");
unk_dim_idx = i;
if (shape[i] <= 0 && in_dims[i] > 0) {
PADDLE_ENFORCE_NE(
shape[i], 0,
"The element in Attr(shape) of Op(crop_tensor) should not be zero.");
PADDLE_ENFORCE_EQ(shape[i], -1,
"When the element in Attr(shape) of Op(crop_tensor) is "
"negative, only -1 is supported.");
output_shape[i] = in_dims[i] - offsets[i];
} else {
PADDLE_ENFORCE_GT(shape[i], 0,
"Each element of shape must be greater than 0 "
"except the first element.");
output_shape[i] = static_cast<int64_t>(shape[i]);
}
output_shape[i] = static_cast<int64_t>(shape[i]);
}
return framework::make_ddim(output_shape);
......@@ -164,21 +163,15 @@ void CropTensorFunction(const framework::ExecutionContext& context) {
shape.push_back(out_dims[i]);
}
}
out_dims = ValidateShape(shape, x->dims());
if (out_dims[0] == -1) {
out_dims[0] = x->dims()[0];
}
out->mutable_data<T>(out_dims, context.GetPlace());
auto x_stride = framework::stride(x->dims());
auto offsets = GetOffsets(context);
int64_t offset = 0;
out_dims = ValidateShape(shape, offsets, x->dims());
out->mutable_data<T>(out_dims, context.GetPlace());
for (size_t i = 0; i < offsets.size(); ++i) {
PADDLE_ENFORCE_LE(
offsets[i] + shape[i], x_dims[i],
"The sum of the Attr(offsets) and Attr(shape) of Op(crop_tensor) "
"should be less than or equal to corresponding input dimension size.");
offset += (x_stride[i] * offsets[i]);
}
auto x_tensor = EigenTensor<T, D>::From(*x);
......
......@@ -11391,7 +11391,7 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
* Case 1 (input is a 2-D Tensor):
Input:
X.shape = [3. 5]
X.shape = [3, 5]
X.data = [[0, 1, 2, 0, 0],
[0, 3, 4, 0, 0],
[0, 0, 0, 0, 0]]
......@@ -11399,8 +11399,9 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
shape = [2, 2]
offsets = [0, 1]
Output:
Out = [[1, 2],
[3, 4]]
Out.shape = [2, 2]
Out.data = [[1, 2],
[3, 4]]
* Case 2 (input is a 3-D Tensor):
Input:
X.shape = [2, 3, 4]
......@@ -11411,24 +11412,23 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
[0, 6, 7, 8],
[0, 0, 0, 0]]]
Parameters:
shape = [2, 2, 3]
shape = [2, 2, -1]
offsets = [0, 0, 1]
Output:
Out = [[[1, 2, 3],
[5, 6, 7]],
[[3, 4, 5],
[6, 7, 8]]]
Out.shape = [2, 2, 3]
Out.data = [[[1, 2, 3],
[5, 6, 7]],
[[3, 4, 5],
[6, 7, 8]]]
Parameters:
x (Variable): 1-D to 6-D Tensor, the data type is float32 or float64.
x (Variable): 1-D to 6-D Tensor, the data type is float32, float64, int32 or int64.
shape (list|tuple|Variable): The output shape is specified
by `shape`. Its data type is int32. If a list/tuple, it's length must be
the same as the dimension size of `x`. If a Variable, it shoule be a 1-D Tensor.
When it is a list, each element can be an integer or a Tensor of shape: [1].
If Variable contained, it is suitable for the case that the shape may
be changed each iteration. Only the first element of list/tuple can be
set to -1, it means that the first dimension's size of the output is the same
as the input.
If Variable contained, it is suitable for the case that the shape may
be changed each iteration.
offsets (list|tuple|Variable, optional): Specifies the cropping
offsets at each dimension. Its data type is int32. If a list/tuple, it's length
must be the same as the dimension size of `x`. If a Variable, it shoule be a 1-D
......@@ -11442,8 +11442,12 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
Variable: The cropped Tensor has same data type with `x`.
Raises:
ValueError: If shape is not a list, tuple or Variable.
ValueError: If offsets is not None and not a list, tuple or Variable.
TypeError: If the data type of `x` is not in: float32, float64, int32, int64.
TypeError: If `shape` is not a list, tuple or Variable.
TypeError: If the data type of `shape` is not int32.
TypeError: If `offsets` is not None and not a list, tuple or Variable.
TypeError: If the data type of `offsets` is not int32.
ValueError: If the element in `offsets` is less than zero.
Examples:
......@@ -11459,7 +11463,7 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
# crop0.shape = [-1, -1, -1], it means crop0.shape[0] = x.shape[0] in runtime.
# or shape is a list in which each element is a constant
crop1 = fluid.layers.crop_tensor(x, shape=[-1, 2, 3])
crop1 = fluid.layers.crop_tensor(x, shape=[-1, -1, 3], offsets=[0, 1, 0])
# crop1.shape = [-1, 2, 3]
# or shape is a list in which each element is a constant or Variable
......@@ -11481,70 +11485,98 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
"""
helper = LayerHelper('crop_tensor', **locals())
if convert_dtype(x.dtype) not in ['float32', 'float64', 'int32', 'int64']:
raise TypeError(
"Input(x)'s dtype of Op(crop_tensor) must be float32, float64, int32 or int64, "
"but received %s." % (convert_dtype(x.dtype)))
if not (isinstance(shape, list) or isinstance(shape, tuple) or \
isinstance(shape, Variable)):
raise ValueError("The shape should be a list, tuple or Variable.")
raise TypeError(
"Attr(shape) of Op(crop_tensor) should be a list, tuple or Variable."
)
if offsets is None:
offsets = [0] * len(x.shape)
if not (isinstance(offsets, list) or isinstance(offsets, tuple) or \
isinstance(offsets, Variable)):
raise ValueError("The offsets should be a list, tuple or Variable.")
raise TypeError(
"Attr(offsets) of Op(crop_tensor) should be a list, tuple or Variable."
)
out = helper.create_variable_for_type_inference(x.dtype)
ipts = {'X': x}
attrs = {}
def contain_var(input_list):
def _contain_var(input_list):
for ele in input_list:
if isinstance(ele, Variable):
return True
return False
def _attr_shape_check(shape_val):
if not isinstance(shape_val, int):
raise TypeError(
"Attr(shape)'s dtype of Op(crop_tensor) should be int32, but received: %s."
% type(shape_val))
if shape_val == 0:
raise ValueError(
"Attr(shape) of Op(crop_tensor) should not be zero, but received: %s."
% str(shape_val))
if shape_val < -1:
raise ValueError(
"When the element in Attr(shape) of Op(crop_tensor) is negative, only -1 is supported, but received: %s."
% str(shape_val))
def _attr_offsets_check(offset_val):
if not isinstance(offset_val, int):
raise TypeError(
"Attr(offsets)'s dtype of Op(crop_tensor) should be int32, but received: %s."
% type(offset_val))
if offset_val < 0:
raise ValueError(
"Attr(offsets) of Op(crop_tensor) should be greater or equal to zero, but received: %s."
% str(offset_val))
if isinstance(offsets, Variable):
offsets.stop_gradient = True
ipts['Offsets'] = offsets
elif contain_var(offsets):
attrs['offsets'] = [-1] * len(x.shape)
elif _contain_var(offsets):
new_offsets_tensor = []
offsets_attr = []
for dim in offsets:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_offsets_tensor.append(dim)
offsets_attr.append(-1)
else:
assert (isinstance(dim, int))
assert dim >= 0, ("offsets should be greater or equal to zero.")
_attr_offsets_check(dim)
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out)
new_offsets_tensor.append(temp_out)
offsets_attr.append(dim)
ipts['OffsetsTensor'] = new_offsets_tensor
attrs['offsets'] = offsets_attr
else:
for offset in offsets:
_attr_offsets_check(offset)
attrs['offsets'] = offsets
unk_dim_idx = -1
if isinstance(shape, Variable):
shape.stop_gradient = True
ipts['Shape'] = shape
elif contain_var(shape):
elif _contain_var(shape):
new_shape_tensor = []
shape_attr = []
for dim_idx, dim_size in enumerate(shape):
for dim_size in shape:
if isinstance(dim_size, Variable):
dim_size.stop_gradient = True
new_shape_tensor.append(dim_size)
shape_attr.append(-1)
shape_attr.append(0)
else:
assert (isinstance(dim_size, int))
if dim_size == -1:
assert unk_dim_idx == -1, (
"Only one element in shape can be unknown.")
assert dim_idx == 0, (
"Only the first element in shape can be -1.")
unk_dim_idx = dim_idx
else:
assert dim_size > 0, (
"Each dimension size given in shape must be greater than zero."
)
_attr_shape_check(dim_size)
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant(
[1], 'int32', dim_size, force_cpu=True, out=temp_out)
......@@ -11553,6 +11585,8 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
ipts['ShapeTensor'] = new_shape_tensor
attrs['shape'] = shape_attr
else:
for dim_size in shape:
_attr_shape_check(dim_size)
attrs['shape'] = shape
helper.append_op(
......
......@@ -44,13 +44,13 @@ def crop(data, offsets, crop_shape):
class TestCropTensorOp(OpTest):
def setUp(self):
self.op_type = "crop_tensor"
self.crop_by_1D_shape = False
self.shape_by_input = False
self.offset_by_input = False
self.unk_dim_idx = -1
self.attrs = {}
self.initTestCase()
if self.crop_by_1D_shape:
if self.shape_by_input:
self.inputs = {
'X': np.random.random(self.x_shape).astype("float32"),
'Shape': np.array(self.crop_shape).astype("int32")
......@@ -65,11 +65,11 @@ class TestCropTensorOp(OpTest):
else:
self.attrs['offsets'] = self.offsets
if self.unk_dim_idx != -1:
self.crop_shape[self.unk_dim_idx] = self.x_shape[self.unk_dim_idx]
self.outputs = {
'Out': crop(self.inputs['X'], self.offsets, self.crop_shape)
}
crop_shape = [val for val in self.crop_shape]
for i in range(len(self.crop_shape)):
if self.crop_shape[i] == -1:
crop_shape[i] = self.x_shape[i] - self.offsets[i]
self.outputs = {'Out': crop(self.inputs['X'], self.offsets, crop_shape)}
def initTestCase(self):
self.x_shape = (8, 8)
......@@ -93,9 +93,8 @@ class TestCase1(TestCropTensorOp):
class TestCase2(TestCropTensorOp):
def initTestCase(self):
self.x_shape = (12, 24)
self.crop_shape = [-1, 8] #only the first dimension (batch) can be -1
self.crop_shape = [-1, 8]
self.offsets = [0, 0]
self.unk_dim_idx = 0
class TestCase3(TestCropTensorOp):
......@@ -103,16 +102,15 @@ class TestCase3(TestCropTensorOp):
self.x_shape = (4, 8, 16)
self.crop_shape = [2, 2, 3]
self.offsets = [1, 5, 3]
self.crop_by_1D_shape = True
self.shape_by_input = True
class TestCase4(TestCropTensorOp):
def initTestCase(self):
self.x_shape = (8, 3, 6, 6)
self.crop_shape = [-1, 3, 4, 4]
self.offsets = [0, 0, 0, 0]
self.crop_by_1D_shape = True
self.unk_dim_idx = 0
self.crop_shape = [-1, 3, -1, 4]
self.offsets = [0, 0, 1, 0]
self.shape_by_input = True
class TestCase5(TestCropTensorOp):
......@@ -128,14 +126,13 @@ class TestCase6(TestCropTensorOp):
self.x_shape = (2, 2, 4, 4, 4, 2)
self.crop_shape = [1, 1, 4, 2, 2, 2]
self.offsets = [0, 0, 0, 0, 0, 0]
self.crop_by_1D_shape = True
self.shape_by_input = True
self.offset_by_input = True
class TestCropTensorOp_attr_tensor(OpTest):
class TestCropTensorOpTensorAttr(OpTest):
def setUp(self):
self.op_type = "crop_tensor"
self.mixed_type = False
self.OffsetsTensor = False
self.ShapeTensor = True
self.attrs = {}
......@@ -150,8 +147,7 @@ class TestCropTensorOp_attr_tensor(OpTest):
'X': np.random.random(self.x_shape).astype("float32"),
'ShapeTensor': shape_tensor
}
if self.mixed_type:
self.attrs['shape'] = self.shape_attr
self.attrs['shape'] = self.shape_attr
if self.OffsetsTensor:
offsets_tensor = []
......@@ -162,17 +158,21 @@ class TestCropTensorOp_attr_tensor(OpTest):
'X': np.random.random(self.x_shape).astype("float32"),
'OffsetsTensor': offsets_tensor
}
else:
self.attrs['offsets'] = self.offsets
self.attrs['offsets'] = self.offsets_attr
self.outputs = {
'Out': crop(self.inputs['X'], self.offsets, self.crop_shape)
}
self.attrs['shape'] = self.crop_shape
self.attrs['offsets'] = self.offsets
crop_shape = [val for val in self.crop_shape]
for i in range(len(self.crop_shape)):
if self.crop_shape[i] == -1:
crop_shape[i] = self.x_shape[i] - self.offsets[i]
self.outputs = {'Out': crop(self.inputs['X'], self.offsets, crop_shape)}
def initTestCase(self):
self.x_shape = (8, 8)
self.crop_shape = (2, 2)
self.offsets = [1, 2]
self.shape_attr = [0, 0]
def test_check_output(self):
self.check_output()
......@@ -181,38 +181,85 @@ class TestCropTensorOp_attr_tensor(OpTest):
self.check_grad(["X"], "Out", max_relative_error=0.006)
class TestCropTensorOp_attr_tensor_case1(TestCropTensorOp_attr_tensor):
def init_data(self):
class TestCropTensorOpTensorAttrCase1(TestCropTensorOpTensorAttr):
def initTestCase(self):
self.x_shape = (16, 8, 32)
self.crop_shape = [2, 2, 3]
self.crop_shape = [-1, -1, 3]
self.offsets = [1, 5, 3]
self.shape_attr = [-1, -1, 3]
class TestCropTensorOp_attr_tensor_case2(TestCropTensorOp_attr_tensor):
def init_data(self):
class TestCropTensorOpTensorAttrCase2(TestCropTensorOpTensorAttr):
def initTestCase(self):
self.x_shape = (4, 8, 16, 8)
self.crop_shape = [2, 2, 3, 4]
self.offsets = [1, 5, 3, 0]
self.shape_attr = [-1, -1, 3, 4]
self.mixed_type = True
self.shape_attr = [0, 0, 3, 4]
class TestCropTensorOp_attr_tensor_case3(TestCropTensorOp_attr_tensor):
def init_data(self):
class TestCropTensorOpTensorAttrCase3(TestCropTensorOpTensorAttr):
def initTestCase(self):
self.x_shape = (16, 8, 32)
self.crop_shape = [2, 2, 3]
self.offsets = [1, 5, 3]
self.offsets_attr = [-1, -1, 3]
self.ShapeTensor = False
self.OffsetsTensor = True
class TestCropTensorOp_attr_tensor_case4(TestCropTensorOp_attr_tensor):
def init_data(self):
class TestCropTensorOpTensorAttrCase4(TestCropTensorOpTensorAttr):
def initTestCase(self):
self.x_shape = (16, 8, 32)
self.crop_shape = [2, 2, 3]
self.shape_attr = [0, 2, 3]
self.offsets = [1, 5, 3]
self.offsets_attr = [-1, -1, 3]
self.OffsetsTensor = True
class TestCropTensorException(OpTest):
def test_exception(self):
input1 = fluid.data(name="input1", shape=[2, 3, 6, 6], dtype="float32")
input2 = fluid.data(name="input2", shape=[2, 3, 6, 6], dtype="float16")
dim = fluid.data(name='dim', shape=[1], dtype='int32')
offset = fluid.data(name='offset', shape=[1], dtype='int32')
def attr_shape_type():
out = fluid.layers.crop_tensor(input1, shape=3)
def attr_shape_dtype():
out = fluid.layers.crop_tensor(input1, shape=[2, 2.0, 3, 3])
def attr_shape_value1():
out = fluid.layers.crop_tensor(input1, shape=[2, -2, dim, 3])
def attr_shape_value2():
out = fluid.layers.crop_tensor(input1, shape=[2, 0, dim, 3])
def attr_offsets_type():
out = fluid.layers.crop_tensor(
input1, shape=[2, 2, 3, 3], offsets=0)
def attr_offsets_dtype():
out = fluid.layers.crop_tensor(
input1, shape=[2, 2, 3, 3], offsets=[0, 1.0, 0, 0])
def attr_offsets_value():
out = fluid.layers.crop_tensor(
input1, shape=[2, 2, 3, 3], offsets=[0, -1, offset, 0])
def input_dtype():
out = fluid.layers.crop_tensor(input2, shape=[2, 2, 3, 3])
self.assertRaises(TypeError, attr_shape_type)
self.assertRaises(TypeError, attr_shape_dtype)
self.assertRaises(ValueError, attr_shape_value1)
self.assertRaises(ValueError, attr_shape_value2)
self.assertRaises(TypeError, attr_offsets_type)
self.assertRaises(TypeError, attr_offsets_dtype)
self.assertRaises(ValueError, attr_offsets_value)
self.assertRaises(TypeError, input_dtype)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册