Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
c33ddc74
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c33ddc74
编写于
9月 01, 2017
作者:
Y
yangyaming
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix some bugs, add more unittests.
上级
e9cc3282
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
79 addition
and
20 deletion
+79
-20
paddle/operators/squared_l2_distance_op.cc
paddle/operators/squared_l2_distance_op.cc
+5
-3
paddle/operators/squared_l2_distance_op.h
paddle/operators/squared_l2_distance_op.h
+11
-8
python/paddle/v2/framework/tests/test_squared_l2_distance_op.py
.../paddle/v2/framework/tests/test_squared_l2_distance_op.py
+63
-9
未找到文件。
paddle/operators/squared_l2_distance_op.cc
浏览文件 @
c33ddc74
...
...
@@ -49,7 +49,9 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
"First dimension of target must be equal to input "
"or to 1."
);
ctx
.
Output
<
Tensor
>
(
"sub_result"
)
->
Resize
(
x_dims
);
ctx
.
Output
<
Tensor
>
(
"sub_result"
)
->
Resize
({
static_cast
<
int
>
(
x_dims
[
0
]),
static_cast
<
int
>
(
framework
::
product
(
x_dims
)
/
x_dims
[
0
])});
ctx
.
Output
<
Tensor
>
(
"Out"
)
->
Resize
({
x_dims
[
0
],
1
});
}
};
...
...
@@ -97,8 +99,8 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
"must be 1."
);
auto
*
x_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
y_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
if
(
x_grad
!=
nullptr
)
x_grad
->
Resize
(
x_dims
);
if
(
y_grad
!=
nullptr
)
y_grad
->
Resize
(
y_dims
);
if
(
x_grad
)
x_grad
->
Resize
(
x_dims
);
if
(
y_grad
)
y_grad
->
Resize
(
y_dims
);
}
};
...
...
paddle/operators/squared_l2_distance_op.h
浏览文件 @
c33ddc74
...
...
@@ -53,14 +53,16 @@ class SquaredL2DistanceKernel : public framework::OpKernel {
auto
y_dims
=
y
.
dimensions
();
// buffer the substraction result
if
(
y_dims
[
0
]
==
1
&&
x_dims
[
0
]
>
y_dims
[
0
])
{
auto
y_broadcast_dims
=
y_dims
;
y_broadcast_dims
[
0
]
=
x_dims
[
0
];
sub_result
.
device
(
place
)
=
x
-
y
.
broadcast
(
y_broadcast_dims
);
sub_result
.
device
(
place
)
=
x
-
y
.
broadcast
(
Eigen
::
array
<
int
,
2
>
({
static_cast
<
int
>
(
x_dims
[
0
]),
1
})
);
}
else
{
sub_result
.
device
(
place
)
=
x
-
y
;
}
z
.
device
(
place
)
=
sub_result
.
pow
(
2
).
sum
(
Eigen
::
array
<
int
,
1
>
({
1
}));
auto
sub_res_pow2
=
sub_result
*
sub_result
;
z
.
device
(
place
)
=
sub_res_pow2
.
sum
(
Eigen
::
array
<
int
,
1
>
({
1
}))
.
reshape
(
Eigen
::
array
<
int
,
2
>
({
static_cast
<
int
>
(
x_dims
[
0
]),
1
}));
}
};
...
...
@@ -86,7 +88,7 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel {
// propagate back to input
auto
eigen_place
=
context
.
GetEigenDevice
<
Place
>
();
if
(
x_g
!=
nullptr
)
{
if
(
x_g
)
{
x_g
->
mutable_data
<
T
>
(
context
.
GetPlace
());
// eigen matrix
auto
x_grad
=
...
...
@@ -95,7 +97,7 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel {
x_grad
.
device
(
eigen_place
)
=
grad_mat
;
}
if
(
y_g
!=
nullptr
)
{
if
(
y_g
)
{
y_g
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
y_grad
=
EigenMatrix
<
T
>::
From
(
*
y_g
,
framework
::
make_ddim
({
y_dims
[
0
],
cols
}));
...
...
@@ -107,8 +109,9 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel {
if
(
sub_result
.
dimensions
()[
0
]
==
y_dims
[
0
])
{
y_grad
.
device
(
eigen_place
)
=
-
1
*
grad_mat
;
}
else
{
auto
col_sum_res
=
-
1
*
(
grad_mat
.
sum
(
Eigen
::
array
<
int
,
1
>
({
0
})));
y_grad
.
device
(
eigen_place
)
=
-
1
*
(
grad_mat
.
sum
(
Eigen
::
array
<
int
,
2
>
({
0
})
));
col_sum_res
.
reshape
(
Eigen
::
array
<
int
,
2
>
({
1
,
cols
}
));
}
}
}
...
...
python/paddle/v2/framework/tests/test_squared_l2_distance_op.py
浏览文件 @
c33ddc74
...
...
@@ -4,30 +4,84 @@ from gradient_checker import GradientChecker, create_op
import
numpy
as
np
class
TestSquaredL2DistanceOp
(
unittest
.
TestCase
):
class
TestSquaredL2DistanceOp
_f0
(
unittest
.
TestCase
):
__metaclass__
=
OpTestMeta
def
setUp
(
self
):
self
.
type
=
'squared_l2_distance'
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
1.
,
(
2
,
3
)).
astype
(
'float32'
),
'Y'
:
np
.
random
.
uniform
(
0.1
,
1.
,
(
2
,
3
)).
astype
(
'float32'
)
'X'
:
np
.
random
.
uniform
(
0.1
,
1.
,
(
32
,
64
)).
astype
(
'float32'
),
'Y'
:
np
.
random
.
uniform
(
0.1
,
1.
,
(
32
,
64
)).
astype
(
'float32'
)
}
sub
R
es
=
self
.
inputs
[
'X'
]
-
self
.
inputs
[
'Y'
]
output
=
sub
Res
*
subR
es
sub
_r
es
=
self
.
inputs
[
'X'
]
-
self
.
inputs
[
'Y'
]
output
=
sub
_res
*
sub_r
es
self
.
outputs
=
{
'sub_result'
:
subRes
,
'sub_result'
:
sub_res
,
'Out'
:
np
.
expand_dims
(
output
.
sum
(
1
),
1
)
}
class
TestSquaredL2DistanceOp_f1
(
unittest
.
TestCase
):
__metaclass__
=
OpTestMeta
def
setUp
(
self
):
self
.
type
=
'squared_l2_distance'
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
1.
,
(
32
,
64
)).
astype
(
'float32'
),
'Y'
:
np
.
random
.
uniform
(
0.1
,
1.
,
(
1
,
64
)).
astype
(
'float32'
)
}
sub_res
=
self
.
inputs
[
'X'
]
-
self
.
inputs
[
'Y'
]
output
=
sub_res
*
sub_res
self
.
outputs
=
{
'sub_result'
:
sub_res
,
'Out'
:
np
.
expand_dims
(
output
.
sum
(
1
),
1
)
}
class
TestSquaredL2DistanceOp_f2
(
unittest
.
TestCase
):
__metaclass__
=
OpTestMeta
def
setUp
(
self
):
self
.
type
=
'squared_l2_distance'
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
1.
,
(
32
,
64
,
128
)).
astype
(
'float32'
),
'Y'
:
np
.
random
.
uniform
(
0.1
,
1.
,
(
1
,
64
,
128
)).
astype
(
'float32'
)
}
sub_res
=
self
.
inputs
[
'X'
]
-
self
.
inputs
[
'Y'
]
sub_res
=
sub_res
.
reshape
((
32
,
64
*
128
))
output
=
sub_res
*
sub_res
self
.
outputs
=
{
'sub_result'
:
sub_res
,
'Out'
:
np
.
expand_dims
(
output
.
sum
(
1
),
1
)
}
class
TestSquaredL2DistanceGradOp
(
GradientChecker
):
def
test_squared_l2_distance
(
self
):
def
test_squared_l2_distance_b0
(
self
):
op
=
create_op
(
"squared_l2_distance"
)
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
.
6
,
(
2
,
3
)).
astype
(
'float32'
),
'Y'
:
np
.
random
.
uniform
(
0.1
,
.
6
,
(
2
,
3
)).
astype
(
'float32'
)
}
self
.
compare_grad
(
op
,
inputs
)
self
.
check_grad
(
op
,
inputs
,
set
([
"X"
,
"Y"
]),
"Out"
)
def
test_squared_l2_distance_b1
(
self
):
op
=
create_op
(
"squared_l2_distance"
)
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
.
6
,
(
2
,
3
)).
astype
(
'float32'
),
'Y'
:
np
.
random
.
uniform
(
0.1
,
.
6
,
(
1
,
3
)).
astype
(
'float32'
)
}
self
.
compare_grad
(
op
,
inputs
)
self
.
check_grad
(
op
,
inputs
,
set
([
"X"
,
"Y"
]),
"Out"
)
def
test_squared_l2_distance_b2
(
self
):
op
=
create_op
(
"squared_l2_distance"
)
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
1.
,
(
2
,
3
)).
astype
(
'float32'
),
'Y'
:
np
.
random
.
uniform
(
0.1
,
1.
,
(
2
,
3
)).
astype
(
'float32'
)
'X'
:
np
.
random
.
uniform
(
0.1
,
.
6
,
(
2
,
3
,
4
)).
astype
(
'float32'
),
'Y'
:
np
.
random
.
uniform
(
0.1
,
.
6
,
(
1
,
3
,
4
)).
astype
(
'float32'
)
}
self
.
compare_grad
(
op
,
inputs
)
self
.
check_grad
(
op
,
inputs
,
set
([
"X"
,
"Y"
]),
"Out"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录