diff --git a/paddle/fluid/operators/math/padding.h b/paddle/fluid/operators/math/padding.h new file mode 100644 index 0000000000000000000000000000000000000000..3ae25eae98b25bca015ec4383c7126eb81e52b8a --- /dev/null +++ b/paddle/fluid/operators/math/padding.h @@ -0,0 +1,124 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/fluid/framework/tensor.h" + +namespace paddle { +namespace operators { +namespace math { + +template +using EigenTensor = framework::EigenTensor; + +template +void PadFunction(const framework::ExecutionContext& context, + const std::vector& pads, const framework::Tensor& src, + T pad_value, framework::Tensor* out) { + Eigen::array, D> paddings; + + for (size_t i = 0; i < paddings.size(); ++i) { + paddings[i].first = pads[i * 2]; + paddings[i].second = pads[i * 2 + 1]; + } + + auto src_tensor = EigenTensor::From(src); + auto out_tensor = EigenTensor::From(*out); + + auto& place = + *context.template device_context().eigen_device(); + out_tensor.device(place) = src_tensor.pad(paddings, pad_value); +} + +template +void PadGradFunction(const framework::ExecutionContext& context, + const std::vector& pads, const framework::Tensor& src, + framework::Tensor* d_out) { + Eigen::array, D> paddings; + for (size_t i = 0; i < paddings.size(); ++i) { + paddings[i].first = -pads[i * 2]; + paddings[i].second = -pads[i * 2 + 1]; + } + + auto d_out_tensor = EigenTensor::From(*d_out); + auto src_tensor = EigenTensor::From(src); + auto& place = + *context.template device_context().eigen_device(); + d_out_tensor.device(place) = src_tensor.pad(paddings, 0); +} + +template +void PaddingFunctor(int rank, const framework::ExecutionContext& context, + const std::vector& pads, T pad_value, + const framework::Tensor& src, framework::Tensor* out) { + switch (rank) { + case 1: + PadFunction(context, pads, src, pad_value, out); + break; + case 2: + PadFunction(context, pads, src, pad_value, out); + break; + case 3: + PadFunction(context, pads, src, pad_value, out); + break; + case 4: + PadFunction(context, pads, src, pad_value, out); + break; + case 5: + PadFunction(context, pads, src, pad_value, out); + break; + case 6: + PadFunction(context, pads, src, pad_value, out); + break; + default: + PADDLE_THROW( + "PadOp only support tensors with no more than 6 dimensions."); + } +} + +template +void PaddingGradFunctor(int rank, const framework::ExecutionContext& context, + const std::vector& pads, + const framework::Tensor& src, framework::Tensor* out) { + switch (rank) { + case 1: + PadGradFunction(context, pads, src, out); + break; + case 2: + PadGradFunction(context, pads, src, out); + break; + case 3: + PadGradFunction(context, pads, src, out); + break; + case 4: + PadGradFunction(context, pads, src, out); + break; + case 5: + PadGradFunction(context, pads, src, out); + break; + case 6: + PadGradFunction(context, pads, src, out); + break; + default: + PADDLE_THROW( + "PadOp only support tensors with no more than 6 dimensions."); + } +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/pad_constant_like_op.cc b/paddle/fluid/operators/pad_constant_like_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..5958811d38f4f772264d2c837c42a704062649f7 --- /dev/null +++ b/paddle/fluid/operators/pad_constant_like_op.cc @@ -0,0 +1,196 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/pad_constant_like_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class PadConstantLikeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of PadConstantLikeOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), + "Input(Y) of PadConstantLikeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of PadConstantLikeOp should not be null."); + + auto x_dim = ctx->GetInputDim("X"); + auto y_dim = ctx->GetInputDim("Y"); + + PADDLE_ENFORCE_EQ(x_dim.size(), y_dim.size(), + "The dimention of X and Y should be the same."); + + for (int i = 0; i < x_dim.size(); ++i) { + PADDLE_ENFORCE_GE(x_dim[i], y_dim[i]); + } + ctx->SetOutputDim("Out", x_dim); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class PadConstantLikeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "The input of pad_constant_like op. " + "The input should be a k-D tensor(k > 0 and k < 7)"); + AddInput("Y", + "The input of pad_constant_like op. " + "The input should be a k-D tensor(k > 0 and k < 7)"); + AddOutput("Out", + "The output of pad_constant_like op. " + "A tensor with the same shape as X."); + AddAttr("pad_value", + "(float, default 0.0) " + "The value to fill the padded areas.") + .SetDefault(0.0f); + AddComment(R"DOC( +PadConstantLikeOp Operator. + +Pad input(Y) with a pad_value, the number of values padded to the edges of each +axis is specified by the difference of the shape of X and Y. +((0, shape_x_0 - shape_y_0), … (0, shape_x_n - shape_y_n)) unique pad widths for +each axis. +The input should be a k-D tensor(k > 0 and k < 7). As an example: + +case1: + Given: + X = [[1, 2], + [3, 4], + [1, 2], + [3, 4]]], + X.shape = (4, 2) + + Y = [[5, 6], + [7, 8]], + Y.shape = (2, 2) + + And + pad_value = 0, + + Return: + Out = [[5, 6], + [7, 8], + [0, 0], + [0, 0]] + Out.shape = (4, 2) + +case2: + Given: + X = [[[[ 0, 1, 2], + [ 3, 4, 5]], + [[ 6, 7, 8], + [ 9, 10, 11]], + [[12, 13, 14], + [15, 16, 17]]], + [[[18, 19, 20], + [21, 22, 23]], + [[24, 25, 26], + [27, 28, 29]], + [[30, 31, 32], + [33, 34, 35]]]] + X.shape = (2, 3, 2, 3) + + Y = [[[[35, 36, 37]], + [[38, 39, 40]], + [[41, 42, 43]]]] + Y.shape = (1, 3, 1, 3) + + And + pad_value = -1, + + Return: + + Out = [[[[35, 36, 37], + [-1, -1, -1]], + [[38, 39, 40], + [-1, -1, -1]], + [[41, 42, 43], + [-1, -1, -1]]], + [[[-1, -1, -1], + [-1, -1, -1]], + [[-1, -1, -1], + [-1, -1, -1]], + [[-1, -1, -1], + [-1, -1, -1]]]] + Out.shape = (2, 3, 2, 3) +)DOC"); + } +}; + +class PadConstantLikeOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto y_dim = ctx->GetInputDim("Y"); + auto dout_dim = ctx->GetInputDim(framework::GradVarName("Out")); + + PADDLE_ENFORCE_EQ(dout_dim.size(), y_dim.size(), + "The dimention of X and Y should be the same."); + + auto y_grad_name = framework::GradVarName("Y"); + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dim); + ctx->ShareLoD("Y", /*->*/ y_grad_name); + + for (int i = 0; i < y_dim.size(); ++i) { + PADDLE_ENFORCE_GE(dout_dim[i], y_dim[i]); + } + } + } +}; + +class PadConstantLikeOpGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto *bind = new framework::OpDesc(); + bind->SetType("pad_constant_like_grad"); + bind->SetInput("Y", Input("Y")); + bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + bind->SetOutput(framework::GradVarName("Y"), InputGrad("Y")); + bind->SetAttrMap(Attrs()); + return std::unique_ptr(bind); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(pad_constant_like, ops::PadConstantLikeOp, + ops::PadConstantLikeOpMaker, ops::PadConstantLikeOpGradMaker); +REGISTER_OPERATOR(pad_constant_like_grad, ops::PadConstantLikeOpGrad); + +REGISTER_OP_CPU_KERNEL( + pad_constant_like, + ops::PadConstantLikeKernel, + ops::PadConstantLikeKernel); +REGISTER_OP_CPU_KERNEL( + pad_constant_like_grad, + ops::PadConstantLikeGradKernel, + ops::PadConstantLikeGradKernel); diff --git a/paddle/fluid/operators/pad_constant_like_op.cu b/paddle/fluid/operators/pad_constant_like_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..ea69577904577de353b63491973bf74b7724e18e --- /dev/null +++ b/paddle/fluid/operators/pad_constant_like_op.cu @@ -0,0 +1,27 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/fluid/operators/pad_constant_like_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + pad_constant_like, + ops::PadConstantLikeKernel, + ops::PadConstantLikeKernel); +REGISTER_OP_CUDA_KERNEL( + pad_constant_like_grad, + ops::PadConstantLikeGradKernel, + ops::PadConstantLikeGradKernel); diff --git a/paddle/fluid/operators/pad_constant_like_op.h b/paddle/fluid/operators/pad_constant_like_op.h new file mode 100644 index 0000000000000000000000000000000000000000..01d66901afc49a487c344b039b65f547967e95ff --- /dev/null +++ b/paddle/fluid/operators/pad_constant_like_op.h @@ -0,0 +1,93 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/math/padding.h" + +namespace paddle { +namespace operators { + +template +class PadConstantLikeKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto in_x = context.Input("X"); + auto in_y = context.Input("Y"); + auto* out = context.Output("Out"); + + if (in_x->dims() == in_y->dims()) { + // TensorCopy(in_y, context.GetPlace(), context, out); + out->ShareDataWith(*in_y); + return; + } + + T pad_value = context.Attr("pad_value"); + out->mutable_data(context.GetPlace()); + + int rank = context.Input("X")->dims().size(); + + std::vector pads(rank * 2, 0); + + for (int j = 0; j < rank; ++j) { + pads[j * 2] = 0; + pads[j * 2 + 1] = static_cast(in_x->dims()[j] - in_y->dims()[j]); + } + + math::PaddingFunctor(rank, context, pads, pad_value, + *in_y, out); + } +}; + +template +class PadConstantLikeGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto in_y = context.Input("Y"); + auto in_dout = + context.Input(framework::GradVarName("Out")); + auto* d_y = context.Output(framework::GradVarName("Y")); + + if (d_y == nullptr) { + return; + } + + if (in_dout->dims() == in_y->dims()) { + // TensorCopy(in_dout, context.GetPlace(), context, d_y); + d_y->ShareDataWith(*in_dout); + return; + } + + d_y->mutable_data(context.GetPlace()); + int rank = in_dout->dims().size(); + + std::vector pads(static_cast(rank) * 2, 0); + for (int j = 0; j < rank; ++j) { + pads[j * 2] = 0; + pads[j * 2 + 1] = static_cast(in_dout->dims()[j] - in_y->dims()[j]); + } + + math::PaddingGradFunctor(rank, context, pads, *in_dout, + d_y); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/pad_op.h b/paddle/fluid/operators/pad_op.h index c93c096575a30dd9344894ead4b81acc16930e21..32698dac4917e183cfe36c831787b049985b19b3 100644 --- a/paddle/fluid/operators/pad_op.h +++ b/paddle/fluid/operators/pad_op.h @@ -18,117 +18,44 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/padding.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -using EigenTensor = framework::EigenTensor; - -template -void PadFunction(const framework::ExecutionContext& context) { - auto pads = context.Attr>("paddings"); - Eigen::array, D> paddings; - for (size_t i = 0; i < paddings.size(); ++i) { - paddings[i].first = pads[i * 2]; - paddings[i].second = pads[i * 2 + 1]; - } - T pad_value = context.Attr("pad_value"); - - auto* x = context.Input("X"); - auto* out = context.Output("Out"); - out->mutable_data(context.GetPlace()); - - auto x_tensor = EigenTensor::From(*x); - auto out_tensor = EigenTensor::From(*out); - auto& place = - *context.template device_context().eigen_device(); - out_tensor.device(place) = x_tensor.pad(paddings, pad_value); -} - template class PadKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - int rank = context.Input("X")->dims().size(); - switch (rank) { - case 1: - PadFunction(context); - break; - case 2: - PadFunction(context); - break; - case 3: - PadFunction(context); - break; - case 4: - PadFunction(context); - break; - case 5: - PadFunction(context); - break; - case 6: - PadFunction(context); - break; - default: - PADDLE_THROW( - "PadOp only support tensors with no more than 6 dimensions."); - } + auto pads = context.Attr>("paddings"); + T pad_value = context.Attr("pad_value"); + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + out->mutable_data(context.GetPlace()); + + int rank = x->dims().size(); + math::PaddingFunctor(rank, context, pads, pad_value, *x, + out); } }; -template -void PadGradFunction(const framework::ExecutionContext& context) { - auto pads = context.Attr>("paddings"); - Eigen::array, D> paddings; - for (size_t i = 0; i < paddings.size(); ++i) { - paddings[i].first = -pads[i * 2]; - paddings[i].second = -pads[i * 2 + 1]; - } - auto* d_out = context.Input(framework::GradVarName("Out")); - auto* d_x = context.Output(framework::GradVarName("X")); - if (d_x != nullptr) { - d_x->mutable_data(context.GetPlace()); - auto d_x_tensor = EigenTensor::From(*d_x); - auto d_out_tensor = EigenTensor::From(*d_out); - auto& place = - *context.template device_context().eigen_device(); - d_x_tensor.device(place) = d_out_tensor.pad(paddings, 0); - } -} - template class PadGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - size_t rank = - context.Input(framework::GradVarName("Out"))->dims().size(); - switch (rank) { - case 1: - PadGradFunction(context); - break; - case 2: - PadGradFunction(context); - break; - case 3: - PadGradFunction(context); - break; - case 4: - PadGradFunction(context); - break; - case 5: - PadGradFunction(context); - break; - case 6: - PadGradFunction(context); - break; - default: - PADDLE_THROW( - "PadOp only support tensors with no more than 6 dimensions."); + auto pads = context.Attr>("paddings"); + auto* d_out = context.Input(framework::GradVarName("Out")); + auto* d_x = context.Output(framework::GradVarName("X")); + if (d_x == nullptr) { + return; } + + d_x->mutable_data(context.GetPlace()); + int rank = d_out->dims().size(); + math::PaddingGradFunctor(rank, context, pads, *d_out, + d_x); } }; diff --git a/python/paddle/fluid/tests/unittests/test_pad_constant_like.py b/python/paddle/fluid/tests/unittests/test_pad_constant_like.py new file mode 100644 index 0000000000000000000000000000000000000000..6b733fd8fa023f07013909502dbbd5371297216e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pad_constant_like.py @@ -0,0 +1,69 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + + +class TestPadOp(OpTest): + def setUp(self): + self.initTestCase() + self.op_type = "pad_constant_like" + self.inputs = { + 'X': np.random.random(self.x_shape).astype("float32"), + 'Y': np.random.random(self.y_shape).astype("float32") + } + self.attrs = {} + self.attrs['pad_value'] = self.pad_value + self.outputs = { + 'Out': np.pad(self.inputs['Y'], + self.paddings, + mode='constant', + constant_values=self.pad_value) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['Y'], 'Out', max_relative_error=0.006) + + def initTestCase(self): + self.x_shape = (16, 16) + self.y_shape = (3, 16) + self.pad_value = 0.1 + self.paddings = [(0, 13), (0, 0)] + + +class TestCase1(TestPadOp): + def initTestCase(self): + self.x_shape = (4, 3, 4, 4) + self.y_shape = (2, 3, 4, 4) + self.paddings = [(0, 2), (0, 0), (0, 0), (0, 0)] + self.pad_value = 0.5 + + +class TestCase2(TestPadOp): + def initTestCase(self): + self.x_shape = (4, 3, 4, 4) + self.y_shape = (2, 3, 2, 4) + self.paddings = [(0, 2), (0, 0), (0, 2), (0, 0)] + self.pad_value = 0.5 + + +if __name__ == '__main__': + unittest.main()