From 26cec83901dc443a60aef911c1ad2baf882eb474 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Wed, 30 Aug 2017 19:54:14 +0800
Subject: [PATCH] Add pad op

---
 paddle/operators/CMakeLists.txt                  |  1 +
 paddle/operators/pad_op.cc                       | 77 ++++++++++++++++++
 paddle/operators/pad_op.cu                       | 21 +++++
 paddle/operators/pad_op.h                        | 81 +++++++++++++++++++
 paddle/pybind/CMakeLists.txt                     |  3 +-
 paddle/pybind/pybind.cc                          |  1 +
 .../paddle/v2/framework/tests/test_pad_op.py     | 32 ++++++++
 7 files changed, 215 insertions(+), 1 deletion(-)
 create mode 100644 paddle/operators/pad_op.cc
 create mode 100644 paddle/operators/pad_op.cu
 create mode 100644 paddle/operators/pad_op.h
 create mode 100644 python/paddle/v2/framework/tests/test_pad_op.py

diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index f466dbc79a2..1a759133e14 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -72,3 +72,4 @@ op_library(uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu)
 op_library(lookup_table_op SRCS lookup_table_op.cc lookup_table_op.cu)
 op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op)
 op_library(minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op)
+op_library(pad_op SRCS pad_op.cc pad_op.cu)
diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc
new file mode 100644
index 00000000000..f96d61669bd
--- /dev/null
+++ b/paddle/operators/pad_op.cc
@@ -0,0 +1,77 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/pad_op.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class PadOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto dim0 = ctx.Input<Tensor>("X")->dims();
+    auto paddings = GetAttr<std::vector<std::pair<int, int>>>("paddings");
+    PADDLE_ENFORCE(static_cast<size_t>(dim0.size()) == paddings.size(),
+                   "paddings should contain one (before, after) pair "
+                   "for each dimension of Input(X).");
+    auto dim1 = dim0;
+    for (int i = 0; i < dim0.size(); ++i) {
+      dim1[i] = dim0[i] + paddings[i].first + paddings[i].second;
+    }
+    ctx.Output<Tensor>("Out")->Resize(dim1);
+  }
+};
+
+class PadOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  PadOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "The input of pad op.");
+    AddOutput("Out", "The output of pad op.");
+    AddComment(R"DOC(
+Pad Operator.
+)DOC"); + AddAttr>>( + "paddings", "The padding rules for each dimension"); + AddAttr("pad_value", "The value to be padded into tensor") + .SetDefault(0.0f); + } +}; + +class PadOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx.Input("X")->dims(); + auto *x_grad = ctx.Output(framework::GradVarName("X")); + + x_grad->Resize(x_dims); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(pad, ops::PadOp, ops::PadOpMaker, pad_grad, ops::PadOpGrad); +REGISTER_OP_CPU_KERNEL(pad, ops::PadKernel); +REGISTER_OP_CPU_KERNEL(pad_grad, + ops::PadGradKernel); diff --git a/paddle/operators/pad_op.cu b/paddle/operators/pad_op.cu new file mode 100644 index 00000000000..555a7dba23c --- /dev/null +++ b/paddle/operators/pad_op.cu @@ -0,0 +1,21 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/pad_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(pad, ops::PadKernel); +REGISTER_OP_GPU_KERNEL(pad_grad, + ops::PadGradKernel); diff --git a/paddle/operators/pad_op.h b/paddle/operators/pad_op.h new file mode 100644 index 00000000000..6a743bd31c8 --- /dev/null +++ b/paddle/operators/pad_op.h @@ -0,0 +1,81 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+
+#pragma once
+
+#include <utility>
+#include <vector>
+
+#include "paddle/operators/math/math_function.h"
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T, size_t D, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
+
+template <typename Place, typename T>
+class PadKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto pad_attr =
+        context.op_.GetAttr<std::vector<std::pair<int, int>>>("paddings");
+    T pad_value = context.op_.GetAttr<T>("pad_value");
+
+    auto* X = context.Input<Tensor>("X");
+    auto* Out = context.Output<Tensor>("Out");
+    Out->mutable_data<T>(context.GetPlace());
+
+    // Eigen's pad() takes a fixed-size array of (before, after) pairs.
+    // The kernel currently supports rank-2 input only, as in the unit test.
+    Eigen::array<std::pair<int, int>, 2> paddings;
+    for (size_t i = 0; i < paddings.size(); ++i) {
+      paddings[i] = pad_attr[i];
+    }
+
+    typename EigenTensor<T, 2>::ConstType X_tensor =
+        EigenTensor<T, 2>::From(*X);
+    typename EigenTensor<T, 2>::Type Out_tensor = EigenTensor<T, 2>::From(*Out);
+    Out_tensor = X_tensor.pad(paddings, pad_value);
+  }
+};
+
+template <typename Place, typename T>
+class PadGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto pad_attr =
+        context.op_.GetAttr<std::vector<std::pair<int, int>>>("paddings");
+    // Negate the paddings so that pad() strips the padded border instead of
+    // adding one; the result is the slice of Out@GRAD that corresponds to X.
+    Eigen::array<std::pair<int, int>, 2> paddings;
+    for (size_t i = 0; i < paddings.size(); ++i) {
+      paddings[i].first = -pad_attr[i].first;
+      paddings[i].second = -pad_attr[i].second;
+    }
+
+    auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
+    dX->mutable_data<T>(context.GetPlace());
+
+    typename EigenTensor<T, 2>::Type dX_tensor = EigenTensor<T, 2>::From(*dX);
+    typename EigenTensor<T, 2>::ConstType dOut_tensor =
+        EigenTensor<T, 2>::From(*dOut);
+    dX_tensor = dOut_tensor.pad(paddings, static_cast<T>(0));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt
index abb9c248eee..17ef1e8291a 100644
--- a/paddle/pybind/CMakeLists.txt
+++ b/paddle/pybind/CMakeLists.txt
@@ -17,5 +17,6 @@ cc_library(paddle_pybind SHARED
     fill_zeros_like_op
     lookup_table_op
     scale_op
-    minus_op)
+    minus_op
+    pad_op)
 endif(WITH_PYTHON)
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 8fa8be2cef5..0176eb7a88a 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -47,6 +47,7 @@ USE_OP(scale);
 USE_OP_ITSELF(identity);
 USE_OP(minus);
 USE_CPU_ONLY_OP(gather);
+USE_OP(pad);
 
 namespace paddle {
 namespace framework {
diff --git a/python/paddle/v2/framework/tests/test_pad_op.py b/python/paddle/v2/framework/tests/test_pad_op.py
new file mode 100644
index 00000000000..89ac7e7e1d0
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_pad_op.py
@@ -0,0 +1,32 @@
+import unittest
+import numpy as np
+from gradient_checker import GradientChecker, create_op
+from op_test_util import OpTestMeta
+from paddle.v2.framework.op import Operator
+
+
+class TestPadOp(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "pad"
+        self.inputs = {'X': np.random.random((16, 16)).astype("float32"), }
+        self.attrs = {'paddings': ((0, 1), (2, 3)), 'pad_value': 0.0}
+        self.outputs = {
+            'Out': np.pad(self.inputs['X'],
+                          self.attrs['paddings'],
+                          mode='constant',
+                          constant_values=self.attrs['pad_value'])
+        }
+
+
+class PadGradOpTest(GradientChecker):
+    def test_pad(self):
+        op = Operator(
+            "pad", X="X", Out="Out", paddings=((0, 1), (2, 3)), pad_value=0.0)
+        inputs = {'X': np.random.random((16, 16)).astype("float32"), }
+
+        self.check_grad(op, inputs, set(["X"]), "Out", max_relative_error=0.5)
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
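
For reference (not part of the patch): a minimal NumPy sketch of the forward and
backward semantics that PadKernel and PadGradKernel above implement, assuming a
2-D input and the paddings used in test_pad_op.py.

import numpy as np

# Forward: pad each dimension with (before, after) extra entries filled with
# pad_value; a (16, 16) input with paddings ((0, 1), (2, 3)) becomes (17, 21).
x = np.random.random((16, 16)).astype("float32")
paddings = ((0, 1), (2, 3))
out = np.pad(x, paddings, mode='constant', constant_values=0.0)
assert out.shape == (17, 21)

# Backward: the gradient w.r.t. X is the slice of Out@GRAD that skips the
# padded border, which is what PadGradKernel recovers by negating the paddings.
d_out = np.ones_like(out)
d_x = d_out[0:0 + 16, 2:2 + 16]
assert d_x.shape == x.shape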