diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index fc3d508553c0e966978b28d58127bdbff10d45f1..a3357867530c110df16a5f3ec8c799735206cc71 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -292,5 +292,13 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims) { DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); } +DDim stride(const DDim& ddim) { + std::vector strides(ddim.size()); + strides[ddim.size() - 1] = 1; + for (int i = ddim.size() - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * ddim[i + 1]; + } + return framework::make_ddim(strides); +} } // namespace framework } // namespace paddle diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index ca29e7e8c7776de6adf3e3b0e8f11f0d4d8487c3..4a871bb0a91ed4050847509cc3f24218bcd57142 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -121,6 +121,7 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims); DDim flatten_to_1d(const DDim& src); +DDim stride(const DDim& ddim); } // namespace framework } // namespace paddle diff --git a/paddle/operators/crop_op.cc b/paddle/operators/crop_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..7ed21f336f69e494f3c4039c609c83407a80cd8c --- /dev/null +++ b/paddle/operators/crop_op.cc @@ -0,0 +1,139 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/crop_op.h" +#include + +namespace paddle { +namespace operators { + +using framework::Tensor; +using framework::LoDTensor; + +class CropOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of CropOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of CropOp should not be null."); + auto x_dim = ctx.Input("X")->dims(); + auto *y = ctx.Input("Y"); + auto *out = ctx.Output("Out"); + if (y == nullptr) { + auto shape = Attr>("shape"); + PADDLE_ENFORCE_EQ( + int64_t(shape.size()), x_dim.size(), + "Shape size should be equal to dimention size of input tensor."); + std::vector tensor_shape(shape.size()); + for (size_t i = 0; i < shape.size(); ++i) { + tensor_shape[i] = static_cast(shape[i]); + } + out->Resize(framework::make_ddim(tensor_shape)); + } else { + PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y->dims()), + "Tensor rank of both CropOp's " + "inputs must be same."); + out->Resize(y->dims()); + } + } +}; + +class CropOpMaker : public framework::OpProtoAndCheckerMaker { + public: + CropOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "The input of pad op. " + "The input should be a k-D tensor(k > 0 and k < 7)"); + AddInput("Y", + "The input used as reference for cropping" + " with the same dimension as X. "); + AddOutput("Out", + "The output of crop op " + "with the same dimension as X."); + AddAttr>("offsets", + "A list describing offsets to be cropped." + "The size of offsets list should be as same as " + "dimension size of input X."); + AddAttr>("shape", + "A list describing the shape of output." + "The size of shape list should be as same as " + "dimension size of input X.") + .SetDefault(std::vector()); + AddComment(R"DOC( +Crop Operator. +Crop input into output, as specified by offsets and shape. + +There are two ways to set shape: +1. referenc input: crop input X as shape as reference input. + The dimension of reference input should + be as same as input X. +2. shape list: crop input X by shape described by a list. + The size of shape list should be as same as + dimension size of input X. + +The input should be a k-D tensor(k > 0 and k < 7). As an example: + +Given: + + X = [[0, 1, 2, 0, 0] + [0, 3, 4, 0, 0] + [0, 0, 0, 0, 0]] + +and + + offsets = [0, 1] + +and + + shape = [2, 2] + +then we get + + Out = [[1, 2], + [3, 4]] + +)DOC"); + } +}; + +class CropOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx.Input("X")->dims(); + auto *x_grad = ctx.Output(framework::GradVarName("X")); + if (x_grad != nullptr) { + x_grad->Resize(x_dims); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(crop, ops::CropOp, ops::CropOpMaker, crop_grad, ops::CropOpGrad); +REGISTER_OP_CPU_KERNEL(crop, ops::CropKernel); +REGISTER_OP_CPU_KERNEL(crop_grad, + ops::CropGradKernel); diff --git a/paddle/operators/crop_op.cu b/paddle/operators/crop_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..f8ee18a1d6e894cbb2d71dd4b6b459abeb076817 --- /dev/null +++ b/paddle/operators/crop_op.cu @@ -0,0 +1,21 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/crop_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(crop, ops::CropKernel); +REGISTER_OP_GPU_KERNEL(crop_grad, + ops::CropGradKernel); diff --git a/paddle/operators/crop_op.h b/paddle/operators/crop_op.h new file mode 100644 index 0000000000000000000000000000000000000000..2f40c059033ec649b29f6ecdee4fcedd128a63a6 --- /dev/null +++ b/paddle/operators/crop_op.h @@ -0,0 +1,104 @@ +/* Copyright (c) 2016 CropdleCropdle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/strided_memcpy.h" + +namespace paddle { +namespace operators { // Internal + +template +using EigenTensor = framework::EigenTensor; +using framework::Tensor; + +template +class CropKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + const T* x_data = x->data(); + T* out_data = out->mutable_data(context.GetPlace()); + auto x_stride = framework::stride(x->dims()); + auto out_stride = framework::stride(out->dims()); + auto offsets = context.Attr>("offsets"); + PADDLE_ENFORCE_EQ( + x->dims().size(), offsets.size(), + "Offsets size should be equal to dimension size of input tensor."); + int64_t offset = 0; + for (int i = 0; i < offsets.size(); ++i) { + offset += (x_stride[i] * offsets[i]); + } + StridedMemcpy(context.device_context(), x_data + offset, x_stride, + out->dims(), out_stride, out_data); + } +}; + +template +void CropGradFunction(const framework::ExecutionContext& context) { + auto* d_x = context.Output(framework::GradVarName("X")); + if (d_x != nullptr) { + auto* d_out = context.Input(framework::GradVarName("Out")); + d_x->mutable_data(context.GetPlace()); + auto offsets = context.Attr>("offsets"); + Eigen::array, D> paddings; + for (int i = 0; i < D; ++i) { + paddings[i].first = offsets[i]; + paddings[i].second = d_x->dims()[i] - d_out->dims()[i] - offsets[i]; + } + auto d_x_tensor = EigenTensor::From(*d_x); + auto d_out_tensor = EigenTensor::From(*d_out); + d_x_tensor.device(context.GetEigenDevice()) = + d_out_tensor.pad(paddings, 0); + } +} + +template +class CropGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + size_t rank = + context.Input(framework::GradVarName("Out"))->dims().size(); + switch (rank) { + case 1: + CropGradFunction(context); + break; + case 2: + CropGradFunction(context); + break; + case 3: + CropGradFunction(context); + break; + case 4: + CropGradFunction(context); + break; + case 5: + CropGradFunction(context); + break; + case 6: + CropGradFunction(context); + break; + default: + PADDLE_THROW( + "CropOp only support tensors with no more than 6 dimensions."); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_crop_op.py b/python/paddle/v2/framework/tests/test_crop_op.py new file mode 100644 index 0000000000000000000000000000000000000000..62c883bdc130021d06c33ded9c2865505da0b719 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_crop_op.py @@ -0,0 +1,91 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def crop(data, offsets, crop_shape): + def indexOf(shape, index): + result = [] + for dim in reversed(shape): + result.append(index % dim) + index = index / dim + return result[::-1] + + result = [] + for i, value in enumerate(data.flatten()): + index = indexOf(data.shape, i) + selected = True + if len(index) == len(offsets): + for j, offset in enumerate(offsets): + selected = selected and index[j] >= offset and index[ + j] < crop_shape[j] + offset + if selected: + result.append(value) + return np.array(result).reshape(crop_shape) + + +class TestCropOp(OpTest): + def setUp(self): + self.op_type = "crop" + self.crop_by_input = False + self.attrs = {} + self.initTestCase() + self.attrs['offsets'] = self.offsets + if self.crop_by_input: + self.inputs = { + 'X': np.random.random(self.x_shape).astype("float32"), + 'Y': np.random.random(self.crop_shape).astype("float32") + } + else: + self.attrs['shape'] = self.crop_shape + self.inputs = { + 'X': np.random.random(self.x_shape).astype("float32"), + } + self.outputs = { + 'Out': crop(self.inputs['X'], self.offsets, self.crop_shape) + } + + def initTestCase(self): + self.x_shape = (8, 8) + self.crop_shape = (2, 2) + self.offsets = [1, 2] + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X'], 'Out', max_relative_error=0.006) + + +class TestCase1(TestCropOp): + def initTestCase(self): + self.x_shape = (16, 8, 32) + self.crop_shape = [2, 2, 3] + self.offsets = [1, 5, 3] + + +class TestCase2(TestCropOp): + def initTestCase(self): + self.x_shape = (4, 8) + self.crop_shape = [4, 8] + self.offsets = [0, 0] + + +class TestCase3(TestCropOp): + def initTestCase(self): + self.x_shape = (4, 8, 16) + self.crop_shape = [2, 2, 3] + self.offsets = [1, 5, 3] + self.crop_by_input = True + + +class TestCase4(TestCropOp): + def initTestCase(self): + self.x_shape = (4, 4) + self.crop_shape = [4, 4] + self.offsets = [0, 0] + self.crop_by_input = True + + +if __name__ == '__main__': + unittest.main()