提交 06b42e9e 编写于 作者: W wanghaoshuang

Add crop op.

上级 f2f839af
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/crop_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class CropOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto dim0 = ctx.Input<Tensor>("X")->dims();
auto Y = ctx.Input<Tensor>("Y");
if (Y == nullptr) {
auto shape = GetAttr<std::vector<int>>("shape");
PADDLE_ENFORCE_EQ(
shape.size(), dim0.size(),
"Shape size should be equal to dimention size of input tensor.");
ctx.Output<Tensor>("Out")->Resize(paddle::framework::make_ddim(shape));
} else {
ctx.Output<Tensor>("Out")->Resize(Y->dims());
}
}
};
class CropOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CropOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of crop op");
AddInput("Y", "The input used as reference for cropping. ");
AddOutput("Out", "The output of crop op.");
AddComment(R"DOC(
Crop Operator.
)DOC");
AddAttr<std::vector<int>>("offsets", "The offsets for cropping.");
AddAttr<std::vector<int>>("shape", "The shape for cropping.");
}
};
class CropOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
x_grad->Resize(x_dims);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(crop, ops::CropOp, ops::CropOpMaker, crop_grad, ops::CropOpGrad);
REGISTER_OP_CPU_KERNEL(crop,
ops::CropKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(crop_grad,
ops::CropGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/crop_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(crop,
ops::CropKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(crop_grad,
ops::CropGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 CropdleCropdle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Tensor = framework::Tensor;
template <typename Place, typename T, size_t D>
void CropFunction(const framework::ExecutionContext& context) {
auto* x = context.Input<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto x_dims = x->dims();
auto out_dims = out->dims();
auto offsets = context.op().GetAttr<std::vector<int>>("offsets");
PADDLE_ENFORCE_EQ(
x_dims.size(), offsets.size(),
"Offsets size should be equal to dimension size of input tensor.");
Eigen::array<std::pair<int, int>, D> paddings;
for (size_t i = 0; i < D; ++i) {
paddings[i].first = -(offsets[i]);
paddings[i].second = -(x_dims[i] - out_dims[i] - offsets[i]);
}
auto x_tensor = EigenTensor<T, D>::From(*x);
auto out_tensor = EigenTensor<T, D>::From(*out);
auto place = context.GetEigenDevice<Place>();
out_tensor.device(place) = x_tensor.pad(paddings, 0);
}
template <typename Place, typename T>
class CropKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
int dim = context.Input<Tensor>("X")->dims().size();
switch (dim) {
case 1:
CropFunction<Place, T, 1>(context);
break;
case 2:
CropFunction<Place, T, 2>(context);
break;
case 3:
CropFunction<Place, T, 3>(context);
break;
case 4:
CropFunction<Place, T, 4>(context);
break;
case 5:
CropFunction<Place, T, 5>(context);
break;
case 6:
CropFunction<Place, T, 6>(context);
break;
default:
LOG(ERROR) << "Only ranks up to 6 supported.";
}
}
};
template <typename Place, typename T, size_t D>
void CropGradFunction(const framework::ExecutionContext& context) {
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
d_x->mutable_data<T>(context.GetPlace());
auto d_x_dims = d_x->dims();
auto d_out_dims = d_out->dims();
auto offsets = context.op().GetAttr<std::vector<int>>("offsets");
Eigen::array<std::pair<int, int>, D> paddings;
for (int i = 0; i < d_out_dims.size(); ++i) {
paddings[i].first = offsets[i];
paddings[i].second = d_x_dims[i] - d_out_dims[i] - offsets[i];
}
auto d_x_tensor = EigenTensor<T, D>::From(*d_x);
auto d_out_tensor = EigenTensor<T, D>::From(*d_out);
auto place = context.GetEigenDevice<Place>();
d_x_tensor.device(place) = d_out_tensor.pad(paddings, 0);
}
template <typename Place, typename T>
class CropGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
size_t dim =
context.Input<Tensor>(framework::GradVarName("Out"))->dims().size();
switch (dim) {
case 1:
CropGradFunction<Place, T, 1>(context);
break;
case 2:
CropGradFunction<Place, T, 2>(context);
break;
case 3:
CropGradFunction<Place, T, 3>(context);
break;
case 4:
CropGradFunction<Place, T, 4>(context);
break;
case 5:
CropGradFunction<Place, T, 5>(context);
break;
case 6:
CropGradFunction<Place, T, 6>(context);
break;
default:
LOG(ERROR) << "Only ranks up to 6 supported.";
}
}
};
} // namespace operators
} // namespace paddle
......@@ -48,6 +48,7 @@ USE_NO_KERNEL_OP(identity);
USE_OP(minus);
USE_CPU_ONLY_OP(gather);
USE_CPU_ONLY_OP(scatter);
USE_OP(crop);
namespace paddle {
namespace framework {
......
import unittest
import numpy as np
from paddle.v2.framework.op import Operator
from gradient_checker import GradientChecker
from op_test_util import OpTestMeta
class TestCropOp(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
self.type = "crop"
self.inputs = {'X': np.random.random((16, 16)).astype("float32"), }
self.attrs = {}
self.attrs['offsets'] = [2, 3]
self.attrs['shape'] = [8, 8]
self.outputs = {'Out': self.inputs['X'][2:10, 3:11]}
class TestCropGradOp(GradientChecker):
def setUp(self):
self.op = Operator(
type="crop", X="X", Out="Out", offsets=[2, 3], shape=[8, 8])
self.inputs = {'X': np.random.random((16, 16)).astype("float32"), }
def test_normal(self):
self.check_grad(
self.op, self.inputs, set(["X"]), "Out", max_relative_error=0.5)
def test_cpu_gpu_compare(self):
self.compare_grad(self.op, self.inputs)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册