Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
16d4e137
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
16d4e137
编写于
12月 24, 2018
作者:
S
shippingwang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add ShuffleChannelOP
上级
7f73c16e
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
173 addition
and
77 deletion
+173
-77
paddle/fluid/operators/shuffle_channel_op.cc
paddle/fluid/operators/shuffle_channel_op.cc
+29
-31
paddle/fluid/operators/shuffle_channel_op.cu
paddle/fluid/operators/shuffle_channel_op.cu
+106
-6
paddle/fluid/operators/shuffle_channel_op.h
paddle/fluid/operators/shuffle_channel_op.h
+16
-20
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+7
-3
python/paddle/fluid/tests/unittests/test_layers.py
python/paddle/fluid/tests/unittests/test_layers.py
+1
-1
python/paddle/fluid/tests/unittests/test_shuffle_channel_op.py
...n/paddle/fluid/tests/unittests/test_shuffle_channel_op.py
+14
-16
未找到文件。
paddle/fluid/operators/shuffle_channel_op.cc
浏览文件 @
16d4e137
...
@@ -19,26 +19,27 @@ class ShuffleChannelOp : public framework::OperatorWithKernel {
...
@@ -19,26 +19,27 @@ class ShuffleChannelOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
-
>
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ShuffleChannelOp should not be null."
);
"Input(X) of ShuffleChannelOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
Has
In
put
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
Has
Out
put
(
"Out"
),
"Output(Out) of ShuffleChannelOp should not be null."
);
"Output(Out) of ShuffleChannelOp should not be null."
);
auto
input_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
input_dims
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE
(
input_dims
.
size
()
==
4
,
"The layout of input is NCHW."
);
PADDLE_ENFORCE
(
input_dims
.
size
()
==
4
,
"The layout of input is NCHW."
);
// ENFORCE group
// ENFORCE group
auto
group
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>
>
(
"group"
);
// auto group = ctx->Attrs().Get<int
>("group");
ctx
->
SetOutputDim
(
"Out"
,
input_dims
);
ctx
->
SetOutputDim
(
"Out"
,
input_dims
);
}
}
/*
protected:
protected:
framework::OpKernelType GetExpectedKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx
.
GetPlace
());
ctx.device_context
());
}
}
*/
};
};
class
ShuffleChannelOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
ShuffleChannelOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
@@ -63,7 +64,7 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -63,7 +64,7 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
then, feed each group in the next layer with different subgroups.
then, feed each group in the next layer with different subgroups.
According to the paper, "Suppose a convolution layer with g groups
According to the paper, "Suppose a convolution layer with g groups
whose output has g
x
n channels, first reshape the output channel dimension into(g,n),
whose output has g
*
n channels, first reshape the output channel dimension into(g,n),
transposing and then flattening it back as the input of next layer. "
transposing and then flattening it back as the input of next layer. "
Shuffle channel operation makes it possible to build more powerful structures
Shuffle channel operation makes it possible to build more powerful structures
...
@@ -75,52 +76,49 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -75,52 +76,49 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
}
}
};
};
// Grad
class
ShuffleChannelGradOp
:
public
framework
::
OperatorWithKernel
{
class
ShuffleChannelOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@Grad) should not be null"
)
"Input(Out@Grad) should not be null"
)
;
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
"Output(X@Grad) should not be null"
);
"Output(X@Grad) should not be null"
);
auto
input_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
input_dims
=
ctx
->
GetInputDim
(
"X"
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
input_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
input_dims
);
}
}
/*
protected:
protected:
framework::OpKernelType GetExpectedKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
return framework::OpKernelType(
framework::ToDataType(
framework::ToDataType(
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
))
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
->
type
()),
ctx.device_context());
ctx.device_context());
}
}
*/
};
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
// how to write gpu kernal
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
shufflechannel
,
ops
::
ShuffleChannelOp
,
REGISTER_OPERATOR
(
shuffle
_
channel
,
ops
::
ShuffleChannelOp
,
ops
::
ShuffleChannelOpMaker
,
ops
::
ShuffleChannelOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
// paddle::framework::EmptyGradOpMaker);
// paddle::framework::EmptyGradOpMaker);
REGISTER_OPERATOR
(
shufflechannel_grad
,
ops
::
ShuffleChannelGradOp
);
REGISTER_OPERATOR
(
shuffle
_
channel_grad
,
ops
::
ShuffleChannelGradOp
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
shufflechannel
,
shuffle
_
channel
,
ops
::
ShuffleChannelOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ShuffleChannelOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ShuffleChannelOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
ops
::
ShuffleChannelOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
shufflechannel_grad
,
shuffle
_
channel_grad
,
ops
::
ShuffleChannelGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ShuffleChannelGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ShuffleChannelGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
ShuffleChannelGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
double
>
);
paddle/fluid/operators/shuffle_channel_op.cu
浏览文件 @
16d4e137
...
@@ -10,15 +10,115 @@ See the License for the specific language governing permissions and
...
@@ -10,15 +10,115 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/shuffle_channel_op.h"
#include "paddle/fluid/operators/shuffle_channel_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
static
constexpr
int
kNumCUDAThreads
=
512
;
static
constexpr
int
kNumMaximumNumBlocks
=
4096
;
static
inline
int
NumBlocks
(
const
int
N
)
{
return
std
::
min
((
N
+
kNumCUDAThreads
-
1
)
/
kNumCUDAThreads
,
kNumMaximumNumBlocks
);
}
template
<
typename
T
>
__global__
void
ShuffleChannel
(
const
int
nthreads
,
const
int
feature_map_size
,
T
*
output
,
const
T
*
input
,
int
group_row
,
int
group_column
,
int
len
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
for
(
size_t
ii
=
index
;
ii
<
nthreads
;
ii
+=
offset
)
{
const
int
n
=
index
/
group_row
/
group_column
/
len
;
const
int
i
=
(
index
/
group_column
/
len
)
%
group_row
;
const
int
j
=
index
/
len
%
group_column
;
const
int
k
=
index
-
(
n
*
feature_map_size
+
(
i
*
group_column
+
j
)
*
len
);
T
*
p_o
=
output
+
n
*
feature_map_size
+
(
j
*
group_row
+
i
)
*
len
;
p_o
[
k
]
=
input
[
index
];
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
ShuffleChannelOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
output
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
int
group
=
ctx
.
Attr
<
int
>
(
"group"
);
auto
input_dims
=
input
->
dims
();
auto
num
=
input_dims
[
0
];
auto
channel
=
input_dims
[
1
];
auto
height
=
input_dims
[
2
];
auto
weight
=
input_dims
[
3
];
auto
feature_map_size
=
channel
*
height
*
weight
;
auto
sp_sz
=
height
*
weight
;
int
group_row
=
group
;
int
group_column
=
channel
/
group_row
;
// count is the product of NCHW same as numel()
int
count
=
num
*
group_column
*
group_row
*
sp_sz
;
int
blocks
=
NumBlocks
(
output
->
numel
());
int
threads
=
kNumCUDAThreads
;
const
T
*
input_data
=
input
->
data
<
T
>
();
T
*
output_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
ShuffleChannel
<
T
><<<
blocks
,
threads
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
count
,
feature_map_size
,
output_data
,
input_data
,
group_row
,
group_column
,
sp_sz
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
ShuffleChannelGradOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
int
group
=
ctx
.
Attr
<
int
>
(
"group"
);
auto
input_dims
=
input
->
dims
();
auto
num
=
input_dims
[
0
];
auto
channel
=
input_dims
[
1
];
auto
height
=
input_dims
[
2
];
auto
weight
=
input_dims
[
3
];
auto
feature_map_size
=
channel
*
height
*
weight
;
auto
sp_sz
=
height
*
weight
;
int
group_row
=
group
;
int
group_column
=
channel
/
group_row
;
auto
*
output_grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
input_grad
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
T
*
output_grad_data
=
output_grad
->
data
<
T
>
();
int
blocks
=
NumBlocks
(
output_grad
->
numel
());
int
threads
=
kNumCUDAThreads
;
int
count
=
num
*
group_column
*
group_row
*
sp_sz
;
ShuffleChannel
<
T
><<<
blocks
,
threads
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
count
,
feature_map_size
,
input_grad_data
,
output_grad_data
,
group_row
,
group_column
,
sp_sz
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
shufflechannel
,
shuffle
_
channel
,
ops
::
ShuffleChannelOp
Kernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
ops
::
ShuffleChannelOp
CUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
ShuffleChannelOp
Kernel
<
paddle
::
platform
::
CUDADeviceContext
,
ops
::
ShuffleChannelOpCUDA
Kernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
double
>
);
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
shufflechannel_grad
,
shuffle_channel_grad
,
ops
::
ShuffleChannelOpGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
ops
::
ShuffleChannelGradOpCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
ops
::
ShuffleChannelOpGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
ShuffleChannelGradOpCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
double
>
);
paddle/fluid/operators/shuffle_channel_op.h
浏览文件 @
16d4e137
...
@@ -21,10 +21,10 @@ namespace operators {
...
@@ -21,10 +21,10 @@ namespace operators {
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
ShuffleChannelOpKernel
:
public
framework
::
OpKernel
<
T
>
{
class
ShuffleChannelOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
c
ontext
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
c
tx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
input
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
output
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
output
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
group
=
ctx
.
Input
<
framework
::
Tensor
>
(
"group"
);
int
group
=
ctx
.
Attr
<
int
>
(
"group"
);
auto
input_dims
=
input
->
dims
();
auto
input_dims
=
input
->
dims
();
auto
num
=
input_dims
[
0
];
auto
num
=
input_dims
[
0
];
...
@@ -34,21 +34,19 @@ class ShuffleChannelOpKernel : public framework::OpKernel<T> {
...
@@ -34,21 +34,19 @@ class ShuffleChannelOpKernel : public framework::OpKernel<T> {
auto
feature_map_size
=
channel
*
height
*
weight
;
auto
feature_map_size
=
channel
*
height
*
weight
;
auto
sp_sz
=
height
*
weight
;
auto
sp_sz
=
height
*
weight
;
int
group_row
=
group
;
int
group_row
=
group
;
int
group_column
=
channel
s
/
group_row
;
int
group_column
=
channel
/
group_row
;
const
T
*
input_data
=
input
->
data
<
T
>
();
const
T
*
input_data
=
input
->
data
<
T
>
();
T
*
output_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
output_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
output_data_temp
=
output_data
+
n
*
feature_map_size
;
input_data_temp
=
input_data
+
n
*
feature_map_size
;
for
(
int
i
=
0
;
i
<
group_row
;
++
i
)
{
for
(
int
i
=
0
;
i
<
group_row
;
++
i
)
{
for
(
int
j
=
0
;
j
<
group_column
;
++
j
)
{
for
(
int
j
=
0
;
j
<
group_column
;
++
j
)
{
const
auto
*
p_i
=
input_data_temp
+
(
i
*
group_column
+
j
)
*
sp_sz
;
const
T
*
p_i
=
input_data
+
n
*
feature_map_size
+
auto
*
p_o
=
output_data_temp
+
(
j
*
group_row
+
i
)
*
sp_sz
;
(
i
*
group_column
+
j
)
*
sp_sz
;
memcpy
(
p_o
,
p_i
,
sizeof
(
Dtype
)
*
sp_sz
);
T
*
p_o
=
output_data
+
n
*
feature_map_size
+
(
j
*
group_row
+
i
)
*
sp_sz
;
memcpy
(
p_o
,
p_i
,
sizeof
(
int
)
*
sp_sz
);
}
}
}
}
}
}
...
@@ -61,7 +59,7 @@ class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
...
@@ -61,7 +59,7 @@ class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
input
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
group
=
ctx
.
Input
<
framework
::
Tensor
>
(
"group"
);
int
group
=
ctx
.
Attr
<
int
>
(
"group"
);
auto
input_dims
=
input
->
dims
();
auto
input_dims
=
input
->
dims
();
auto
num
=
input_dims
[
0
];
auto
num
=
input_dims
[
0
];
...
@@ -72,7 +70,7 @@ class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
...
@@ -72,7 +70,7 @@ class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
auto
sp_sz
=
height
*
weight
;
auto
sp_sz
=
height
*
weight
;
int
group_row
=
group
;
int
group_row
=
group
;
int
group_column
=
channel
s
/
group_row
;
int
group_column
=
channel
/
group_row
;
auto
*
output_grad
=
auto
*
output_grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
...
@@ -81,19 +79,17 @@ class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
...
@@ -81,19 +79,17 @@ class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
T
*
output_grad_data
=
output_grad
->
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
->
data
<
T
>
();
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
output_grad_temp
=
output_grad_data
+
n
*
feature_map_size
;
input_grad_temp
=
input_grad_data
+
n
*
feature_map_size
;
for
(
int
i
=
0
;
i
<
group_row
;
++
i
)
{
for
(
int
i
=
0
;
i
<
group_row
;
++
i
)
{
for
(
int
j
=
0
;
j
<
group_column
;
++
j
)
{
for
(
int
j
=
0
;
j
<
group_column
;
++
j
)
{
const
auto
*
p_i
=
output_grad_temp
+
(
i
*
group_column
+
j
)
*
sp_sz
;
const
T
*
p_i
=
output_grad_data
+
n
*
feature_map_size
+
auto
*
p_o
=
input_grad_temp
+
(
j
*
group_row
+
i
)
*
sp_sz
;
(
i
*
group_column
+
j
)
*
sp_sz
;
memcpy
(
p_o
,
p_i
,
sizeof
(
Dtype
)
*
sp_sz
);
T
*
p_o
=
input_grad_data
+
n
*
feature_map_size
+
(
j
*
group_row
+
i
)
*
sp_sz
;
memcpy
(
p_o
,
p_i
,
sizeof
(
int
)
*
sp_sz
);
}
}
}
}
}
}
return
;
}
}
};
};
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
16d4e137
...
@@ -173,7 +173,7 @@ __all__ = [
...
@@ -173,7 +173,7 @@ __all__ = [
'merge_selected_rows'
,
'merge_selected_rows'
,
'get_tensor_from_selected_rows'
,
'get_tensor_from_selected_rows'
,
'lstm'
,
'lstm'
,
'shufflechannel'
,
'shuffle
_
channel'
,
'psroi_pool'
,
'psroi_pool'
,
]
]
...
@@ -9334,18 +9334,21 @@ def shuffle_channel(x, group=1, name=None):
...
@@ -9334,18 +9334,21 @@ def shuffle_channel(x, group=1, name=None):
with multiple group convolutional layers.
with multiple group convolutional layers.
Args:
Args:
x: The input tensor variable.
x: The input tensor variable..
group: The num of group
Returns:
Returns:
Variable: channel shuffled tensor variable.
Variable: channel shuffled tensor variable.
Raises:
Raises:
ValueError: If group in not a int type variable.
ValueError: If group in not a
n
int type variable.
Examples:
Examples:
.. code-block:: python
.. code-block:: python
out = fluid.layers.shuffle_channel(x=group_conv,group=4)
"""
"""
helper
=
LayerHelper
(
"shuffle_channel"
,
**
locals
())
helper
=
LayerHelper
(
"shuffle_channel"
,
**
locals
())
...
@@ -9361,6 +9364,7 @@ def shuffle_channel(x, group=1, name=None):
...
@@ -9361,6 +9364,7 @@ def shuffle_channel(x, group=1, name=None):
inputs
=
{
"X"
:
x
},
inputs
=
{
"X"
:
x
},
outputs
=
{
"Out"
:
out
},
outputs
=
{
"Out"
:
out
},
attrs
=
{
"group"
:
group
})
attrs
=
{
"group"
:
group
})
return
out
@
templatedoc
()
@
templatedoc
()
...
...
python/paddle/fluid/tests/unittests/test_layers.py
浏览文件 @
16d4e137
...
@@ -1018,7 +1018,7 @@ class TestBook(unittest.TestCase):
...
@@ -1018,7 +1018,7 @@ class TestBook(unittest.TestCase):
def
test_shuffle_channel
(
self
):
def
test_shuffle_channel
(
self
):
program
=
Program
()
program
=
Program
()
with
program_guard
(
program
):
with
program_guard
(
program
):
x
=
layers
.
data
(
name
=
"x"
,
shape
=
[
1
0
,
32
,
16
,
16
],
dtype
=
"float32"
)
x
=
layers
.
data
(
name
=
"x"
,
shape
=
[
1
,
4
,
2
,
2
],
dtype
=
"float32"
)
group
=
layers
.
data
(
name
=
"group"
,
shape
=
[
1
],
dtype
=
"int32"
)
group
=
layers
.
data
(
name
=
"group"
,
shape
=
[
1
],
dtype
=
"int32"
)
out
=
layers
.
shuffle_channel
(
x
,
group
)
out
=
layers
.
shuffle_channel
(
x
,
group
)
self
.
assertIsNotNone
(
out
)
self
.
assertIsNotNone
(
out
)
...
...
python/paddle/fluid/tests/unittests/test_shuffle_channel_op.py
浏览文件 @
16d4e137
...
@@ -23,31 +23,29 @@ import paddle.fluid.core as core
...
@@ -23,31 +23,29 @@ import paddle.fluid.core as core
class
TestShuffleChannelOp
(
OpTest
):
class
TestShuffleChannelOp
(
OpTest
):
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'output'
)
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"shuffle_channel"
self
.
op_type
=
"shuffle_channel"
self
.
batch_size
=
10
self
.
batch_size
=
1
self
.
input_channels
=
16
self
.
input_channels
=
4
self
.
layer_h
=
32
self
.
layer_h
=
2
self
.
layer_w
=
32
self
.
layer_w
=
2
self
.
group
=
4
self
.
group
=
2
self
.
x
=
np
.
random
.
random
(
self
.
x
=
np
.
random
.
random
(
(
self
.
batch_size
,
self
.
input_channels
,
self
.
layer_h
,
self
,
(
self
.
batch_size
,
self
.
input_channels
,
self
.
layer_h
,
layer_w
)).
astype
(
'float32'
)
self
.
layer_w
)).
astype
(
'float32'
)
self
.
inputs
=
{
'X'
:
self
.
x
}
self
.
inputs
=
{
'X'
:
self
.
x
}
self
.
attrs
=
{
'group'
:
self
.
group
}
self
.
attrs
=
{
'group'
:
self
.
group
}
n
,
c
,
h
,
w
=
self
.
x
.
shape
n
,
c
,
h
,
w
=
self
.
x
.
shape
input_reshaped
=
np
.
reshape
(
self
.
x
,
input_reshaped
=
np
.
reshape
(
self
.
x
,
(
-
1
,
self
.
group
,
c
//
self
.
group
,
h
,
w
))
(
-
1
,
self
.
group
,
c
//
self
.
group
,
h
,
w
))
input_transposed
=
np
.
transpose
(
input_reshaped
,
(
0
,
2
,
1
,
3
,
4
))
input_transposed
=
np
.
transpose
(
input_reshaped
,
(
0
,
2
,
1
,
3
,
4
))
self
.
outputs
=
np
.
reshape
(
input_transposed
,
(
-
1
,
c
,
h
,
w
))
self
.
outputs
=
{
'Out'
:
np
.
reshape
(
input_transposed
,
(
-
1
,
c
,
h
,
w
))}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Out'
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录