Commit 26cec839 authored by wanghaoshuang

Add pad op

Parent 818a64f4
@@ -72,3 +72,4 @@ op_library(uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu)
op_library(lookup_table_op SRCS lookup_table_op.cc lookup_table_op.cu)
op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op)
op_library(minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op)
op_library(pad_op SRCS pad_op.cc pad_op.cu)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/pad_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class PadOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    auto dim0 = ctx.Input<Tensor>("X")->dims();
    auto paddings = GetAttr<std::vector<std::pair<int, int>>>("paddings");
    PADDLE_ENFORCE(dim0.size() == static_cast<int>(paddings.size()),
                   "paddings should contain one (before, after) pair "
                   "per dimension of Input(X).");
    std::vector<int64_t> dim1(dim0.size());
    for (int i = 0; i < dim0.size(); ++i) {
      dim1[i] = dim0[i] + paddings[i].first + paddings[i].second;
    }
    ctx.Output<Tensor>("Out")->Resize(framework::make_ddim(dim1));
  }
};
class PadOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  PadOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input of pad op.");
    AddOutput("Out", "The output of pad op.");
    AddComment(R"DOC(
Pad Operator.

Pad the input tensor with a constant value, as specified by the
`paddings` attribute (one (before, after) pair per dimension) and
`pad_value`.
)DOC");
    AddAttr<std::vector<std::pair<int, int>>>(
        "paddings",
        "A (before, after) padding pair for each dimension of the input.");
    AddAttr<float>("pad_value", "The value to fill the padded areas with.")
        .SetDefault(0.0f);
  }
};
class PadOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
x_grad->Resize(x_dims);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(pad, ops::PadOp, ops::PadOpMaker, pad_grad, ops::PadOpGrad);
REGISTER_OP_CPU_KERNEL(pad, ops::PadKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(pad_grad,
ops::PadGradKernel<paddle::platform::CPUPlace, float>);
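A note on the shape rule that InferShape above implements: each dimension simply grows by its (before, after) padding pair. A minimal standalone sketch of that arithmetic (not part of the diff; `PaddedDims` is an illustrative name), using the same (16, 16) input and ((0, 1), (2, 3)) paddings as the Python test further below:

#include <cstdio>
#include <utility>
#include <vector>

// Output dims = input dims plus the (before, after) widths per dimension.
std::vector<int> PaddedDims(const std::vector<int>& x_dims,
                            const std::vector<std::pair<int, int>>& paddings) {
  std::vector<int> out(x_dims.size());
  for (size_t i = 0; i < x_dims.size(); ++i) {
    out[i] = x_dims[i] + paddings[i].first + paddings[i].second;
  }
  return out;
}

int main() {
  auto dims = PaddedDims({16, 16}, {{0, 1}, {2, 3}});
  std::printf("%d x %d\n", dims[0], dims[1]);  // prints "17 x 21"
  return 0;
}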
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/pad_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(pad, ops::PadKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(pad_grad,
ops::PadGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/math/math_function.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename Place, typename T>
class PadKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    // Eigen tensor ranks are compile-time constants, so dispatch on the
    // runtime rank of the input; dims.size() cannot be a template argument.
    int rank = context.Input<Tensor>("X")->dims().size();
    switch (rank) {
      case 1: PadFunction<1>(context); break;
      case 2: PadFunction<2>(context); break;
      case 3: PadFunction<3>(context); break;
      case 4: PadFunction<4>(context); break;
      case 5: PadFunction<5>(context); break;
      case 6: PadFunction<6>(context); break;
      default:
        PADDLE_THROW("PadOp only supports tensors with rank 1 to 6.");
    }
  }

 private:
  template <size_t D>
  void PadFunction(const framework::ExecutionContext& context) const {
    auto pads =
        context.op_.GetAttr<std::vector<std::pair<int, int>>>("paddings");
    Eigen::array<std::pair<int, int>, D> paddings;
    for (size_t i = 0; i < D; ++i) {
      paddings[i] = pads[i];
    }
    T pad_value = context.op_.GetAttr<T>("pad_value");

    auto* x = context.Input<Tensor>("X");
    auto* out = context.Output<Tensor>("Out");
    out->mutable_data<T>(context.GetPlace());

    auto x_tensor = EigenTensor<T, D>::From(*x);
    auto out_tensor = EigenTensor<T, D>::From(*out);
    // Evaluate the Eigen expression on the right device (CPU or GPU).
    out_tensor.device(context.GetEigenDevice<Place>()) =
        x_tensor.pad(paddings, pad_value);
  }
};
template <typename Place, typename T>
class PadGradKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    int rank =
        context.Input<Tensor>(framework::GradVarName("Out"))->dims().size();
    switch (rank) {
      case 1: PadGradFunction<1>(context); break;
      case 2: PadGradFunction<2>(context); break;
      case 3: PadGradFunction<3>(context); break;
      case 4: PadGradFunction<4>(context); break;
      case 5: PadGradFunction<5>(context); break;
      case 6: PadGradFunction<6>(context); break;
      default:
        PADDLE_THROW("PadOp only supports tensors with rank 1 to 6.");
    }
  }

 private:
  template <size_t D>
  void PadGradFunction(const framework::ExecutionContext& context) const {
    auto pads =
        context.op_.GetAttr<std::vector<std::pair<int, int>>>("paddings");
    // Negate every width: pad() with negative widths trims, which slices
    // the original (unpadded) region out of the output gradient.
    Eigen::array<std::pair<int, int>, D> paddings;
    for (size_t i = 0; i < D; ++i) {
      paddings[i].first = -pads[i].first;
      paddings[i].second = -pads[i].second;
    }
    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
    if (d_x == nullptr) return;
    d_x->mutable_data<T>(context.GetPlace());

    auto d_x_tensor = EigenTensor<T, D>::From(*d_x);
    auto d_out_tensor = EigenTensor<T, D>::From(*d_out);
    d_x_tensor.device(context.GetEigenDevice<Place>()) =
        d_out_tensor.pad(paddings, 0);
  }
};
} // namespace operators
} // namespace paddle
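Both kernels reduce to a single call to Eigen's Tensor::pad(); the gradient additionally relies on pad() accepting negated widths to trim rather than pad, which is assumed behavior of the Eigen build Paddle ships. A self-contained sketch of the forward semantics, assuming Eigen's unsupported Tensor module is on the include path; values are illustrative:

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>
#include <utility>

int main() {
  Eigen::Tensor<float, 2, Eigen::RowMajor> x(2, 2);
  x.setValues({{1.f, 2.f}, {3.f, 4.f}});

  // One (before, after) pair per dimension, as in the "paddings" attribute.
  Eigen::array<std::pair<int, int>, 2> paddings;
  paddings[0] = std::make_pair(0, 1);  // rows: 0 before, 1 after
  paddings[1] = std::make_pair(1, 2);  // cols: 1 before, 2 after

  // pad() fills the new cells with the given constant (the pad_value).
  Eigen::Tensor<float, 2, Eigen::RowMajor> y = x.pad(paddings, 0.0f);
  std::cout << y << std::endl;
  // Shape (3, 5):
  //   0 1 2 0 0
  //   0 3 4 0 0
  //   0 0 0 0 0
  return 0;
}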
@@ -17,5 +17,6 @@ cc_library(paddle_pybind SHARED
fill_zeros_like_op
lookup_table_op
scale_op
minus_op)
minus_op
pad_op)
endif(WITH_PYTHON)
@@ -47,6 +47,7 @@ USE_OP(scale);
USE_OP_ITSELF(identity);
USE_OP(minus);
USE_CPU_ONLY_OP(gather);
USE_OP(pad);
namespace paddle {
namespace framework {
......
import unittest
import numpy as np
from paddle.v2.framework.op import Operator
from gradient_checker import GradientChecker, create_op
from op_test_util import OpTestMeta


class TestPadOp(unittest.TestCase):
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "pad"
        self.inputs = {'X': np.random.random((16, 16)).astype("float32"), }
        self.attrs = {'paddings': ((0, 1), (2, 3)), 'pad_value': 0.0}
        self.outputs = {
            'Out': np.pad(self.inputs['X'],
                          self.attrs['paddings'],
                          mode='constant',
                          constant_values=0)
        }


class PadGradOpTest(GradientChecker):
    def test_pad(self):
        op = Operator(
            "pad", X="X", Out="Out", paddings=((0, 1), (2, 3)), pad_value=0.0)
        inputs = {'X': np.random.random((16, 16)).astype("float32"), }
        self.check_grad(op, inputs, set(["X"]), "Out", max_relative_error=0.5)


if __name__ == '__main__':
    unittest.main()