Unverified commit 229dc932 authored by ruri, committed by GitHub

Add Pixel shuffle OP (#15782)

* add pixel_shuffle op

* add pixel_shuffle op, test=develop

* rewrite code, test=develop

* delete useless comment, test=develop

* Refine pixel_shuffle_op and unit testing

* refine code,test=develop

* refine .cu,test=develop

* fix unittest,test=develop

* Fix unit testing
test=develop

* resolve conflict, test=develop

* fix test, test=develop

* fix API, test=develop

* fix test datatype bug,test=develop

* polish comments,test=develop

* add API,test=develop

* test=develop

* Add Pixel_Shuffle OP,test=develop

* support python3,test=develop

* add <memory> include to fix Travis CI failure, test=develop
Parent 38382f8e
@@ -235,6 +235,7 @@ paddle.fluid.layers.huber_loss (ArgSpec(args=['input', 'label', 'delta'], vararg
paddle.fluid.layers.kldiv_loss (ArgSpec(args=['x', 'target', 'reduction', 'name'], varargs=None, keywords=None, defaults=('mean', None)), ('document', '776d536cac47c89073abc7ee524d5aec'))
paddle.fluid.layers.tree_conv (ArgSpec(args=['nodes_vector', 'edge_set', 'output_size', 'num_filters', 'max_depth', 'act', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 2, 'tanh', None, None, None)), ('document', '34ea12ac9f10a65dccbc50100d12e607'))
paddle.fluid.layers.npair_loss (ArgSpec(args=['anchor', 'positive', 'labels', 'l2_reg'], varargs=None, keywords=None, defaults=(0.002,)), ('document', '46994d10276dd4cb803b4062b5d14329'))
paddle.fluid.layers.pixel_shuffle (ArgSpec(args=['x', 'upscale_factor'], varargs=None, keywords=None, defaults=None), ('document', 'ad669cdf83e72a69ebc5ed79e36486de'))
paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=None, defaults=None), ('document', 'b76ccca3735bea4a58a0dbf0d77c5393'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '33bbd42027d872b3818b3d64ec52e139'))
paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'b1ae2e1cc0750e58726374061ea90ecc'))
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/pixel_shuffle_op.h"
#include <memory>
namespace paddle {
namespace operators {
class PixelShuffleOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of PixelShuffleOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of PixelShuffleOp should not be null.");
auto input_dims = ctx->GetInputDim("X");
    PADDLE_ENFORCE(input_dims.size() == 4,
                   "Input(X) should be a 4-D tensor with NCHW layout.");
auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");
    PADDLE_ENFORCE(input_dims[1] % (upscale_factor * upscale_factor) == 0,
                   "The square of upscale_factor should divide the number of channels.");
auto output_dims = input_dims;
output_dims[0] = input_dims[0];
output_dims[1] = input_dims[1] / (upscale_factor * upscale_factor);
output_dims[2] = input_dims[2] * upscale_factor;
output_dims[3] = input_dims[3] * upscale_factor;
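    // Illustrative example (not from the original source): input dims
    // [2, 9, 4, 4] with upscale_factor = 3 yield output dims [2, 1, 12, 12].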
ctx->SetOutputDim("Out", output_dims);
}
};
class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"X",
"(Tensor, default Tensor<float>), "
"the input feature data of PixelShuffleOp, the layout is [N C H W].");
AddOutput(
"Out",
"(Tensor, default Tensor<float>), the output of "
"PixelShuffleOp. The layout is [N,C/factor^2,H*factor,W*factor].");
AddAttr<int>("upscale_factor",
"the factor to increase spatial resolution by.")
.SetDefault(1)
.AddCustomChecker([](const int& upscale_factor) {
PADDLE_ENFORCE_GE(upscale_factor, 1,
"upscale_factor should be larger than 0.");
});
AddComment(R"DOC(
Pixel Shuffle operator
This operator rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
to a tensor of shape :math:`(*, C, H \times r, W \times r)`.
This is useful for implementing efficient sub-pixel convolution
with a stride of :math:`1/r`.
Please refer to the paper:
`Real-Time Single Image and Video Super-Resolution Using an Efficient
Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158v2>`_
by Shi et al. (2016) for more details.
)DOC");
}
};
class PixelShuffleGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
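  // pixel_shuffle only rearranges data, so the backward op needs no forward
  // inputs or outputs: Out@Grad alone determines X@Grad.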
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType("pixel_shuffle_grad");
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
return std::unique_ptr<framework::OpDesc>(op);
}
};
class PixelShuffleGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@Grad) should not be null");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@Grad) should not be null");
auto do_dims = ctx->GetInputDim(framework::GradVarName("Out"));
    PADDLE_ENFORCE(do_dims.size() == 4,
                   "Input(Out@Grad) should be a 4-D tensor with NCHW layout.");
auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");
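    // Invert the forward shape rule: channels grow by factor^2 while each
    // spatial dimension shrinks by factor.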
auto dx_dims = do_dims;
dx_dims[0] = do_dims[0];
dx_dims[1] = do_dims[1] * (upscale_factor * upscale_factor);
dx_dims[2] = do_dims[2] / upscale_factor;
dx_dims[3] = do_dims[3] / upscale_factor;
ctx->SetOutputDim(framework::GradVarName("X"), dx_dims);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(pixel_shuffle, ops::PixelShuffleOp, ops::PixelShuffleOpMaker,
ops::PixelShuffleGradMaker);
REGISTER_OPERATOR(pixel_shuffle_grad, ops::PixelShuffleGradOp);
REGISTER_OP_CPU_KERNEL(
pixel_shuffle,
ops::PixelShuffleOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::PixelShuffleOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
pixel_shuffle_grad,
ops::PixelShuffleGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::PixelShuffleGradOpKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/pixel_shuffle_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
pixel_shuffle, ops::PixelShuffleOpKernel<plat::CUDADeviceContext, float>,
ops::PixelShuffleOpKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
pixel_shuffle_grad,
ops::PixelShuffleGradOpKernel<plat::CUDADeviceContext, float>,
ops::PixelShuffleGradOpKernel<plat::CUDADeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class PixelShuffleOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
int factor = ctx.Attr<int>("upscale_factor");
auto in_dims = in->dims();
auto o_dims = out->dims();
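    // Pixel shuffle as a single 6-D transpose: view the input as
    // [N, C/r^2, r, r, H, W], permute to [N, C/r^2, H, r, W, r], and the
    // flattened result is exactly [N, C/r^2, H*r, W*r].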
framework::Tensor t;
t.ShareDataWith(*in);
t.Resize({in_dims[0], o_dims[1], factor, factor, in_dims[2], in_dims[3]});
std::vector<int> axis = {0, 1, 4, 2, 5, 3};
framework::Tensor o;
o.ShareDataWith(*out);
o.Resize({in_dims[0], o_dims[1], in_dims[2], factor, in_dims[3], factor});
math::Transpose<DeviceContext, T, 6> trans;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
trans(dev_ctx, t, &o, axis);
out->Resize(o_dims);
}
};
template <typename DeviceContext, typename T>
class PixelShuffleGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
int factor = ctx.Attr<int>("upscale_factor");
auto do_dims = dout->dims();
auto dx_dims = dx->dims();
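    // Apply the inverse permutation: view dOut as [N, C/r^2, H, r, W, r],
    // permute back to [N, C/r^2, r, r, H, W], and flatten to the input
    // shape [N, C, H, W].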
framework::Tensor t;
t.ShareDataWith(*dout);
t.Resize({do_dims[0], do_dims[1], dx_dims[2], factor, dx_dims[3], factor});
std::vector<int> axis = {0, 1, 3, 5, 2, 4};
framework::Tensor o;
o.ShareDataWith(*dx);
o.Resize({do_dims[0], do_dims[1], factor, factor, dx_dims[2], dx_dims[3]});
math::Transpose<DeviceContext, T, 6> trans;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
trans(dev_ctx, t, &o, axis);
dx->Resize(dx_dims);
}
};
} // namespace operators
} // namespace paddle
@@ -191,6 +191,7 @@ __all__ = [
'kldiv_loss',
'tree_conv',
'npair_loss',
'pixel_shuffle',
'fsp_matrix',
]
@@ -10923,6 +10924,65 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002):
return l2loss + celoss
def pixel_shuffle(x, upscale_factor):
"""
**Pixel Shuffle Layer**
This layer rearranges elements in a tensor of shape [N, C, H, W]
to a tensor of shape [N, C/r**2, H*r, W*r].
This is useful for implementing efficient sub-pixel convolution
with a stride of 1/r.
Please refer to the paper `Real-Time Single Image and Video Super-Resolution
Using an Efficient Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158v2>`_
by Shi et al. (2016) for more details.
.. code-block:: text
Given a 4-D tensor with the shape:
x.shape = [1, 9, 4, 4]
Given upscale_factor:
upscale_factor = 3
output shape is:
[1, 1, 12, 12]
Args:
x(Variable): The input tensor variable.
upscale_factor(int): factor to increase spatial resolution.
Returns:
Out(Variable): The pixel shuffle result, a tensor variable with shape [N, C/r**2, H*r, W*r] and the same type as the input.
Raises:
ValueError: If the square of upscale_factor cannot divide the channels of input.
Examples:
.. code-block:: python
input = fluid.layers.data(name="input", shape=[9, 4, 4], dtype="float32")
output = fluid.layers.pixel_shuffle(x=input, upscale_factor=3)
"""
    helper = LayerHelper("pixel_shuffle", **locals())

    if not isinstance(upscale_factor, int):
        raise TypeError("upscale_factor must be of type int")

    out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="pixel_shuffle",
inputs={"X": x},
outputs={"Out": out},
attrs={"upscale_factor": upscale_factor})
return out
def fsp_matrix(x, y):
"""
......
@@ -1927,6 +1927,14 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(out)
print(str(program))
def test_pixel_shuffle(self):
program = Program()
with program_guard(program):
x = layers.data(name="X", shape=[9, 4, 4], dtype="float32")
out = layers.pixel_shuffle(x, upscale_factor=3)
self.assertIsNotNone(out)
print(str(program))
def test_fsp(self):
program = Program()
with program_guard(program):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
class TestPixelShuffle(OpTest):
def setUp(self):
self.op_type = "pixel_shuffle"
n, c, h, w = 2, 9, 4, 4
up_factor = 3
shape = [n, c, h, w]
x = np.random.random(shape).astype("float32")
new_shape = (n, c // (up_factor * up_factor), up_factor, up_factor, h,
w)
# reshape to (num,output_channel,upscale_factor,upscale_factor,h,w)
npresult = np.reshape(x, new_shape)
# transpose to (num,output_channel,h,upscale_factor,w,upscale_factor)
npresult = npresult.transpose(0, 1, 4, 2, 5, 3)
oshape = [n, c // (up_factor * up_factor), h * up_factor, w * up_factor]
npresult = np.reshape(npresult, oshape)
self.inputs = {'X': x}
self.outputs = {'Out': npresult}
self.attrs = {'upscale_factor': up_factor}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
if __name__ == '__main__':
unittest.main()
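A minimal end-to-end usage sketch (not part of this commit): it assumes the fluid executor API of this release, and the local names `place`, `exe`, and `data` are introduced here for illustration only.

import numpy as np
import paddle.fluid as fluid

# Build a program with a pixel_shuffle layer: [N, 9, 4, 4] -> [N, 1, 12, 12].
x = fluid.layers.data(name="x", shape=[9, 4, 4], dtype="float32")
out = fluid.layers.pixel_shuffle(x, upscale_factor=3)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

data = np.random.random([1, 9, 4, 4]).astype("float32")
result, = exe.run(fluid.default_main_program(),
                  feed={"x": data},
                  fetch_list=[out])
print(result.shape)  # expected: (1, 1, 12, 12)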