diff --git a/paddle/fluid/operators/reorg_op.cc b/paddle/fluid/operators/reorg_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f9da1f7977e7d21d24086df81bc5b8dfb0c6445 --- /dev/null +++ b/paddle/fluid/operators/reorg_op.cc @@ -0,0 +1,127 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/reorg_op.h" +#include +#include + +namespace paddle { +namespace operators { + +class ReorgOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of reorgOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of reorgOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ(x_dims.size(), 4, "input should be a 4D tensor"); + auto stride = ctx->Attrs().Get("stride"); + + PADDLE_ENFORCE_GT(stride, 0, "The stride should be Greater than 0"); + PADDLE_ENFORCE_GT(x_dims[1], 0, "input channel should be Greater than 0"); + PADDLE_ENFORCE_GT(x_dims[2], 0, "input Height should be Greater than 0"); + PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0"); + + PADDLE_ENFORCE_EQ( + x_dims[1] % (stride * stride), 0, + "input channel should be dvisible of the square of reorg stride"); + PADDLE_ENFORCE_EQ( + x_dims[2] % (stride), 0, + "input Height should be dvisible of the square of reorg stride"); + PADDLE_ENFORCE_EQ( + x_dims[3] % (stride), 0, + "input Width should be dvisible of the square of reorg stride"); + + VLOG(3) << "reorg operator x.shape=" << x_dims << "Attribute stride" + << stride << std::endl; + + std::vector output_shape(4, 0); // [B,C,H,W] + output_shape[0] = x_dims[0]; + output_shape[1] = x_dims[1] * stride * stride; + output_shape[2] = x_dims[2] / stride; + output_shape[3] = x_dims[3] / stride; + + auto out_dims = framework::make_ddim(output_shape); + + ctx->SetOutputDim("Out", out_dims); + + if (x_dims[0] == out_dims[0]) { + // Only pass LoD when the first dimension of output and Input(X) + // are the same. + ctx->ShareLoD("X", /*->*/ "Out"); + } + } +}; + +class ReorgOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor). The input should be a 4D tensor B * C * W * H of reorg " + "operator."); + AddOutput("Out", + "(Tensor), The output should be a 4D tensor B * C2 * W2 * H2 of " + "reorg operator."); + AddAttr("stride", + "(int64_t, default 1) stride used to do reorgnization.") + .SetDefault(1) + .EqualGreaterThan(1); + AddComment(R"DOC( + reorg operator used in Yolo v2. + The equation is: C2 = C1/stride * stride, W2 = W1 ∗ stride + offset % stride, H2 = H1 ∗ stride + offset / stride, + + Reshape Input(X) into the shape according to Attr(stride). The + data in Input(X) are unchanged. + + Examples: + + 1. Given a 3-D tensor Input(X) with a shape [2048, 26, 26], and the stride is 2, the reorg operator will transform Input(X) + into a 3-D tensor with shape [2048, 13, 13] and leaving Input(X)'s data unchanged. + + )DOC"); + } +}; + +class ReorgGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) shouldn't be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(reorg, ops::ReorgOp, ops::ReorgOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(reorg_grad, ops::ReorgGradOp); +REGISTER_OP_CPU_KERNEL( + reorg, ops::ReorgKernel, + ops::ReorgKernel, + ops::ReorgKernel); +REGISTER_OP_CPU_KERNEL( + reorg_grad, ops::ReorgGradKernel, + ops::ReorgGradKernel, + ops::ReorgGradKernel); diff --git a/paddle/fluid/operators/reorg_op.cu b/paddle/fluid/operators/reorg_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..de1c7d7468e105afc4f350e036a3e0eabc37a72c --- /dev/null +++ b/paddle/fluid/operators/reorg_op.cu @@ -0,0 +1,29 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reorg_op.h" + +namespace plat = paddle::platform; +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + reorg, ops::ReorgKernel, + ops::ReorgKernel, + ops::ReorgKernel); + +REGISTER_OP_CUDA_KERNEL( + reorg_grad, + ops::ReorgGradKernel, + ops::ReorgGradKernel, + ops::ReorgGradKernel); diff --git a/paddle/fluid/operators/reorg_op.h b/paddle/fluid/operators/reorg_op.h new file mode 100644 index 0000000000000000000000000000000000000000..108437b4d8f895f5951b4d560307548672f7e9d1 --- /dev/null +++ b/paddle/fluid/operators/reorg_op.h @@ -0,0 +1,126 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#ifndef PADDLE_FLUID_OPERATORS_REORG_OP_H_ +#define PADDLE_FLUID_OPERATORS_REORG_OP_H_ +#endif // PADDLE_FLUID_OPERATORS_REORG_OP_H_ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { + +template +class reorg_cpu { + public: + HOSTDEVICE reorg_cpu(const T *x, int64_t w, int64_t h, int64_t c, + int64_t batch, int64_t stride, int64_t forward, T *out) + : x_(x), + w_(w), + h_(h), + c_(c), + batch_(batch), + stride_(stride), + forward_(forward), + out_(out) {} + + HOSTDEVICE void operator()(int64_t in_index) { + int64_t out_c = c_ / (stride_ * stride_); + // calculate each dim position with index of tensor + int64_t b = in_index / (c_ * h_ * w_); + int64_t k = (in_index % (c_ * h_ * w_)) / (h_ * w_); + int64_t j = ((in_index % (c_ * h_ * w_)) % (h_ * w_)) / w_; + int64_t i = ((in_index % (c_ * h_ * w_)) % (h_ * w_)) % w_; + + int64_t c2 = k % out_c; + int64_t offset = k / out_c; + int64_t w2 = i * stride_ + offset % stride_; + int64_t h2 = j * stride_ + offset / stride_; + int64_t out_index = + w2 + w_ * stride_ * (h2 + h_ * stride_ * (c2 + out_c * b)); + if (forward_) + out_[out_index] = x_[in_index]; + else + out_[in_index] = x_[out_index]; + } + + private: + const T *x_; + int64_t w_, h_, c_, batch_, stride_, forward_; + T *out_; +}; + +template +class ReorgKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *out = context.Output("Out"); + auto *x = context.Input("X"); + auto stride = context.Attr("stride"); + auto in_dims = x->dims(); + out->mutable_data(context.GetPlace(), x->type()); + + auto out_dims = out->dims(); + auto B = in_dims[0]; + auto C = in_dims[1]; + auto H = in_dims[2]; + auto W = in_dims[3]; + platform::ForRange for_range( + context.template device_context(), + static_cast(x->numel())); + + auto *x_data = x->data(); + auto *out_data = out->data(); + paddle::operators::reorg_cpu reorg(x_data, W, H, C, B, stride, 1, + out_data); + for_range(reorg); + + out->Resize(out_dims); + } +}; + +template +class ReorgGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *d_out = + context.Input(framework::GradVarName("Out")); + auto *d_x = + context.Output(framework::GradVarName("X")); + auto stride = context.Attr("stride"); + auto in_dims = d_x->dims(); + d_x->mutable_data(context.GetPlace(), d_out->type()); + + auto B = in_dims[0]; + auto C = in_dims[1]; + auto H = in_dims[2]; + auto W = in_dims[3]; + + platform::ForRange for_range( + context.template device_context(), + static_cast(d_x->numel())); + + auto *dx_data = d_x->data(); + auto *dout_data = d_out->data(); + + paddle::operators::reorg_cpu reorg(dout_data, W, H, C, B, stride, 0, + dx_data); + for_range(reorg); + + d_x->Resize(in_dims); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 8c0ef7a82421ffc04bf669e6850e075226c09d27..35a1a899e793b98f47f69a5db189d9d861112684 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -150,6 +150,7 @@ __all__ = [ 'mul', 'sigmoid_cross_entropy_with_logits', 'maxout', + 'reorg', ] @@ -7084,3 +7085,54 @@ def maxout(x, groups, name=None): attrs={"groups": groups}, outputs={"Out": out}) return out + + +def reorg(x, stride, name=None): + """ + Gives a stride to reorg the input tensor + + Here are some example: + + input is 4D LoDtensor with shape [batch, channel, height, width] and has an attrs stride = 2 + + reorg will do some math work to reorder the elements of input according to stride to construt + put with shape [batch, channel * stride * stride, height/stride, width/stride] + + reorg is used to reorgnization the output of pre_layer and change the tensor to fit the shape + + Args: + x(variable): The input tensor. + stride(variable): The stride to reorg + + Returns: + Variable: The output tensor. + + Raises: + TypeError: stride type must be a long. + + Examples: + .. code-block:: python + + data = fluid.layers.data( + name='data', shape=[1, 4, 2, 2], dtype='float32') + reorged = fluid.layers.reorged( + x=data, stride=2) + """ + + if not (isinstance(stride, long)): + raise ValueError("stride must be a python long") + + helper = LayerHelper("reorg", **locals()) + if name is None: + out = helper.create_tmp_variable(dtype=x.dtype) + else: + out = helper.create_variable( + name=name, dtype=x.dtype, persistable=False) + + helper.append_op( + type="reorg", + inputs={"X": x}, + attrs={"stride": stride}, + outputs={"Out": out}) + + return out diff --git a/python/paddle/fluid/op.py b/python/paddle/fluid/op.py index 667db10d3ebdd24ddd9efbe2310ebb331e268ee2..52b169fb3ccaad16b9acfed56eeffbdd324d7bde 100644 --- a/python/paddle/fluid/op.py +++ b/python/paddle/fluid/op.py @@ -108,6 +108,8 @@ class OpDescCreationMethod(object): new_attr.i = user_defined_attr elif attr.type == framework_pb2.FLOAT: new_attr.f = user_defined_attr + elif attr.type == framework_pb2.LONG: + new_attr.l = user_defined_attr elif attr.type == framework_pb2.STRING: new_attr.s = user_defined_attr elif attr.type == framework_pb2.BOOLEAN: diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 1d8d0b55f0c5d7cffa01a100847bdf48b6d7023d..f34c385617c119f0a3dc26ffb40caacfa901fccb 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -240,6 +240,17 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(layers.softmax(hid)) print(str(program)) + def test_reorg(self): + program = Program() + with program_guard(program): + data = layers.data( + name="data", + shape=[32, 9, 6, 6], + append_batch_size=False, + dtype='float32') + self.assertIsNotNone(layers.reorg(data, long(3))) + print(str(program)) + def test_sequence_unsqueeze(self): program = Program() with program_guard(program): diff --git a/python/paddle/fluid/tests/unittests/test_reorg_op.py b/python/paddle/fluid/tests/unittests/test_reorg_op.py new file mode 100644 index 0000000000000000000000000000000000000000..9d4fa4d0ff72d5535da417c2d303b34ff581d855 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_reorg_op.py @@ -0,0 +1,93 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +import paddle.fluid as fluid +from op_test import OpTest + + +class TestReorgOp(OpTest): + @staticmethod + def helper(in_, width, height, channel, batch, stride, forward, out_): + channel_out = channel / (stride * stride) + for b in range(batch): + for k in range(channel): + for j in range(height): + for i in range(width): + in_index = i + width * (j + height * (k + channel * b)) + channel2 = k % channel_out + offset = k / channel_out + width2 = i * stride + offset % stride + height2 = j * stride + offset / stride + out_index = width2 + width * stride * ( + height2 + height * stride * + (channel2 + channel_out * b)) + if forward: + out_[out_index] = in_[in_index] + else: + out_[in_index] = in_[out_index] + + def setUp(self): + self.init_data() + + self.op_type = "reorg" + self.inputs = {"X": self.x} + self.helper(self.x_1d, self.x.shape[3], self.x.shape[2], + self.x.shape[1], self.x.shape[0], self.stride, self.forward, + self.out_1d) + self.out = np.reshape(self.out_1d, self.infered_shape) + self.attrs = {"stride": long(self.stride)} + self.outputs = {"Out": self.out} + + def init_data(self): + self.ori_shape = (32, 12, 6, 6) + self.infered_shape = (32, 48, 3, 3) + self.one_d_len = 32 * 48 * 3 * 3 + + self.stride = 2 + self.x = np.random.random(self.ori_shape).astype('float32') + self.x_1d = np.reshape(self.x, self.one_d_len) + self.out = np.zeros(self.infered_shape).astype('float32') + self.out_1d = np.reshape(self.out, self.one_d_len) + self.forward = 1 + + def test_check_output(self): + place = fluid.core.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.core.CPUPlace() + self.check_output_with_place(place, 1e-5, None, False) + + def test_check_grad(self): + place = fluid.core.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.core.CPUPlace() + self.check_grad_with_place(place, ['X'], 'Out') + + +class TestReorgOp2(TestReorgOp): + def init_data(self): + self.ori_shape = (32, 9, 6, 6) + self.infered_shape = (32, 81, 2, 2) + self.one_d_len = 32 * 81 * 2 * 2 + + self.stride = 3 + self.x = np.random.random(self.ori_shape).astype('float32') + self.x_1d = np.reshape(self.x, self.one_d_len) + self.out = np.zeros(self.infered_shape).astype('float32') + self.out_1d = np.reshape(self.out, self.one_d_len) + self.forward = 1 + + +if __name__ == '__main__': + unittest.main()