Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
c7e739f5
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c7e739f5
编写于
12月 06, 2017
作者:
G
gongweibao
提交者:
GitHub
12月 06, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add LRN efficient GPU implement. (#5894)
Add LRN efficient GPU implement
上级
1d1555e2
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
289 addition
and
93 deletion
+289
-93
paddle/operators/lrn_op.cc
paddle/operators/lrn_op.cc
+100
-4
paddle/operators/lrn_op.cu
paddle/operators/lrn_op.cu
+158
-2
paddle/operators/lrn_op.h
paddle/operators/lrn_op.h
+30
-85
python/paddle/v2/fluid/tests/test_lrn_op.py
python/paddle/v2/fluid/tests/test_lrn_op.py
+1
-2
未找到文件。
paddle/operators/lrn_op.cc
浏览文件 @
c7e739f5
...
...
@@ -19,6 +19,103 @@ namespace operators {
using
framework
::
Tensor
;
template
<
typename
T
>
struct
LRNFunctor
<
platform
::
CPUPlace
,
T
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
*
out
,
framework
::
Tensor
*
mid
,
int
N
,
int
C
,
int
H
,
int
W
,
int
n
,
T
k
,
T
alpha
,
T
beta
)
{
auto
x_v
=
framework
::
EigenVector
<
T
>::
Flatten
(
input
);
const
int
start
=
-
(
n
-
1
)
/
2
;
const
int
end
=
start
+
n
;
auto
e_mid
=
framework
::
EigenTensor
<
T
,
4
>::
From
(
*
mid
);
e_mid
=
e_mid
.
constant
(
k
);
auto
e_x
=
framework
::
EigenTensor
<
T
,
4
>::
From
(
input
);
for
(
int
m
=
0
;
m
<
N
;
m
++
)
{
for
(
int
i
=
0
;
i
<
C
;
i
++
)
{
for
(
int
c
=
start
;
c
<=
end
;
c
++
)
{
int
ch
=
i
+
c
;
if
(
ch
>=
0
&&
ch
<
C
)
{
auto
s
=
e_mid
.
slice
(
Eigen
::
array
<
int
,
4
>
({{
m
,
i
,
0
,
0
}}),
Eigen
::
array
<
int
,
4
>
({{
1
,
1
,
H
,
W
}}));
auto
r
=
e_x
.
slice
(
Eigen
::
array
<
int
,
4
>
({{
m
,
ch
,
0
,
0
}}),
Eigen
::
array
<
int
,
4
>
({{
1
,
1
,
H
,
W
}}));
s
+=
alpha
*
r
.
square
();
}
}
}
}
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
out
);
out_e
=
x_v
*
e_mid
.
reshape
(
Eigen
::
DSizes
<
int
,
1
>
(
e_mid
.
size
())).
pow
(
-
beta
);
}
};
template
struct
LRNFunctor
<
platform
::
CPUPlace
,
float
>;
template
struct
LRNFunctor
<
platform
::
CPUPlace
,
double
>;
template
<
typename
T
>
struct
LRNGradFunctor
<
platform
::
CPUPlace
,
T
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
&
x
,
const
framework
::
Tensor
&
out
,
const
framework
::
Tensor
&
mid
,
framework
::
Tensor
*
x_g
,
const
framework
::
Tensor
&
out_g
,
int
N
,
int
C
,
int
H
,
int
W
,
int
n
,
T
alpha
,
T
beta
)
{
T
ratio
=
-
2
*
alpha
*
beta
;
auto
x_g_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
x_g
);
x_g_e
=
x_g_e
.
constant
(
0.0
);
auto
e_x
=
framework
::
EigenTensor
<
T
,
4
>::
From
(
x
);
auto
e_x_g
=
framework
::
EigenTensor
<
T
,
4
>::
From
(
*
x_g
);
auto
e_out
=
framework
::
EigenTensor
<
T
,
4
>::
From
(
out
);
auto
e_out_g
=
framework
::
EigenTensor
<
T
,
4
>::
From
(
out_g
);
auto
e_mid
=
framework
::
EigenTensor
<
T
,
4
>::
From
(
mid
);
const
int
start
=
-
(
n
-
1
)
/
2
;
const
int
end
=
start
+
n
;
for
(
int
m
=
0
;
m
<
N
;
m
++
)
{
for
(
int
i
=
0
;
i
<
C
;
i
++
)
{
auto
i_x
=
e_x
.
slice
(
Eigen
::
array
<
int
,
4
>
({{
m
,
i
,
0
,
0
}}),
Eigen
::
array
<
int
,
4
>
({{
1
,
1
,
H
,
W
}}));
auto
i_x_g
=
e_x_g
.
slice
(
Eigen
::
array
<
int
,
4
>
({{
m
,
i
,
0
,
0
}}),
Eigen
::
array
<
int
,
4
>
({{
1
,
1
,
H
,
W
}}));
auto
i_out_g
=
e_out_g
.
slice
(
Eigen
::
array
<
int
,
4
>
({{
m
,
i
,
0
,
0
}}),
Eigen
::
array
<
int
,
4
>
({{
1
,
1
,
H
,
W
}}));
auto
i_mid
=
e_mid
.
slice
(
Eigen
::
array
<
int
,
4
>
({{
m
,
i
,
0
,
0
}}),
Eigen
::
array
<
int
,
4
>
({{
1
,
1
,
H
,
W
}}));
i_x_g
=
i_mid
.
pow
(
-
beta
)
*
i_out_g
;
for
(
int
c
=
start
;
c
<=
end
;
c
++
)
{
int
ch
=
i
+
c
;
if
(
ch
<
0
||
ch
>=
C
)
{
continue
;
}
auto
c_out
=
e_out
.
slice
(
Eigen
::
array
<
int
,
4
>
({{
m
,
ch
,
0
,
0
}}),
Eigen
::
array
<
int
,
4
>
({{
1
,
1
,
H
,
W
}}));
auto
c_mid
=
e_mid
.
slice
(
Eigen
::
array
<
int
,
4
>
({{
m
,
ch
,
0
,
0
}}),
Eigen
::
array
<
int
,
4
>
({{
1
,
1
,
H
,
W
}}));
auto
c_out_g
=
e_out_g
.
slice
(
Eigen
::
array
<
int
,
4
>
({{
m
,
ch
,
0
,
0
}}),
Eigen
::
array
<
int
,
4
>
({{
1
,
1
,
H
,
W
}}));
i_x_g
+=
ratio
*
c_out_g
*
c_out
*
i_x
/
c_mid
;
}
}
}
}
};
template
struct
LRNGradFunctor
<
platform
::
CPUPlace
,
float
>;
template
struct
LRNGradFunctor
<
platform
::
CPUPlace
,
double
>;
class
LRNOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
...
...
@@ -83,8 +180,8 @@ class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment
(
R"DOC(
Local Response Normalization Operator.
This operator comes from the paper
"ImageNet Classification with Deep Convolutional Neural Networks"
.
This operator comes from the paper
:
<<ImageNet Classification with Deep Convolutional Neural Networks>>
.
The original formula is:
...
...
@@ -119,8 +216,7 @@ class LRNOpGrad : public framework::OperatorWithKernel {
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"MidOut"
)),
"Input(MidOut@GRAD) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"MidOut"
),
"Input(MidOut) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
...
...
paddle/operators/lrn_op.cu
浏览文件 @
c7e739f5
...
...
@@ -12,11 +12,167 @@
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/lrn_op.h"
namespace
ops
=
paddle
::
operators
;
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
__global__
void
KeCMRNormFillScale
(
int
img_size
,
const
T
*
in
,
T
*
mid
,
int
C
,
int
H
,
int
W
,
int
size
,
T
k
,
T
alpha
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
img_size
)
{
const
int
w
=
idx
%
W
;
const
int
h
=
(
idx
/
W
)
%
H
;
const
int
n
=
idx
/
W
/
H
;
const
int
offset
=
(
n
*
C
*
H
+
h
)
*
W
+
w
;
in
+=
offset
;
mid
+=
offset
;
const
int
step
=
H
*
W
;
const
int
pre_pad
=
(
size
-
1
)
/
2
;
const
int
post_pad
=
size
-
pre_pad
-
1
;
T
accum
=
0
;
int
index
=
0
;
while
(
index
<
C
+
post_pad
)
{
if
(
index
<
C
)
{
T
val
=
in
[
index
*
step
];
accum
+=
val
*
val
;
}
if
(
index
>=
size
)
{
T
val
=
in
[(
index
-
size
)
*
step
];
accum
-=
val
*
val
;
}
if
(
index
>=
post_pad
)
{
mid
[(
index
-
post_pad
)
*
step
]
=
k
+
accum
*
alpha
;
}
++
index
;
}
}
}
template
<
typename
T
>
__global__
void
KeCMRNormOutput
(
int
input_size
,
const
T
*
in
,
const
T
*
mid
,
T
negative_beta
,
T
*
out
)
{
const
int
index
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
index
<
input_size
)
{
out
[
index
]
=
in
[
index
]
*
pow
(
mid
[
index
],
negative_beta
);
}
}
template
<
typename
T
>
void
CrossMapNormal
(
const
framework
::
ExecutionContext
&
ctx
,
const
T
*
inputs
,
T
*
outputs
,
T
*
mid
,
int
N
,
int
C
,
int
H
,
int
W
,
int
n
,
T
k
,
T
alpha
,
T
beta
)
{
int
img_size
=
N
*
H
*
W
;
const
int
block_size
=
1024
;
int
grid_size
=
(
img_size
+
block_size
-
1
)
/
block_size
;
KeCMRNormFillScale
<
T
><<<
grid_size
,
block_size
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
img_size
,
inputs
,
mid
,
C
,
H
,
W
,
n
,
k
,
alpha
);
int
input_size
=
N
*
H
*
W
*
C
;
grid_size
=
(
input_size
+
block_size
-
1
)
/
block_size
;
KeCMRNormOutput
<
T
><<<
grid_size
,
block_size
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_size
,
inputs
,
mid
,
-
beta
,
outputs
);
}
template
<
typename
T
>
struct
LRNFunctor
<
platform
::
GPUPlace
,
T
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
*
out
,
framework
::
Tensor
*
mid
,
int
N
,
int
C
,
int
H
,
int
W
,
int
n
,
T
k
,
T
alpha
,
T
beta
)
{
CrossMapNormal
<
T
>
(
ctx
,
input
.
data
<
T
>
(),
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
mid
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
N
,
C
,
H
,
W
,
n
,
k
,
alpha
,
beta
);
}
};
template
struct
LRNFunctor
<
platform
::
GPUPlace
,
float
>;
template
struct
LRNFunctor
<
platform
::
GPUPlace
,
double
>;
template
<
typename
T
>
__global__
void
KeCMRNormDiff
(
int
img_size
,
const
T
*
x
,
const
T
*
out
,
const
T
*
mid
,
T
*
x_g
,
const
T
*
out_g
,
int
C
,
int
H
,
int
W
,
int
size
,
T
negative_beta
,
T
ratio
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
img_size
)
{
const
int
w
=
idx
%
W
;
const
int
h
=
(
idx
/
W
)
%
H
;
const
int
n
=
idx
/
W
/
H
;
const
int
offset
=
(
n
*
C
*
H
+
h
)
*
W
+
w
;
x
+=
offset
;
out
+=
offset
;
mid
+=
offset
;
out_g
+=
offset
;
x_g
+=
offset
;
const
int
step
=
H
*
W
;
const
int
pre_pad
=
size
-
(
size
+
1
)
/
2
;
const
int
post_pad
=
size
-
pre_pad
-
1
;
int
index
=
0
;
T
accum
=
0
;
// TODO(gongwb): optimize this with thread shared array.
while
(
index
<
C
+
post_pad
)
{
if
(
index
<
C
)
{
x_g
[
index
*
step
]
=
0.0
;
accum
+=
out_g
[
index
*
step
]
*
out
[
index
*
step
]
/
mid
[
index
*
step
];
}
if
(
index
>=
size
)
{
accum
-=
out_g
[(
index
-
size
)
*
step
]
*
out
[(
index
-
size
)
*
step
]
/
mid
[(
index
-
size
)
*
step
];
}
if
(
index
>=
post_pad
)
{
x_g
[(
index
-
post_pad
)
*
step
]
+=
out_g
[(
index
-
post_pad
)
*
step
]
*
pow
(
mid
[(
index
-
post_pad
)
*
step
],
negative_beta
)
-
ratio
*
x
[(
index
-
post_pad
)
*
step
]
*
accum
;
}
++
index
;
}
}
}
template
<
typename
T
>
void
CrossMapNormalGrad
(
const
framework
::
ExecutionContext
&
ctx
,
const
T
*
x
,
const
T
*
out
,
const
T
*
mid
,
T
*
x_g
,
const
T
*
out_g
,
int
N
,
int
C
,
int
H
,
int
W
,
int
n
,
T
alpha
,
T
beta
)
{
int
img_size
=
N
*
H
*
W
;
const
int
block_size
=
1024
;
int
grid_size
=
(
img_size
+
block_size
-
1
)
/
block_size
;
KeCMRNormDiff
<
T
><<<
grid_size
,
block_size
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
img_size
,
x
,
out
,
mid
,
x_g
,
out_g
,
C
,
H
,
W
,
n
,
-
beta
,
2.0
f
*
alpha
*
beta
);
}
template
<
typename
T
>
struct
LRNGradFunctor
<
platform
::
GPUPlace
,
T
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
&
x
,
const
framework
::
Tensor
&
out
,
const
framework
::
Tensor
&
mid
,
framework
::
Tensor
*
x_g
,
const
framework
::
Tensor
&
out_g
,
int
N
,
int
C
,
int
H
,
int
W
,
int
n
,
T
alpha
,
T
beta
)
{
CrossMapNormalGrad
<
T
>
(
ctx
,
x
.
data
<
T
>
(),
out
.
data
<
T
>
(),
mid
.
data
<
T
>
(),
x_g
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
out_g
.
data
<
T
>
(),
N
,
C
,
H
,
W
,
n
,
alpha
,
beta
);
}
};
template
struct
LRNGradFunctor
<
platform
::
GPUPlace
,
float
>;
template
struct
LRNGradFunctor
<
platform
::
GPUPlace
,
double
>;
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
lrn
,
ops
::
LRNKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
REGISTER_OP_GPU_KERNEL
(
lrn_grad
,
ops
::
LRNGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
paddle/operators/lrn_op.h
浏览文件 @
c7e739f5
...
...
@@ -21,6 +21,14 @@
namespace
paddle
{
namespace
operators
{
template
<
typename
place
,
typename
T
>
struct
LRNFunctor
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
*
out
,
framework
::
Tensor
*
mid
,
int
N
,
int
C
,
int
H
,
int
W
,
int
n
,
T
k
,
T
alpha
,
T
beta
);
};
template
<
typename
Place
,
typename
T
>
class
LRNKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -31,8 +39,8 @@ class LRNKernel : public framework::OpKernel<T> {
// f(x) represents outputs
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
// input
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
x_dims
=
x
->
dims
();
const
Tensor
&
x
=
*
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
x_dims
=
x
.
dims
();
// NCHW
int
N
=
x_dims
[
0
];
...
...
@@ -57,38 +65,20 @@ class LRNKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE
(
beta
>=
0.0
,
"beta should >= 0.0"
);
PADDLE_ENFORCE
(
k
>=
0.0
,
"k should >= 0.0"
);
auto
x_v
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
x
);
const
int
start
=
-
(
n
-
1
)
/
2
;
const
int
end
=
start
+
n
;
auto
e_mid
=
framework
::
EigenTensor
<
T
,
4
>::
From
(
*
mid
);
e_mid
.
device
(
ctx
.
GetEigenDevice
<
Place
>
())
=
e_mid
.
constant
(
k
);
auto
e_x
=
framework
::
EigenTensor
<
T
,
4
>::
From
(
*
x
);
for
(
int
m
=
0
;
m
<
N
;
m
++
)
{
for
(
int
i
=
0
;
i
<
C
;
i
++
)
{
for
(
int
c
=
start
;
c
<=
end
;
c
++
)
{
int
ch
=
i
+
c
;
if
(
ch
>=
0
&&
ch
<
C
)
{
auto
s
=
e_mid
.
slice
(
Eigen
::
array
<
int
,
4
>
({{
m
,
i
,
0
,
0
}}),
Eigen
::
array
<
int
,
4
>
({{
1
,
1
,
H
,
W
}}));
auto
r
=
e_x
.
slice
(
Eigen
::
array
<
int
,
4
>
({{
m
,
ch
,
0
,
0
}}),
Eigen
::
array
<
int
,
4
>
({{
1
,
1
,
H
,
W
}}));
s
.
device
(
ctx
.
GetEigenDevice
<
Place
>
())
+=
alpha
*
r
.
square
();
}
}
}
}
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
out
);
out_e
.
device
(
ctx
.
GetEigenDevice
<
Place
>
())
=
x_v
*
e_mid
.
reshape
(
Eigen
::
DSizes
<
int
,
1
>
(
e_mid
.
size
())).
pow
(
-
beta
);
LRNFunctor
<
Place
,
T
>
f
;
f
(
ctx
,
x
,
out
,
mid
,
N
,
C
,
H
,
W
,
n
,
k
,
alpha
,
beta
);
}
};
template
<
typename
Place
,
typename
T
>
struct
LRNGradFunctor
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
&
x
,
const
framework
::
Tensor
&
out
,
const
framework
::
Tensor
&
mid
,
framework
::
Tensor
*
x_g
,
const
framework
::
Tensor
&
out_g
,
int
N
,
int
C
,
int
H
,
int
W
,
int
n
,
T
alpha
,
T
beta
);
};
/**
* \brief Backward calculation for normalization with across maps.
*
...
...
@@ -97,7 +87,7 @@ class LRNKernel : public framework::OpKernel<T> {
* The implementation of this Function is derived from the
* CrossMapNormalFunc implementation.
*
* InputGrad = OutputGrad *
denoms
^ (-beta)
* InputGrad = OutputGrad *
MidOut
^ (-beta)
* -- upper
* + > (OutputGrad * OutputValue * (-2 * alpha * beta) / MidOut) * InputValue
* -- lower
...
...
@@ -113,18 +103,15 @@ class LRNGradKernel : public framework::OpKernel<T> {
public:
using
Tensor
=
framework
::
Tensor
;
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
out
=
ctx
.
Input
<
Tensor
>
(
"Out"
);
const
Tensor
*
out_g
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
const
Tensor
*
mid
=
ctx
.
Input
<
Tensor
>
(
"MidOut"
);
const
Tensor
&
x
=
*
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
&
out
=
*
ctx
.
Input
<
Tensor
>
(
"Out"
);
const
Tensor
&
out_g
=
*
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
const
Tensor
&
mid
=
*
ctx
.
Input
<
Tensor
>
(
"MidOut"
);
auto
x_g
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
x_g
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
x_g_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
x_g
);
x_g_e
.
device
(
ctx
.
GetEigenDevice
<
Place
>
())
=
x_g_e
.
constant
(
0.0
);
auto
x_dims
=
x
->
dims
();
auto
x_dims
=
x
.
dims
();
int
N
=
x_dims
[
0
];
int
C
=
x_dims
[
1
];
int
H
=
x_dims
[
2
];
...
...
@@ -133,51 +120,9 @@ class LRNGradKernel : public framework::OpKernel<T> {
int
n
=
ctx
.
Attr
<
int
>
(
"n"
);
T
alpha
=
ctx
.
Attr
<
T
>
(
"alpha"
);
T
beta
=
ctx
.
Attr
<
T
>
(
"beta"
);
T
ratio
=
-
2
*
alpha
*
beta
;
auto
e_x
=
framework
::
EigenTensor
<
T
,
4
>::
From
(
*
x
);
auto
e_x_g
=
framework
::
EigenTensor
<
T
,
4
>::
From
(
*
x_g
);
auto
e_out
=
framework
::
EigenTensor
<
T
,
4
>::
From
(
*
out
);
auto
e_out_g
=
framework
::
EigenTensor
<
T
,
4
>::
From
(
*
out_g
);
auto
e_mid
=
framework
::
EigenTensor
<
T
,
4
>::
From
(
*
mid
);
const
int
start
=
-
(
n
-
1
)
/
2
;
const
int
end
=
start
+
n
;
for
(
int
m
=
0
;
m
<
N
;
m
++
)
{
for
(
int
i
=
0
;
i
<
C
;
i
++
)
{
auto
i_x
=
e_x
.
slice
(
Eigen
::
array
<
int
,
4
>
({{
m
,
i
,
0
,
0
}}),
Eigen
::
array
<
int
,
4
>
({{
1
,
1
,
H
,
W
}}));
auto
i_x_g
=
e_x_g
.
slice
(
Eigen
::
array
<
int
,
4
>
({{
m
,
i
,
0
,
0
}}),
Eigen
::
array
<
int
,
4
>
({{
1
,
1
,
H
,
W
}}));
auto
i_out_g
=
e_out_g
.
slice
(
Eigen
::
array
<
int
,
4
>
({{
m
,
i
,
0
,
0
}}),
Eigen
::
array
<
int
,
4
>
({{
1
,
1
,
H
,
W
}}));
auto
i_mid
=
e_mid
.
slice
(
Eigen
::
array
<
int
,
4
>
({{
m
,
i
,
0
,
0
}}),
Eigen
::
array
<
int
,
4
>
({{
1
,
1
,
H
,
W
}}));
i_x_g
.
device
(
ctx
.
GetEigenDevice
<
Place
>
())
=
i_mid
.
pow
(
-
beta
)
*
i_out_g
;
for
(
int
c
=
start
;
c
<=
end
;
c
++
)
{
int
ch
=
i
+
c
;
if
(
ch
<
0
||
ch
>=
C
)
{
continue
;
}
auto
c_out
=
e_out
.
slice
(
Eigen
::
array
<
int
,
4
>
({{
m
,
ch
,
0
,
0
}}),
Eigen
::
array
<
int
,
4
>
({{
1
,
1
,
H
,
W
}}));
auto
c_mid
=
e_mid
.
slice
(
Eigen
::
array
<
int
,
4
>
({{
m
,
ch
,
0
,
0
}}),
Eigen
::
array
<
int
,
4
>
({{
1
,
1
,
H
,
W
}}));
auto
c_out_g
=
e_out_g
.
slice
(
Eigen
::
array
<
int
,
4
>
({{
m
,
ch
,
0
,
0
}}),
Eigen
::
array
<
int
,
4
>
({{
1
,
1
,
H
,
W
}}));
i_x_g
.
device
(
ctx
.
GetEigenDevice
<
Place
>
())
+=
ratio
*
c_out_g
*
c_out
*
i_x
/
c_mid
;
}
}
}
LRNGradFunctor
<
Place
,
T
>
f
;
f
(
ctx
,
x
,
out
,
mid
,
x_g
,
out_g
,
N
,
C
,
H
,
W
,
n
,
alpha
,
beta
);
}
};
...
...
python/paddle/v2/fluid/tests/test_lrn_op.py
浏览文件 @
c7e739f5
...
...
@@ -23,7 +23,7 @@ class TestLRNOp(OpTest):
start
=
-
(
self
.
n
-
1
)
/
2
end
=
start
+
self
.
n
mid
=
np
.
empty
((
self
.
N
,
self
.
C
,
self
.
H
,
self
.
W
)
,
dtype
=
float
)
mid
=
np
.
empty
((
self
.
N
,
self
.
C
,
self
.
H
,
self
.
W
)
).
astype
(
"float32"
)
mid
.
fill
(
self
.
k
)
for
m
in
range
(
0
,
self
.
N
):
for
i
in
range
(
0
,
self
.
C
):
...
...
@@ -74,5 +74,4 @@ class TestLRNOp(OpTest):
if
__name__
==
"__main__"
:
exit
(
0
)
# LRN grad implement wrong
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录