Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
43a3af86
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
43a3af86
编写于
9月 27, 2018
作者:
C
chengduo
提交者:
GitHub
9月 27, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine sgd_op (#13626)
test=develop
上级
adae0a3b
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
21 addition
and
20 deletion
+21
-20
paddle/fluid/operators/sgd_op.cu
paddle/fluid/operators/sgd_op.cu
+21
-20
未找到文件。
paddle/fluid/operators/sgd_op.cu
浏览文件 @
43a3af86
...
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#
define EIGEN_USE_GPU
#
include <algorithm>
#include "paddle/fluid/operators/sgd_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
...
...
@@ -33,22 +33,21 @@ __global__ void SGDKernel(const T* g, const T* p, const T* learning_rate,
}
}
template
<
typename
T
,
int
block_size
>
template
<
typename
T
>
__global__
void
SparseSGDFunctorKernel
(
const
T
*
selected_rows
,
const
int64_t
*
rows
,
const
T
*
learning_rate
,
T
*
tensor_out
,
int64_t
row_numel
)
{
const
int
ty
=
blockIdx
.
y
;
int
tid
=
threadIdx
.
x
;
selected_rows
+=
ty
*
row_numel
;
tensor_out
+=
rows
[
ty
]
*
row_numel
;
for
(
int
index
=
tid
;
index
<
row_numel
;
index
+=
block_size
)
{
// Since index in rows of SelectedRows can be duplicate, we have to use
// Atomic Operation to avoid concurrent write error.
paddle
::
platform
::
CudaAtomicAdd
(
tensor_out
+
index
,
-
1.0
*
learning_rate
[
0
]
*
selected_rows
[
index
]);
int64_t
row_numel
,
int64_t
limit
)
{
for
(
int64_t
i
=
blockIdx
.
x
;
i
<
limit
;
i
+=
gridDim
.
x
)
{
const
T
*
selected_rows_ptr
=
selected_rows
+
i
*
row_numel
;
T
*
tensor_out_ptr
=
tensor_out
+
rows
[
i
]
*
row_numel
;
for
(
int64_t
index
=
threadIdx
.
x
;
index
<
row_numel
;
index
+=
blockDim
.
x
)
{
// Since index in rows of SelectedRows can be duplicate, we have to use
// Atomic Operation to avoid concurrent write error.
paddle
::
platform
::
CudaAtomicAdd
(
tensor_out_ptr
+
index
,
-
1.0
*
learning_rate
[
0
]
*
selected_rows_ptr
[
index
]);
}
}
}
}
// namespace
...
...
@@ -97,13 +96,15 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> {
auto
*
in_data
=
in_value
.
data
<
T
>
();
auto
*
out_data
=
param_out
->
data
<
T
>
();
const
int
block_size
=
256
;
dim3
threads
(
block_size
,
1
);
dim3
grid
(
1
,
in_rows
.
size
());
SparseSGDFunctorKernel
<
T
,
256
><<<
grid
,
threads
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
const
int
kThreadsPerBlock
=
256
;
int
thread_x
=
kThreadsPerBlock
;
int
max_threads
=
ctx
.
cuda_device_context
().
GetMaxPhysicalThreadCount
();
int
max_blocks
=
std
::
max
(
max_threads
/
kThreadsPerBlock
,
1
);
SparseSGDFunctorKernel
<<<
max_blocks
,
thread_x
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
in_data
,
in_rows
.
CUDAData
(
ctx
.
GetPlace
()),
learning_rate
->
data
<
T
>
(),
out_data
,
in_row_numel
);
out_data
,
in_row_numel
,
in_rows
.
size
()
);
}
else
{
PADDLE_THROW
(
"Unsupported Variable Type of Grad"
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录