Unverified commit bbaaf217, authored by BrilliantYuKaimin, committed by GitHub

[PaddlePaddle Hackathon 2] No. 24: Add the nn.ChannelShuffle API to Paddle (#40743)

* Add infermeta for ChannelShuffle

* Create channel_shuffle_grad_kernel.h

* Create channel_shuffle_kernel.h

* Create channel_shuffle_sig.cc

* Create channel_shuffle_op.cc

Description of the ChannelShuffle operator

* Create channel_shuffle_kernel_impl.h

Implementation of the ChannelShuffle kernel

* Create channel_shuffle_grad_kernel_impl.h

Implementation of the ChannelShuffle backward kernel

* Add kernel register of channel shuffle and grad

Register the ChannelShuffle kernel and its backward kernel

* add nn.functional.channel_shuffle

* add nn.ChannelShuffle

* Create test_channel_shuffle.py

* Update example of ChannelShuffle in vision.py

* Update test_channel_shuffle.py

* Move the implementation of the channel_shuffle kernel

* Fix code formatting

* Remove extra whitespace

* Improve error checking in channel_shuffle

* Update unary.cc

* Update channel_shuffle_op.cc

* Update test_channel_shuffle.py

* Update unary.cc

* add channel_shuffle

* Update test_channel_shuffle.py

* Update vision.py

* Adjust code formatting

* Update channel_shuffle_sig.cc

* Update the ChannelShuffle documentation

* Update the channel_shuffle documentation

* remove ChannelShuffleOpArgumentMapping

* add ChannelShuffleGradInferMeta

* Update channel_shuffle_op.cc

* Move the channel_shuffle kernel and its gradient kernel
Parent c2a05a90
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class ChannelShuffleOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
class ChannelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), "
"the input feature data of ChannelShuffleOp, the layout is "
"[N, C, H, W] or [N, H, W, C].");
AddOutput("Out",
"(Tensor, default Tensor<float>), the output of "
"ChannelShuffleOp. The layout is also [N, C, "
"H, W] or [N, H, W, C].");
AddAttr<int>("groups", "number of groups to divide channels in.");
AddAttr<std::string>(
"data_format",
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\", Specify the data format of the input data.")
.SetDefault("NCHW");
AddComment(R"DOC(
Channel Shuffle operator
This operator divides channels in a tensor of shape :math:`(*, C, H, W)`
into :math:`g` groups and rearranges them as :math:`(*, C/g, g, H, W)`
while keeping the original tensor shape.
Please refer to the paper:
`ShuffleNet: An Extremely Efficient Convolutional Neural Network for
Mobile Devices <https://arxiv.org/abs/1707.01083>`_
by Zhang et al. (2017) for more details.
)DOC");
}
};
class ChannelShuffleGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
template <typename T>
class ChannelShuffleGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("channel_shuffle_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(channel_shuffle, ChannelShuffleInferShapeFunctor,
PD_INFER_META(phi::ChannelShuffleInferMeta));
REGISTER_OPERATOR(channel_shuffle, ops::ChannelShuffleOp,
ops::ChannelShuffleOpMaker,
ops::ChannelShuffleGradOpMaker<paddle::framework::OpDesc>,
ops::ChannelShuffleGradOpMaker<paddle::imperative::OpBase>,
ChannelShuffleInferShapeFunctor);
DECLARE_INFER_SHAPE_FUNCTOR(channel_shuffle_grad,
ChannelShuffleGradInferShapeFunctor,
PD_INFER_META(phi::ChannelShuffleGradInferMeta));
REGISTER_OPERATOR(channel_shuffle_grad, ops::ChannelShuffleGradOp,
ChannelShuffleGradInferShapeFunctor);
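
For reference, the reshape-transpose-reshape procedure described in the operator DOC above can be written in a few lines of NumPy. This is a minimal illustrative sketch for an NCHW input (the helper name channel_shuffle_ref is hypothetical, not part of this diff):

import numpy as np

def channel_shuffle_ref(x, groups):
    # x is an NCHW array; groups must evenly divide the channel count.
    n, c, h, w = x.shape
    assert c % groups == 0
    # (N, C, H, W) -> (N, g, C/g, H, W) -> (N, C/g, g, H, W) -> (N, C, H, W)
    x = x.reshape(n, groups, c // groups, h, w)
    x = x.transpose(0, 2, 1, 3, 4)
    return x.reshape(n, c, h, w)

x = np.arange(6).reshape(1, 6, 1, 1)
print(channel_shuffle_ref(x, 3).ravel())  # [0 2 4 1 3 5]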
......@@ -67,6 +67,22 @@ void BilinearTensorProductGradInferMeta(const MetaTensor& x,
}
}
void ChannelShuffleGradInferMeta(const MetaTensor& out_grad,
int groups,
const std::string& data_format,
MetaTensor* x_grad) {
auto do_dims = out_grad.dims();
PADDLE_ENFORCE_EQ(do_dims.size(),
4,
phi::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
do_dims.size()));
auto dx_dims = do_dims;
x_grad->set_dims(dx_dims);
x_grad->set_dtype(out_grad.dtype());
}
void ConvTransposeGradInferMeta(const MetaTensor& x,
const MetaTensor& filter,
const MetaTensor& dout,
......
......@@ -37,6 +37,11 @@ void BilinearTensorProductGradInferMeta(const MetaTensor& x,
MetaTensor* dweight,
MetaTensor* dbias);
void ChannelShuffleGradInferMeta(const MetaTensor& out_grad,
int groups,
const std::string& data_format,
MetaTensor* x_grad);
void ConvTransposeGradInferMeta(const MetaTensor& x,
const MetaTensor& filter,
const MetaTensor& dout,
......
......@@ -2999,6 +2999,52 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) {
out->set_dtype(DataType::INT64);
}
void ChannelShuffleInferMeta(const MetaTensor& x,
int groups,
const std::string& data_format,
MetaTensor* out) {
auto input_dims = x.dims();
PADDLE_ENFORCE_EQ(input_dims.size(),
4,
phi::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
input_dims.size()));
PADDLE_ENFORCE_GE(
groups,
1,
phi::errors::InvalidArgument("groups should be larger than 0."));
PADDLE_ENFORCE_EQ(data_format == "NCHW" || data_format == "NHWC",
true,
phi::errors::InvalidArgument(
"data_format must be one of "
"NCHW and NHWC. But recevied data_format: %s",
data_format));
const bool channel_last = (data_format == "NHWC");
if (!channel_last) {
PADDLE_ENFORCE_EQ(input_dims[1] % groups,
0,
phi::errors::InvalidArgument(
"The number of groups to divide channels in [%u] "
"should divide the number of channel [%u]",
groups,
input_dims[1]));
} else {
PADDLE_ENFORCE_EQ(input_dims[3] % groups,
0,
phi::errors::InvalidArgument(
"The number of groups to divide channels in [%u] "
"should divide the number of channel [%u]",
groups,
input_dims[3]));
}
auto output_dims = input_dims;
out->set_dtype(x.dtype());
out->set_dims(output_dims);
}
} // namespace phi
PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta);
......
......@@ -435,4 +435,9 @@ void OneHotInferMeta(const MetaTensor& x, const Scalar& depth, MetaTensor* out);
void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out);
void ChannelShuffleInferMeta(const MetaTensor& x,
int groups,
const std::string& data_format,
MetaTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ChannelShuffleGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
int groups,
const std::string& data_format,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ChannelShuffleKernel(const Context& dev_ctx,
const DenseTensor& x,
int groups,
const std::string& data_format,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/channel_shuffle_grad_kernel.h"
#include "paddle/phi/kernels/impl/channel_shuffle_grad_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(channel_shuffle_grad,
CPU,
ALL_LAYOUT,
phi::ChannelShuffleGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/channel_shuffle_kernel.h"
#include "paddle/phi/kernels/impl/channel_shuffle_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(channel_shuffle,
CPU,
ALL_LAYOUT,
phi::ChannelShuffleKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/channel_shuffle_grad_kernel.h"
#include "paddle/phi/kernels/impl/channel_shuffle_grad_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(channel_shuffle_grad,
GPU,
ALL_LAYOUT,
phi::ChannelShuffleGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/channel_shuffle_kernel.h"
#include "paddle/phi/kernels/impl/channel_shuffle_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(channel_shuffle,
GPU,
ALL_LAYOUT,
phi::ChannelShuffleKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void ChannelShuffleGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
int groups,
const std::string& data_format,
DenseTensor* x_grad) {
auto* dout = &out_grad;
auto* dx = x_grad;
dev_ctx.template Alloc<T>(dx);
bool channel_last = (data_format == "NHWC");
auto do_dims = dout->dims();
auto dx_dims = dx->dims();
DenseTensor t(*dout);
if (!channel_last) {
t.Resize({do_dims[0], do_dims[1] / groups, groups, do_dims[2], do_dims[3]});
} else {
t.Resize({do_dims[0], do_dims[1], do_dims[2], do_dims[3] / groups, groups});
}
auto axis = !channel_last ? std::vector<int>{0, 2, 1, 3, 4}
: std::vector<int>{0, 1, 2, 4, 3};
DenseTensor o(*dx);
if (!channel_last) {
o.Resize({dx_dims[0], groups, dx_dims[1] / groups, dx_dims[2], dx_dims[3]});
} else {
o.Resize({dx_dims[0], dx_dims[1], dx_dims[2], groups, dx_dims[3] / groups});
}
phi::funcs::Transpose<Context, T, 5> trans;
trans(dev_ctx, t, &o, axis);
dx->Resize(dx_dims);
}
} // namespace phi
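
A note on the backward kernel above: because its transpose axes are complementary to those of the forward pass, the gradient computation is itself a channel shuffle, just with C/g groups instead of g. A minimal NumPy sketch of this round-trip property, reusing the hypothetical channel_shuffle_ref helper from the earlier sketch:

import numpy as np

def channel_shuffle_ref(x, groups):
    n, c, h, w = x.shape
    return (x.reshape(n, groups, c // groups, h, w)
            .transpose(0, 2, 1, 3, 4)
            .reshape(n, c, h, w))

x = np.random.rand(2, 6, 4, 4)
# Shuffling with 3 groups and then with 6/3 = 2 groups restores the input,
# which is why the grad kernel can be another shuffle applied to out_grad.
assert np.allclose(channel_shuffle_ref(channel_shuffle_ref(x, 3), 2), x)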
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void ChannelShuffleKernel(const Context& dev_ctx,
const DenseTensor& x,
int groups,
const std::string& data_format,
DenseTensor* out) {
auto* in = &x;
dev_ctx.template Alloc<T>(out);
bool channel_last = (data_format == "NHWC");
auto in_dims = in->dims();
auto o_dims = out->dims();
DenseTensor t(*in);
if (!channel_last) {
t.Resize({in_dims[0], groups, in_dims[1] / groups, in_dims[2], in_dims[3]});
} else {
t.Resize({in_dims[0], in_dims[1], in_dims[2], groups, in_dims[3] / groups});
}
auto axis = !channel_last ? std::vector<int>{0, 2, 1, 3, 4}
: std::vector<int>{0, 1, 2, 4, 3};
DenseTensor o(*out);
if (!channel_last) {
o.Resize({in_dims[0], in_dims[1] / groups, groups, in_dims[2], in_dims[3]});
} else {
o.Resize({in_dims[0], in_dims[1], in_dims[2], in_dims[3] / groups, groups});
}
phi::funcs::Transpose<Context, T, 5> trans;
trans(dev_ctx, t, &o, axis);
out->Resize(o_dims);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature ChannelShuffleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("channel_shuffle_grad",
{"Out@GRAD"},
{"groups", "data_format"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(channel_shuffle_grad,
phi::ChannelShuffleGradOpArgumentMapping);
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.nn.functional as F
import paddle.fluid.core as core
import paddle.fluid as fluid
def channel_shuffle_np(x, groups, data_format="NCHW"):
if data_format == "NCHW":
n, c, h, w = x.shape
new_shape = (n, groups, c // groups, h, w)
npresult = np.reshape(x, new_shape)
npresult = npresult.transpose(0, 2, 1, 3, 4)
oshape = [n, c, h, w]
npresult = np.reshape(npresult, oshape)
return npresult
else:
n, h, w, c = x.shape
new_shape = (n, h, w, groups, c // groups)
npresult = np.reshape(x, new_shape)
npresult = npresult.transpose(0, 1, 2, 4, 3)
oshape = [n, h, w, c]
npresult = np.reshape(npresult, oshape)
return npresult
class TestChannelShuffleOp(OpTest):
def setUp(self):
self.op_type = "channel_shuffle"
self.init_data_format()
n, c, h, w = 2, 9, 4, 4
if self.format == "NCHW":
shape = [n, c, h, w]
if self.format == "NHWC":
shape = [n, h, w, c]
groups = 3
x = np.random.random(shape).astype("float64")
npresult = channel_shuffle_np(x, groups, self.format)
self.inputs = {'X': x}
self.outputs = {'Out': npresult}
self.attrs = {'groups': groups, "data_format": self.format}
def init_data_format(self):
self.format = "NCHW"
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestChannelLast(TestChannelShuffleOp):
def init_data_format(self):
self.format = "NHWC"
class TestChannelShuffleAPI(unittest.TestCase):
def setUp(self):
self.x_1_np = np.random.random([2, 9, 4, 4]).astype("float64")
self.x_2_np = np.random.random([2, 4, 4, 9]).astype("float64")
self.out_1_np = channel_shuffle_np(self.x_1_np, 3)
self.out_2_np = channel_shuffle_np(self.x_2_np, 3, "NHWC")
def test_static_graph_functional(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.enable_static()
x_1 = paddle.fluid.data(
name="x", shape=[2, 9, 4, 4], dtype="float64")
x_2 = paddle.fluid.data(
name="x2", shape=[2, 4, 4, 9], dtype="float64")
out_1 = F.channel_shuffle(x_1, 3)
out_2 = F.channel_shuffle(x_2, 3, "NHWC")
exe = paddle.static.Executor(place=place)
res_1 = exe.run(fluid.default_main_program(),
feed={"x": self.x_1_np},
fetch_list=out_1,
use_prune=True)
res_2 = exe.run(fluid.default_main_program(),
feed={"x2": self.x_2_np},
fetch_list=out_2,
use_prune=True)
assert np.allclose(res_1, self.out_1_np)
assert np.allclose(res_2, self.out_2_np)
# The layer test below mirrors the functional test above.
def test_static_graph_layer(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.enable_static()
x_1 = paddle.fluid.data(
name="x", shape=[2, 9, 4, 4], dtype="float64")
x_2 = paddle.fluid.data(
name="x2", shape=[2, 4, 4, 9], dtype="float64")
# init instance
ps_1 = paddle.nn.ChannelShuffle(3)
ps_2 = paddle.nn.ChannelShuffle(3, "NHWC")
out_1 = ps_1(x_1)
out_2 = ps_2(x_2)
out_1_np = channel_shuffle_np(self.x_1_np, 3)
out_2_np = channel_shuffle_np(self.x_2_np, 3, "NHWC")
exe = paddle.static.Executor(place=place)
res_1 = exe.run(fluid.default_main_program(),
feed={"x": self.x_1_np},
fetch_list=out_1,
use_prune=True)
res_2 = exe.run(fluid.default_main_program(),
feed={"x2": self.x_2_np},
fetch_list=out_2,
use_prune=True)
assert np.allclose(res_1, out_1_np)
assert np.allclose(res_2, out_2_np)
def run_dygraph(self, groups, data_format):
n, c, h, w = 2, 9, 4, 4
if data_format == "NCHW":
shape = [n, c, h, w]
if data_format == "NHWC":
shape = [n, h, w, c]
x = np.random.random(shape).astype("float64")
npresult = channel_shuffle_np(x, groups, data_format)
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.disable_static(place=place)
channel_shuffle = paddle.nn.ChannelShuffle(
groups, data_format=data_format)
result = channel_shuffle(paddle.to_tensor(x))
self.assertTrue(np.allclose(result.numpy(), npresult))
result_functional = F.channel_shuffle(
paddle.to_tensor(x), groups, data_format)
self.assertTrue(np.allclose(result_functional.numpy(), npresult))
channel_shuffle_str = 'groups={}'.format(groups)
if data_format != 'NCHW':
channel_shuffle_str += ', data_format={}'.format(data_format)
self.assertEqual(channel_shuffle.extra_repr(), channel_shuffle_str)
def test_dygraph1(self):
self.run_dygraph(3, "NCHW")
def test_dygraph2(self):
self.run_dygraph(3, "NHWC")
class TestChannelShuffleError(unittest.TestCase):
def test_error_functional(self):
def error_input():
with paddle.fluid.dygraph.guard():
x = np.random.random([9, 4, 4]).astype("float64")
channel_shuffle = F.channel_shuffle(paddle.to_tensor(x), 3)
self.assertRaises(ValueError, error_input)
def error_groups_1():
with paddle.fluid.dygraph.guard():
x = np.random.random([2, 9, 4, 4]).astype("float64")
channel_shuffle = F.channel_shuffle(paddle.to_tensor(x), 3.33)
self.assertRaises(TypeError, error_groups_1)
def error_groups_2():
with paddle.fluid.dygraph.guard():
x = np.random.random([2, 9, 4, 4]).astype("float64")
channel_shuffle = F.channel_shuffle(paddle.to_tensor(x), -1)
self.assertRaises(ValueError, error_groups_2)
def error_data_format():
with paddle.fluid.dygraph.guard():
x = np.random.random([2, 9, 4, 4]).astype("float64")
channel_shuffle = F.channel_shuffle(
paddle.to_tensor(x), 3, "WOW")
self.assertRaises(ValueError, error_data_format)
def test_error_layer(self):
def error_input_layer():
with paddle.fluid.dygraph.guard():
x = np.random.random([9, 4, 4]).astype("float64")
cs = paddle.nn.ChannelShuffle(3)
cs(paddle.to_tensor(x))
self.assertRaises(ValueError, error_input_layer)
def error_groups_layer_1():
with paddle.fluid.dygraph.guard():
x = np.random.random([2, 9, 4, 4]).astype("float64")
cs = paddle.nn.ChannelShuffle(3.33)
self.assertRaises(TypeError, error_groups_layer_1)
def error_groups_layer_2():
with paddle.fluid.dygraph.guard():
x = np.random.random([2, 9, 4, 4]).astype("float64")
cs = paddle.nn.ChannelShuffle(-1)
self.assertRaises(ValueError, error_groups_layer_2)
def error_data_format_layer():
with paddle.fluid.dygraph.guard():
x = np.random.random([2, 9, 4, 4]).astype("float64")
cs = paddle.nn.ChannelShuffle(3, "MEOW")
self.assertRaises(ValueError, error_data_format_layer)
if __name__ == '__main__':
unittest.main()
......@@ -138,6 +138,7 @@ from .layer.transformer import Transformer # noqa: F401
from .layer.distance import PairwiseDistance # noqa: F401
from .layer.vision import PixelShuffle # noqa: F401
from .layer.vision import ChannelShuffle # noqa: F401
from .layer.container import LayerDict # noqa: F401
from .utils.spectral_norm_hook import spectral_norm
......@@ -300,6 +301,7 @@ __all__ = [ #noqa
'Swish',
'Mish',
'PixelShuffle',
'ChannelShuffle',
'ELU',
'ReLU6',
'LayerDict',
......
......@@ -114,6 +114,7 @@ from .pooling import max_unpool3d # noqa: F401
from .vision import affine_grid # noqa: F401
from .vision import grid_sample # noqa: F401
from .vision import pixel_shuffle # noqa: F401
from .vision import channel_shuffle # noqa: F401
from .input import one_hot # noqa: F401
from .input import embedding # noqa: F401
from ...fluid.layers import gather_tree # noqa: F401
......@@ -213,6 +214,7 @@ __all__ = [ #noqa
'grid_sample',
'local_response_norm',
'pixel_shuffle',
'channel_shuffle',
'embedding',
'gather_tree',
'one_hot',
......
......@@ -21,6 +21,7 @@ import numpy as np
from paddle import _C_ops
from ...device import is_compiled_with_rocm
from paddle import in_dynamic_mode
from paddle.framework import _non_static_mode
__all__ = []
......@@ -344,3 +345,71 @@ def pixel_shuffle(x, upscale_factor, data_format="NCHW", name=None):
attrs={"upscale_factor": upscale_factor,
"data_format": data_format})
return out
def channel_shuffle(x, groups, data_format="NCHW", name=None):
"""
This API implements the channel shuffle operation.
See more details in :ref:`api_nn_vision_ChannelShuffle` .
Parameters:
x (Tensor): 4-D tensor, the data type should be float32 or float64.
groups (int): Number of groups to divide channels in.
data_format (str): The data format of the input and output data. An optional string of NCHW or NHWC. The default is NCHW. When it is NCHW, the data is stored in the order of [batch_size, input_channels, input_height, input_width].
name (str, optional): Name for the operation (optional, default is None). Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Out (Tensor): Rearranged tensor keeping the original tensor shape.
Examples:
.. code-block:: python
:name: channel_shuffle-example
import paddle
import paddle.nn.functional as F
x = paddle.arange(0, 0.6, 0.1, 'float32')
x = paddle.reshape(x, [1, 6, 1, 1])
# [[[[0. ]],
# [[0.10000000]],
# [[0.20000000]],
# [[0.30000001]],
# [[0.40000001]],
# [[0.50000000]]]]
y = F.channel_shuffle(x, 3)
# [[[[0. ]],
# [[0.20000000]],
# [[0.40000001]],
# [[0.10000000]],
# [[0.30000001]],
# [[0.50000000]]]]
"""
if len(x.shape) != 4:
raise ValueError(
"Input x should be 4D tensor, but received x with the shape of {}".
format(x.shape))
if not isinstance(groups, int):
raise TypeError("groups must be int type")
if groups <= 0:
raise ValueError("groups must be positive")
if data_format not in ["NCHW", "NHWC"]:
raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'."
"But recevie Attr(data_format): {} ".format(
data_format))
if _non_static_mode():
return _C_ops.channel_shuffle(x, "groups", groups, "data_format",
data_format)
helper = LayerHelper("channel_shuffle", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'channel_shuffle')
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="channel_shuffle",
inputs={"X": x},
outputs={"Out": out},
attrs={"groups": groups,
"data_format": data_format})
return out
......@@ -88,6 +88,7 @@ from .norm import SpectralNorm # noqa: F401
from .norm import LocalResponseNorm # noqa: F401
from .vision import PixelShuffle # noqa: F401
from .vision import ChannelShuffle # noqa: F401
from .distance import PairwiseDistance # noqa: F401
from .container import LayerDict # noqa: F401
......
......@@ -87,3 +87,76 @@ class PixelShuffle(Layer):
if self._name is not None:
main_str += ', name={}'.format(self._name)
return main_str
class ChannelShuffle(Layer):
"""
This operator divides channels in a tensor of shape [N, C, H, W] or [N, H, W, C] into g groups,
producing a tensor with the shape [N, g, C/g, H, W] or [N, H, W, g, C/g], transposes it
to [N, C/g, g, H, W] or [N, H, W, C/g, g], and then reshapes it back to the original tensor
shape. This operation can improve the interaction between channels and make more efficient
use of features. Please refer to the paper `ShuffleNet: An Extremely Efficient
Convolutional Neural Network for Mobile Devices <https://arxiv.org/abs/1707.01083>`_
by Zhang et al. (2017) for more details.
Parameters:
groups (int): Number of groups to divide channels in.
data_format (str): The data format of the input and output data. An optional string of NCHW or NHWC. The default is NCHW. When it is NCHW, the data is stored in the order of [batch_size, input_channels, input_height, input_width].
name (str, optional): Name for the operation (optional, default is None). Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Shape:
- **x**: 4-D tensor with shape of [N, C, H, W] or [N, H, W, C].
- **out**: 4-D tensor with shape and dtype same as x.
Examples:
.. code-block:: python
:name: ChannelShuffle-example
import paddle
import paddle.nn as nn
x = paddle.arange(0, 0.6, 0.1, 'float32')
x = paddle.reshape(x, [1, 6, 1, 1])
# [[[[0. ]],
# [[0.10000000]],
# [[0.20000000]],
# [[0.30000001]],
# [[0.40000001]],
# [[0.50000000]]]]
channel_shuffle = nn.ChannelShuffle(3)
y = channel_shuffle(x)
# [[[[0. ]],
# [[0.20000000]],
# [[0.40000001]],
# [[0.10000000]],
# [[0.30000001]],
# [[0.50000000]]]]
"""
def __init__(self, groups, data_format="NCHW", name=None):
super(ChannelShuffle, self).__init__()
if not isinstance(groups, int):
raise TypeError("groups must be int type")
if groups <= 0:
raise ValueError("groups must be positive")
if data_format not in ["NCHW", "NHWC"]:
raise ValueError("Data format should be 'NCHW' or 'NHWC'."
"But recevie data format: {}".format(data_format))
self._groups = groups
self._data_format = data_format
self._name = name
def forward(self, x):
return functional.channel_shuffle(x, self._groups, self._data_format,
self._name)
def extra_repr(self):
main_str = 'groups={}'.format(self._groups)
if self._data_format != 'NCHW':
main_str += ', data_format={}'.format(self._data_format)
if self._name is not None:
main_str += ', name={}'.format(self._name)
return main_str
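
For context, a sketch of how this layer is typically used in a ShuffleNet-style block: a grouped pointwise convolution followed by ChannelShuffle, so that the next grouped convolution sees channels from every group. The channel counts and layer arrangement here are illustrative assumptions, not from this PR:

import paddle
import paddle.nn as nn

block = nn.Sequential(
    nn.Conv2D(24, 24, 1, groups=3),             # grouped 1x1 conv
    nn.ChannelShuffle(3),                       # mix channels across the 3 groups
    nn.Conv2D(24, 24, 3, padding=1, groups=3),  # grouped 3x3 conv
)
x = paddle.rand([1, 24, 32, 32])
print(block(x).shape)  # [1, 24, 32, 32]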
......@@ -92,6 +92,7 @@ STATIC_MODE_TESTING_LIST = [
'test_case',
'test_cast_op',
'test_center_loss',
'test_channel_shuffle',
'test_cholesky_op',
'test_chunk_eval_op',
'test_chunk_op',
......