Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
a29b4227
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a29b4227
编写于
9月 20, 2018
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix sparse gradient clip
上级
b6f61faf
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
35 addition
and
17 deletion
+35
-17
paddle/fluid/operators/clip_op.h
paddle/fluid/operators/clip_op.h
+32
-11
paddle/fluid/operators/math/selected_rows_functor.cu
paddle/fluid/operators/math/selected_rows_functor.cu
+3
-6
未找到文件。
paddle/fluid/operators/clip_op.h
浏览文件 @
a29b4227
...
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/transform.h"
namespace
paddle
{
...
...
@@ -61,14 +62,32 @@ class ClipKernel : public framework::OpKernel<T> {
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
max
=
context
.
Attr
<
T
>
(
"max"
);
auto
min
=
context
.
Attr
<
T
>
(
"min"
);
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
Tensor
>
(
"Out"
);
auto
*
x_var
=
context
.
InputVar
(
"X"
);
if
(
x_var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
*
x
=
context
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
const
T
*
x_data
=
x
->
data
<
T
>
();
int64_t
numel
=
x
->
numel
();
Transform
<
DeviceContext
>
trans
;
trans
(
context
.
template
device_context
<
DeviceContext
>(),
x_data
,
x_data
+
numel
,
out_data
,
ClipFunctor
<
T
>
(
min
,
max
));
}
else
if
(
x_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
*
x
=
context
.
Input
<
framework
::
SelectedRows
>
(
"X"
);
auto
*
out
=
context
.
Output
<
framework
::
SelectedRows
>
(
"Out"
);
PADDLE_ENFORCE_NE
(
x
,
out
,
"Inplace clip is not allowed when x is SelectedRows"
);
math
::
scatter
::
MergeAdd
<
DeviceContext
,
T
>
merge_func
;
merge_func
(
context
.
template
device_context
<
DeviceContext
>(),
*
x
,
out
);
auto
*
out_tensor
=
out
->
mutable_value
();
auto
*
out_data
=
out_tensor
->
data
<
T
>
();
int64_t
numel
=
out_tensor
->
numel
();
Transform
<
DeviceContext
>
trans
;
trans
(
context
.
template
device_context
<
DeviceContext
>(),
out_data
,
out_data
+
numel
,
out_data
,
ClipFunctor
<
T
>
(
min
,
max
));
}
else
{
PADDLE_THROW
(
"ClipOp only supports LoDTensor and SelectedRows"
);
}
}
};
...
...
@@ -78,10 +97,12 @@ class ClipGradKernel : public framework::OpKernel<T> {
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
max
=
context
.
Attr
<
T
>
(
"max"
);
auto
min
=
context
.
Attr
<
T
>
(
"min"
);
auto
*
d_out
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_x
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_out
=
context
.
Input
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_x
=
context
.
Output
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
if
(
d_x
!=
nullptr
)
{
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
x
=
context
.
Input
<
framework
::
LoD
Tensor
>
(
"X"
);
int64_t
numel
=
d_out
->
numel
();
auto
*
d_x_data
=
d_x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
const
T
*
d_out_data
=
d_out
->
data
<
T
>
();
...
...
paddle/fluid/operators/math/selected_rows_functor.cu
浏览文件 @
a29b4227
...
...
@@ -236,7 +236,7 @@ template <typename T, int block_size>
__global__
void
MergeAddKernel
(
const
T
*
input
,
const
int64_t
*
input_rows
,
T
*
out
,
const
int64_t
*
out_rows
,
size_t
out_rows_size
,
int64_t
row_numel
)
{
const
int
ty
=
blockIdx
.
y
;
const
int
ty
=
blockIdx
.
x
;
int
tid
=
threadIdx
.
x
;
__shared__
size_t
out_idx
;
...
...
@@ -291,12 +291,9 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
const
int
block_size
=
256
;
dim3
threads
(
block_size
,
1
);
dim3
grid1
(
1
,
input_rows
.
size
()
);
dim3
grid1
(
input_rows
.
size
(),
1
);
MergeAddKernel
<
T
,
256
><<<
grid1
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
MergeAddKernel
<
T
,
256
><<<
grid1
,
threads
,
0
,
context
.
stream
()
>>>
(
input_data
,
input_rows
.
CUDAData
(
context
.
GetPlace
()),
out_data
,
out
.
mutable_rows
()
->
CUDAMutableData
(
context
.
GetPlace
()),
out
.
rows
().
size
(),
input_width
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录