未验证 提交 1f82c0cd 编写于 作者: R ruri 提交者: GitHub

[Api2.0] add pixel shuffle (#26071)

上级 1ed74aae
......@@ -28,25 +28,44 @@ class PixelShuffleOp : public framework::OperatorWithKernel {
"Output(Out) of PixelShuffleOp should not be null."));
auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(
input_dims.size(), 4,
PADDLE_ENFORCE_EQ(input_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W], but got %u.",
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
input_dims.size()));
auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");
PADDLE_ENFORCE_EQ(input_dims[1] % (upscale_factor * upscale_factor), 0,
const std::string data_format =
ctx->Attrs().Get<std::string>("data_format");
const bool channel_last = (data_format == "NHWC");
if (!channel_last) {
PADDLE_ENFORCE_EQ(
input_dims[1] % (upscale_factor * upscale_factor), 0,
platform::errors::InvalidArgument(
"The square of upscale_factor[%u] should divide the "
"number of channel[%u]",
input_dims[1], upscale_factor * upscale_factor));
} else {
PADDLE_ENFORCE_EQ(
input_dims[3] % (upscale_factor * upscale_factor), 0,
platform::errors::InvalidArgument(
"The square of upscale_factor[%u] should divide the "
"number of channel[%u]",
input_dims[3], upscale_factor * upscale_factor));
}
auto output_dims = input_dims;
output_dims[0] = input_dims[0];
if (!channel_last) {
output_dims[1] = input_dims[1] / (upscale_factor * upscale_factor);
output_dims[2] = input_dims[2] * upscale_factor;
output_dims[3] = input_dims[3] * upscale_factor;
} else {
output_dims[1] = input_dims[1] * upscale_factor;
output_dims[2] = input_dims[2] * upscale_factor;
output_dims[3] = input_dims[3] / (upscale_factor * upscale_factor);
}
ctx->SetOutputDim("Out", output_dims);
}
};
......@@ -54,14 +73,14 @@ class PixelShuffleOp : public framework::OperatorWithKernel {
class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"X",
AddInput("X",
"(Tensor, default Tensor<float>), "
"the input feature data of PixelShuffleOp, the layout is [N C H W].");
AddOutput(
"Out",
"the input feature data of PixelShuffleOp, the layout is [N, C, "
"H, W] or [N, H, W, C].");
AddOutput("Out",
"(Tensor, default Tensor<float>), the output of "
"PixelShuffleOp. The layout is [N,C/factor^2,H*factor,W*factor].");
"PixelShuffleOp. The layout is [N, C/factor^2, H*factor, "
"W*factor] or [N, H*factor, W*factor, C/factor^2].");
AddAttr<int>("upscale_factor",
"the factor to increase spatial resolution by.")
.SetDefault(1)
......@@ -70,6 +89,11 @@ class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
platform::errors::InvalidArgument(
"upscale_factor should be larger than 0."));
});
AddAttr<std::string>(
"data_format",
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\", Specify the data format of the input data.")
.SetDefault("NCHW");
AddComment(R"DOC(
Pixel Shuffle operator
......@@ -114,19 +138,30 @@ class PixelShuffleGradOp : public framework::OperatorWithKernel {
platform::errors::NotFound("Output(X@Grad) should not be null"));
auto do_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(
do_dims.size(), 4,
PADDLE_ENFORCE_EQ(do_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W], but got %u.",
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
do_dims.size()));
auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");
const std::string data_format =
ctx->Attrs().Get<std::string>("data_format");
const bool channel_last = (data_format == "NHWC");
auto dx_dims = do_dims;
dx_dims[0] = do_dims[0];
if (!channel_last) {
dx_dims[1] = do_dims[1] * (upscale_factor * upscale_factor);
dx_dims[2] = do_dims[2] / upscale_factor;
dx_dims[3] = do_dims[3] / upscale_factor;
} else {
dx_dims[1] = do_dims[1] / upscale_factor;
dx_dims[2] = do_dims[2] / upscale_factor;
dx_dims[3] = do_dims[3] * (upscale_factor * upscale_factor);
}
ctx->SetOutputDim(framework::GradVarName("X"), dx_dims);
}
};
......
......@@ -11,6 +11,7 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -24,23 +25,33 @@ class PixelShuffleOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
int factor = ctx.Attr<int>("upscale_factor");
std::string data_format = ctx.Attr<std::string>("data_format");
bool channel_last = (data_format == "NHWC");
auto in_dims = in->dims();
auto o_dims = out->dims();
framework::Tensor t;
t.ShareDataWith(*in);
if (!channel_last) {
t.Resize({in_dims[0], o_dims[1], factor, factor, in_dims[2], in_dims[3]});
} else {
t.Resize({in_dims[0], in_dims[1], in_dims[2], o_dims[3], factor, factor});
}
std::vector<int> axis = {0, 1, 4, 2, 5, 3};
framework::Tensor o;
o.ShareDataWith(*out);
if (!channel_last) {
o.Resize({in_dims[0], o_dims[1], in_dims[2], factor, in_dims[3], factor});
} else {
o.Resize({in_dims[0], in_dims[1], factor, in_dims[2], factor, o_dims[3]});
}
math::Transpose<DeviceContext, T, 6> trans;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
trans(dev_ctx, t, &o, axis);
......@@ -58,19 +69,32 @@ class PixelShuffleGradOpKernel : public framework::OpKernel<T> {
int factor = ctx.Attr<int>("upscale_factor");
std::string data_format = ctx.Attr<std::string>("data_format");
bool channel_last = (data_format == "NHWC");
auto do_dims = dout->dims();
auto dx_dims = dx->dims();
framework::Tensor t;
t.ShareDataWith(*dout);
t.Resize({do_dims[0], do_dims[1], dx_dims[2], factor, dx_dims[3], factor});
if (!channel_last) {
t.Resize(
{do_dims[0], do_dims[1], dx_dims[2], factor, dx_dims[3], factor});
} else {
t.Resize(
{do_dims[0], dx_dims[1], factor, dx_dims[2], factor, do_dims[3]});
}
std::vector<int> axis = {0, 1, 3, 5, 2, 4};
framework::Tensor o;
o.ShareDataWith(*dx);
o.Resize({do_dims[0], do_dims[1], factor, factor, dx_dims[2], dx_dims[3]});
if (!channel_last) {
o.Resize(
{do_dims[0], do_dims[1], factor, factor, dx_dims[2], dx_dims[3]});
} else {
o.Resize(
{do_dims[0], dx_dims[1], dx_dims[2], do_dims[3], factor, factor});
}
math::Transpose<DeviceContext, T, 6> trans;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
trans(dev_ctx, t, &o, axis);
......
......@@ -16,16 +16,17 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.nn.functional as F
import paddle.fluid.core as core
import paddle.fluid as fluid
class TestPixelShuffle(OpTest):
def setUp(self):
self.op_type = "pixel_shuffle"
n, c, h, w = 2, 9, 4, 4
up_factor = 3
shape = [n, c, h, w]
x = np.random.random(shape).astype("float64")
def pixel_shuffle_np(x, up_factor, data_format="NCHW"):
if data_format == "NCHW":
n, c, h, w = x.shape
new_shape = (n, c // (up_factor * up_factor), up_factor, up_factor, h,
w)
# reshape to (num,output_channel,upscale_factor,upscale_factor,h,w)
......@@ -34,10 +35,42 @@ class TestPixelShuffle(OpTest):
npresult = npresult.transpose(0, 1, 4, 2, 5, 3)
oshape = [n, c // (up_factor * up_factor), h * up_factor, w * up_factor]
npresult = np.reshape(npresult, oshape)
return npresult
else:
n, h, w, c = x.shape
new_shape = (n, h, w, c // (up_factor * up_factor), up_factor,
up_factor)
# reshape to (num,h,w,output_channel,upscale_factor,upscale_factor)
npresult = np.reshape(x, new_shape)
# transpose to (num,h,upscale_factor,w,upscale_factor,output_channel)
npresult = npresult.transpose(0, 1, 4, 2, 5, 3)
oshape = [n, h * up_factor, w * up_factor, c // (up_factor * up_factor)]
npresult = np.reshape(npresult, oshape)
return npresult
class TestPixelShuffleOp(OpTest):
def setUp(self):
self.op_type = "pixel_shuffle"
self.init_data_format()
n, c, h, w = 2, 9, 4, 4
if self.format == "NCHW":
shape = [n, c, h, w]
if self.format == "NHWC":
shape = [n, h, w, c]
up_factor = 3
x = np.random.random(shape).astype("float64")
npresult = pixel_shuffle_np(x, up_factor, self.format)
self.inputs = {'X': x}
self.outputs = {'Out': npresult}
self.attrs = {'upscale_factor': up_factor}
self.attrs = {'upscale_factor': up_factor, "data_format": self.format}
def init_data_format(self):
self.format = "NCHW"
def test_check_output(self):
self.check_output()
......@@ -46,5 +79,141 @@ class TestPixelShuffle(OpTest):
self.check_grad(['X'], 'Out')
class TestChannelLast(TestPixelShuffleOp):
def init_data_format(self):
self.format = "NHWC"
class TestPixelShuffleAPI(unittest.TestCase):
def setUp(self):
self.x_1_np = np.random.random([2, 9, 4, 4]).astype("float64")
self.x_2_np = np.random.random([2, 4, 4, 9]).astype("float64")
self.out_1_np = pixel_shuffle_np(self.x_1_np, 3)
self.out_2_np = pixel_shuffle_np(self.x_2_np, 3, "NHWC")
def test_static_graph_functional(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.enable_static()
x_1 = paddle.data(name="x", shape=[2, 9, 4, 4], dtype="float64")
x_2 = paddle.data(name="x2", shape=[2, 4, 4, 9], dtype="float64")
out_1 = F.pixel_shuffle(x_1, 3)
out_2 = F.pixel_shuffle(x_2, 3, "NHWC")
exe = paddle.static.Executor(place=place)
res_1 = exe.run(fluid.default_main_program(),
feed={"x": self.x_1_np},
fetch_list=out_1,
use_prune=True)
res_2 = exe.run(fluid.default_main_program(),
feed={"x2": self.x_2_np},
fetch_list=out_2,
use_prune=True)
assert np.allclose(res_1, self.out_1_np)
assert np.allclose(res_2, self.out_2_np)
# same test between layer and functional in this op.
def test_static_graph_layer(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.enable_static()
x_1 = paddle.data(name="x", shape=[2, 9, 4, 4], dtype="float64")
x_2 = paddle.data(name="x2", shape=[2, 4, 4, 9], dtype="float64")
# init instance
ps_1 = paddle.nn.PixelShuffle(3)
ps_2 = paddle.nn.PixelShuffle(3, "NHWC")
out_1 = ps_1(x_1)
out_2 = ps_2(x_2)
out_1_np = pixel_shuffle_np(self.x_1_np, 3)
out_2_np = pixel_shuffle_np(self.x_2_np, 3, "NHWC")
exe = paddle.static.Executor(place=place)
res_1 = exe.run(fluid.default_main_program(),
feed={"x": self.x_1_np},
fetch_list=out_1,
use_prune=True)
res_2 = exe.run(fluid.default_main_program(),
feed={"x2": self.x_2_np},
fetch_list=out_2,
use_prune=True)
assert np.allclose(res_1, out_1_np)
assert np.allclose(res_2, out_2_np)
def run_dygraph(self, up_factor, data_format):
n, c, h, w = 2, 9, 4, 4
if data_format == "NCHW":
shape = [n, c, h, w]
if data_format == "NHWC":
shape = [n, h, w, c]
x = np.random.random(shape).astype("float64")
npresult = pixel_shuffle_np(x, up_factor, data_format)
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.disable_static(place=place)
pixel_shuffle = paddle.nn.PixelShuffle(
up_factor, data_format=data_format)
result = pixel_shuffle(paddle.to_tensor(x))
self.assertTrue(np.allclose(result.numpy(), npresult))
result_functional = F.pixel_shuffle(
paddle.to_tensor(x), 3, data_format)
self.assertTrue(np.allclose(result_functional.numpy(), npresult))
def test_dygraph1(self):
self.run_dygraph(3, "NCHW")
def test_dygraph2(self):
self.run_dygraph(3, "NHWC")
class TestPixelShuffleError(unittest.TestCase):
def test_error_functional(self):
def error_upscale_factor():
with paddle.fluid.dygraph.guard():
x = np.random.random([2, 9, 4, 4]).astype("float64")
pixel_shuffle = F.pixel_shuffle(paddle.to_tensor(x), 3.33)
self.assertRaises(TypeError, error_upscale_factor)
def error_data_format():
with paddle.fluid.dygraph.guard():
x = np.random.random([2, 9, 4, 4]).astype("float64")
pixel_shuffle = F.pixel_shuffle(paddle.to_tensor(x), 3, "WOW")
self.assertRaises(ValueError, error_data_format)
def test_error_layer(self):
def error_upscale_factor_layer():
with paddle.fluid.dygraph.guard():
x = np.random.random([2, 9, 4, 4]).astype("float64")
ps = paddle.nn.PixelShuffle(3.33)
self.assertRaises(TypeError, error_upscale_factor_layer)
def error_data_format_layer():
with paddle.fluid.dygraph.guard():
x = np.random.random([2, 9, 4, 4]).astype("float64")
ps = paddle.nn.PixelShuffle(3, "MEOW")
self.assertRaises(ValueError, error_data_format_layer)
if __name__ == '__main__':
unittest.main()
......@@ -139,7 +139,10 @@ from .layer.transformer import TransformerDecoder
from .layer.transformer import Transformer
from .layer.distance import PairwiseDistance #DEFINE_ALIAS
from .layer.vision import PixelShuffle
from .layer import loss #DEFINE_ALIAS
from .layer import conv #DEFINE_ALIAS
from .layer import vision #DEFINE_ALIAS
from ..fluid.dygraph.layers import Layer #DEFINE_ALIAS
from ..fluid.dygraph.container import LayerList, ParameterList, Sequential #DEFINE_ALIAS
......@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: define specitial functions used in computer vision task
from ...fluid.data_feeder import check_variable_and_dtype
from ...fluid.layer_helper import LayerHelper
from ...fluid.framework import core, in_dygraph_mode
from ...fluid.layers import affine_channel #DEFINE_ALIAS
from ...fluid.layers import affine_grid #DEFINE_ALIAS
from ...fluid.layers import anchor_generator #DEFINE_ALIAS
......@@ -28,6 +31,7 @@ from ...fluid.layers import distribute_fpn_proposals #DEFINE_ALIAS
from ...fluid.layers import generate_mask_labels #DEFINE_ALIAS
from ...fluid.layers import generate_proposal_labels #DEFINE_ALIAS
from ...fluid.layers import generate_proposals #DEFINE_ALIAS
from ...fluid.layers import grid_sampler #DEFINE_ALIAS
from ...fluid.layers import image_resize #DEFINE_ALIAS
from ...fluid.layers import prior_box #DEFINE_ALIAS
from ...fluid.layers import prroi_pool #DEFINE_ALIAS
......@@ -43,7 +47,7 @@ from ...fluid.layers import yolov3_loss #DEFINE_ALIAS
from ...fluid.layers import fsp_matrix #DEFINE_ALIAS
from ...fluid.layers import image_resize_short #DEFINE_ALIAS
from ...fluid.layers import pixel_shuffle #DEFINE_ALIAS
# from ...fluid.layers import pixel_shuffle #DEFINE_ALIAS
from ...fluid.layers import retinanet_detection_output #DEFINE_ALIAS
from ...fluid.layers import retinanet_target_assign #DEFINE_ALIAS
from ...fluid.layers import roi_perspective_transform #DEFINE_ALIAS
......@@ -67,7 +71,7 @@ __all__ = [
'generate_mask_labels',
'generate_proposal_labels',
'generate_proposals',
'grid_sample',
'grid_sampler',
'image_resize',
'image_resize_short',
# 'multi_box_head',
......@@ -89,8 +93,6 @@ __all__ = [
'yolov3_loss'
]
from ...fluid.layer_helper import LayerHelper
from ...fluid.data_feeder import check_variable_and_dtype
from ...fluid import core, dygraph_utils
from ...fluid.framework import Variable, in_dygraph_mode
from ...device import get_cudnn_version
......@@ -112,22 +114,16 @@ def grid_sample(x,
data x and y is indexing the 3rd dimension (in height dimension),
finally results is the bilinear interpolation or nearest value of 4 nearest corner
points. The output tensor shape will be [N, C, H, W].
.. code-block:: text
Step 1:
Get (x, y) grid coordinates and scale to [0, H-1/W-1].
.. code-block:: text
grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)
Step 2:
Indices input data X with grid (x, y) in each [H, W] area, and bilinear
interpolate point value by 4 nearest points or nearest interpolate point value
by nearest point.
wn ------- y_n ------- en
| | |
| d_n |
......@@ -137,27 +133,21 @@ def grid_sample(x,
| d_s |
| | |
ws ------- y_s ------- wn
For bilinear interpolation:
x_w = floor(x) // west side x coord
x_e = x_w + 1 // east side x coord
y_n = floor(y) // north side y coord
y_s = y_s + 1 // south side y coord
d_w = grid_x - x_w // distance to west side
d_e = x_e - grid_x // distance to east side
d_n = grid_y - y_n // distance to north side
d_s = y_s - grid_y // distance to south side
wn = X[:, :, y_n, x_w] // north-west point value
en = X[:, :, y_n, x_e] // north-east point value
ws = X[:, :, y_s, x_w] // south-east point value
es = X[:, :, y_s, x_w] // north-east point value
output = wn * d_e * d_s + en * d_w * d_s
+ ws * d_e * d_n + es * d_w * d_n
Args:
x(Tensor): The input tensor, which is a 4-d tensor with shape
[N, C, H, W], N is the batch size, C is the channel
......@@ -176,14 +166,10 @@ def grid_sample(x,
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
Returns: Tensor, The shape of output is [N, C, grid_H, grid_W] in which `grid_H` is the height of grid
and `grid_W` is the width of grid. The data type is same as input tensor.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
......@@ -272,3 +258,57 @@ def grid_sample(x,
attrs=attrs,
outputs={'Output': out})
return out
def pixel_shuffle(x, upscale_factor, data_format="NCHW", name=None):
"""
This API implements pixel shuffle operation.
See more details in :ref:`api_nn_vision_PixelShuffle` .
Parameters:
x(Tensor): 4-D tensor, the data type should be float32 or float64.
upscale_factor(int): factor to increase spatial resolution.
data_format (str): The data format of the input and output data. An optional string from: "NCHW", "NHWC". The default is "NCHW". When it is "NCHW", the data is stored in the order of: [batch_size, input_channels, input_height, input_width].
name (str, optional): The default value is None. Normally there is no need for user to set this property.
Returns:
Out(tensor): Reshaped tensor according to the new dimension.
Raises:
ValueError: If the square of upscale_factor cannot divide the channels of input.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
x = np.random.randn(2, 9, 4, 4).astype(np.float32)
paddle.disable_static()
x_var = paddle.to_tensor(x)
out_var = F.pixel_shuffle(x_var, 3)
out = out_var.numpy()
print(out.shape)
# (2, 1, 12, 12)
"""
if not in_dygraph_mode():
check_variable_and_dtype(x, 'x', ['float32', 'float64'],
'pixel_shuffle')
if not isinstance(upscale_factor, int):
raise TypeError("upscale factor must be int type")
if data_format not in ["NCHW", "NHWC"]:
raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'."
"But recevie Attr(data_format): {} ".format(
data_format))
if in_dygraph_mode():
return core.ops.pixel_shuffle(x, "upscale_factor", upscale_factor,
"data_format", data_format)
helper = LayerHelper("pixel_shuffle", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="pixel_shuffle",
inputs={"X": x},
outputs={"Out": out},
attrs={"upscale_factor": upscale_factor,
"data_format": data_format})
return out
......@@ -20,6 +20,7 @@ from . import conv
from . import extension
from . import activation
from . import norm
from . import vision
from . import distance
from . import transformer
......@@ -29,6 +30,8 @@ from .conv import *
from .extension import *
from .activation import *
from .norm import *
from .vision import *
from .transformer import *
# from .activation import PReLU #DEFINE_ALIAS
from .activation import ReLU #DEFINE_ALIAS
......@@ -104,4 +107,6 @@ from .norm import InstanceNorm #DEFINE_ALIAS
# from .rnn import RNNCell #DEFINE_ALIAS
# from .rnn import GRUCell #DEFINE_ALIAS
# from .rnn import LSTMCell #DEFINE_ALIAS
from .vision import PixelShuffle #DEFINE_ALIAS
from .distance import PairwiseDistance #DEFINE_ALIAS
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: define specitial functions used in computer vision task
from ...fluid.dygraph import layers
from .. import functional
__all__ = ['PixelShuffle']
class PixelShuffle(layers.Layer):
"""
PixelShuffle Layer
This operator rearranges elements in a tensor of shape [N, C, H, W]
to a tensor of shape [N, C/upscale_factor**2, H*upscale_factor, W*upscale_factor],
or from shape [N, H, W, C] to [N, H*upscale_factor, W*upscale_factor, C/upscale_factor**2].
This is useful for implementing efficient sub-pixel convolution
with a stride of 1/upscale_factor.
Please refer to the paper: `Real-Time Single Image and Video Super-Resolution
Using an Efficient Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158v2>`_ .
by Shi et. al (2016) for more details.
Parameters:
upscale_factor(int): factor to increase spatial resolution.
data_format (str): The data format of the input and output data. An optional string from: "NCHW", "NHWC". The default is "NCHW". When it is "NCHW", the data is stored in the order of: [batch_size, input_channels, input_height, input_width].
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Shape:
- x: 4-D tensor with shape: (N, C, H, W) or (N, H, W, C).
- out: 4-D tensor with shape: (N, C/upscale_factor**2, H*upscale_factor, W*upscale_factor) or (N, H*upscale_factor, W*upscale_factor, C/upscale_factor^2).
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
import numpy as np
paddle.disable_static()
x = np.random.randn(2, 9, 4, 4).astype(np.float32)
x_var = paddle.to_tensor(x)
pixel_shuffle = nn.PixelShuffle(3)
out_var = pixel_shuffle(x_var)
out = out_var.numpy()
print(out.shape)
# (2, 1, 12, 12)
"""
def __init__(self, upscale_factor, data_format="NCHW", name=None):
super(PixelShuffle, self).__init__()
if not isinstance(upscale_factor, int):
raise TypeError("upscale factor must be int type")
if data_format not in ["NCHW", "NHWC"]:
raise ValueError("Data format should be 'NCHW' or 'NHWC'."
"But recevie data format: {}".format(data_format))
self._upscale_factor = upscale_factor
self._data_format = data_format
self._name = name
def forward(self, x):
return functional.pixel_shuffle(x, self._upscale_factor,
self._data_format, self._name)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册