Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
5fea8cd4
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5fea8cd4
编写于
12月 14, 2018
作者:
M
minqiyang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add sorted_result parameter to SelectedRows Functor
test=develop
上级
da796dfe
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
24 addition
and
20 deletion
+24
-20
paddle/fluid/operators/math/selected_rows_functor.cc
paddle/fluid/operators/math/selected_rows_functor.cc
+10
-7
paddle/fluid/operators/math/selected_rows_functor.cu
paddle/fluid/operators/math/selected_rows_functor.cu
+2
-1
paddle/fluid/operators/math/selected_rows_functor.h
paddle/fluid/operators/math/selected_rows_functor.h
+6
-10
paddle/fluid/operators/optimizers/adam_op.h
paddle/fluid/operators/optimizers/adam_op.h
+6
-2
未找到文件。
paddle/fluid/operators/math/selected_rows_functor.cc
浏览文件 @
5fea8cd4
...
@@ -253,23 +253,26 @@ elementwise_add_to(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
...
@@ -253,23 +253,26 @@ elementwise_add_to(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
template
<
typename
T
>
template
<
typename
T
>
struct
MergeAdd
<
platform
::
CPUDeviceContext
,
T
>
{
struct
MergeAdd
<
platform
::
CPUDeviceContext
,
T
>
{
framework
::
SelectedRows
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
framework
::
SelectedRows
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
)
{
const
framework
::
SelectedRows
&
input
,
const
bool
sorted_result
=
false
)
{
framework
::
SelectedRows
out
;
framework
::
SelectedRows
out
;
(
*
this
)(
context
,
input
,
&
out
);
(
*
this
)(
context
,
input
,
&
out
,
sorted_result
);
return
out
;
return
out
;
}
}
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
,
const
framework
::
SelectedRows
&
input
,
framework
::
SelectedRows
*
output
)
{
framework
::
SelectedRows
*
output
,
const
bool
sorted_result
=
false
)
{
std
::
vector
<
const
framework
::
SelectedRows
*>
inputs
;
std
::
vector
<
const
framework
::
SelectedRows
*>
inputs
;
inputs
.
push_back
(
&
input
);
inputs
.
push_back
(
&
input
);
(
*
this
)(
context
,
inputs
,
output
);
(
*
this
)(
context
,
inputs
,
output
,
sorted_result
);
}
}
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
std
::
vector
<
const
framework
::
SelectedRows
*>&
inputs
,
const
std
::
vector
<
const
framework
::
SelectedRows
*>&
inputs
,
framework
::
SelectedRows
*
output
)
{
framework
::
SelectedRows
*
output
,
const
bool
sorted_result
=
false
)
{
if
(
inputs
.
size
()
==
0
)
{
if
(
inputs
.
size
()
==
0
)
{
VLOG
(
3
)
<<
"no input! return"
;
VLOG
(
3
)
<<
"no input! return"
;
return
;
return
;
...
@@ -302,8 +305,8 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
...
@@ -302,8 +305,8 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
}
}
std
::
vector
<
int64_t
>
merge_rows
(
merged_row_set
.
begin
(),
std
::
vector
<
int64_t
>
merge_rows
(
merged_row_set
.
begin
(),
merged_row_set
.
end
());
merged_row_set
.
end
());
if
(
sorted_result
_
)
{
if
(
sorted_result
)
{
std
::
sort
(
merge_rows
);
std
::
sort
(
merge_rows
.
begin
(),
merge_rows
.
end
()
);
}
}
std
::
unordered_map
<
int64_t
,
size_t
>
rows_to_id
;
std
::
unordered_map
<
int64_t
,
size_t
>
rows_to_id
;
for
(
size_t
i
=
0
;
i
<
merge_rows
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
merge_rows
.
size
();
++
i
)
{
...
...
paddle/fluid/operators/math/selected_rows_functor.cu
浏览文件 @
5fea8cd4
...
@@ -266,7 +266,8 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
...
@@ -266,7 +266,8 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
template
<
typename
T
>
template
<
typename
T
>
struct
MergeAdd
<
platform
::
CUDADeviceContext
,
T
>
{
struct
MergeAdd
<
platform
::
CUDADeviceContext
,
T
>
{
framework
::
SelectedRows
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
framework
::
SelectedRows
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
)
{
const
framework
::
SelectedRows
&
input
,
const
bool
sorted_result
=
false
)
{
framework
::
SelectedRows
out
;
framework
::
SelectedRows
out
;
(
*
this
)(
context
,
input
,
&
out
);
(
*
this
)(
context
,
input
,
&
out
);
return
out
;
return
out
;
...
...
paddle/fluid/operators/math/selected_rows_functor.h
浏览文件 @
5fea8cd4
...
@@ -78,23 +78,19 @@ namespace scatter {
...
@@ -78,23 +78,19 @@ namespace scatter {
// functors for manuplating SelectedRows data
// functors for manuplating SelectedRows data
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
struct
MergeAdd
{
struct
MergeAdd
{
MergeAdd
()
:
sorted_result_
(
false
)
{}
explicit
MergeAdd
(
bool
sorted_result
)
:
sorted_result_
(
sorted_result
)
{}
// unary functor, merge by adding duplicated rows in
// unary functor, merge by adding duplicated rows in
// the input SelectedRows object.
// the input SelectedRows object.
framework
::
SelectedRows
operator
()(
const
DeviceContext
&
context
,
framework
::
SelectedRows
operator
()(
const
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
);
const
framework
::
SelectedRows
&
input
,
const
bool
sorted_result
=
false
);
void
operator
()(
const
DeviceContext
&
context
,
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
,
const
framework
::
SelectedRows
&
input
,
framework
::
SelectedRows
*
output
);
framework
::
SelectedRows
*
output
,
const
bool
sorted_result
=
false
);
void
operator
()(
const
DeviceContext
&
context
,
void
operator
()(
const
DeviceContext
&
context
,
const
std
::
vector
<
const
framework
::
SelectedRows
*>&
inputs
,
const
std
::
vector
<
const
framework
::
SelectedRows
*>&
inputs
,
framework
::
SelectedRows
*
output
);
framework
::
SelectedRows
*
output
,
const
bool
sorted_result
=
false
);
private:
bool
sorted_result_
;
};
};
enum
class
ScatterOps
{
ASSIGN
,
ADD
,
SUB
,
SUBBY
,
MUL
,
DIV
,
DIVBY
};
enum
class
ScatterOps
{
ASSIGN
,
ADD
,
SUB
,
SUBBY
,
MUL
,
DIV
,
DIVBY
};
...
...
paddle/fluid/operators/optimizers/adam_op.h
浏览文件 @
5fea8cd4
...
@@ -157,6 +157,9 @@ struct AdamFunctor<T, CPUAdam> {
...
@@ -157,6 +157,9 @@ struct AdamFunctor<T, CPUAdam> {
}
}
};
};
template
<
typename
T
,
typename
Flavour
>
struct
SparseAdamFunctor
;
template
<
typename
T
>
template
<
typename
T
>
struct
SparseAdamFunctor
<
T
,
GPUAdam
>
{
struct
SparseAdamFunctor
<
T
,
GPUAdam
>
{
T
beta1_
;
T
beta1_
;
...
@@ -283,6 +286,7 @@ struct SparseAdamFunctor<T, CPUAdam> {
...
@@ -283,6 +286,7 @@ struct SparseAdamFunctor<T, CPUAdam> {
// Calculation
// Calculation
if
(
i
==
*
(
rows_
+
j
))
{
if
(
i
==
*
(
rows_
+
j
))
{
T
g
=
grad_
[
j
*
row_numel_
];
mom1
=
beta1_
*
mom1
+
(
1
-
beta1_
)
*
g
;
mom1
=
beta1_
*
mom1
+
(
1
-
beta1_
)
*
g
;
mom2
=
beta2_
*
mom2
+
(
1
-
beta2_
)
*
g
*
g
;
mom2
=
beta2_
*
mom2
+
(
1
-
beta2_
)
*
g
*
g
;
++
j
;
++
j
;
...
@@ -388,12 +392,12 @@ class AdamOpKernel : public framework::OpKernel<T> {
...
@@ -388,12 +392,12 @@ class AdamOpKernel : public framework::OpKernel<T> {
}
else
{
}
else
{
// merge duplicated rows if any.
// merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor
// The rows of grad_merge have been sorted inside MergeAdd functor
scatter
::
MergeAdd
<
DeviceContext
,
T
>
merge_func
(
true
)
;
scatter
::
MergeAdd
<
DeviceContext
,
T
>
merge_func
;
auto
*
grad_merge_var
=
const_cast
<
framework
::
Scope
&>
(
ctx
.
scope
())
auto
*
grad_merge_var
=
const_cast
<
framework
::
Scope
&>
(
ctx
.
scope
())
.
Var
()
.
Var
()
->
GetMutable
<
framework
::
SelectedRows
>
();
->
GetMutable
<
framework
::
SelectedRows
>
();
merge_func
(
ctx
.
template
device_context
<
DeviceContext
>(),
grad
,
merge_func
(
ctx
.
template
device_context
<
DeviceContext
>(),
grad
,
grad_merge_var
);
grad_merge_var
,
true
);
grad_merge_ptr
=
grad_merge_var
;
grad_merge_ptr
=
grad_merge_var
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录