diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index b19d50a6ad6afa312f5e695583174e56bf490755..d71d792b4e0147afc0ec09808d2004cb4b48ba46 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -235,6 +235,7 @@ paddle.fluid.layers.huber_loss (ArgSpec(args=['input', 'label', 'delta'], vararg paddle.fluid.layers.kldiv_loss (ArgSpec(args=['x', 'target', 'reduction', 'name'], varargs=None, keywords=None, defaults=('mean', None)), ('document', '776d536cac47c89073abc7ee524d5aec')) paddle.fluid.layers.tree_conv (ArgSpec(args=['nodes_vector', 'edge_set', 'output_size', 'num_filters', 'max_depth', 'act', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 2, 'tanh', None, None, None)), ('document', '34ea12ac9f10a65dccbc50100d12e607')) paddle.fluid.layers.npair_loss (ArgSpec(args=['anchor', 'positive', 'labels', 'l2_reg'], varargs=None, keywords=None, defaults=(0.002,)), ('document', '46994d10276dd4cb803b4062b5d14329')) +paddle.fluid.layers.pixel_shuffle (ArgSpec(args=['x', 'upscale_factor'], varargs=None, keywords=None, defaults=None), ('document', 'ad669cdf83e72a69ebc5ed79e36486de')) paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=None, defaults=None), ('document', 'b76ccca3735bea4a58a0dbf0d77c5393')) paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '33bbd42027d872b3818b3d64ec52e139')) paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'b1ae2e1cc0750e58726374061ea90ecc')) diff --git a/paddle/fluid/operators/pixel_shuffle_op.cc b/paddle/fluid/operators/pixel_shuffle_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..59ba660af79bff02cd350afb3eb7675bfe8ac498 --- /dev/null +++ b/paddle/fluid/operators/pixel_shuffle_op.cc @@ -0,0 +1,135 @@ +/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/pixel_shuffle_op.h" +#include + +namespace paddle { +namespace operators { + +class PixelShuffleOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of PixelShuffleOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of PixelShuffleOp should not be null."); + + auto input_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW."); + auto upscale_factor = ctx->Attrs().Get("upscale_factor"); + + PADDLE_ENFORCE(input_dims[1] % (upscale_factor * upscale_factor) == 0, + "Upscale_factor should devide the number of channel"); + + auto output_dims = input_dims; + output_dims[0] = input_dims[0]; + output_dims[1] = input_dims[1] / (upscale_factor * upscale_factor); + output_dims[2] = input_dims[2] * upscale_factor; + output_dims[3] = input_dims[3] * upscale_factor; + ctx->SetOutputDim("Out", output_dims); + } +}; + +class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput( + "X", + "(Tensor, default Tensor), " + "the input feature data of PixelShuffleOp, the layout is [N C H W]."); + AddOutput( + "Out", + "(Tensor, default Tensor), the output of " + "PixelShuffleOp. The layout is [N,C/factor^2,H*factor,W*factor]."); + AddAttr("upscale_factor", + "the factor to increase spatial resolution by.") + .SetDefault(1) + .AddCustomChecker([](const int& upscale_factor) { + PADDLE_ENFORCE_GE(upscale_factor, 1, + "upscale_factor should be larger than 0."); + }); + + AddComment(R"DOC( + Pixel Shuffle operator + This operator rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` + to a tensor of shape :math:`(C, H \times r, W \times r)`. + + This is useful for implementing efficient sub-pixel convolution + with a stride of :math:`1/r`. + + Please refer to the paper: + `Real-Time Single Image and Video Super-Resolution Using an Efficient + Sub-Pixel Convolutional Neural Network `_ + by Shi et. al (2016) for more details. + + )DOC"); + } +}; + +class PixelShuffleGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + std::unique_ptr Apply() const override { + auto* op = new framework::OpDesc(); + op->SetType("pixel_shuffle_grad"); + op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + op->SetAttrMap(Attrs()); + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + return std::unique_ptr(op); + } +}; + +class PixelShuffleGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@Grad) should not be null"); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@Grad) should not be null"); + + auto do_dims = ctx->GetInputDim(framework::GradVarName("Out")); + PADDLE_ENFORCE(do_dims.size() == 4, "The layout of input is NCHW."); + + auto upscale_factor = ctx->Attrs().Get("upscale_factor"); + + auto dx_dims = do_dims; + dx_dims[0] = do_dims[0]; + dx_dims[1] = do_dims[1] * (upscale_factor * upscale_factor); + dx_dims[2] = do_dims[2] / upscale_factor; + dx_dims[3] = do_dims[3] / upscale_factor; + ctx->SetOutputDim(framework::GradVarName("X"), dx_dims); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(pixel_shuffle, ops::PixelShuffleOp, ops::PixelShuffleOpMaker, + ops::PixelShuffleGradMaker); + +REGISTER_OPERATOR(pixel_shuffle_grad, ops::PixelShuffleGradOp); + +REGISTER_OP_CPU_KERNEL( + pixel_shuffle, + ops::PixelShuffleOpKernel, + ops::PixelShuffleOpKernel); + +REGISTER_OP_CPU_KERNEL( + pixel_shuffle_grad, + ops::PixelShuffleGradOpKernel, + ops::PixelShuffleGradOpKernel); diff --git a/paddle/fluid/operators/pixel_shuffle_op.cu b/paddle/fluid/operators/pixel_shuffle_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..6faf91079e1dac00b3516ccde8dc82cec73a79e6 --- /dev/null +++ b/paddle/fluid/operators/pixel_shuffle_op.cu @@ -0,0 +1,26 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/pixel_shuffle_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + pixel_shuffle, ops::PixelShuffleOpKernel, + ops::PixelShuffleOpKernel); +REGISTER_OP_CUDA_KERNEL( + pixel_shuffle_grad, + ops::PixelShuffleGradOpKernel, + ops::PixelShuffleGradOpKernel); diff --git a/paddle/fluid/operators/pixel_shuffle_op.h b/paddle/fluid/operators/pixel_shuffle_op.h new file mode 100644 index 0000000000000000000000000000000000000000..1ae1c7e9d50cb9d701fd0e79337a1906f2f5d545 --- /dev/null +++ b/paddle/fluid/operators/pixel_shuffle_op.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +class PixelShuffleOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + int factor = ctx.Attr("upscale_factor"); + + auto in_dims = in->dims(); + auto o_dims = out->dims(); + + framework::Tensor t; + t.ShareDataWith(*in); + t.Resize({in_dims[0], o_dims[1], factor, factor, in_dims[2], in_dims[3]}); + + std::vector axis = {0, 1, 4, 2, 5, 3}; + + framework::Tensor o; + o.ShareDataWith(*out); + o.Resize({in_dims[0], o_dims[1], in_dims[2], factor, in_dims[3], factor}); + + math::Transpose trans; + auto& dev_ctx = ctx.template device_context(); + trans(dev_ctx, t, &o, axis); + out->Resize(o_dims); + } +}; + +template +class PixelShuffleGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + dx->mutable_data(ctx.GetPlace()); + + int factor = ctx.Attr("upscale_factor"); + + auto do_dims = dout->dims(); + auto dx_dims = dx->dims(); + + framework::Tensor t; + t.ShareDataWith(*dout); + t.Resize({do_dims[0], do_dims[1], dx_dims[2], factor, dx_dims[3], factor}); + + std::vector axis = {0, 1, 3, 5, 2, 4}; + + framework::Tensor o; + o.ShareDataWith(*dx); + o.Resize({do_dims[0], do_dims[1], factor, factor, dx_dims[2], dx_dims[3]}); + + math::Transpose trans; + auto& dev_ctx = ctx.template device_context(); + trans(dev_ctx, t, &o, axis); + dx->Resize(dx_dims); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 91414fdeb207781afd5e28afa5a3fa6e1018efb1..b51e218c49abd424937be1216e7a15f800e8f02c 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -191,6 +191,7 @@ __all__ = [ 'kldiv_loss', 'tree_conv', 'npair_loss', + 'pixel_shuffle', 'fsp_matrix', ] @@ -10923,6 +10924,65 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002): return l2loss + celoss +def pixel_shuffle(x, upscale_factor): + """ + + **Pixel Shuffle Layer** + + This layer rearranges elements in a tensor of shape [N, C, H, W] + to a tensor of shape [N, C/r**2, H*r, W*r]. + This is useful for implementing efficient sub-pixel convolution + with a stride of 1/r. + Please refer to the paper: `Real-Time Single Image and Video Super-Resolution + Using an Efficient Sub-Pixel Convolutional Neural Network `_ . + by Shi et. al (2016) for more details. + + .. code-block:: text + + Given a 4-D tensor with the shape: + x.shape = [1, 9, 4, 4] + Given upscale_factor: + upscale_factor= 3 + output shape is: + [1, 1, 12, 12] + + Args: + + x(Variable): The input tensor variable. + upscale_factor(int): factor to increase spatial resolution + + Returns: + + Out(Variable): the pixel shuffle result is a tensor variable with the same shape and the same type as the input. + + Raises: + + ValueError: If the square of upscale_factor cannot divide the channels of input. + + Examples: + + .. code-block:: python + + input = fluid.layers.data(shape=[9,4,4]) + output = fluid.layers.pixel_shuffle(x=input, upscale_factor=3) + + """ + + helper = LayerHelper("pixel_shuffle", **locals()) + + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + if not isinstance(upscale_factor, int): + raise TypeError("upscale factor must be int type") + + helper.append_op( + type="pixel_shuffle", + inputs={"X": x}, + outputs={"Out": out}, + attrs={"upscale_factor": upscale_factor}) + return out + + def fsp_matrix(x, y): """ diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 89ba17f8b940d8f34e0df3ce6980ce7ddced607c..65bc9a78a121ade33c7d059544fff136d9cac696 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1927,6 +1927,14 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(out) print(str(program)) + def test_pixel_shuffle(self): + program = Program() + with program_guard(program): + x = layers.data(name="X", shape=[9, 4, 4], dtype="float32") + out = layers.pixel_shuffle(x, upscale_factor=3) + self.assertIsNotNone(out) + print(str(program)) + def test_fsp(self): program = Program() with program_guard(program): diff --git a/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py b/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py new file mode 100644 index 0000000000000000000000000000000000000000..cc3ae2b3b9d4c40a7ee992c04cac79f518acac6d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py @@ -0,0 +1,50 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + + +class TestPixelShuffle(OpTest): + def setUp(self): + self.op_type = "pixel_shuffle" + n, c, h, w = 2, 9, 4, 4 + up_factor = 3 + shape = [n, c, h, w] + x = np.random.random(shape).astype("float32") + new_shape = (n, c // (up_factor * up_factor), up_factor, up_factor, h, + w) + # reshape to (num,output_channel,upscale_factor,upscale_factor,h,w) + npresult = np.reshape(x, new_shape) + # transpose to (num,output_channel,h,upscale_factor,w,upscale_factor) + npresult = npresult.transpose(0, 1, 4, 2, 5, 3) + oshape = [n, c // (up_factor * up_factor), h * up_factor, w * up_factor] + npresult = np.reshape(npresult, oshape) + + self.inputs = {'X': x} + self.outputs = {'Out': npresult} + self.attrs = {'upscale_factor': up_factor} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +if __name__ == '__main__': + unittest.main()