Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
bd773b9c
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bd773b9c
编写于
11月 14, 2017
作者:
W
wanghaox
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify for maxoutop code review
上级
ab9c71d9
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
98 addition
and
99 deletion
+98
-99
paddle/operators/math/CMakeLists.txt
paddle/operators/math/CMakeLists.txt
+4
-2
paddle/operators/math/maxouting.cc
paddle/operators/math/maxouting.cc
+12
-13
paddle/operators/math/maxouting.cu
paddle/operators/math/maxouting.cu
+27
-34
paddle/operators/math/maxouting.h
paddle/operators/math/maxouting.h
+8
-14
paddle/operators/maxout_op.cc
paddle/operators/maxout_op.cc
+43
-20
paddle/operators/maxout_op.h
paddle/operators/maxout_op.h
+2
-5
python/paddle/v2/framework/tests/test_maxout_op.py
python/paddle/v2/framework/tests/test_maxout_op.py
+2
-11
未找到文件。
paddle/operators/math/CMakeLists.txt
浏览文件 @
bd773b9c
...
...
@@ -8,24 +8,26 @@ if(WITH_GPU)
nv_library
(
softmax SRCS softmax.cc softmax.cu DEPS operator
)
nv_library
(
cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator
)
nv_library
(
pooling SRCS pooling.cc pooling.cu DEPS device_context
)
nv_library
(
maxouting SRCS maxouting.cc maxouting.cu DEPS device_context
)
nv_library
(
sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function
)
nv_library
(
vol2col SRCS vol2col.cc vol2col.cu DEPS device_context
)
nv_library
(
context_project SRCS context_project.cc context_project.cu DEPS device_context
)
nv_library
(
sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context
)
nv_library
(
lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions
)
nv_library
(
gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function
)
nv_library
(
maxouting SRCS maxouting.cc maxouting.cu DEPS device_context
)
else
()
cc_library
(
math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator
)
cc_library
(
selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function
)
cc_library
(
softmax SRCS softmax.cc DEPS operator
)
cc_library
(
cross_entropy SRCS cross_entropy.cc DEPS operator
)
cc_library
(
pooling SRCS pooling.cc DEPS device_context
)
cc_library
(
maxouting SRCS maxouting.cc DEPS device_context
)
cc_library
(
sequence_pooling SRCS sequence_pooling.cc DEPS device_context math_function
)
cc_library
(
vol2col SRCS vol2col.cc DEPS device_context
)
cc_library
(
context_project SRCS context_project.cc DEPS device_context
)
cc_library
(
sequence2batch SRCS sequence2batch.cc DEPS device_context
)
cc_library
(
lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions
)
cc_library
(
gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function
)
cc_library
(
maxouting SRCS maxouting.cc DEPS device_context
)
endif
()
cc_test
(
math_function_test SRCS math_function_test.cc DEPS math_function tensor
)
...
...
paddle/operators/math/maxouting.cc
浏览文件 @
bd773b9c
...
...
@@ -20,25 +20,27 @@ namespace math {
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
* groups mustbe > 1
*/
template
<
typename
MaxOutProcess
,
typename
T
>
class
MaxOutFunctor
<
platform
::
CPUPlace
,
MaxOutProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
int
groups
,
int
num_channels
,
MaxOutProcess
maxout_process
)
{
const
framework
::
Tensor
&
input
,
framework
::
Tensor
*
output
,
int
groups
,
MaxOutProcess
maxout_process
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
const
int
output_channels
=
num_channels
/
groups
;
const
int
output_channels
=
output
->
dims
()[
1
]
;
int
fea_size
=
input_height
*
input_width
;
// c_size mean output one batch size
int
c_size
=
fea_size
*
output_channels
;
const
T
*
input_data
=
input
.
data
<
T
>
();
T
*
output_data
=
output
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
output_data
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
int
new_bindex
=
c_size
*
i
;
...
...
@@ -50,7 +52,6 @@ class MaxOutFunctor<platform::CPUPlace, MaxOutProcess, T> {
maxout_process
.
compute
(
ele
,
input_data
[(
new_bindex
+
new_cindex
)
*
groups
+
ph
*
fea_size
+
f
]);
}
maxout_process
.
finalize
(
ele
,
(
static_cast
<
T
>
(
groups
)));
output_data
[(
new_bindex
+
new_cindex
+
f
)]
=
ele
;
}
}
...
...
@@ -68,11 +69,11 @@ public:
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
int
groups
,
int
num_channels
)
{
int
groups
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
const
int
output_channels
=
num_channels
/
groups
;
const
int
output_channels
=
output
.
dims
()[
1
]
;
int
fea_size
=
input_height
*
input_width
;
...
...
@@ -95,8 +96,6 @@ public:
if
(
input_data
[
input_idx
]
==
output_data
[
output_idx
])
{
input_grad_data
[
input_idx
]
+=
output_grad_data
[
output_idx
];
stop
=
true
;
}
else
{
input_grad_data
[
input_idx
]
=
0
;
}
}
}
...
...
@@ -108,9 +107,9 @@ public:
template
class
MaxOutGradFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
MaxOutGradFunctor
<
platform
::
CPUPlace
,
double
>;
template
class
MaxOutFunctor
<
platform
::
CPUPlace
,
paddle
::
operators
::
math
::
MaxOut
<
float
>,
float
>
;
math
::
MaxOut
<
float
>,
float
>
;
template
class
MaxOutFunctor
<
platform
::
CPUPlace
,
paddle
::
operators
::
math
::
MaxOut
<
double
>,
double
>
;
math
::
MaxOut
<
double
>,
double
>
;
}
// namespace math
}
// namespace operators
...
...
paddle/operators/math/maxouting.cu
浏览文件 @
bd773b9c
...
...
@@ -24,21 +24,20 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data,
T
*
output_data
,
const
int
channels
,
const
int
input_height
,
const
int
input_width
,
int
groups
,
MaxOutProcess
maxout_process
)
{
int
size
=
input_height
*
input_width
*
channels
/
groups
;
int
featL
en
=
input_height
*
input_width
;
const
int
size
=
input_height
*
input_width
*
channels
/
groups
;
const
int
feat_l
en
=
input_height
*
input_width
;
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
batch_idx
=
index
/
size
;
int
i
=
index
%
size
;
int
channel_idx
=
i
/
featL
en
;
int
feat_idx
=
i
%
featL
en
;
int
batch_offset
=
index
%
size
;
int
channel_idx
=
batch_offset
/
feat_l
en
;
int
feat_idx
=
batch_offset
%
feat_l
en
;
int
data_idx
=
(
batch_idx
*
size
+
channel_idx
*
feat
L
en
)
*
groups
+
feat_idx
;
(
batch_idx
*
size
+
channel_idx
*
feat
_l
en
)
*
groups
+
feat_idx
;
T
ele
=
maxout_process
.
initial
();
for
(
int
g
=
0
;
g
<
groups
;
g
++
)
{
maxout_process
.
compute
(
ele
,
input_data
[
data_idx
+
g
*
feat
L
en
]);
for
(
int
g
=
0
;
g
<
groups
;
++
g
)
{
maxout_process
.
compute
(
ele
,
input_data
[
data_idx
+
g
*
feat
_l
en
]);
}
maxout_process
.
finalize
(
ele
,
(
static_cast
<
T
>
(
groups
)));
output_data
[
index
]
=
ele
;
}
}
...
...
@@ -47,21 +46,21 @@ __global__ void KernelMaxoutGrad(
const
int
nthreads
,
const
T
*
input_data
,
const
T
*
output_data
,
const
T
*
output_grad
,
T
*
input_grad
,
const
int
channels
,
const
int
input_height
,
const
int
input_width
,
int
groups
)
{
int
size
=
input_height
*
input_width
*
channels
/
groups
;
int
featL
en
=
input_height
*
input_width
;
const
int
size
=
input_height
*
input_width
*
channels
/
groups
;
const
int
feat_l
en
=
input_height
*
input_width
;
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
batch_idx
=
index
/
size
;
int
i
=
index
%
size
;
int
channel_idx
=
i
/
featL
en
;
int
feat_idx
=
i
%
featL
en
;
int
batch_offset
=
index
%
size
;
int
channel_idx
=
batch_offset
/
feat_l
en
;
int
feat_idx
=
batch_offset
%
feat_l
en
;
int
data_idx
=
(
batch_idx
*
size
+
channel_idx
*
feat
L
en
)
*
groups
+
feat_idx
;
(
batch_idx
*
size
+
channel_idx
*
feat
_l
en
)
*
groups
+
feat_idx
;
int
maxIndex
=
-
1
;
bool
stop
=
false
;
for
(
int
g
=
0
;
g
<
groups
&&
!
stop
;
g
++
)
{
if
(
input_data
[
data_idx
+
g
*
feat
L
en
]
==
output_data
[
index
])
{
maxIndex
=
data_idx
+
g
*
feat
L
en
;
if
(
input_data
[
data_idx
+
g
*
feat
_l
en
]
==
output_data
[
index
])
{
maxIndex
=
data_idx
+
g
*
feat
_l
en
;
stop
=
true
;
}
}
...
...
@@ -73,28 +72,25 @@ __global__ void KernelMaxoutGrad(
}
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template
<
typename
MaxOutProcess
,
typename
T
>
class
MaxOutFunctor
<
platform
::
GPUPlace
,
MaxOutProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
int
groups
,
int
num_channels
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
*
output
,
int
groups
,
MaxOutProcess
maxout_process
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
const
int
output_channels
=
num_channels
/
groups
;
const
int
output_height
=
output
.
dims
()[
2
];
const
int
output_width
=
output
.
dims
()[
3
];
const
int
output_channels
=
output
->
dims
()[
1
]
;
const
int
output_height
=
output
->
dims
()[
2
];
const
int
output_width
=
output
->
dims
()[
3
];
const
T
*
input_data
=
input
.
data
<
T
>
();
T
*
output_data
=
output
.
mutable_data
<
T
>
(
context
.
GetPlace
());
int
nthreads
=
batch_size
*
output_channels
*
output_height
*
output_width
;
T
*
output_data
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
nthreads
=
output
->
numel
();
int
blocks
=
(
nthreads
+
1024
-
1
)
/
1024
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blocks
,
1
);
...
...
@@ -110,8 +106,6 @@ class MaxOutFunctor<platform::GPUPlace, MaxOutProcess, T> {
};
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template
<
typename
T
>
class
MaxOutGradFunctor
<
platform
::
GPUPlace
,
T
>
{
...
...
@@ -120,7 +114,7 @@ class MaxOutGradFunctor<platform::GPUPlace, T> {
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
int
groups
,
int
num_channels
)
{
int
groups
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_height
=
input
.
dims
()[
2
];
...
...
@@ -133,8 +127,7 @@ class MaxOutGradFunctor<platform::GPUPlace, T> {
const
T
*
output_data
=
output
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
.
mutable_data
<
T
>
(
context
.
GetPlace
());
int
nthreads
=
batch_size
*
output_channels
*
output_height
*
output_width
;
int
nthreads
=
output
.
numel
();
int
blocks
=
(
nthreads
+
1024
-
1
)
/
1024
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blocks
,
1
);
...
...
@@ -152,9 +145,9 @@ template class MaxOutGradFunctor<platform::GPUPlace, float>;
template
class
MaxOutGradFunctor
<
platform
::
GPUPlace
,
double
>;
template
class
MaxOutFunctor
<
platform
::
GPUPlace
,
paddle
::
operators
::
math
::
MaxOut
<
float
>,
float
>
;
math
::
MaxOut
<
float
>,
float
>
;
template
class
MaxOutFunctor
<
platform
::
GPUPlace
,
paddle
::
operators
::
math
::
MaxOut
<
double
>,
double
>
;
math
::
MaxOut
<
double
>,
double
>
;
}
// namespace math
}
// namespace operators
...
...
paddle/operators/math/maxouting.h
浏览文件 @
bd773b9c
...
...
@@ -22,26 +22,20 @@ namespace paddle {
namespace
operators
{
namespace
math
{
#define FLT_MAX \
__FLT_MAX__ // It might need to be placed in another file, but I'm still
// wondering where to put it.
__FLT_MAX__
/*
* \brief Extracting simple operations from
pooling
.
*
Both MaxPool and AvgPool need "initial", "compute" and "finaliz
e"
* \brief Extracting simple operations from
maxout
.
*
need "initial", "comput
e"
* operation.
* MaxPool initializes temp variable to the negative maximum to find the
* maximum value in the pooling field.
* AvgPool initializes temp variable to the zero to accumulate all values
* in pool pooling, and finally takes the average.
* MaxPoolGrad and AvgPoolGrad are gradient operations respectively.
*/
template
<
class
T
>
class
MaxOut
{
public:
DEVICE
inline
T
initial
()
{
return
static_cast
<
T
>
(
-
FLT_MAX
);
}
DEVICE
inline
void
compute
(
T
&
y
,
const
T
&
x
)
{
y
=
y
>
x
?
y
:
x
;
}
DEVICE
inline
void
finalize
(
T
&
y
,
const
T
&
group
)
{}
};
template
<
class
T
>
...
...
@@ -69,11 +63,12 @@ class MaxOutGrad {
* MaxPool2dGradFunctor, MaxPool3dGradFunctor.
*/
template
<
typename
Place
,
typename
MaxOutProcess
,
typename
T
>
class
MaxOutFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
int
groups
,
int
num_channels
,
MaxOutProcess
maxout_compute
);
const
framework
::
Tensor
&
input
,
framework
::
Tensor
*
output
,
int
groups
,
MaxOutProcess
maxout_compute
);
};
...
...
@@ -84,8 +79,7 @@ class MaxOutGradFunctor {
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
int
groups
,
int
num_channels
);
const
framework
::
Tensor
&
output_grad
,
int
groups
);
};
...
...
paddle/operators/maxout_op.cc
浏览文件 @
bd773b9c
...
...
@@ -19,17 +19,16 @@ namespace operators {
using
framework
::
Tensor
;
/********first define ProtoMaker类 ***************/
class
MaxOutOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
MaxOutOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"(Tensor) The input tensor of
pooling
operator. "
"(Tensor) The input tensor of
maxout
operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature."
);
AddOutput
(
"Out"
,
"(Tensor) The output tensor of
pooling
operator."
"(Tensor) The output tensor of
maxout
operator."
"The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
...
...
@@ -38,23 +37,53 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr
<
int
>
(
"groups"
,
R"DOC(The group number of input layer.
)DOC"
)
.
SetDefault
(
2
);
AddAttr
<
int
>
(
"num_channels"
,
R"DOC(The channel number of input layer.
)DOC"
)
.
SetDefault
(
0
);
AddComment
(
R"DOC(A layer to do max out on conv layer output.
- Input: output of a conv layer.
)DOC"
);
AddComment
(
R"DOC(
- Input: NCHW.
- Output: feature map size same as input. Channel is (input channel) / groups.
So groups should be larger than 1, and the num of channels should be able
to devided by groups.
.. math::
y_{si+j} = \max_k x_{gsi + sk + j}
g = groups
s = input.size / num_channels
0 \le i < num_channels / groups
0 \le j < s
0 \le k < groups
Please refer to Paper:
- Maxout Networks: http://www.jmlr.org/proceedings/papers/v28/goodfellow13.pdf
- Multi-digit Number Recognition from Street View \
Imagery using Deep Convolutional Neural Networks: \
https://arxiv.org/pdf/1312.6082v4.pdf
The simple usage is:
.. code-block:: python
maxout = maxout_layer(input,
num_channels=128,
groups=4)
:param input: The input of this layer.
:type input: LayerOutput
:param num_channels: The channel number of input layer. If None will be set
automatically from previous output.
:type num_channels: int | None
:param groups: The group number of input layer.
:type groups: int
:param name: The name of this layer. It is optional.
:type name: None | basestring.
:param layer_attr: Extra Layer attribute.
:type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
:rtype: LayerOutput
)DOC"
);
}
};
/******************2nd **********************************/
class
MaxOutOp
:
public
framework
::
OperatorWithKernel
{
public:
...
...
@@ -67,20 +96,14 @@ class MaxOutOp : public framework::OperatorWithKernel {
"Output(Out) of maxoutOp should not be null."
);
auto
in_x_dims
=
ctx
->
GetInputDim
(
"X"
);
int
groups
=
ctx
->
Attrs
().
Get
<
int
>
(
"groups"
);
int
num_channels
=
ctx
->
Attrs
().
Get
<
int
>
(
"num_channels"
);
// check groups > 1
PADDLE_ENFORCE_GT
(
groups
,
1
,
"in maxoutop groups should be larger than 1"
);
// check num_channels%groups=0
PADDLE_ENFORCE_EQ
(
num_channels
%
groups
,
0
,
"the num of channels should be able"
"to devided by groups"
);
int
out_num_channels
=
num_channels
/
groups
;
std
::
vector
<
int64_t
>
output_shape
({
in_x_dims
[
0
],
out_num_channel
s
});
std
::
vector
<
int64_t
>
output_shape
({
in_x_dims
[
0
],
in_x_dims
[
1
]
/
group
s
});
output_shape
.
push_back
(
in_x_dims
[
2
]);
output_shape
.
push_back
(
in_x_dims
[
3
]);
...
...
paddle/operators/maxout_op.h
浏览文件 @
bd773b9c
...
...
@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/maxouting.h"
...
...
@@ -32,14 +31,13 @@ class MaxOutKernel : public framework::OpKernel<T> {
Tensor
*
out
=
context
.
Output
<
Tensor
>
(
"Out"
);
int
groups
=
context
.
template
Attr
<
int
>(
"groups"
);
int
num_channels
=
context
.
template
Attr
<
int
>(
"num_channels"
);
paddle
::
operators
::
math
::
MaxOutFunctor
<
Place
,
paddle
::
operators
::
math
::
MaxOut
<
T
>
,
T
>
maxout_forward
;
paddle
::
operators
::
math
::
MaxOut
<
T
>
maxout_process
;
maxout_forward
(
context
.
device_context
(),
*
in_x
,
*
out
,
groups
,
num_channel
s
,
maxout_forward
(
context
.
device_context
(),
*
in_x
,
out
,
group
s
,
maxout_process
);
}
};
...
...
@@ -55,7 +53,6 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
Tensor
*
in_x_grad
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
int
groups
=
context
.
template
Attr
<
int
>(
"groups"
);
int
num_channels
=
context
.
template
Attr
<
int
>(
"num_channels"
);
...
...
@@ -68,7 +65,7 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
paddle
::
operators
::
math
::
MaxOutGradFunctor
<
Place
,
T
>
maxout_backward
;
maxout_backward
(
context
.
device_context
(),
*
in_x
,
*
in_x_grad
,
*
out
,
*
out_grad
,
groups
,
num_channels
);
*
out_grad
,
groups
);
}
}
};
...
...
python/paddle/v2/framework/tests/test_maxout_op.py
浏览文件 @
bd773b9c
...
...
@@ -3,22 +3,13 @@ import numpy as np
from
op_test
import
OpTest
def
maxout_forward_naive_2sweetsky
(
input
,
groups
,
num_channels
):
s0
,
s1
,
s2
,
s3
=
input
.
shape
return
np
.
ndarray
([
s0
,
s1
/
groups
,
groups
,
s2
,
s3
],
\
buffer
=
input
,
dtype
=
input
.
dtype
).
max
(
axis
=
(
2
))
def
maxout_forward_naive
(
input
,
groups
,
num_channels
):
s0
,
s1
,
s2
,
s3
=
input
.
shape
return
np
.
ndarray
([
s0
,
s1
/
groups
,
groups
,
s2
,
s3
],
\
buffer
=
input
,
dtype
=
input
.
dtype
).
max
(
axis
=
(
2
))
class
TestMaxOut_Op
(
OpTest
):
class
TestMaxOutOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"maxout"
self
.
init_test_case
()
...
...
@@ -37,7 +28,7 @@ class TestMaxOut_Op(OpTest):
def
test_check_grad
(
self
):
print
self
.
inputs
print
self
.
outputs
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.5
)
self
.
check_grad
([
'X'
],
'Out'
)
def
init_test_case
(
self
):
self
.
MaxOut_forward_naive
=
maxout_forward_naive
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录