Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
de26ae41
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
de26ae41
编写于
12月 27, 2017
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add gpu code
上级
4f5e3d0d
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
59 addition
and
77 deletion
+59
-77
paddle/operators/cos_sim_op.cc
paddle/operators/cos_sim_op.cc
+17
-33
paddle/operators/cos_sim_op.cu
paddle/operators/cos_sim_op.cu
+33
-33
paddle/operators/cos_sim_op.h
paddle/operators/cos_sim_op.h
+9
-11
未找到文件。
paddle/operators/cos_sim_op.cc
浏览文件 @
de26ae41
...
...
@@ -151,42 +151,26 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
template
<
typename
T
>
struct
CosSimDyFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
CosSimDyFunctor
(
const
T
*
x_norm
,
const
T
*
y_norm
,
const
T
*
x
,
const
T
*
y
,
const
T
*
z
,
const
T
*
dz
,
T
*
dy
,
int
cols
)
:
x_norm_
(
x_norm
),
y_norm_
(
y_norm
),
x_
(
x
),
y_
(
y
),
z_
(
z
),
dz_
(
dz
),
dy_
(
dy
),
cols_
(
static_cast
<
size_t
>
(
cols
))
{}
inline
HOSTDEVICE
void
operator
()(
size_t
offset
)
const
{
auto
xy_norm_prod
=
x_norm_
[
offset
]
*
y_norm_
[
0
];
auto
dz
=
dz_
[
offset
];
auto
z
=
z_
[
offset
];
auto
*
x
=
x_
+
cols_
*
offset
;
inline
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
T
*
x_norm
,
const
T
*
y_norm
,
const
T
*
x
,
const
T
*
y
,
const
T
*
z
,
const
T
*
dz
,
const
size_t
rows
,
const
size_t
cols
,
T
*
dy
)
const
{
for
(
size_t
offset
=
0
;
offset
<
rows
;
++
offset
)
{
auto
xy_norm_prod
=
x_norm
[
offset
]
*
y_norm
[
0
];
auto
dz_data
=
dz
[
offset
];
auto
z_data
=
z
[
offset
];
auto
*
x_data
=
x
+
cols
*
offset
;
auto
reciprocal_xy_norm_prod
=
1
/
xy_norm_prod
;
auto
y_norm_square
=
y_norm_
[
0
]
*
y_norm_
[
0
];
auto
y_norm_square
=
y_norm
[
0
]
*
y_norm
[
0
];
auto
reciprocal_y_norm_square
=
1
/
y_norm_square
;
for
(
size_t
i
=
0
;
i
<
cols_
;
++
i
)
{
dy_
[
i
]
+=
dz
*
(
x
[
i
]
*
reciprocal_xy_norm_prod
-
z
*
y_
[
i
]
*
reciprocal_y_norm_square
);
for
(
size_t
i
=
0
;
i
<
cols
;
++
i
)
{
dy
[
i
]
+=
dz_data
*
(
x_data
[
i
]
*
reciprocal_xy_norm_prod
-
z_data
*
y
[
i
]
*
reciprocal_y_norm_square
);
}
}
}
const
T
*
x_norm_
;
const
T
*
y_norm_
;
const
T
*
x_
;
const
T
*
y_
;
const
T
*
z_
;
const
T
*
dz_
;
T
*
dy_
;
const
size_t
cols_
;
};
}
// namespace operators
}
// namespace paddle
...
...
paddle/operators/cos_sim_op.cu
浏览文件 @
de26ae41
...
...
@@ -20,45 +20,45 @@ namespace paddle {
namespace
operators
{
template
<
typename
T
>
struct
CosSimDyFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
CosSimDyFunctor
(
const
T
*
x_norm
,
const
T
*
y_norm
,
const
T
*
x
,
const
T
*
y
,
const
T
*
z
,
const
T
*
dz
,
T
*
dy
,
int
cols
)
:
x_norm_
(
x_norm
),
y_norm_
(
y_norm
),
x_
(
x
),
y_
(
y
),
z_
(
z
),
dz_
(
dz
),
dy_
(
dy
),
cols_
(
static_cast
<
size_t
>
(
cols
))
{}
inline
HOSTDEVICE
void
operator
()(
size_t
offset
)
const
{
auto
xy_norm_prod
=
x_norm_
[
offset
]
*
y_norm_
[
0
];
auto
dz
=
dz_
[
offset
];
auto
z
=
z_
[
offset
];
auto
*
x
=
x_
+
cols_
*
offset
;
auto
reciprocal_xy_norm_prod
=
1
/
xy_norm_prod
;
__global__
void
CosSimDyKernel
(
const
T
*
x_norm
,
const
T
*
y_norm
,
const
T
*
x
,
const
T
*
y
,
const
T
*
z
,
const
T
*
dz
,
const
size_t
rows
,
const
size_t
cols
,
T
*
dy
)
{
int
grid_size
=
blockDim
.
x
*
gridDim
.
x
;
T
y_norm_data
=
y_norm
[
0
];
for
(
int
offset
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
offset
<
rows
;
offset
+=
grid_size
)
{
T
xy_norm_prod
=
x_norm
[
offset
]
*
y_norm_data
;
T
dz_data
=
dz
[
offset
];
T
z_data
=
z
[
offset
];
const
T
*
x_data
=
x
+
cols
*
offset
;
T
reciprocal_xy_norm_prod
=
1
/
xy_norm_prod
;
auto
y_norm_square
=
y_norm_
[
0
]
*
y_norm_
[
0
];
auto
reciprocal_y_norm_square
=
1
/
y_norm_square
;
for
(
size_t
i
=
0
;
i
<
cols_
;
++
i
)
{
T
dy
=
dz
*
(
x
[
i
]
*
reciprocal_xy_norm_prod
-
z
*
y_
[
i
]
*
reciprocal_y_norm_square
);
// platform::CudaAtomicAdd(dy_ + i, dy);
dy_
[
i
]
+=
dy
;
T
y_norm_square
=
y_norm_data
*
y_norm_data
;
T
reciprocal_y_norm_square
=
1
/
y_norm_square
;
for
(
size_t
i
=
0
;
i
<
cols
;
++
i
)
{
T
dy_data
=
dz_data
*
(
x_data
[
i
]
*
reciprocal_xy_norm_prod
-
z_data
*
y
[
i
]
*
reciprocal_y_norm_square
);
platform
::
CudaAtomicAdd
(
dy
+
i
,
dy_data
);
}
}
}
const
T
*
x_norm_
;
const
T
*
y_norm_
;
const
T
*
x_
;
const
T
*
y_
;
const
T
*
z_
;
const
T
*
dz_
;
T
*
dy_
;
const
size_t
cols_
;
template
<
typename
T
>
struct
CosSimDyFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
inline
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
T
*
x_norm
,
const
T
*
y_norm
,
const
T
*
x
,
const
T
*
y
,
const
T
*
z
,
const
T
*
dz
,
const
size_t
rows
,
const
size_t
cols
,
T
*
dy
)
const
{
const
int
block_size
=
512
;
dim3
threads
(
block_size
,
1
);
dim3
grid
(
1
,
(
rows
+
block_size
-
1
)
/
block_size
);
CosSimDyKernel
<
T
><<<
grid
,
threads
,
0
,
ctx
.
stream
()
>>>
(
x_norm
,
y_norm
,
x
,
y
,
z
,
dz
,
rows
,
cols
,
dy
);
}
};
template
struct
CosSimDyFunctor
<
platform
::
CUDADeviceContext
,
float
>;
}
// namespace operators
}
// namespace paddle
...
...
paddle/operators/cos_sim_op.h
浏览文件 @
de26ae41
...
...
@@ -193,9 +193,10 @@ struct CosSimDxFunctor {
template
<
typename
DeviceContext
,
typename
T
>
struct
CosSimDyFunctor
{
CosSimDyFunctor
(
const
T
*
x_norm
,
const
T
*
y_norm
,
const
T
*
x
,
const
T
*
y
,
const
T
*
z
,
const
T
*
dz
,
T
*
dy
,
int
cols
);
inline
HOSTDEVICE
void
operator
()(
size_t
)
const
;
inline
void
operator
()(
const
DeviceContext
&
ctx
,
const
T
*
x_norm
,
const
T
*
y_norm
,
const
T
*
x
,
const
T
*
y
,
const
T
*
z
,
const
T
*
dz
,
const
size_t
rows
,
const
size_t
cols
,
T
*
dy
)
const
;
};
template
<
typename
DeviceContext
,
typename
T
>
...
...
@@ -255,14 +256,11 @@ class CosSimGradKernel : public framework::OpKernel<T> {
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
set_zero
(
dev_ctx
,
out_grad_y
,
static_cast
<
T
>
(
0
));
CosSimDyFunctor
<
DeviceContext
,
T
>
functor
(
in_x_norm
->
data
<
T
>
(),
in_y_norm
->
data
<
T
>
(),
in_x
->
data
<
T
>
(),
in_y
->
data
<
T
>
(),
in_z
->
data
<
T
>
(),
in_grad_z
->
data
<
T
>
(),
out_grad_y
->
data
<
T
>
(),
cols
);
platform
::
ForRange
<
DeviceContext
>
for_range
(
static_cast
<
const
DeviceContext
&>
(
context
.
device_context
()),
rows_x
);
for_range
(
functor
);
CosSimDyFunctor
<
DeviceContext
,
T
>
functor
;
functor
(
dev_ctx
,
in_x_norm
->
data
<
T
>
(),
in_y_norm
->
data
<
T
>
(),
in_x
->
data
<
T
>
(),
in_y
->
data
<
T
>
(),
in_z
->
data
<
T
>
(),
in_grad_z
->
data
<
T
>
(),
static_cast
<
size_t
>
(
rows_x
),
static_cast
<
size_t
>
(
cols
),
out_grad_y
->
data
<
T
>
());
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录