提交 5c057f95 编写于 作者: S sweetsky0901

add spp op only can test ok

上级 d89061c3
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/spp_op.h"
namespace paddle {
namespace operators {
class SppOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SppOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"(Tensor) The input tensor of spp operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature.");
AddOutput("Out",
"(Tensor) The output tensor of spp operator."
"N * M."
"M = C * H * W");
AddAttr<int>("pyramid_height", ">= 1");
AddComment(R"DOC(
"Input shape: $(N, C_{in}, H_{in}, W_{in})$
Output shape: $(H_{out}, W_{out})$
Where
$$
H_{out} = (H_{in}−1) * strides[0] − 2 * paddings[0] + ksize[0] \\
W_{out} = (W_{in}−1) * strides[1] − 2 * paddings[1] + ksize[1]
$$
)DOC");
}
};
int OutputSize(int pyramid_level, int input_size) {
int bins = std::pow(2, pyramid_level);
int ksize = std::ceil(input_size / static_cast<double>(bins));
int padding = (ksize * bins - input_size + 1) / 2;
int output_size = (input_size - ksize + 2 * padding) / ksize + 1;
// output_size = bins
return output_size;
}
class SppOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SppOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SppOp should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
int pyramid_height = ctx->Attrs().Get<int>("pyramid_height");
PADDLE_ENFORCE(in_x_dims.size() == 4,
"Spping intput must be of 4-dimensional.");
int outlen = 0;
for (int p = 0; p < pyramid_height; ++p) {
int outh = OutputSize(p, in_x_dims[2]);
int outw = OutputSize(p, in_x_dims[3]);
int p_level_outlen = outh * outw * in_x_dims[1];
outlen += p_level_outlen;
}
std::vector<int64_t> output_shape({in_x_dims[0], outlen});
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
};
class SppOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input(X@GRAD) should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(spp, ops::SppOp, ops::SppOpMaker, spp_grad, ops::SppOpGrad);
REGISTER_OP_CPU_KERNEL(spp, ops::SppKernel<paddle::platform::CPUPlace, float>,
ops::SppKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(spp_grad,
ops::SppGradKernel<paddle::platform::CPUPlace, float>,
ops::SppGradKernel<paddle::platform::CPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/pooling.h"
#include "paddle/operators/strided_memcpy.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class SppKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
int pyramid_height = context.template Attr<int>("pyramid_height");
out->mutable_data<T>(context.GetPlace());
auto out_stride = framework::stride(out->dims());
int input_h = in_x->dims()[2];
int input_w = in_x->dims()[3];
size_t output_offset = 0;
for (int p = 0; p < pyramid_height; ++p) {
int bins = std::pow(2, p);
int ksize_h = std::ceil(input_h / static_cast<double>(bins));
int ksize_w = std::ceil(input_w / static_cast<double>(bins));
int padding_h = (ksize_h * bins - input_h + 1) / 2;
int padding_w = (ksize_w * bins - input_w + 1) / 2;
std::vector<int> ksize({ksize_h, ksize_w});
std::vector<int> strides({ksize_h, ksize_w});
std::vector<int> paddings({padding_h, padding_w});
// pooling output shape
std::vector<int64_t> output_shape_vec({in_x->dims()[0], in_x->dims()[1]});
output_shape_vec.push_back((input_h - ksize_h + 2 * padding_h) / ksize_h +
1);
output_shape_vec.push_back((input_w - ksize_w + 2 * padding_w) / ksize_w +
1);
framework::DDim output_shape(framework::make_ddim(output_shape_vec));
// flatten pooling output shape
int output_flatten_w = in_x->dims()[1] * bins * bins;
std::vector<int64_t> output_flatten_shape_vec(
{in_x->dims()[0], output_flatten_w});
framework::DDim output_flatten_shape(
framework::make_ddim(output_flatten_shape_vec));
framework::Tensor out_level;
framework::Tensor out_flatten_level;
out_level.mutable_data<T>(output_shape, context.GetPlace());
// pooling
math::Pool2dFunctor<Place, math::MaxPool<T>, T> pool_forward;
math::MaxPool<T> max_process;
pool_forward(context.device_context(), *in_x, ksize, strides, paddings,
max_process, &out_level);
out_flatten_level.ShareDataWith(out_level);
out_flatten_level.Resize(output_flatten_shape);
auto in_stride = framework::stride(out_flatten_level.dims());
const T* src_data = out_flatten_level.data<T>();
StridedMemcpy<T>(context.device_context(), src_data, in_stride,
out_flatten_level.dims(), out_stride,
out->data<T>() + output_offset);
output_offset += out_flatten_level.dims()[1] * in_stride[1];
}
}
};
template <typename Place, typename T>
class SppGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* out = context.Input<framework::Tensor>("Out");
const framework::Tensor* out_grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor* in_x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X"));
auto& device_ctx = context.device_context();
math::SetConstant<Place, T> zero;
in_x_grad->mutable_data<T>(context.GetPlace());
zero(device_ctx, in_x_grad, static_cast<T>(0));
int pyramid_height = context.template Attr<int>("pyramid_height");
auto outgrad_stride = framework::stride(out_grad->dims());
auto out_stride = framework::stride(out->dims());
int input_h = in_x->dims()[2];
int input_w = in_x->dims()[3];
size_t out_offset = 0;
for (int p = 0; p < pyramid_height; ++p) {
int bins = std::pow(2, p);
int ksize_h = std::ceil(input_h / static_cast<double>(bins));
int ksize_w = std::ceil(input_w / static_cast<double>(bins));
int padding_h = (ksize_h * bins - input_h + 1) / 2;
int padding_w = (ksize_w * bins - input_w + 1) / 2;
std::vector<int> ksize({ksize_h, ksize_w});
std::vector<int> strides({ksize_h, ksize_w});
std::vector<int> paddings({padding_h, padding_w});
// split outgrad and get flatten
std::vector<int64_t> out_shape_vec({in_x->dims()[0], in_x->dims()[1]});
out_shape_vec.push_back((input_h - ksize_h + 2 * padding_h) / ksize_h +
1);
out_shape_vec.push_back((input_w - ksize_w + 2 * padding_w) / ksize_w +
1);
framework::DDim out_shape(framework::make_ddim(out_shape_vec));
int out_flatten_w = in_x->dims()[1] * bins * bins;
std::vector<int64_t> out_flatten_shape_vec(
{in_x->dims()[0], out_flatten_w});
framework::DDim out_flatten_shape(
framework::make_ddim(out_flatten_shape_vec));
framework::Tensor out_level;
framework::Tensor outgrad_level;
framework::Tensor out_flatten_level;
framework::Tensor outgrad_flatten_level;
out_flatten_level.mutable_data<T>(out_flatten_shape, context.GetPlace());
outgrad_flatten_level.mutable_data<T>(out_flatten_shape,
context.GetPlace());
auto flatten_stride = framework::stride(out_flatten_level.dims());
// memcpy
StridedMemcpy<T>(context.device_context(), out->data<T>() + out_offset,
out_stride, out_flatten_level.dims(), flatten_stride,
out_flatten_level.data<T>());
StridedMemcpy<T>(context.device_context(),
out_grad->data<T>() + out_offset, outgrad_stride,
outgrad_flatten_level.dims(), flatten_stride,
outgrad_flatten_level.data<T>());
out_offset += out_flatten_level.dims()[1] * out_stride[1];
// flatten backward
out_level.ShareDataWith(out_flatten_level);
out_level.Resize(out_shape);
outgrad_level.ShareDataWith(outgrad_flatten_level);
outgrad_level.Resize(out_shape);
math::MaxPool2dGradFunctor<Place, T> pool2d_backward;
pool2d_backward(context.device_context(), *in_x, *&out_level,
*&outgrad_level, ksize, strides, paddings, in_x_grad);
}
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
from test_pool2d_op import max_pool2D_forward_naive
class TestSppOp(OpTest):
def setUp(self):
self.op_type = "spp"
self.init_test_case()
input = np.random.random(self.shape).astype("float32")
nsize, csize, hsize, wsize = input.shape
out_level_flatten = []
for i in xrange(self.pyramid_height):
bins = np.power(2, i)
ksize = [0, 0]
padding = [0, 0]
ksize[0] = np.ceil(hsize / bins.astype("double")).astype("int32")
padding[0] = ((ksize[0] * bins - hsize + 1) / 2).astype("int32")
ksize[1] = np.ceil(wsize / bins.astype("double")).astype("int32")
padding[1] = ((ksize[1] * bins - wsize + 1) / 2).astype("int32")
out_level = max_pool2D_forward_naive(input, ksize, ksize, padding)
out_level_flatten.append(
out_level.reshape(nsize, bins * bins * csize))
if i == 0:
output = out_level_flatten[i]
else:
output = np.concatenate((output, out_level_flatten[i]), 1)
# output = np.concatenate(out_level_flatten.tolist(), 0);
self.inputs = {'X': input.astype('float32'), }
self.attrs = {'pyramid_height': self.pyramid_height}
self.outputs = {'Out': output.astype('float32')}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
def init_test_case(self):
self.shape = [1, 1, 2, 2]
self.pyramid_height = 2
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册