未验证 提交 b587a7f6 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #11293 from JiayiFeng/update_crop_op

Update crop op
...@@ -48,6 +48,13 @@ class CropOp : public framework::OperatorWithKernel { ...@@ -48,6 +48,13 @@ class CropOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", y_dim); ctx->SetOutputDim("Out", y_dim);
} }
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
}; };
class CropOpMaker : public framework::OpProtoAndCheckerMaker { class CropOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -60,13 +67,19 @@ class CropOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -60,13 +67,19 @@ class CropOpMaker : public framework::OpProtoAndCheckerMaker {
"The input used as reference for cropping, " "The input used as reference for cropping, "
"which is of the same dimensions as X.") "which is of the same dimensions as X.")
.AsDispensable(); .AsDispensable();
AddInput("Offsets",
"The input used to describe offsets in runtime, which is a "
"1-D vector whose size equals to the rank of input 'X'. The "
"elements data type must be int.")
.AsDispensable();
AddOutput("Out", AddOutput("Out",
"The output of crop op, " "The output of crop op, "
"which is of the same dimensions as X."); "which is of the same dimensions as X.");
AddAttr<std::vector<int>>("offsets", AddAttr<std::vector<int>>("offsets",
"A list<int> describing offsets to be cropped. " "A list<int> describing offsets to be cropped. "
"The size of offsets list should be the same as " "The size of offsets list should be the same as "
"the dimension size of input X."); "the dimension size of input X.")
.SetDefault(std::vector<int>());
AddAttr<std::vector<int>>("shape", AddAttr<std::vector<int>>("shape",
"A list<int> describing the shape of output. " "A list<int> describing the shape of output. "
"The size of shape list should be the same as " "The size of shape list should be the same as "
...@@ -77,6 +90,17 @@ Crop Operator. ...@@ -77,6 +90,17 @@ Crop Operator.
Crop input into output, as specified by offsets and shape. Crop input into output, as specified by offsets and shape.
There are two ways to set the offsets:
1. In runtime: Using the input 'Offsets', which is a Vairbale and can be
output of other operators. This way is suitable for
dynamic offsets.
2. In network configuration: Using the attribute 'offsets', which will be
set in Python configure script. This way is
suitable for fixed offsets.
You CANNOT use these two ways at the same time. An exception will be raised
if input 'Offset' is configured and meanwhile the attribute 'offsets' is
not empty.
There are two ways to set shape: There are two ways to set shape:
1. reference input: crop input X into the same shape as reference input. 1. reference input: crop input X into the same shape as reference input.
The dimension of reference input should The dimension of reference input should
...@@ -146,6 +170,15 @@ class CropOpGrad : public framework::OperatorWithKernel { ...@@ -146,6 +170,15 @@ class CropOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(x_grad_name, x_dims); ctx->SetOutputDim(x_grad_name, x_dims);
} }
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))
->type()),
ctx.device_context());
}
}; };
} // namespace operators } // namespace operators
......
...@@ -27,6 +27,37 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor, ...@@ -27,6 +27,37 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>; using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using framework::Tensor; using framework::Tensor;
static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) {
std::vector<int> res;
int rank = ctx.Input<Tensor>("X")->dims().size();
if (ctx.HasInput("Offsets")) {
PADDLE_ENFORCE(ctx.Attr<std::vector<int>>("offsets").empty(),
"Input 'Offsets' and attribute 'offsets' should not be used "
"at the same time.");
const auto* offsets_tensor = ctx.Input<Tensor>("Offsets");
PADDLE_ENFORCE_EQ(offsets_tensor->dims().size(), 1);
PADDLE_ENFORCE_EQ(
rank, offsets_tensor->dims()[0],
"Offsets size should be equal to dimension size of input tensor.");
const int* offsets_data;
framework::Tensor cpu_tmp_tensor;
if (platform::is_cpu_place(offsets_tensor->place())) {
offsets_data = offsets_tensor->data<int>();
} else {
framework::TensorCopySync(*offsets_tensor, platform::CPUPlace(),
&cpu_tmp_tensor);
offsets_data = cpu_tmp_tensor.data<int>();
}
res = std::vector<int>(offsets_data, offsets_data + rank);
} else {
res = ctx.Attr<std::vector<int>>("offsets");
PADDLE_ENFORCE_EQ(
rank, res.size(),
"Offsets size should be equal to dimension size of input tensor.");
}
return res;
}
template <typename T> template <typename T>
class CropKernel : public framework::OpKernel<T> { class CropKernel : public framework::OpKernel<T> {
public: public:
...@@ -37,10 +68,7 @@ class CropKernel : public framework::OpKernel<T> { ...@@ -37,10 +68,7 @@ class CropKernel : public framework::OpKernel<T> {
T* out_data = out->mutable_data<T>(context.GetPlace()); T* out_data = out->mutable_data<T>(context.GetPlace());
auto x_stride = framework::stride(x->dims()); auto x_stride = framework::stride(x->dims());
auto out_stride = framework::stride(out->dims()); auto out_stride = framework::stride(out->dims());
auto offsets = context.Attr<std::vector<int>>("offsets"); auto offsets = GetOffsets(context);
PADDLE_ENFORCE_EQ(
x->dims().size(), static_cast<int64_t>(offsets.size()),
"Offsets size should be equal to dimension size of input tensor.");
int64_t offset = 0; int64_t offset = 0;
for (size_t i = 0; i < offsets.size(); ++i) { for (size_t i = 0; i < offsets.size(); ++i) {
offset += (x_stride[i] * offsets[i]); offset += (x_stride[i] * offsets[i]);
...@@ -56,7 +84,7 @@ void CropGradFunction(const framework::ExecutionContext& context) { ...@@ -56,7 +84,7 @@ void CropGradFunction(const framework::ExecutionContext& context) {
if (d_x != nullptr) { if (d_x != nullptr) {
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out")); auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
d_x->mutable_data<T>(context.GetPlace()); d_x->mutable_data<T>(context.GetPlace());
auto offsets = context.Attr<std::vector<int>>("offsets"); auto offsets = GetOffsets(context);
Eigen::array<std::pair<int, int>, D> paddings; Eigen::array<std::pair<int, int>, D> paddings;
for (size_t i = 0; i < D; ++i) { for (size_t i = 0; i < D; ++i) {
paddings[i].first = offsets[i]; paddings[i].first = offsets[i];
......
...@@ -20,7 +20,6 @@ class RandomCropOp : public framework::OperatorWithKernel { ...@@ -20,7 +20,6 @@ class RandomCropOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
......
...@@ -42,9 +42,9 @@ class TestCropOp(OpTest): ...@@ -42,9 +42,9 @@ class TestCropOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "crop" self.op_type = "crop"
self.crop_by_input = False self.crop_by_input = False
self.offset_by_input = False
self.attrs = {} self.attrs = {}
self.initTestCase() self.initTestCase()
self.attrs['offsets'] = self.offsets
if self.crop_by_input: if self.crop_by_input:
self.inputs = { self.inputs = {
'X': np.random.random(self.x_shape).astype("float32"), 'X': np.random.random(self.x_shape).astype("float32"),
...@@ -55,6 +55,10 @@ class TestCropOp(OpTest): ...@@ -55,6 +55,10 @@ class TestCropOp(OpTest):
self.inputs = { self.inputs = {
'X': np.random.random(self.x_shape).astype("float32"), 'X': np.random.random(self.x_shape).astype("float32"),
} }
if self.offset_by_input:
self.inputs['Offsets'] = np.array(self.offsets).astype('int32')
else:
self.attrs['offsets'] = self.offsets
self.outputs = { self.outputs = {
'Out': crop(self.inputs['X'], self.offsets, self.crop_shape) 'Out': crop(self.inputs['X'], self.offsets, self.crop_shape)
} }
...@@ -101,5 +105,22 @@ class TestCase4(TestCropOp): ...@@ -101,5 +105,22 @@ class TestCase4(TestCropOp):
self.crop_by_input = True self.crop_by_input = True
class TestCase5(TestCropOp):
def initTestCase(self):
self.x_shape = (3, 4, 5)
self.crop_shape = [2, 2, 3]
self.offsets = [1, 0, 2]
self.offset_by_input = True
class TestCase6(TestCropOp):
def initTestCase(self):
self.x_shape = (10, 9, 14)
self.crop_shape = [3, 3, 5]
self.offsets = [3, 5, 4]
self.crop_by_input = True
self.offset_by_input = True
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部