Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
4a11fdb4
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4a11fdb4
编写于
12月 29, 2017
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
follow comments
上级
8bd75900
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
35 addition
and
35 deletion
+35
-35
paddle/operators/cos_sim_op.cc
paddle/operators/cos_sim_op.cc
+5
-5
paddle/operators/cos_sim_op.cu
paddle/operators/cos_sim_op.cu
+6
-6
paddle/operators/cos_sim_op.h
paddle/operators/cos_sim_op.h
+24
-24
未找到文件。
paddle/operators/cos_sim_op.cc
浏览文件 @
4a11fdb4
...
...
@@ -155,11 +155,11 @@ struct CosSimDyFunctor<platform::CPUDeviceContext, T> {
const
T
*
y_norm
,
const
T
*
x
,
const
T
*
y
,
const
T
*
z
,
const
T
*
dz
,
const
size_t
rows
,
const
size_t
cols
,
T
*
dy
)
const
{
for
(
size_t
offset
=
0
;
offset
<
rows
;
++
offset
)
{
auto
xy_norm_prod
=
x_norm
[
offset
]
*
y_norm
[
0
];
auto
dz_data
=
dz
[
offset
];
auto
z_data
=
z
[
offset
];
auto
*
x_data
=
x
+
cols
*
offset
;
for
(
size_t
row_id
=
0
;
row_id
<
rows
;
++
row_id
)
{
auto
xy_norm_prod
=
x_norm
[
row_id
]
*
y_norm
[
0
];
auto
dz_data
=
dz
[
row_id
];
auto
z_data
=
z
[
row_id
];
auto
*
x_data
=
x
+
cols
*
row_id
;
auto
reciprocal_xy_norm_prod
=
1
/
xy_norm_prod
;
auto
y_norm_square
=
y_norm
[
0
]
*
y_norm
[
0
];
...
...
paddle/operators/cos_sim_op.cu
浏览文件 @
4a11fdb4
...
...
@@ -25,12 +25,12 @@ __global__ void CosSimDyKernel(const T* x_norm, const T* y_norm, const T* x,
const
size_t
rows
,
const
size_t
cols
,
T
*
dy
)
{
int
grid_size
=
blockDim
.
x
*
gridDim
.
x
;
T
y_norm_data
=
y_norm
[
0
];
for
(
int
offset
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
offset
<
rows
;
offset
+=
grid_size
)
{
T
xy_norm_prod
=
x_norm
[
offset
]
*
y_norm_data
;
T
dz_data
=
dz
[
offset
];
T
z_data
=
z
[
offset
];
const
T
*
x_data
=
x
+
cols
*
offset
;
for
(
int
row_id
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
row_id
<
rows
;
row_id
+=
grid_size
)
{
T
xy_norm_prod
=
x_norm
[
row_id
]
*
y_norm_data
;
T
dz_data
=
dz
[
row_id
];
T
z_data
=
z
[
row_id
];
const
T
*
x_data
=
x
+
cols
*
row_id
;
T
reciprocal_xy_norm_prod
=
1
/
xy_norm_prod
;
T
y_norm_square
=
y_norm_data
*
y_norm_data
;
...
...
paddle/operators/cos_sim_op.h
浏览文件 @
4a11fdb4
...
...
@@ -32,11 +32,11 @@ struct CosSimFunctor {
z_
(
z
),
cols_
(
static_cast
<
size_t
>
(
cols
))
{}
inline
HOSTDEVICE
void
operator
()(
size_t
offset
)
const
{
auto
*
x
=
x_
+
cols_
*
offset
;
inline
HOSTDEVICE
void
operator
()(
size_t
row_id
)
const
{
auto
*
x
=
x_
+
cols_
*
row_id
;
T
xx
=
0
,
xy
=
0
,
yy
=
0
;
if
(
same_row
)
{
auto
*
y
=
y_
+
cols_
*
offset
;
auto
*
y
=
y_
+
cols_
*
row_id
;
T
tep_x
,
tep_y
;
for
(
size_t
i
=
0
;
i
<
cols_
;
++
i
)
{
tep_x
=
x
[
i
];
...
...
@@ -47,9 +47,9 @@ struct CosSimFunctor {
}
xx
=
sqrt
(
xx
);
yy
=
sqrt
(
yy
);
y_norm_
[
offset
]
=
yy
;
x_norm_
[
offset
]
=
xx
;
z_
[
offset
]
=
xy
/
(
xx
*
yy
);
y_norm_
[
row_id
]
=
yy
;
x_norm_
[
row_id
]
=
xx
;
z_
[
row_id
]
=
xy
/
(
xx
*
yy
);
}
else
{
// This can be wrote in a better way.
T
tep_x
,
tep_y
;
for
(
size_t
i
=
0
;
i
<
cols_
;
++
i
)
{
...
...
@@ -61,9 +61,9 @@ struct CosSimFunctor {
}
xx
=
sqrt
(
xx
);
yy
=
sqrt
(
yy
);
if
(
offset
==
0
)
y_norm_
[
0
]
=
yy
;
x_norm_
[
offset
]
=
xx
;
z_
[
offset
]
=
xy
/
(
xx
*
yy
);
if
(
row_id
==
0
)
y_norm_
[
0
]
=
yy
;
x_norm_
[
row_id
]
=
xx
;
z_
[
row_id
]
=
xy
/
(
xx
*
yy
);
}
}
...
...
@@ -125,15 +125,15 @@ struct CosSimGradFunctor {
dx_
(
dx
),
cols_
(
static_cast
<
size_t
>
(
cols
))
{}
inline
HOSTDEVICE
void
operator
()(
size_t
offset
)
const
{
auto
x_norm_square
=
x_norm_
[
offset
]
*
x_norm_
[
offset
];
auto
xy_norm_prod
=
x_norm_
[
offset
]
*
y_norm_
[
offset
];
auto
dz
=
dz_
[
offset
];
auto
z
=
z_
[
offset
];
inline
HOSTDEVICE
void
operator
()(
size_t
row_id
)
const
{
auto
x_norm_square
=
x_norm_
[
row_id
]
*
x_norm_
[
row_id
];
auto
xy_norm_prod
=
x_norm_
[
row_id
]
*
y_norm_
[
row_id
];
auto
dz
=
dz_
[
row_id
];
auto
z
=
z_
[
row_id
];
auto
*
dx
=
dx_
+
cols_
*
offset
;
auto
*
x
=
x_
+
cols_
*
offset
;
auto
*
y
=
y_
+
cols_
*
offset
;
auto
*
dx
=
dx_
+
cols_
*
row_id
;
auto
*
x
=
x_
+
cols_
*
row_id
;
auto
*
y
=
y_
+
cols_
*
row_id
;
auto
reciprocal_xy_norm_prod
=
1
/
xy_norm_prod
;
auto
reciprocal_x_norm_square
=
1
/
x_norm_square
;
...
...
@@ -166,14 +166,14 @@ struct CosSimDxFunctor {
dx_
(
dx
),
cols_
(
static_cast
<
size_t
>
(
cols
))
{}
inline
HOSTDEVICE
void
operator
()(
size_t
offset
)
const
{
auto
xy_norm_prod
=
x_norm_
[
offset
]
*
y_norm_
[
0
];
auto
dz
=
dz_
[
offset
];
auto
z
=
z_
[
offset
];
auto
*
x
=
x_
+
cols_
*
offset
;
inline
HOSTDEVICE
void
operator
()(
size_t
row_id
)
const
{
auto
xy_norm_prod
=
x_norm_
[
row_id
]
*
y_norm_
[
0
];
auto
dz
=
dz_
[
row_id
];
auto
z
=
z_
[
row_id
];
auto
*
x
=
x_
+
cols_
*
row_id
;
auto
reciprocal_xy_norm_prod
=
1
/
xy_norm_prod
;
auto
x_norm_square
=
x_norm_
[
offset
]
*
x_norm_
[
offset
];
auto
*
dx
=
dx_
+
cols_
*
offset
;
auto
x_norm_square
=
x_norm_
[
row_id
]
*
x_norm_
[
row_id
];
auto
*
dx
=
dx_
+
cols_
*
row_id
;
auto
reciprocal_x_norm_square
=
1
/
x_norm_square
;
for
(
size_t
i
=
0
;
i
<
cols_
;
++
i
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录