diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..7e78b6ec133981494a65b5e16316ae8fdbd61a60 --- /dev/null +++ b/paddle/operators/pad_op.cc @@ -0,0 +1,112 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/pad_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class PadOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + auto x_dim = ctx.Input("X")->dims(); + auto paddings = Attr>("paddings"); + PADDLE_ENFORCE_EQ(x_dim.size() * 2, int64_t(paddings.size()), + "Size of paddings should be equal to 2 * dimension size " + "of input tensor."); + std::vector out_dims(x_dim.size()); + for (int i = 0; i < x_dim.size(); ++i) { + out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1]; + } + ctx.Output("Out")->Resize(framework::make_ddim(out_dims)); + } +}; + +class PadOpMaker : public framework::OpProtoAndCheckerMaker { + public: + PadOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "The input of pad op. " + "The input should be a k-D tensor(k > 0 and k < 7)"); + AddOutput("Out", + "The output of pad op." + "A tensor with the same shape as X.") + .NotInGradient(); + AddComment(R"DOC( +Pad input into output, as specified by paddings and pad_value. The input should be a k-D tensor(k > 0 and k < 7). As an example: + +Given: + +X = [[1, 2], + [3, 4]] + +and + +paddings = [0, 1, 1, 2] + +and + +pad_value = 0 + +then we get + +Out = [[0, 1, 2, 0, 0] + [0, 3, 4, 0, 0] + [0, 0, 0, 0, 0]] +)DOC"); + AddAttr>( + "paddings", + "A list to describes padding rules for each dimension." + " For 2-D image tensor, paddings=[0, 1, 2, 3] means" + " padding 0 row to top, 1 row to bottom, 2 columns to left" + " and 3 columns to right.Size of paddings should be equal to" + " 2 * dimension size of input tensor."); + AddAttr("pad_value", + "(float) default to 0; " + "The value to fill padded areas.") + .SetDefault(0.0f); + } +}; + +class PadOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx.Input("X")->dims(); + auto *x_grad = ctx.Output(framework::GradVarName("X")); + if (x_grad != nullptr) { + x_grad->Resize(x_dims); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(pad, ops::PadOp, ops::PadOpMaker, pad_grad, ops::PadOpGrad); +REGISTER_OP_CPU_KERNEL(pad, ops::PadKernel); +REGISTER_OP_CPU_KERNEL(pad_grad, + ops::PadGradKernel); diff --git a/paddle/operators/pad_op.cu b/paddle/operators/pad_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..555a7dba23c6fa2659cabf4858b42ff70d74bf18 --- /dev/null +++ b/paddle/operators/pad_op.cu @@ -0,0 +1,21 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/pad_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(pad, ops::PadKernel); +REGISTER_OP_GPU_KERNEL(pad_grad, + ops::PadGradKernel); diff --git a/paddle/operators/pad_op.h b/paddle/operators/pad_op.h new file mode 100644 index 0000000000000000000000000000000000000000..2cc3b945ae5b2e2e93d8531c7f99e4c215d1d806 --- /dev/null +++ b/paddle/operators/pad_op.h @@ -0,0 +1,132 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +using EigenTensor = framework::EigenTensor; + +template +void PadFunction(const framework::ExecutionContext& context) { + auto pads = context.Attr>("paddings"); + Eigen::array, D> paddings; + for (size_t i = 0; i < paddings.size(); ++i) { + paddings[i].first = pads[i * 2]; + paddings[i].second = pads[i * 2 + 1]; + } + T pad_value = context.Attr("pad_value"); + + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + out->mutable_data(context.GetPlace()); + + auto x_tensor = EigenTensor::From(*x); + auto out_tensor = EigenTensor::From(*out); + auto place = context.GetEigenDevice(); + out_tensor.device(place) = x_tensor.pad(paddings, pad_value); +} + +template +class PadKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + int rank = context.Input("X")->dims().size(); + switch (rank) { + case 1: + PadFunction(context); + break; + case 2: + PadFunction(context); + break; + case 3: + PadFunction(context); + break; + case 4: + PadFunction(context); + break; + case 5: + PadFunction(context); + break; + case 6: + PadFunction(context); + break; + default: + PADDLE_THROW( + "PadOp only support tensors with no more than 6 dimensions."); + } + } +}; + +template +void PadGradFunction(const framework::ExecutionContext& context) { + auto pads = context.Attr>("paddings"); + Eigen::array, D> paddings; + for (size_t i = 0; i < paddings.size(); ++i) { + paddings[i].first = -pads[i * 2]; + paddings[i].second = -pads[i * 2 + 1]; + } + auto* d_out = context.Input(framework::GradVarName("Out")); + auto* d_x = context.Output(framework::GradVarName("X")); + if (d_x != nullptr) { + d_x->mutable_data(context.GetPlace()); + auto d_x_tensor = EigenTensor::From(*d_x); + auto d_out_tensor = EigenTensor::From(*d_out); + auto place = context.GetEigenDevice(); + d_x_tensor.device(place) = d_out_tensor.pad(paddings, 0); + } +} + +template +class PadGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + size_t rank = + context.Input(framework::GradVarName("Out"))->dims().size(); + switch (rank) { + case 1: + PadGradFunction(context); + break; + case 2: + PadGradFunction(context); + break; + case 3: + PadGradFunction(context); + break; + case 4: + PadGradFunction(context); + break; + case 5: + PadGradFunction(context); + break; + case 6: + PadGradFunction(context); + break; + default: + PADDLE_THROW( + "PadOp only support tensors with no more than 6 dimensions."); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index ef62d6e997fa0af5c342e0c310c09b840b0b8583..fe1e50927a425018e471ac8065c0349474ee93a5 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -50,6 +50,7 @@ USE_NO_KERNEL_OP(identity); USE_OP(minus); USE_OP(cos_sim); USE_CPU_ONLY_OP(gather); +USE_OP(pad); USE_CPU_ONLY_OP(scatter); USE_CPU_ONLY_OP(concat); USE_OP(top_k); diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py index 9e665adad2d3ad91d183c6815fbd7135ac4e8965..15e0d125c495fbc0688d8dc4e66881cb9ab95a90 100644 --- a/python/paddle/v2/framework/op.py +++ b/python/paddle/v2/framework/op.py @@ -97,7 +97,7 @@ class OpDescCreationMethod(object): new_attr.strings.extend(user_defined_attr) elif attr.type == framework_pb2.INT_PAIRS: for p in user_defined_attr: - pair = new_attr.pairs.add() + pair = new_attr.int_pairs.add() pair.first = p[0] pair.second = p[1] else: diff --git a/python/paddle/v2/framework/tests/test_pad_op.py b/python/paddle/v2/framework/tests/test_pad_op.py new file mode 100644 index 0000000000000000000000000000000000000000..456b765e331fc4c80e6fd817c88d7ec533158ecb --- /dev/null +++ b/python/paddle/v2/framework/tests/test_pad_op.py @@ -0,0 +1,55 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestPadOp(OpTest): + def setUp(self): + self.initTestCase() + self.op_type = "pad" + self.inputs = {'X': np.random.random(self.shape).astype("float32"), } + self.attrs = {} + self.attrs['paddings'] = np.array(self.paddings).flatten() + self.attrs['pad_value'] = self.pad_value + self.outputs = { + 'Out': np.pad(self.inputs['X'], + self.paddings, + mode='constant', + constant_values=self.pad_value) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X'], 'Out') + + def initTestCase(self): + self.shape = (16, 16) + self.paddings = [(0, 1), (2, 3)] + self.pad_value = 0 + + +class TestCase1(TestPadOp): + def initTestCase(self): + self.shape = (2, 3, 4, 4) + self.paddings = [(0, 1), (2, 3), (2, 1), (1, 1)] + self.pad_value = 0.5 + + +class TestCase2(TestPadOp): + def initTestCase(self): + self.shape = (2, 2, 2) + self.paddings = [(0, 0), (0, 0), (1, 2)] + self.pad_value = 1 + + +class TestCase3(TestPadOp): + def initTestCase(self): + self.shape = (8) + self.paddings = [(0, 1)] + self.pad_value = 0.9 + + +if __name__ == '__main__': + unittest.main()