From 5c057f95529656e379ef404a2e388e1be3e88de1 Mon Sep 17 00:00:00 2001 From: sweetsky0901 Date: Sun, 3 Dec 2017 14:40:33 +0800 Subject: [PATCH] add spp op only can test ok --- paddle/operators/spp_op.cc | 98 +++++++++++++ paddle/operators/spp_op.h | 148 ++++++++++++++++++++ python/paddle/v2/fluid/tests/test_spp_op.py | 48 +++++++ 3 files changed, 294 insertions(+) create mode 100644 paddle/operators/spp_op.cc create mode 100644 paddle/operators/spp_op.h create mode 100644 python/paddle/v2/fluid/tests/test_spp_op.py diff --git a/paddle/operators/spp_op.cc b/paddle/operators/spp_op.cc new file mode 100644 index 00000000000..62fc2112a8b --- /dev/null +++ b/paddle/operators/spp_op.cc @@ -0,0 +1,98 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/spp_op.h" +namespace paddle { +namespace operators { + +class SppOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SppOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "(Tensor) The input tensor of spp operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of feature."); + AddOutput("Out", + "(Tensor) The output tensor of spp operator." + "N * M." + "M = C * H * W"); + AddAttr("pyramid_height", ">= 1"); + AddComment(R"DOC( + "Input shape: $(N, C_{in}, H_{in}, W_{in})$ + Output shape: $(H_{out}, W_{out})$ + Where + $$ + H_{out} = (H_{in}−1) * strides[0] − 2 * paddings[0] + ksize[0] \\ + W_{out} = (W_{in}−1) * strides[1] − 2 * paddings[1] + ksize[1] + $$ + )DOC"); + } +}; + +int OutputSize(int pyramid_level, int input_size) { + int bins = std::pow(2, pyramid_level); + int ksize = std::ceil(input_size / static_cast(bins)); + int padding = (ksize * bins - input_size + 1) / 2; + int output_size = (input_size - ksize + 2 * padding) / ksize + 1; + // output_size = bins + return output_size; +} + +class SppOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SppOp" + "should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SppOp should not be null."); + auto in_x_dims = ctx->GetInputDim("X"); + int pyramid_height = ctx->Attrs().Get("pyramid_height"); + PADDLE_ENFORCE(in_x_dims.size() == 4, + "Spping intput must be of 4-dimensional."); + int outlen = 0; + for (int p = 0; p < pyramid_height; ++p) { + int outh = OutputSize(p, in_x_dims[2]); + int outw = OutputSize(p, in_x_dims[3]); + int p_level_outlen = outh * outw * in_x_dims[1]; + outlen += p_level_outlen; + } + std::vector output_shape({in_x_dims[0], outlen}); + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + } +}; + +class SppOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Input(X@GRAD) should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(spp, ops::SppOp, ops::SppOpMaker, spp_grad, ops::SppOpGrad); +REGISTER_OP_CPU_KERNEL(spp, ops::SppKernel, + ops::SppKernel); +REGISTER_OP_CPU_KERNEL(spp_grad, + ops::SppGradKernel, + ops::SppGradKernel); diff --git a/paddle/operators/spp_op.h b/paddle/operators/spp_op.h new file mode 100644 index 00000000000..2a2824bb310 --- /dev/null +++ b/paddle/operators/spp_op.h @@ -0,0 +1,148 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/pooling.h" +#include "paddle/operators/strided_memcpy.h" + +namespace paddle { +namespace operators { +template +class SppKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const framework::Tensor* in_x = context.Input("X"); + auto* out = context.Output("Out"); + int pyramid_height = context.template Attr("pyramid_height"); + out->mutable_data(context.GetPlace()); + auto out_stride = framework::stride(out->dims()); + int input_h = in_x->dims()[2]; + int input_w = in_x->dims()[3]; + size_t output_offset = 0; + for (int p = 0; p < pyramid_height; ++p) { + int bins = std::pow(2, p); + int ksize_h = std::ceil(input_h / static_cast(bins)); + int ksize_w = std::ceil(input_w / static_cast(bins)); + int padding_h = (ksize_h * bins - input_h + 1) / 2; + int padding_w = (ksize_w * bins - input_w + 1) / 2; + std::vector ksize({ksize_h, ksize_w}); + std::vector strides({ksize_h, ksize_w}); + std::vector paddings({padding_h, padding_w}); + // pooling output shape + std::vector output_shape_vec({in_x->dims()[0], in_x->dims()[1]}); + output_shape_vec.push_back((input_h - ksize_h + 2 * padding_h) / ksize_h + + 1); + output_shape_vec.push_back((input_w - ksize_w + 2 * padding_w) / ksize_w + + 1); + framework::DDim output_shape(framework::make_ddim(output_shape_vec)); + // flatten pooling output shape + int output_flatten_w = in_x->dims()[1] * bins * bins; + std::vector output_flatten_shape_vec( + {in_x->dims()[0], output_flatten_w}); + framework::DDim output_flatten_shape( + framework::make_ddim(output_flatten_shape_vec)); + framework::Tensor out_level; + framework::Tensor out_flatten_level; + out_level.mutable_data(output_shape, context.GetPlace()); + // pooling + math::Pool2dFunctor, T> pool_forward; + math::MaxPool max_process; + pool_forward(context.device_context(), *in_x, ksize, strides, paddings, + max_process, &out_level); + out_flatten_level.ShareDataWith(out_level); + out_flatten_level.Resize(output_flatten_shape); + auto in_stride = framework::stride(out_flatten_level.dims()); + const T* src_data = out_flatten_level.data(); + StridedMemcpy(context.device_context(), src_data, in_stride, + out_flatten_level.dims(), out_stride, + out->data() + output_offset); + output_offset += out_flatten_level.dims()[1] * in_stride[1]; + } + } +}; +template +class SppGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const framework::Tensor* in_x = context.Input("X"); + const framework::Tensor* out = context.Input("Out"); + const framework::Tensor* out_grad = + context.Input(framework::GradVarName("Out")); + framework::Tensor* in_x_grad = + context.Output(framework::GradVarName("X")); + auto& device_ctx = context.device_context(); + math::SetConstant zero; + in_x_grad->mutable_data(context.GetPlace()); + zero(device_ctx, in_x_grad, static_cast(0)); + int pyramid_height = context.template Attr("pyramid_height"); + auto outgrad_stride = framework::stride(out_grad->dims()); + auto out_stride = framework::stride(out->dims()); + int input_h = in_x->dims()[2]; + int input_w = in_x->dims()[3]; + size_t out_offset = 0; + for (int p = 0; p < pyramid_height; ++p) { + int bins = std::pow(2, p); + int ksize_h = std::ceil(input_h / static_cast(bins)); + int ksize_w = std::ceil(input_w / static_cast(bins)); + int padding_h = (ksize_h * bins - input_h + 1) / 2; + int padding_w = (ksize_w * bins - input_w + 1) / 2; + std::vector ksize({ksize_h, ksize_w}); + std::vector strides({ksize_h, ksize_w}); + std::vector paddings({padding_h, padding_w}); + // split outgrad and get flatten + std::vector out_shape_vec({in_x->dims()[0], in_x->dims()[1]}); + out_shape_vec.push_back((input_h - ksize_h + 2 * padding_h) / ksize_h + + 1); + out_shape_vec.push_back((input_w - ksize_w + 2 * padding_w) / ksize_w + + 1); + framework::DDim out_shape(framework::make_ddim(out_shape_vec)); + int out_flatten_w = in_x->dims()[1] * bins * bins; + std::vector out_flatten_shape_vec( + {in_x->dims()[0], out_flatten_w}); + framework::DDim out_flatten_shape( + framework::make_ddim(out_flatten_shape_vec)); + framework::Tensor out_level; + framework::Tensor outgrad_level; + framework::Tensor out_flatten_level; + framework::Tensor outgrad_flatten_level; + out_flatten_level.mutable_data(out_flatten_shape, context.GetPlace()); + outgrad_flatten_level.mutable_data(out_flatten_shape, + context.GetPlace()); + + auto flatten_stride = framework::stride(out_flatten_level.dims()); + // memcpy + StridedMemcpy(context.device_context(), out->data() + out_offset, + out_stride, out_flatten_level.dims(), flatten_stride, + out_flatten_level.data()); + + StridedMemcpy(context.device_context(), + out_grad->data() + out_offset, outgrad_stride, + outgrad_flatten_level.dims(), flatten_stride, + outgrad_flatten_level.data()); + out_offset += out_flatten_level.dims()[1] * out_stride[1]; + // flatten backward + out_level.ShareDataWith(out_flatten_level); + out_level.Resize(out_shape); + outgrad_level.ShareDataWith(outgrad_flatten_level); + outgrad_level.Resize(out_shape); + math::MaxPool2dGradFunctor pool2d_backward; + pool2d_backward(context.device_context(), *in_x, *&out_level, + *&outgrad_level, ksize, strides, paddings, in_x_grad); + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_spp_op.py b/python/paddle/v2/fluid/tests/test_spp_op.py new file mode 100644 index 00000000000..806d5e7736b --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_spp_op.py @@ -0,0 +1,48 @@ +import unittest +import numpy as np +from op_test import OpTest +from test_pool2d_op import max_pool2D_forward_naive + + +class TestSppOp(OpTest): + def setUp(self): + self.op_type = "spp" + self.init_test_case() + input = np.random.random(self.shape).astype("float32") + nsize, csize, hsize, wsize = input.shape + out_level_flatten = [] + for i in xrange(self.pyramid_height): + bins = np.power(2, i) + ksize = [0, 0] + padding = [0, 0] + ksize[0] = np.ceil(hsize / bins.astype("double")).astype("int32") + padding[0] = ((ksize[0] * bins - hsize + 1) / 2).astype("int32") + + ksize[1] = np.ceil(wsize / bins.astype("double")).astype("int32") + padding[1] = ((ksize[1] * bins - wsize + 1) / 2).astype("int32") + out_level = max_pool2D_forward_naive(input, ksize, ksize, padding) + out_level_flatten.append( + out_level.reshape(nsize, bins * bins * csize)) + if i == 0: + output = out_level_flatten[i] + else: + output = np.concatenate((output, out_level_flatten[i]), 1) + # output = np.concatenate(out_level_flatten.tolist(), 0); + self.inputs = {'X': input.astype('float32'), } + self.attrs = {'pyramid_height': self.pyramid_height} + + self.outputs = {'Out': output.astype('float32')} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def init_test_case(self): + self.shape = [1, 1, 2, 2] + self.pyramid_height = 2 + + +if __name__ == '__main__': + unittest.main() -- GitLab