Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
bfbbe19f
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
bfbbe19f
编写于
4月 24, 2018
作者:
C
chengduo
提交者:
GitHub
4月 24, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #10150 from chengduoZH/fix_elementwise_gradient
Fix elementwise_gradient bug
上级
5ce57555
0f5d5b1f
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
105 addition
and
1 deletion
+105
-1
paddle/fluid/operators/elementwise_op_function.h
paddle/fluid/operators/elementwise_op_function.h
+2
-1
python/paddle/fluid/tests/unittests/test_elementwise_gradient_op.py
...dle/fluid/tests/unittests/test_elementwise_gradient_op.py
+103
-0
未找到文件。
paddle/fluid/operators/elementwise_op_function.h
浏览文件 @
bfbbe19f
...
@@ -356,8 +356,8 @@ __device__ T reduceSum(T val, int tid, int len) {
...
@@ -356,8 +356,8 @@ __device__ T reduceSum(T val, int tid, int len) {
// I use Warp-Level Parallelism and assume the Warp size
// I use Warp-Level Parallelism and assume the Warp size
// is 32 which may be different for different GPU,
// is 32 which may be different for different GPU,
// but most card's warp size is 32.
// but most card's warp size is 32.
__shared__
T
shm
[
32
];
const
int
warpSize
=
32
;
const
int
warpSize
=
32
;
__shared__
T
shm
[
warpSize
];
unsigned
mask
=
0u
;
unsigned
mask
=
0u
;
CREATE_SHFL_MASK
(
mask
,
tid
<
len
);
CREATE_SHFL_MASK
(
mask
,
tid
<
len
);
...
@@ -371,6 +371,7 @@ __device__ T reduceSum(T val, int tid, int len) {
...
@@ -371,6 +371,7 @@ __device__ T reduceSum(T val, int tid, int len) {
if
(
tid
%
warpSize
==
0
)
{
if
(
tid
%
warpSize
==
0
)
{
shm
[
tid
/
warpSize
]
=
val
;
shm
[
tid
/
warpSize
]
=
val
;
}
}
__syncthreads
();
CREATE_SHFL_MASK
(
mask
,
tid
<
warpSize
);
CREATE_SHFL_MASK
(
mask
,
tid
<
warpSize
);
...
...
python/paddle/fluid/tests/unittests/test_elementwise_gradient_op.py
0 → 100644
浏览文件 @
bfbbe19f
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
paddle.fluid.core
as
core
import
paddle.fluid
as
fluid
class
TestElementWiseAddOp
(
unittest
.
TestCase
):
def
__assert_close
(
self
,
tensor
,
np_array
,
msg
,
atol
=
1e-4
):
self
.
assertTrue
(
np
.
allclose
(
np
.
array
(
tensor
),
np_array
,
atol
=
atol
),
msg
)
def
check_forward_backward
(
self
):
def
test_with_place
(
place
):
out_grad
=
np
.
random
.
random_sample
(
self
.
x
.
shape
).
astype
(
np
.
float32
)
x_grad
=
out_grad
sum_axis
=
range
(
0
,
len
(
self
.
x
.
shape
))
del
sum_axis
[
self
.
axis
]
y_grad
=
np
.
sum
(
out_grad
,
axis
=
tuple
(
sum_axis
))
var_dict
=
locals
()
var_dict
[
'y'
]
=
self
.
y
var_dict
[
'x'
]
=
self
.
x
var_dict
[
'out'
]
=
self
.
out
var_dict
[
'y@GRAD'
]
=
y_grad
var_dict
[
'x@GRAD'
]
=
x_grad
var_dict
[
'out@GRAD'
]
=
out_grad
var_names
=
[
'x'
,
'y'
,
'out'
,
'y@GRAD'
,
'x@GRAD'
,
'out@GRAD'
]
ground_truth
=
{
name
:
var_dict
[
name
]
for
name
in
var_names
}
program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
program
):
block
=
program
.
global_block
()
for
name
in
ground_truth
:
block
.
create_var
(
name
=
name
,
dtype
=
'float32'
,
shape
=
ground_truth
[
name
].
shape
)
elementwise_add_op
=
block
.
append_op
(
type
=
"elementwise_add"
,
inputs
=
{
"X"
:
block
.
var
(
'x'
),
"Y"
:
block
.
var
(
'y'
),
},
outputs
=
{
"Out"
:
block
.
var
(
'out'
),
},
attrs
=
{
"axis"
:
self
.
axis
,
})
# generate backward op_desc
grad_op_desc_list
,
op_grad_to_var
=
core
.
get_grad_op_desc
(
elementwise_add_op
.
desc
,
set
(),
[])
grad_op_desc
=
grad_op_desc_list
[
0
]
new_op_desc
=
block
.
desc
.
append_op
()
new_op_desc
.
copy_from
(
grad_op_desc
)
for
var_name
in
grad_op_desc
.
output_arg_names
():
block
.
desc
.
var
(
var_name
.
encode
(
"ascii"
))
grad_op_desc
.
infer_var_type
(
block
.
desc
)
grad_op_desc
.
infer_shape
(
block
.
desc
)
for
arg
in
grad_op_desc
.
output_arg_names
():
grad_var
=
block
.
desc
.
find_var
(
arg
.
encode
(
"ascii"
))
grad_var
.
set_dtype
(
core
.
VarDesc
.
VarType
.
FP32
)
exe
=
fluid
.
Executor
(
place
)
out
=
exe
.
run
(
program
,
feed
=
{
name
:
var_dict
[
name
]
for
name
in
[
'x'
,
'y'
,
'out@GRAD'
]
},
fetch_list
=
[
'x@GRAD'
,
'y@GRAD'
])
self
.
__assert_close
(
x_grad
,
out
[
0
],
"x@GRAD"
)
self
.
__assert_close
(
y_grad
,
out
[
1
],
"y@GRAD"
,
atol
=
1.4
)
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
()
and
core
.
op_support_gpu
(
"elementwise_add"
):
places
.
append
(
core
.
CUDAPlace
(
0
))
for
place
in
places
:
test_with_place
(
place
)
def
test_check_forward_backward_with_scale_and_bias
(
self
):
np
.
random
.
seed
(
123
)
self
.
x
=
np
.
random
.
random
((
4
,
32
,
220
,
220
)).
astype
(
np
.
float32
)
self
.
y
=
np
.
random
.
random
((
32
)).
astype
(
np
.
float32
)
self
.
out
=
self
.
x
+
self
.
y
.
reshape
(
1
,
32
,
1
,
1
)
self
.
axis
=
1
self
.
check_forward_backward
()
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录