Commit b3888941 authored by Zhang Ting, committed by Aurelius84

add crop_tensor_op, test=develop, test=document_preview (#19314)

Add the crop_tensor op. The main differences from crop are:

1. If the argument shape is a list, each element is either an integer or a tensor Variable with shape [1]. This form suits the case where the shape may change each iteration.

2. If the argument shape is a Variable, its rank must be 1. In the crop op, the rank of shape must be the same as that of x.

offsets can also be a list in which each element is an integer or a tensor Variable with shape [1].
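A minimal usage sketch of the new API (adapted from the docstring added in this commit; the data layer names are illustrative):

    import paddle.fluid as fluid

    x = fluid.layers.data(name="x", shape=[3, 5], dtype="float32")
    # x.shape = [-1, 3, 5]; -1 is the batch size determined at runtime

    # shape as a list mixing constants and tensor Variables of shape [1]
    dim1 = fluid.layers.data(
        name="dim1", shape=[1], dtype="int32", append_batch_size=False)
    crop0 = fluid.layers.crop_tensor(x, shape=[-1, dim1, 3])

    # shape as a 1-D tensor Variable whose length equals the rank of x
    crop_shape = fluid.layers.data(
        name="crop_shape", shape=[3], dtype="int32", append_batch_size=False)
    crop1 = fluid.layers.crop_tensor(x, shape=crop_shape)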
Parent bf836736
......@@ -204,7 +204,8 @@ paddle.fluid.layers.mean_iou (ArgSpec(args=['input', 'label', 'num_classes'], va
paddle.fluid.layers.relu (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '0942c174f4f6fb274976d4357356f6a2'))
paddle.fluid.layers.selu (ArgSpec(args=['x', 'scale', 'alpha', 'name'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'f93c61f5b0bf933cd425a64dca2c4fdd'))
paddle.fluid.layers.log (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '02f668664e3bfc4df6c00d7363467140'))
paddle.fluid.layers.crop (ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'ddf9837ee83e549119210a3d714d5f44'))
paddle.fluid.layers.crop (ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'ba3621917d5beffd3d022b88fbf6dc46'))
paddle.fluid.layers.crop_tensor (ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'cb855453e3506bf54c5c013616ffddfb'))
paddle.fluid.layers.rank_loss (ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '8eb36596bb43d7a907d3397c7aedbdb3'))
paddle.fluid.layers.margin_rank_loss (ArgSpec(args=['label', 'left', 'right', 'margin', 'name'], varargs=None, keywords=None, defaults=(0.1, None)), ('document', '6fc86ed23b420c8a0f6c043563cf3937'))
paddle.fluid.layers.elu (ArgSpec(args=['x', 'alpha', 'name'], varargs=None, keywords=None, defaults=(1.0, None)), ('document', '9af1926c06711eacef9e82d7a9e4d308'))
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/crop_tensor_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
using framework::Tensor;
class CropTensorOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
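    // Shape source priority: Input(ShapeTensor) > Input(Shape) > Attr(shape).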
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of Op(crop_tensor) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of Op(crop_tensor) should not be null.");
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
if (ctx->HasInputs("ShapeTensor")) {
      // Input(ShapeTensor) takes the highest priority when inferring the shape
auto inputs_name = ctx->Inputs("ShapeTensor");
PADDLE_ENFORCE_GT(
inputs_name.size(), 0,
"Input(ShapeTensor)'size of Op(crop_tensor) can't be zero. "
"Please check the Attr(shape)'s size of "
"Op(fluid.layers.crop_tensor).");
auto out_dims = std::vector<int>(inputs_name.size(), -1);
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] != -1) {
out_dims[i] = static_cast<int64_t>(shape[i]);
}
}
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
return;
}
auto x_dim = ctx->GetInputDim("X");
if (ctx->HasInput("Shape")) {
auto shape_dim = ctx->GetInputDim("Shape");
PADDLE_ENFORCE_EQ(
shape_dim.size(), 1,
"Input(Shape)'s dimension size of Op(crop_tensor) must be 1. "
"Please check the Attr(shape)'s dimension size of "
"Op(fluid.layers.crop_tensor).");
PADDLE_ENFORCE_EQ(shape_dim[0], x_dim.size(),
"Input(Shape)'s size of Op(crop_tensor) must be equal "
"to dimension size of input tensor. "
"Please check the Attr(shape)'s size of "
"Op(fluid.layers.crop_tensor).");
if (ctx->IsRuntime()) {
// If true, set the shape of Output(Out) according to Input(Shape) in
// CropTensorKernel with ExecutionContext. Also check LoD in
// CropTensorKernel.
ctx->ShareLoD("X", /*->*/ "Out");
} else {
auto out_dims = std::vector<int>(shape_dim[0], -1);
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
}
return;
}
PADDLE_ENFORCE_EQ(int64_t(shape.size()), x_dim.size(),
"Attr(shape)'size of Op(crop_tensor) should be equal to "
"dimention size of input tensor.");
std::vector<int64_t> tensor_shape(shape.size());
for (size_t i = 0; i < shape.size(); ++i) {
tensor_shape[i] = static_cast<int64_t>(shape[i]);
}
ctx->SetOutputDim("Out", framework::make_ddim(tensor_shape));
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "ShapeTensor" || var_name == "OffsetsTensor" ||
var_name == "Shape" || var_name == "Offsets") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class CropTensorOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input of pad op. "
"The input should be a k-D tensor(k > 0 and k < 7).");
AddInput("Shape",
"The input used to describe shape of output, which is a "
"1-D vector whose size equals to the rank of input 'X'. The "
"elements data type must be int. It has a higher priority than "
"the shape attribute")
.AsDispensable();
AddInput("Offsets",
"The input used to describe offsets in runtime, which is a "
"1-D vector whose size equals to the rank of input 'X'. The "
"elements data type must be int. It has a higher priority than "
"the offsets attribute")
.AsDispensable();
AddInput("ShapeTensor",
"(vector<Tensor<int32>>, optional). If provided, crop_tensor will "
"use this. The shape of the tensor in vector MUST BE [1]. "
"It has the highest priority compare with Input(Shape) and "
"attr(shape).")
.AsDuplicable()
.AsDispensable();
AddInput("OffsetsTensor",
"(vector<Tensor<int32>>, optional). If provided, crop_tensor will "
"use this. The shape of the tensor in vector MUST BE [1]. "
"It has the highest priority compare with Input(Offsets) and "
"attr(offsets).")
.AsDuplicable()
.AsDispensable();
AddOutput("Out",
"The output of crop_tensor op, "
"which is of the same dimensions as X.");
AddAttr<std::vector<int>>("offsets",
"A list<int> describing offsets to be cropped. "
"The size of offsets list should be the same as "
"the dimension size of input X.")
.SetDefault(std::vector<int>());
AddAttr<std::vector<int>>("shape",
"A list<int> describing the shape of output. "
"The size of shape list should be the same as "
"the dimension size of input X.")
.SetDefault(std::vector<int>());
AddComment(R"DOC(
CropTensor Operator.
Crop input into output, as specified by offsets and shape.
There are three ways to set the offsets:
1. Input 'OffsetsTensor': It is a tensor list. It should be set as a list that
   contains tensor Variables in the Python configuration script.
This way is suitable for dynamic offsets.
2. Input 'Offsets': It is a Variable and can be the output of other operators.
This way is suitable for dynamic offsets.
3. Attribute 'offsets': It will be set in python configure script. This way
is suitable for fixed offsets.
You CANNOT use these three ways at the same time. An exception will be raised
if input 'OffsetsTensor' or 'Offsets' is configured while the attribute 'offsets' is
not empty.
There are three ways to set shape:
1. Input 'ShapeTensor': It is a tensor list. It should be set as a list that contains
tensor Variables in the Python configuration script. This way is suitable
for dynamic shape.
2. Input 'Shape': It is a Variable and can be the output of other operators. This way is suitable
for dynamic shape.
3. Attribute 'shape': crop input X into the shape described by a list<int>. The size of the shape
list should be the same as the dimension size of input X. This way is
suitable for fixed shape.
The input should be a k-D tensor (k > 0 and k < 7). As an example:
Case 1:
Given
X = [[0, 1, 2, 0, 0]
[0, 3, 4, 0, 0]
[0, 0, 0, 0, 0]],
and
offsets = [0, 1],
and
shape = [2, 2],
we get:
Out = [[1, 2],
[3, 4]].
Case 2:
Given
X = [[0, 1, 2, 5, 0]
[0, 3, 4, 6, 0]
[0, 0, 0, 0, 0]],
and offsets is a list that contains a tensor Variable;
at runtime offsets_var's value is 1.
offsets = [0, offsets_var],
and shape is a list that contains a tensor Variable;
at runtime dim's value is 2.
shape = [dim, 3]
we get:
Out = [[1, 2, 5],
[3, 4, 6]].
)DOC");
}
};
class CropTensorOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of Op(crop_tensor) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
"Input(Out@GRAD) of Op(crop_tensor) should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "ShapeTensor" || var_name == "OffsetsTensor" ||
var_name == "Shape" || var_name == "Offsets") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class CropTensorGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("crop_tensor_grad");
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetInput("X", Input("X"));
if (ForwardOp().Inputs().count("OffsetsTensor") > 0) {
op->SetInput("OffsetsTensor", Input("OffsetsTensor"));
}
if (ForwardOp().Inputs().count("Offsets") > 0) {
op->SetInput("Offsets", Input("Offsets"));
}
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(crop_tensor, ops::CropTensorOp, ops::CropTensorOpMaker,
ops::CropTensorGradOpDescMaker);
REGISTER_OPERATOR(crop_tensor_grad, ops::CropTensorOpGrad);
REGISTER_OP_CPU_KERNEL(
crop_tensor,
ops::CropTensorKernel<paddle::platform::CPUDeviceContext, float>,
ops::CropTensorKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
crop_tensor_grad,
ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/crop_tensor_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
crop_tensor,
ops::CropTensorKernel<paddle::platform::CUDADeviceContext, float>,
ops::CropTensorKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
crop_tensor_grad,
ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/strided_memcpy.h"
namespace paddle {
namespace operators { // Internal
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using framework::Tensor;
inline std::vector<int> get_new_data(
const std::vector<const Tensor*>& list_new_tensor) {
  // Read the scalar int32 value from each tensor in the list, copying to CPU
  // first when the tensor lives on GPU.
std::vector<int> vec_new_data;
for (size_t i = 0; i < list_new_tensor.size(); ++i) {
auto tensor = list_new_tensor[i];
PADDLE_ENFORCE_EQ(
tensor->dims(), framework::make_ddim({1}),
"The tensor's shape in list of Op(crop_tensor) should be [1].");
if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_new_data.push_back(static_cast<int32_t>(*temp.data<int32_t>()));
} else {
vec_new_data.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
}
}
return vec_new_data;
}
static framework::DDim ValidateShape(const std::vector<int> shape,
const framework::DDim& in_dims) {
auto in_dim_size = in_dims.size();
auto shape_size = shape.size();
PADDLE_ENFORCE_EQ(
in_dim_size, shape_size,
"Input(ShapeTensor)'s dimension size of Op(crop_tensor) should be equal "
"to that of input tensor. "
"Please check the Attr(shape)'s size of Op(fluid.layers.crop_tensor).");
const int64_t unk_dim_val = -1;
int unk_dim_idx = -1;
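  // Only the first element may be -1; CropTensorFunction later resolves it to
  // the first dimension of the input.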
std::vector<int64_t> output_shape(shape.size(), 0);
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] == unk_dim_val) {
PADDLE_ENFORCE_EQ(unk_dim_idx, -1,
"Only one element of shape can be unknown.");
PADDLE_ENFORCE_EQ(i, 0, "Only the first element of shape can be -1.");
unk_dim_idx = i;
} else {
PADDLE_ENFORCE_GT(shape[i], 0,
"Each element of shape must be greater than 0 "
"except the first element.");
}
output_shape[i] = static_cast<int64_t>(shape[i]);
}
return framework::make_ddim(output_shape);
}
static std::vector<int> GetShape(const framework::ExecutionContext& ctx) {
std::vector<int> res;
int rank = ctx.Input<Tensor>("X")->dims().size();
auto list_new_shape_tensor = ctx.MultiInput<framework::Tensor>("ShapeTensor");
if (list_new_shape_tensor.size() > 0) {
    // the ShapeTensor list is provided and takes the highest priority
PADDLE_ENFORCE_EQ(list_new_shape_tensor.size(), rank,
"Input(ShapeTensor)'s length of Op(crop_tensor) should "
"be equal to dimension size of input tensor.");
res = get_new_data(list_new_shape_tensor);
return res;
}
auto* shape_tensor = ctx.HasInput("Shape")
? ctx.Input<framework::LoDTensor>("Shape")
: nullptr;
if (shape_tensor) {
auto* shape_data = shape_tensor->data<int>();
framework::Tensor cpu_shape_tensor;
if (platform::is_gpu_place(shape_tensor->place())) {
TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor);
shape_data = cpu_shape_tensor.data<int>();
}
res = std::vector<int>(shape_data, shape_data + shape_tensor->numel());
}
return res;
}
static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) {
std::vector<int> res;
int rank = ctx.Input<Tensor>("X")->dims().size();
auto list_new_offsets_tensor =
ctx.MultiInput<framework::Tensor>("OffsetsTensor");
if (list_new_offsets_tensor.size() > 0) {
    // the OffsetsTensor list is provided and takes the highest priority
res = get_new_data(list_new_offsets_tensor);
return res;
}
if (ctx.HasInput("Offsets")) {
PADDLE_ENFORCE_EQ(
ctx.Attr<std::vector<int>>("offsets").empty(), true,
"Input 'Offsets' and attribute 'offsets' should not be used "
"at the same time.");
const auto* offsets_tensor = ctx.Input<Tensor>("Offsets");
    PADDLE_ENFORCE_EQ(offsets_tensor->dims().size(), 1,
                      "Input(Offsets) of Op(crop_tensor) must be a 1-D tensor.");
PADDLE_ENFORCE_EQ(
rank, offsets_tensor->dims()[0],
"Offsets size should be equal to dimension size of input tensor.");
const int* offsets_data;
framework::Tensor cpu_tmp_tensor;
if (platform::is_cpu_place(offsets_tensor->place())) {
offsets_data = offsets_tensor->data<int>();
} else {
framework::TensorCopySync(*offsets_tensor, platform::CPUPlace(),
&cpu_tmp_tensor);
offsets_data = cpu_tmp_tensor.data<int>();
}
res = std::vector<int>(offsets_data, offsets_data + rank);
} else {
res = ctx.Attr<std::vector<int>>("offsets");
PADDLE_ENFORCE_EQ(
rank, static_cast<int>(res.size()),
"Offsets size should be equal to dimension size of input tensor.");
}
return res;
}
template <typename DeviceContext, typename T, size_t D>
void CropTensorFunction(const framework::ExecutionContext& context) {
auto* x = context.Input<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
auto x_dims = x->dims();
auto out_dims = out->dims();
  // get the shape from Input(ShapeTensor) or Input(Shape)
std::vector<int> shape = GetShape(context);
  // if empty, out_dims was already set from Attr(shape) in InferShape
if (shape.size() == 0) {
for (size_t i = 0; i < out_dims.size(); ++i) {
shape.push_back(out_dims[i]);
}
}
out_dims = ValidateShape(shape, x->dims());
if (out_dims[0] == -1) {
out_dims[0] = x->dims()[0];
}
out->mutable_data<T>(out_dims, context.GetPlace());
auto x_stride = framework::stride(x->dims());
auto offsets = GetOffsets(context);
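  // Validate offsets + shape against the bounds of x; offset accumulates the
  // linear start position of the crop in x's memory via the strides of x.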
int64_t offset = 0;
for (size_t i = 0; i < offsets.size(); ++i) {
PADDLE_ENFORCE_LE(
offsets[i] + shape[i], x_dims[i],
"The sum of the Attr(offsets) and Attr(shape) of Op(crop_tensor) "
"should be less than or equal to corresponding input dimension size.");
offset += (x_stride[i] * offsets[i]);
}
auto x_tensor = EigenTensor<T, D>::From(*x);
auto out_tensor = EigenTensor<T, D>::From(*out);
Eigen::array<int, D> e_offsets;
Eigen::array<int, D> e_shape;
for (size_t i = 0; i < D; ++i) {
e_offsets[i] = offsets[i];
e_shape[i] = out->dims()[i];
}
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
out_tensor.device(place) = x_tensor.slice(e_offsets, e_shape);
}
template <typename DeviceContext, typename T>
class CropTensorKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
int rank = context.Input<Tensor>("X")->dims().size();
switch (rank) {
case 1:
CropTensorFunction<DeviceContext, T, 1>(context);
break;
case 2:
CropTensorFunction<DeviceContext, T, 2>(context);
break;
case 3:
CropTensorFunction<DeviceContext, T, 3>(context);
break;
case 4:
CropTensorFunction<DeviceContext, T, 4>(context);
break;
case 5:
CropTensorFunction<DeviceContext, T, 5>(context);
break;
case 6:
CropTensorFunction<DeviceContext, T, 6>(context);
break;
default:
PADDLE_THROW(
"CropTensorOp only support tensors with no more than 6 "
"dimensions.");
}
}
};
template <typename DeviceContext, typename T, size_t D>
void CropTensorGradFunction(const framework::ExecutionContext& context) {
auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* x = context.Input<Tensor>("X");
if (d_x != nullptr) {
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
d_x->mutable_data<T>(x->dims(), context.GetPlace());
auto offsets = GetOffsets(context);
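    // The gradient of crop is zero padding: pad d_out back to the shape of x,
    // with offsets[i] zeros before the data and the remaining extent after,
    // along each dimension.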
Eigen::array<std::pair<int, int>, D> paddings;
for (size_t i = 0; i < D; ++i) {
paddings[i].first = offsets[i];
paddings[i].second = d_x->dims()[i] - d_out->dims()[i] - offsets[i];
}
auto d_x_tensor = EigenTensor<T, D>::From(*d_x);
auto d_out_tensor = EigenTensor<T, D>::From(*d_out);
d_x_tensor.device(
*context.template device_context<DeviceContext>().eigen_device()) =
d_out_tensor.pad(paddings, 0);
}
}
template <typename DeviceContext, typename T>
class CropTensorGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
size_t rank =
context.Input<Tensor>(framework::GradVarName("Out"))->dims().size();
switch (rank) {
case 1:
CropTensorGradFunction<DeviceContext, T, 1>(context);
break;
case 2:
CropTensorGradFunction<DeviceContext, T, 2>(context);
break;
case 3:
CropTensorGradFunction<DeviceContext, T, 3>(context);
break;
case 4:
CropTensorGradFunction<DeviceContext, T, 4>(context);
break;
case 5:
CropTensorGradFunction<DeviceContext, T, 5>(context);
break;
case 6:
CropTensorGradFunction<DeviceContext, T, 6>(context);
break;
default:
PADDLE_THROW(
"CropTensorOp only support tensors with no more than 6 "
"dimensions.");
}
}
};
} // namespace operators
} // namespace paddle
......@@ -133,6 +133,7 @@ __all__ = [
'selu',
'log',
'crop',
'crop_tensor',
'rank_loss',
'margin_rank_loss',
'elu',
......@@ -9119,6 +9120,11 @@ def crop(x, shape=None, offsets=None, name=None):
"""
Crop input into output, as specified by offsets and shape.
**Warning:** THIS FUNCTION IS DEPRECATED. It will be removed in a future version.
Instructions for updating: Use `fluid.layers.crop_tensor
<https://www.paddlepaddle.org.cn/documentation/docs/en/api/layers/nn.html#crop_tensor>`_
instead.
.. code-block:: text
* Case 1:
......@@ -9150,16 +9156,16 @@ def crop(x, shape=None, offsets=None, name=None):
Args:
x (Variable): The input tensor variable.
shape (Variable|list/tuple of integer): The output shape is specified
by `shape`, which can a Variable or a list/tupe of integer.
by `shape`, which can be a Variable or a list/tuple of integer.
If a tensor Variable, it's rank must be the same as `x`. This way
is suitable for the case that the output shape may be changed each
iteration. If a list/tupe of integer, it's length must be the same
iteration. If a list/tuple of integer, it's length must be the same
as the rank of `x`
offsets (Variable|list/tuple of integer|None): Specifies the cropping
offsets at each dimension. It can be a Variable or or a list/tupe
offsets at each dimension. It can be a Variable or a list/tuple
of integers. If a tensor Variable, it's rank must be the same as `x`.
This way is suitable for the case that the offsets may be changed
each iteration. If a list/tupe of integer, it's length must be the
each iteration. If a list/tuple of integer, it's length must be the
same as the rank of `x`. If None, the offsets are 0 at each
dimension.
name(str|None): A name for this layer(optional). If set None, the layer
......@@ -9214,6 +9220,188 @@ def crop(x, shape=None, offsets=None, name=None):
return out
def crop_tensor(x, shape=None, offsets=None, name=None):
"""
Crop input into output, as specified by offsets and shape.
.. code-block:: text
* Case 1:
Given
X = [[0, 1, 2, 0, 0]
[0, 3, 4, 0, 0]
[0, 0, 0, 0, 0]],
and
shape = [2, 2],
offsets = [0, 1],
output is:
Out = [[1, 2],
[3, 4]].
* Case 2:
Given
X = [[[0, 1, 2, 3]
[0, 5, 6, 7]
[0, 0, 0, 0]],
[[0, 3, 4, 5]
[0, 6, 7, 8]
[0, 0, 0, 0]]].
and
shape = [2, 2, 3],
offsets = [0, 0, 1],
output is:
Out = [[[1, 2, 3]
[5, 6, 7]],
[[3, 4, 5]
[6, 7, 8]]].
Args:
x (Variable): The input tensor variable.
        shape (Variable|list|tuple of integers): The output shape is specified
            by `shape`. It can be a 1-D tensor Variable or a list/tuple. If a
            1-D tensor Variable, its length must be the same as the rank of `x`.
            If a list/tuple, its length must be the same as the rank of `x`.
            Each element of the list can be an integer or a tensor Variable
            with shape [1]. If the list contains a Variable, this form suits
            the case where the shape may change each iteration. Only the first
            element of the list/tuple can be set to -1, which means the first
            dimension of the output is the same as that of the input.
        offsets (Variable|list|tuple of integers|None): Specifies the cropping
            offsets at each dimension. It can be a 1-D tensor Variable or a
            list/tuple. If a 1-D tensor Variable, its length must be the same
            as the rank of `x`. If a list/tuple, its length must be the same as
            the rank of `x`. Each element of the list can be an integer or a
            tensor Variable with shape [1]. If the list contains a Variable,
            this form suits the case where the offsets may change each
            iteration. If None, the offsets are 0 at each dimension.
        name (str|None): The name of this layer (optional). If set to None, the layer
will be named automatically.
Returns:
Variable: The cropped tensor variable.
Raises:
ValueError: If shape is not a list, tuple or Variable.
ValueError: If offsets is not None and not a list, tuple or Variable.
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.layers.data(name="x", shape=[3, 5], dtype="float32")
            # x.shape = [-1, 3, 5], where -1 indicates batch size, and its exact value is determined at runtime.
# shape is a 1-D tensor variable
crop_shape = fluid.layers.data(name="crop_shape", shape=[3], dtype="int32", append_batch_size=False)
crop0 = fluid.layers.crop_tensor(x, shape=crop_shape)
            # crop0.shape = [-1, -1, -1], which means crop0.shape[0] = x.shape[0] at runtime.
# or shape is a list in which each element is a constant
crop1 = fluid.layers.crop_tensor(x, shape=[-1, 2, 3])
# crop1.shape = [-1, 2, 3]
# or shape is a list in which each element is a constant or variable
y = fluid.layers.data(name="y", shape=[3, 8, 8], dtype="float32")
dim1 = fluid.layers.data(name="dim1", shape=[1], dtype="int32", append_batch_size=False)
crop2 = fluid.layers.crop_tensor(y, shape=[-1, 3, dim1, 4])
# crop2.shape = [-1, 3, -1, 4]
# offsets is a 1-D tensor variable
crop_offsets = fluid.layers.data(name="crop_offsets", shape=[3], dtype="int32", append_batch_size=False)
crop3 = fluid.layers.crop_tensor(x, shape=[-1, 2, 3], offsets=crop_offsets)
# crop3.shape = [-1, 2, 3]
# offsets is a list in which each element is a constant or variable
            offsets_var = fluid.layers.data(name="offsets_var", shape=[1], dtype="int32", append_batch_size=False)
crop4 = fluid.layers.crop_tensor(x, shape=[-1, 2, 3], offsets=[0, 1, offsets_var])
# crop4.shape = [-1, 2, 3]
"""
helper = LayerHelper('crop_tensor', **locals())
if not (isinstance(shape, list) or isinstance(shape, tuple) or \
isinstance(shape, Variable)):
raise ValueError("The shape should be a list, tuple or Variable.")
if offsets is None:
offsets = [0] * len(x.shape)
if not (isinstance(offsets, list) or isinstance(offsets, tuple) or \
isinstance(offsets, Variable)):
raise ValueError("The offsets should be a list, tuple or Variable.")
out = helper.create_variable_for_type_inference(x.dtype)
ipts = {'X': x}
attrs = {}
def contain_var(input_list):
for ele in input_list:
if isinstance(ele, Variable):
return True
return False
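    # Dispatch offsets: a Variable becomes Input(Offsets); a list containing
    # Variables becomes Input(OffsetsTensor); a plain list of ints becomes
    # attr(offsets).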
if isinstance(offsets, Variable):
offsets.stop_gradient = True
ipts['Offsets'] = offsets
elif contain_var(offsets):
new_offsets_tensor = []
for dim in offsets:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_offsets_tensor.append(dim)
else:
assert (isinstance(dim, int))
                assert dim >= 0, ("offsets should be greater than or equal to zero.")
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out)
new_offsets_tensor.append(temp_out)
ipts['OffsetsTensor'] = new_offsets_tensor
else:
attrs['offsets'] = offsets
unk_dim_idx = -1
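    # Dispatch shape the same way; for a mixed list, constant sizes are also
    # recorded in attr(shape) while Variable entries become -1 placeholders.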
if isinstance(shape, Variable):
shape.stop_gradient = True
ipts['Shape'] = shape
elif contain_var(shape):
new_shape_tensor = []
shape_attr = []
for dim_idx, dim_size in enumerate(shape):
if isinstance(dim_size, Variable):
dim_size.stop_gradient = True
new_shape_tensor.append(dim_size)
shape_attr.append(-1)
else:
assert (isinstance(dim_size, int))
if dim_size == -1:
assert unk_dim_idx == -1, (
"Only one element in shape can be unknown.")
assert dim_idx == 0, (
"Only the first element in shape can be -1.")
unk_dim_idx = dim_idx
else:
assert dim_size > 0, (
"Each dimension size given in shape must be greater than zero."
)
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant(
[1], 'int32', dim_size, force_cpu=True, out=temp_out)
new_shape_tensor.append(temp_out)
shape_attr.append(dim_size)
ipts['ShapeTensor'] = new_shape_tensor
attrs['shape'] = shape_attr
else:
attrs['shape'] = shape
helper.append_op(
type='crop_tensor',
inputs=ipts,
outputs={'Out': out},
attrs=None if len(attrs) == 0 else attrs)
return out
def affine_grid(theta, out_shape, name=None):
"""
It generates a grid of (x,y) coordinates using the parameters of
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
def crop(data, offsets, crop_shape):
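    # numpy reference implementation: keep every element whose index falls
    # inside [offsets, offsets + crop_shape) along each dimension.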
def indexOf(shape, index):
result = []
for dim in reversed(shape):
result.append(index % dim)
            index = index // dim
return result[::-1]
result = []
for i, value in enumerate(data.flatten()):
index = indexOf(data.shape, i)
selected = True
if len(index) == len(offsets):
for j, offset in enumerate(offsets):
                selected = (selected and index[j] >= offset and
                            index[j] < crop_shape[j] + offset)
if selected:
result.append(value)
return np.array(result).reshape(crop_shape)
class TestCropTensorOp(OpTest):
def setUp(self):
self.op_type = "crop_tensor"
self.crop_by_1D_shape = False
self.offset_by_input = False
self.unk_dim_idx = -1
self.attrs = {}
self.initTestCase()
if self.crop_by_1D_shape:
self.inputs = {
'X': np.random.random(self.x_shape).astype("float32"),
'Shape': np.array(self.crop_shape).astype("int32")
}
        else:
            # copy the shape list so that resolving -1 below (for computing the
            # expected output) does not overwrite the -1 kept in attr(shape)
            self.attrs['shape'] = list(self.crop_shape)
self.inputs = {
'X': np.random.random(self.x_shape).astype("float32"),
}
if self.offset_by_input:
self.inputs['Offsets'] = np.array(self.offsets).astype('int32')
else:
self.attrs['offsets'] = self.offsets
if self.unk_dim_idx != -1:
self.crop_shape[self.unk_dim_idx] = self.x_shape[self.unk_dim_idx]
self.outputs = {
'Out': crop(self.inputs['X'], self.offsets, self.crop_shape)
}
def initTestCase(self):
self.x_shape = (8, 8)
self.crop_shape = [2, 2]
self.offsets = [1, 2]
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', max_relative_error=0.006)
class TestCase1(TestCropTensorOp):
def initTestCase(self):
        self.x_shape = (100, )
self.crop_shape = [64]
self.offsets = [13]
class TestCase2(TestCropTensorOp):
def initTestCase(self):
self.x_shape = (12, 24)
        self.crop_shape = [-1, 8]  # only the first dimension (batch) can be -1
self.offsets = [0, 0]
self.unk_dim_idx = 0
class TestCase3(TestCropTensorOp):
def initTestCase(self):
self.x_shape = (4, 8, 16)
self.crop_shape = [2, 2, 3]
self.offsets = [1, 5, 3]
self.crop_by_1D_shape = True
class TestCase4(TestCropTensorOp):
def initTestCase(self):
self.x_shape = (8, 3, 6, 6)
self.crop_shape = [-1, 3, 4, 4]
self.offsets = [0, 0, 0, 0]
self.crop_by_1D_shape = True
self.unk_dim_idx = 0
class TestCase5(TestCropTensorOp):
def initTestCase(self):
self.x_shape = (2, 4, 5, 8, 8)
self.crop_shape = [1, 1, 2, 4, 4]
self.offsets = [1, 0, 0, 2, 2]
self.offset_by_input = True
class TestCase6(TestCropTensorOp):
def initTestCase(self):
self.x_shape = (2, 2, 4, 4, 4, 2)
self.crop_shape = [1, 1, 4, 2, 2, 2]
self.offsets = [0, 0, 0, 0, 0, 0]
self.crop_by_1D_shape = True
self.offset_by_input = True
class TestCropTensorOp_attr_tensor(OpTest):
    def setUp(self):
        self.op_type = "crop_tensor"
        self.mixed_type = False
        self.OffsetsTensor = False
        self.ShapeTensor = True
        self.attrs = {}
        self.initTestCase()
        # build the X input once; add ShapeTensor/OffsetsTensor inputs to it
        # instead of reassigning self.inputs (which would drop earlier entries)
        self.inputs = {'X': np.random.random(self.x_shape).astype("float32")}
        if self.ShapeTensor:
            shape_tensor = []
            for index, ele in enumerate(self.crop_shape):
                shape_tensor.append(("x" + str(index), np.ones(
                    (1)).astype('int32') * ele))
            self.inputs['ShapeTensor'] = shape_tensor
            if self.mixed_type:
                self.attrs['shape'] = self.shape_attr
        else:
            self.attrs['shape'] = self.crop_shape
        if self.OffsetsTensor:
            offsets_tensor = []
            for index, ele in enumerate(self.offsets):
                offsets_tensor.append(("x" + str(index), np.ones(
                    (1)).astype('int32') * ele))
            self.inputs['OffsetsTensor'] = offsets_tensor
        else:
            self.attrs['offsets'] = self.offsets
        self.outputs = {
            'Out': crop(self.inputs['X'], self.offsets, self.crop_shape)
        }
def initTestCase(self):
self.x_shape = (8, 8)
self.crop_shape = (2, 2)
self.offsets = [1, 2]
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(["X"], "Out", max_relative_error=0.006)
class TestCropTensorOp_attr_tensor_case1(TestCropTensorOp_attr_tensor):
    def initTestCase(self):
self.x_shape = (16, 8, 32)
self.crop_shape = [2, 2, 3]
self.offsets = [1, 5, 3]
class TestCropTensorOp_attr_tensor_case2(TestCropTensorOp_attr_tensor):
    def initTestCase(self):
self.x_shape = (4, 8, 16, 8)
self.crop_shape = [2, 2, 3, 4]
self.offsets = [1, 5, 3, 0]
self.shape_attr = [-1, -1, 3, 4]
self.mixed_type = True
class TestCropTensorOp_attr_tensor_case3(TestCropTensorOp_attr_tensor):
    def initTestCase(self):
self.x_shape = (16, 8, 32)
self.crop_shape = [2, 2, 3]
self.offsets = [1, 5, 3]
self.ShapeTensor = False
self.OffsetsTensor = True
class TestCropTensorOp_attr_tensor_case4(TestCropTensorOp_attr_tensor):
    def initTestCase(self):
self.x_shape = (16, 8, 32)
self.crop_shape = [2, 2, 3]
self.offsets = [1, 5, 3]
self.OffsetsTensor = True
if __name__ == '__main__':
unittest.main()
......@@ -1100,6 +1100,34 @@ class TestLayer(LayerTest):
for i in range(len(static_ret5)):
self.assertTrue(dcond5.numpy()[i] == static_ret5[i])
def test_crop_tensor(self):
with self.static_graph():
x = fluid.layers.data(name="x1", shape=[6, 5, 8])
dim1 = fluid.layers.data(
name="dim1", shape=[1], append_batch_size=False)
dim2 = fluid.layers.data(
name="dim2", shape=[1], append_batch_size=False)
crop_shape1 = (1, 2, 4, 4)
crop_shape2 = fluid.layers.data(
name="crop_shape", shape=[4], append_batch_size=False)
crop_shape3 = [-1, dim1, dim2, 4]
crop_offsets1 = [0, 0, 1, 0]
crop_offsets2 = fluid.layers.data(
name="crop_offset", shape=[4], append_batch_size=False)
crop_offsets3 = [0, dim1, dim2, 0]
out1 = fluid.layers.crop_tensor(
x, shape=crop_shape1, offsets=crop_offsets1)
out2 = fluid.layers.crop_tensor(
x, shape=crop_shape2, offsets=crop_offsets2)
out3 = fluid.layers.crop_tensor(
x, shape=crop_shape3, offsets=crop_offsets3)
self.assertIsNotNone(out1)
self.assertIsNotNone(out2)
self.assertIsNotNone(out3)
class TestBook(LayerTest):
def test_all_layers(self):
......