提交 5a8d885d 编写于 作者: Z Zhang Ting 提交者: Aurelius84

All elements in attr(shape) of crop_tensor can be -1 and int32/64 kernel registered (#20756)

* All elements in attr(shape) of crop_tensor can be -1, test=develop, test=document_preview

* fix the bug that attr(offsets) should be initialized, test=develop
上级 9171f737
...@@ -31,8 +31,9 @@ class CropTensorOp : public framework::OperatorWithKernel { ...@@ -31,8 +31,9 @@ class CropTensorOp : public framework::OperatorWithKernel {
"Input(X) of Op(crop_tensor) should not be null."); "Input(X) of Op(crop_tensor) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of Op(crop_tensor) should not be null."); "Output(Out) of Op(crop_tensor) should not be null.");
auto x_dim = ctx->GetInputDim("X");
auto shape = ctx->Attrs().Get<std::vector<int>>("shape"); auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
auto offsets = ctx->Attrs().Get<std::vector<int>>("offsets");
if (ctx->HasInputs("ShapeTensor")) { if (ctx->HasInputs("ShapeTensor")) {
// top prority shape // top prority shape
auto inputs_name = ctx->Inputs("ShapeTensor"); auto inputs_name = ctx->Inputs("ShapeTensor");
...@@ -43,15 +44,19 @@ class CropTensorOp : public framework::OperatorWithKernel { ...@@ -43,15 +44,19 @@ class CropTensorOp : public framework::OperatorWithKernel {
"Op(fluid.layers.crop_tensor)."); "Op(fluid.layers.crop_tensor).");
auto out_dims = std::vector<int>(inputs_name.size(), -1); auto out_dims = std::vector<int>(inputs_name.size(), -1);
for (size_t i = 0; i < shape.size(); ++i) { for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] != -1) { if (shape[i] > 0) {
out_dims[i] = static_cast<int64_t>(shape[i]); out_dims[i] = static_cast<int64_t>(shape[i]);
} else {
if (shape[i] == -1 && offsets[i] != -1 && x_dim[i] != -1) {
out_dims[i] = x_dim[i] - static_cast<int64_t>(offsets[i]);
}
} }
} }
ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
return; return;
} }
auto x_dim = ctx->GetInputDim("X");
if (ctx->HasInput("Shape")) { if (ctx->HasInput("Shape")) {
auto shape_dim = ctx->GetInputDim("Shape"); auto shape_dim = ctx->GetInputDim("Shape");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -78,11 +83,17 @@ class CropTensorOp : public framework::OperatorWithKernel { ...@@ -78,11 +83,17 @@ class CropTensorOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(int64_t(shape.size()), x_dim.size(), PADDLE_ENFORCE_EQ(int64_t(shape.size()), x_dim.size(),
"Attr(shape)'size of Op(crop_tensor) should be equal to " "Attr(shape)'size of Op(crop_tensor) should be equal to "
"dimention size of input tensor."); "dimention size of input tensor.");
std::vector<int64_t> tensor_shape(shape.size()); std::vector<int64_t> out_shape(shape.size(), -1);
for (size_t i = 0; i < shape.size(); ++i) { for (size_t i = 0; i < shape.size(); ++i) {
tensor_shape[i] = static_cast<int64_t>(shape[i]); if (shape[i] > 0) {
out_shape[i] = static_cast<int64_t>(shape[i]);
} else {
if (shape[i] == -1 && offsets[i] != -1 && x_dim[i] != -1) {
out_shape[i] = x_dim[i] - static_cast<int64_t>(offsets[i]);
}
}
} }
ctx->SetOutputDim("Out", framework::make_ddim(tensor_shape)); ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
...@@ -293,8 +304,12 @@ REGISTER_OPERATOR(crop_tensor_grad, ops::CropTensorOpGrad); ...@@ -293,8 +304,12 @@ REGISTER_OPERATOR(crop_tensor_grad, ops::CropTensorOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
crop_tensor, crop_tensor,
ops::CropTensorKernel<paddle::platform::CPUDeviceContext, float>, ops::CropTensorKernel<paddle::platform::CPUDeviceContext, float>,
ops::CropTensorKernel<paddle::platform::CPUDeviceContext, double>); ops::CropTensorKernel<paddle::platform::CPUDeviceContext, double>,
ops::CropTensorKernel<paddle::platform::CPUDeviceContext, int>,
ops::CropTensorKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
crop_tensor_grad, crop_tensor_grad,
ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, float>, ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, double>); ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -17,8 +17,12 @@ namespace ops = paddle::operators; ...@@ -17,8 +17,12 @@ namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
crop_tensor, crop_tensor,
ops::CropTensorKernel<paddle::platform::CUDADeviceContext, float>, ops::CropTensorKernel<paddle::platform::CUDADeviceContext, float>,
ops::CropTensorKernel<paddle::platform::CUDADeviceContext, double>); ops::CropTensorKernel<paddle::platform::CUDADeviceContext, double>,
ops::CropTensorKernel<paddle::platform::CUDADeviceContext, int>,
ops::CropTensorKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
crop_tensor_grad, crop_tensor_grad,
ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, float>, ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, double>); ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
...@@ -50,29 +50,28 @@ inline std::vector<int> get_new_data( ...@@ -50,29 +50,28 @@ inline std::vector<int> get_new_data(
} }
static framework::DDim ValidateShape(const std::vector<int> shape, static framework::DDim ValidateShape(const std::vector<int> shape,
const std::vector<int> offsets,
const framework::DDim& in_dims) { const framework::DDim& in_dims) {
auto in_dim_size = in_dims.size(); auto in_dim_size = in_dims.size();
auto shape_size = shape.size(); auto shape_size = shape.size();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_dim_size, shape_size, in_dim_size, shape_size,
"Input(ShapeTensor)'s dimension size of Op(crop_tensor) should be equal " "Attr(shape)'s size of Op(crop_tensor) should be equal "
"to that of input tensor. " "to that of input Tensor. "
"Please check the Attr(shape)'s size of Op(fluid.layers.crop_tensor)."); "Please check the Attr(shape)'s size of Op(fluid.layers.crop_tensor).");
const int64_t unk_dim_val = -1;
int unk_dim_idx = -1;
std::vector<int64_t> output_shape(shape.size(), 0); std::vector<int64_t> output_shape(shape.size(), 0);
for (size_t i = 0; i < shape.size(); ++i) { for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] == unk_dim_val) { if (shape[i] <= 0 && in_dims[i] > 0) {
PADDLE_ENFORCE_EQ(unk_dim_idx, -1, PADDLE_ENFORCE_NE(
"Only one element of shape can be unknown."); shape[i], 0,
PADDLE_ENFORCE_EQ(i, 0, "Only the first element of shape can be -1."); "The element in Attr(shape) of Op(crop_tensor) should not be zero.");
unk_dim_idx = i; PADDLE_ENFORCE_EQ(shape[i], -1,
"When the element in Attr(shape) of Op(crop_tensor) is "
"negative, only -1 is supported.");
output_shape[i] = in_dims[i] - offsets[i];
} else { } else {
PADDLE_ENFORCE_GT(shape[i], 0, output_shape[i] = static_cast<int64_t>(shape[i]);
"Each element of shape must be greater than 0 "
"except the first element.");
} }
output_shape[i] = static_cast<int64_t>(shape[i]);
} }
return framework::make_ddim(output_shape); return framework::make_ddim(output_shape);
...@@ -164,21 +163,15 @@ void CropTensorFunction(const framework::ExecutionContext& context) { ...@@ -164,21 +163,15 @@ void CropTensorFunction(const framework::ExecutionContext& context) {
shape.push_back(out_dims[i]); shape.push_back(out_dims[i]);
} }
} }
out_dims = ValidateShape(shape, x->dims());
if (out_dims[0] == -1) {
out_dims[0] = x->dims()[0];
}
out->mutable_data<T>(out_dims, context.GetPlace());
auto x_stride = framework::stride(x->dims());
auto offsets = GetOffsets(context); auto offsets = GetOffsets(context);
int64_t offset = 0; out_dims = ValidateShape(shape, offsets, x->dims());
out->mutable_data<T>(out_dims, context.GetPlace());
for (size_t i = 0; i < offsets.size(); ++i) { for (size_t i = 0; i < offsets.size(); ++i) {
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
offsets[i] + shape[i], x_dims[i], offsets[i] + shape[i], x_dims[i],
"The sum of the Attr(offsets) and Attr(shape) of Op(crop_tensor) " "The sum of the Attr(offsets) and Attr(shape) of Op(crop_tensor) "
"should be less than or equal to corresponding input dimension size."); "should be less than or equal to corresponding input dimension size.");
offset += (x_stride[i] * offsets[i]);
} }
auto x_tensor = EigenTensor<T, D>::From(*x); auto x_tensor = EigenTensor<T, D>::From(*x);
......
...@@ -11391,7 +11391,7 @@ def crop_tensor(x, shape=None, offsets=None, name=None): ...@@ -11391,7 +11391,7 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
* Case 1 (input is a 2-D Tensor): * Case 1 (input is a 2-D Tensor):
Input: Input:
X.shape = [3. 5] X.shape = [3, 5]
X.data = [[0, 1, 2, 0, 0], X.data = [[0, 1, 2, 0, 0],
[0, 3, 4, 0, 0], [0, 3, 4, 0, 0],
[0, 0, 0, 0, 0]] [0, 0, 0, 0, 0]]
...@@ -11399,8 +11399,9 @@ def crop_tensor(x, shape=None, offsets=None, name=None): ...@@ -11399,8 +11399,9 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
shape = [2, 2] shape = [2, 2]
offsets = [0, 1] offsets = [0, 1]
Output: Output:
Out = [[1, 2], Out.shape = [2, 2]
[3, 4]] Out.data = [[1, 2],
[3, 4]]
* Case 2 (input is a 3-D Tensor): * Case 2 (input is a 3-D Tensor):
Input: Input:
X.shape = [2, 3, 4] X.shape = [2, 3, 4]
...@@ -11411,24 +11412,23 @@ def crop_tensor(x, shape=None, offsets=None, name=None): ...@@ -11411,24 +11412,23 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
[0, 6, 7, 8], [0, 6, 7, 8],
[0, 0, 0, 0]]] [0, 0, 0, 0]]]
Parameters: Parameters:
shape = [2, 2, 3] shape = [2, 2, -1]
offsets = [0, 0, 1] offsets = [0, 0, 1]
Output: Output:
Out = [[[1, 2, 3], Out.shape = [2, 2, 3]
[5, 6, 7]], Out.data = [[[1, 2, 3],
[[3, 4, 5], [5, 6, 7]],
[6, 7, 8]]] [[3, 4, 5],
[6, 7, 8]]]
Parameters: Parameters:
x (Variable): 1-D to 6-D Tensor, the data type is float32 or float64. x (Variable): 1-D to 6-D Tensor, the data type is float32, float64, int32 or int64.
shape (list|tuple|Variable): The output shape is specified shape (list|tuple|Variable): The output shape is specified
by `shape`. Its data type is int32. If a list/tuple, it's length must be by `shape`. Its data type is int32. If a list/tuple, it's length must be
the same as the dimension size of `x`. If a Variable, it shoule be a 1-D Tensor. the same as the dimension size of `x`. If a Variable, it shoule be a 1-D Tensor.
When it is a list, each element can be an integer or a Tensor of shape: [1]. When it is a list, each element can be an integer or a Tensor of shape: [1].
If Variable contained, it is suitable for the case that the shape may If Variable contained, it is suitable for the case that the shape may
be changed each iteration. Only the first element of list/tuple can be be changed each iteration.
set to -1, it means that the first dimension's size of the output is the same
as the input.
offsets (list|tuple|Variable, optional): Specifies the cropping offsets (list|tuple|Variable, optional): Specifies the cropping
offsets at each dimension. Its data type is int32. If a list/tuple, it's length offsets at each dimension. Its data type is int32. If a list/tuple, it's length
must be the same as the dimension size of `x`. If a Variable, it shoule be a 1-D must be the same as the dimension size of `x`. If a Variable, it shoule be a 1-D
...@@ -11442,8 +11442,12 @@ def crop_tensor(x, shape=None, offsets=None, name=None): ...@@ -11442,8 +11442,12 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
Variable: The cropped Tensor has same data type with `x`. Variable: The cropped Tensor has same data type with `x`.
Raises: Raises:
ValueError: If shape is not a list, tuple or Variable. TypeError: If the data type of `x` is not in: float32, float64, int32, int64.
ValueError: If offsets is not None and not a list, tuple or Variable. TypeError: If `shape` is not a list, tuple or Variable.
TypeError: If the data type of `shape` is not int32.
TypeError: If `offsets` is not None and not a list, tuple or Variable.
TypeError: If the data type of `offsets` is not int32.
ValueError: If the element in `offsets` is less than zero.
Examples: Examples:
...@@ -11459,7 +11463,7 @@ def crop_tensor(x, shape=None, offsets=None, name=None): ...@@ -11459,7 +11463,7 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
# crop0.shape = [-1, -1, -1], it means crop0.shape[0] = x.shape[0] in runtime. # crop0.shape = [-1, -1, -1], it means crop0.shape[0] = x.shape[0] in runtime.
# or shape is a list in which each element is a constant # or shape is a list in which each element is a constant
crop1 = fluid.layers.crop_tensor(x, shape=[-1, 2, 3]) crop1 = fluid.layers.crop_tensor(x, shape=[-1, -1, 3], offsets=[0, 1, 0])
# crop1.shape = [-1, 2, 3] # crop1.shape = [-1, 2, 3]
# or shape is a list in which each element is a constant or Variable # or shape is a list in which each element is a constant or Variable
...@@ -11481,70 +11485,98 @@ def crop_tensor(x, shape=None, offsets=None, name=None): ...@@ -11481,70 +11485,98 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
""" """
helper = LayerHelper('crop_tensor', **locals()) helper = LayerHelper('crop_tensor', **locals())
if convert_dtype(x.dtype) not in ['float32', 'float64', 'int32', 'int64']:
raise TypeError(
"Input(x)'s dtype of Op(crop_tensor) must be float32, float64, int32 or int64, "
"but received %s." % (convert_dtype(x.dtype)))
if not (isinstance(shape, list) or isinstance(shape, tuple) or \ if not (isinstance(shape, list) or isinstance(shape, tuple) or \
isinstance(shape, Variable)): isinstance(shape, Variable)):
raise ValueError("The shape should be a list, tuple or Variable.") raise TypeError(
"Attr(shape) of Op(crop_tensor) should be a list, tuple or Variable."
)
if offsets is None: if offsets is None:
offsets = [0] * len(x.shape) offsets = [0] * len(x.shape)
if not (isinstance(offsets, list) or isinstance(offsets, tuple) or \ if not (isinstance(offsets, list) or isinstance(offsets, tuple) or \
isinstance(offsets, Variable)): isinstance(offsets, Variable)):
raise ValueError("The offsets should be a list, tuple or Variable.") raise TypeError(
"Attr(offsets) of Op(crop_tensor) should be a list, tuple or Variable."
)
out = helper.create_variable_for_type_inference(x.dtype) out = helper.create_variable_for_type_inference(x.dtype)
ipts = {'X': x} ipts = {'X': x}
attrs = {} attrs = {}
def contain_var(input_list): def _contain_var(input_list):
for ele in input_list: for ele in input_list:
if isinstance(ele, Variable): if isinstance(ele, Variable):
return True return True
return False return False
def _attr_shape_check(shape_val):
if not isinstance(shape_val, int):
raise TypeError(
"Attr(shape)'s dtype of Op(crop_tensor) should be int32, but received: %s."
% type(shape_val))
if shape_val == 0:
raise ValueError(
"Attr(shape) of Op(crop_tensor) should not be zero, but received: %s."
% str(shape_val))
if shape_val < -1:
raise ValueError(
"When the element in Attr(shape) of Op(crop_tensor) is negative, only -1 is supported, but received: %s."
% str(shape_val))
def _attr_offsets_check(offset_val):
if not isinstance(offset_val, int):
raise TypeError(
"Attr(offsets)'s dtype of Op(crop_tensor) should be int32, but received: %s."
% type(offset_val))
if offset_val < 0:
raise ValueError(
"Attr(offsets) of Op(crop_tensor) should be greater or equal to zero, but received: %s."
% str(offset_val))
if isinstance(offsets, Variable): if isinstance(offsets, Variable):
offsets.stop_gradient = True offsets.stop_gradient = True
ipts['Offsets'] = offsets ipts['Offsets'] = offsets
elif contain_var(offsets): attrs['offsets'] = [-1] * len(x.shape)
elif _contain_var(offsets):
new_offsets_tensor = [] new_offsets_tensor = []
offsets_attr = []
for dim in offsets: for dim in offsets:
if isinstance(dim, Variable): if isinstance(dim, Variable):
dim.stop_gradient = True dim.stop_gradient = True
new_offsets_tensor.append(dim) new_offsets_tensor.append(dim)
offsets_attr.append(-1)
else: else:
assert (isinstance(dim, int)) _attr_offsets_check(dim)
assert dim >= 0, ("offsets should be greater or equal to zero.")
temp_out = helper.create_variable_for_type_inference('int32') temp_out = helper.create_variable_for_type_inference('int32')
fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out) fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out)
new_offsets_tensor.append(temp_out) new_offsets_tensor.append(temp_out)
offsets_attr.append(dim)
ipts['OffsetsTensor'] = new_offsets_tensor ipts['OffsetsTensor'] = new_offsets_tensor
attrs['offsets'] = offsets_attr
else: else:
for offset in offsets:
_attr_offsets_check(offset)
attrs['offsets'] = offsets attrs['offsets'] = offsets
unk_dim_idx = -1
if isinstance(shape, Variable): if isinstance(shape, Variable):
shape.stop_gradient = True shape.stop_gradient = True
ipts['Shape'] = shape ipts['Shape'] = shape
elif contain_var(shape): elif _contain_var(shape):
new_shape_tensor = [] new_shape_tensor = []
shape_attr = [] shape_attr = []
for dim_idx, dim_size in enumerate(shape): for dim_size in shape:
if isinstance(dim_size, Variable): if isinstance(dim_size, Variable):
dim_size.stop_gradient = True dim_size.stop_gradient = True
new_shape_tensor.append(dim_size) new_shape_tensor.append(dim_size)
shape_attr.append(-1) shape_attr.append(0)
else: else:
assert (isinstance(dim_size, int)) _attr_shape_check(dim_size)
if dim_size == -1:
assert unk_dim_idx == -1, (
"Only one element in shape can be unknown.")
assert dim_idx == 0, (
"Only the first element in shape can be -1.")
unk_dim_idx = dim_idx
else:
assert dim_size > 0, (
"Each dimension size given in shape must be greater than zero."
)
temp_out = helper.create_variable_for_type_inference('int32') temp_out = helper.create_variable_for_type_inference('int32')
fill_constant( fill_constant(
[1], 'int32', dim_size, force_cpu=True, out=temp_out) [1], 'int32', dim_size, force_cpu=True, out=temp_out)
...@@ -11553,6 +11585,8 @@ def crop_tensor(x, shape=None, offsets=None, name=None): ...@@ -11553,6 +11585,8 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
ipts['ShapeTensor'] = new_shape_tensor ipts['ShapeTensor'] = new_shape_tensor
attrs['shape'] = shape_attr attrs['shape'] = shape_attr
else: else:
for dim_size in shape:
_attr_shape_check(dim_size)
attrs['shape'] = shape attrs['shape'] = shape
helper.append_op( helper.append_op(
......
...@@ -44,13 +44,13 @@ def crop(data, offsets, crop_shape): ...@@ -44,13 +44,13 @@ def crop(data, offsets, crop_shape):
class TestCropTensorOp(OpTest): class TestCropTensorOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "crop_tensor" self.op_type = "crop_tensor"
self.crop_by_1D_shape = False self.shape_by_input = False
self.offset_by_input = False self.offset_by_input = False
self.unk_dim_idx = -1 self.unk_dim_idx = -1
self.attrs = {} self.attrs = {}
self.initTestCase() self.initTestCase()
if self.crop_by_1D_shape: if self.shape_by_input:
self.inputs = { self.inputs = {
'X': np.random.random(self.x_shape).astype("float32"), 'X': np.random.random(self.x_shape).astype("float32"),
'Shape': np.array(self.crop_shape).astype("int32") 'Shape': np.array(self.crop_shape).astype("int32")
...@@ -65,11 +65,11 @@ class TestCropTensorOp(OpTest): ...@@ -65,11 +65,11 @@ class TestCropTensorOp(OpTest):
else: else:
self.attrs['offsets'] = self.offsets self.attrs['offsets'] = self.offsets
if self.unk_dim_idx != -1: crop_shape = [val for val in self.crop_shape]
self.crop_shape[self.unk_dim_idx] = self.x_shape[self.unk_dim_idx] for i in range(len(self.crop_shape)):
self.outputs = { if self.crop_shape[i] == -1:
'Out': crop(self.inputs['X'], self.offsets, self.crop_shape) crop_shape[i] = self.x_shape[i] - self.offsets[i]
} self.outputs = {'Out': crop(self.inputs['X'], self.offsets, crop_shape)}
def initTestCase(self): def initTestCase(self):
self.x_shape = (8, 8) self.x_shape = (8, 8)
...@@ -93,9 +93,8 @@ class TestCase1(TestCropTensorOp): ...@@ -93,9 +93,8 @@ class TestCase1(TestCropTensorOp):
class TestCase2(TestCropTensorOp): class TestCase2(TestCropTensorOp):
def initTestCase(self): def initTestCase(self):
self.x_shape = (12, 24) self.x_shape = (12, 24)
self.crop_shape = [-1, 8] #only the first dimension (batch) can be -1 self.crop_shape = [-1, 8]
self.offsets = [0, 0] self.offsets = [0, 0]
self.unk_dim_idx = 0
class TestCase3(TestCropTensorOp): class TestCase3(TestCropTensorOp):
...@@ -103,16 +102,15 @@ class TestCase3(TestCropTensorOp): ...@@ -103,16 +102,15 @@ class TestCase3(TestCropTensorOp):
self.x_shape = (4, 8, 16) self.x_shape = (4, 8, 16)
self.crop_shape = [2, 2, 3] self.crop_shape = [2, 2, 3]
self.offsets = [1, 5, 3] self.offsets = [1, 5, 3]
self.crop_by_1D_shape = True self.shape_by_input = True
class TestCase4(TestCropTensorOp): class TestCase4(TestCropTensorOp):
def initTestCase(self): def initTestCase(self):
self.x_shape = (8, 3, 6, 6) self.x_shape = (8, 3, 6, 6)
self.crop_shape = [-1, 3, 4, 4] self.crop_shape = [-1, 3, -1, 4]
self.offsets = [0, 0, 0, 0] self.offsets = [0, 0, 1, 0]
self.crop_by_1D_shape = True self.shape_by_input = True
self.unk_dim_idx = 0
class TestCase5(TestCropTensorOp): class TestCase5(TestCropTensorOp):
...@@ -128,14 +126,13 @@ class TestCase6(TestCropTensorOp): ...@@ -128,14 +126,13 @@ class TestCase6(TestCropTensorOp):
self.x_shape = (2, 2, 4, 4, 4, 2) self.x_shape = (2, 2, 4, 4, 4, 2)
self.crop_shape = [1, 1, 4, 2, 2, 2] self.crop_shape = [1, 1, 4, 2, 2, 2]
self.offsets = [0, 0, 0, 0, 0, 0] self.offsets = [0, 0, 0, 0, 0, 0]
self.crop_by_1D_shape = True self.shape_by_input = True
self.offset_by_input = True self.offset_by_input = True
class TestCropTensorOp_attr_tensor(OpTest): class TestCropTensorOpTensorAttr(OpTest):
def setUp(self): def setUp(self):
self.op_type = "crop_tensor" self.op_type = "crop_tensor"
self.mixed_type = False
self.OffsetsTensor = False self.OffsetsTensor = False
self.ShapeTensor = True self.ShapeTensor = True
self.attrs = {} self.attrs = {}
...@@ -150,8 +147,7 @@ class TestCropTensorOp_attr_tensor(OpTest): ...@@ -150,8 +147,7 @@ class TestCropTensorOp_attr_tensor(OpTest):
'X': np.random.random(self.x_shape).astype("float32"), 'X': np.random.random(self.x_shape).astype("float32"),
'ShapeTensor': shape_tensor 'ShapeTensor': shape_tensor
} }
if self.mixed_type: self.attrs['shape'] = self.shape_attr
self.attrs['shape'] = self.shape_attr
if self.OffsetsTensor: if self.OffsetsTensor:
offsets_tensor = [] offsets_tensor = []
...@@ -162,17 +158,21 @@ class TestCropTensorOp_attr_tensor(OpTest): ...@@ -162,17 +158,21 @@ class TestCropTensorOp_attr_tensor(OpTest):
'X': np.random.random(self.x_shape).astype("float32"), 'X': np.random.random(self.x_shape).astype("float32"),
'OffsetsTensor': offsets_tensor 'OffsetsTensor': offsets_tensor
} }
else: self.attrs['offsets'] = self.offsets_attr
self.attrs['offsets'] = self.offsets
self.outputs = { self.attrs['shape'] = self.crop_shape
'Out': crop(self.inputs['X'], self.offsets, self.crop_shape) self.attrs['offsets'] = self.offsets
} crop_shape = [val for val in self.crop_shape]
for i in range(len(self.crop_shape)):
if self.crop_shape[i] == -1:
crop_shape[i] = self.x_shape[i] - self.offsets[i]
self.outputs = {'Out': crop(self.inputs['X'], self.offsets, crop_shape)}
def initTestCase(self): def initTestCase(self):
self.x_shape = (8, 8) self.x_shape = (8, 8)
self.crop_shape = (2, 2) self.crop_shape = (2, 2)
self.offsets = [1, 2] self.offsets = [1, 2]
self.shape_attr = [0, 0]
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -181,38 +181,85 @@ class TestCropTensorOp_attr_tensor(OpTest): ...@@ -181,38 +181,85 @@ class TestCropTensorOp_attr_tensor(OpTest):
self.check_grad(["X"], "Out", max_relative_error=0.006) self.check_grad(["X"], "Out", max_relative_error=0.006)
class TestCropTensorOp_attr_tensor_case1(TestCropTensorOp_attr_tensor): class TestCropTensorOpTensorAttrCase1(TestCropTensorOpTensorAttr):
def init_data(self): def initTestCase(self):
self.x_shape = (16, 8, 32) self.x_shape = (16, 8, 32)
self.crop_shape = [2, 2, 3] self.crop_shape = [-1, -1, 3]
self.offsets = [1, 5, 3] self.offsets = [1, 5, 3]
self.shape_attr = [-1, -1, 3]
class TestCropTensorOp_attr_tensor_case2(TestCropTensorOp_attr_tensor): class TestCropTensorOpTensorAttrCase2(TestCropTensorOpTensorAttr):
def init_data(self): def initTestCase(self):
self.x_shape = (4, 8, 16, 8) self.x_shape = (4, 8, 16, 8)
self.crop_shape = [2, 2, 3, 4] self.crop_shape = [2, 2, 3, 4]
self.offsets = [1, 5, 3, 0] self.offsets = [1, 5, 3, 0]
self.shape_attr = [-1, -1, 3, 4] self.shape_attr = [0, 0, 3, 4]
self.mixed_type = True
class TestCropTensorOp_attr_tensor_case3(TestCropTensorOp_attr_tensor): class TestCropTensorOpTensorAttrCase3(TestCropTensorOpTensorAttr):
def init_data(self): def initTestCase(self):
self.x_shape = (16, 8, 32) self.x_shape = (16, 8, 32)
self.crop_shape = [2, 2, 3] self.crop_shape = [2, 2, 3]
self.offsets = [1, 5, 3] self.offsets = [1, 5, 3]
self.offsets_attr = [-1, -1, 3]
self.ShapeTensor = False self.ShapeTensor = False
self.OffsetsTensor = True self.OffsetsTensor = True
class TestCropTensorOp_attr_tensor_case4(TestCropTensorOp_attr_tensor): class TestCropTensorOpTensorAttrCase4(TestCropTensorOpTensorAttr):
def init_data(self): def initTestCase(self):
self.x_shape = (16, 8, 32) self.x_shape = (16, 8, 32)
self.crop_shape = [2, 2, 3] self.crop_shape = [2, 2, 3]
self.shape_attr = [0, 2, 3]
self.offsets = [1, 5, 3] self.offsets = [1, 5, 3]
self.offsets_attr = [-1, -1, 3]
self.OffsetsTensor = True self.OffsetsTensor = True
class TestCropTensorException(OpTest):
def test_exception(self):
input1 = fluid.data(name="input1", shape=[2, 3, 6, 6], dtype="float32")
input2 = fluid.data(name="input2", shape=[2, 3, 6, 6], dtype="float16")
dim = fluid.data(name='dim', shape=[1], dtype='int32')
offset = fluid.data(name='offset', shape=[1], dtype='int32')
def attr_shape_type():
out = fluid.layers.crop_tensor(input1, shape=3)
def attr_shape_dtype():
out = fluid.layers.crop_tensor(input1, shape=[2, 2.0, 3, 3])
def attr_shape_value1():
out = fluid.layers.crop_tensor(input1, shape=[2, -2, dim, 3])
def attr_shape_value2():
out = fluid.layers.crop_tensor(input1, shape=[2, 0, dim, 3])
def attr_offsets_type():
out = fluid.layers.crop_tensor(
input1, shape=[2, 2, 3, 3], offsets=0)
def attr_offsets_dtype():
out = fluid.layers.crop_tensor(
input1, shape=[2, 2, 3, 3], offsets=[0, 1.0, 0, 0])
def attr_offsets_value():
out = fluid.layers.crop_tensor(
input1, shape=[2, 2, 3, 3], offsets=[0, -1, offset, 0])
def input_dtype():
out = fluid.layers.crop_tensor(input2, shape=[2, 2, 3, 3])
self.assertRaises(TypeError, attr_shape_type)
self.assertRaises(TypeError, attr_shape_dtype)
self.assertRaises(ValueError, attr_shape_value1)
self.assertRaises(ValueError, attr_shape_value2)
self.assertRaises(TypeError, attr_offsets_type)
self.assertRaises(TypeError, attr_offsets_dtype)
self.assertRaises(ValueError, attr_offsets_value)
self.assertRaises(TypeError, input_dtype)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册