Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
54793e35
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
54793e35
编写于
10月 14, 2017
作者:
D
Dong Zhihong
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'origin/develop' into fix/scope
上级
5eed0134
ec783d6b
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
59 addition
and
25 deletion
+59
-25
paddle/framework/backward.cc
paddle/framework/backward.cc
+44
-13
paddle/framework/backward.h
paddle/framework/backward.h
+12
-2
paddle/framework/block_desc.h
paddle/framework/block_desc.h
+3
-10
未找到文件。
paddle/framework/backward.cc
浏览文件 @
54793e35
...
...
@@ -273,18 +273,40 @@ static bool AllGradInSet(const std::vector<std::string>& names,
return
true
;
}
static
void
CreateGradVarInBlock
(
BlockDescBind
*
block_desc
,
size_t
grad_op_start_index
)
{
static
void
CreateGradVarInBlock
(
std
::
unordered_map
<
std
::
string
,
GradVarInfo
>*
grad_var_record
,
BlockDescBind
*
block_desc
,
size_t
grad_op_start_index
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
param_name_map
)
{
auto
ops
=
block_desc
->
AllOps
();
for
(
size_t
op_index
=
grad_op_start_index
;
op_index
<
ops
.
size
();
++
op_index
)
{
for
(
const
auto
&
output
:
ops
[
op_index
]
->
Outputs
())
{
for
(
const
auto
&
real_output
:
output
.
second
)
{
if
(
!
block_desc
->
HasVar
(
real_output
))
{
block_desc
->
Var
(
real_output
);
}
}
}
// <<<<<<< HEAD
// for (const auto& output : ops[op_index]->Outputs()) {
// for (const auto& real_output : output.second) {
// if (!block_desc->HasVar(real_output)) {
// block_desc->Var(real_output);
// }
// }
// }
// =======
ForEachVarName
(
ops
[
op_index
]
->
Outputs
(),
[
&
](
const
std
::
string
&
grad_var_name
)
{
if
(
block_desc
->
HasVar
(
grad_var_name
))
{
return
false
;
}
block_desc
->
Var
(
grad_var_name
);
auto
it
=
param_name_map
.
find
(
grad_var_name
);
if
(
it
==
param_name_map
.
end
())
{
return
false
;
}
auto
param_var_name
=
it
->
second
;
auto
&
grad_record
=
(
*
grad_var_record
)[
param_var_name
];
grad_record
.
name_
=
grad_var_name
;
grad_record
.
block_idx_
=
block_desc
->
ID
();
grad_record
.
op_idx_
=
static_cast
<
int
>
(
op_index
);
return
false
;
/* not break */
});
// >>>>>>> origin/develop
}
}
...
...
@@ -400,8 +422,9 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
return
backward_descs
;
}
void
AppendBackward
(
ProgramDescBind
&
program_desc
,
const
VarDescBind
&
target
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
)
{
std
::
unordered_map
<
std
::
string
/*fwd_var_name*/
,
GradVarInfo
/*grad_var_info*/
>
AppendBackward
(
ProgramDescBind
&
program_desc
,
const
VarDescBind
&
target
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
)
{
std
::
unordered_set
<
std
::
string
>
no_grad_var_names
;
no_grad_var_names
.
reserve
(
no_grad_vars
.
size
()
+
1
);
no_grad_var_names
.
insert
(
std
::
string
(
kEmptyVarName
)
+
kGradVarSuffix
);
...
...
@@ -423,20 +446,28 @@ void AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
all_ops
.
push_back
(
std
::
move
(
fill_one_op
));
size_t
forward_op_num
=
all_ops
.
size
();
size_t
forward_block_num
=
program_desc
.
Size
();
// Insert backward operators
std
::
unordered_map
<
std
::
string
,
std
::
string
>
grad_to_var
;
auto
backward_op_descs
=
MakeBlockBackward
(
program_desc
,
root_block_idx
,
&
no_grad_var_names
,
&
grad_to_var
);
std
::
unordered_map
<
std
::
string
,
GradVarInfo
>
retv
;
// Create Variable
for
(
auto
&
ptr
:
backward_op_descs
)
{
all_ops
.
push_back
(
std
::
move
(
ptr
));
}
root_block
->
Var
(
fill_one_op_out
);
// create grad_var for all blocks in this program
CreateGradVarInBlock
(
root_block
,
forward_op_num
);
CreateGradVarInBlock
(
&
retv
,
root_block
,
forward_op_num
,
grad_to_var
);
for
(
size_t
block_index
=
forward_block_num
;
block_index
<
program_desc
.
Size
();
++
block_index
)
{
CreateGradVarInBlock
(
program_desc
.
Block
(
block_index
),
0
);
CreateGradVarInBlock
(
&
retv
,
program_desc
.
Block
(
block_index
),
0
,
grad_to_var
);
}
return
retv
;
}
}
// namespace framework
...
...
paddle/framework/backward.h
浏览文件 @
54793e35
...
...
@@ -14,7 +14,10 @@
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/framework/operator.h"
#include "paddle/framework/program_desc.h"
...
...
@@ -27,10 +30,17 @@ extern std::unique_ptr<OperatorBase> Backward(
const
OperatorBase
&
forwardOp
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
);
struct
GradVarInfo
{
std
::
string
name_
;
int
block_idx_
;
int
op_idx_
;
};
// TODO(jiayi): Add target as parameter and generate backward op
// according to target.
void
AppendBackward
(
ProgramDescBind
&
program_desc
,
const
VarDescBind
&
target
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
);
std
::
unordered_map
<
std
::
string
/*fwd_var_name*/
,
GradVarInfo
/*grad_var_info*/
>
AppendBackward
(
ProgramDescBind
&
program_desc
,
const
VarDescBind
&
target
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
);
}
// namespace framework
}
// namespace paddle
paddle/framework/block_desc.h
浏览文件 @
54793e35
...
...
@@ -33,15 +33,6 @@ class ProgramDescBind;
class
BlockDescBind
{
public:
friend
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
MakeBlockBackward
(
ProgramDescBind
&
program_desc
,
int
block_idx
,
std
::
unordered_set
<
std
::
string
>
*
no_grad_vars
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>
*
grad_to_var
);
friend
void
AppendBackward
(
ProgramDescBind
&
program_desc
,
const
VarDescBind
&
target
,
const
std
::
unordered_set
<
std
::
string
>
&
no_grad_vars
);
BlockDescBind
(
ProgramDescBind
*
prog
,
BlockDesc
*
desc
)
:
prog_
(
prog
),
desc_
(
desc
),
need_update_
(
false
)
{}
...
...
@@ -69,7 +60,9 @@ class BlockDescBind {
BlockDesc
*
Proto
();
private:
// FIXME(yuyang18): backward will access private data of BlockDesc.
// Mark it public temporary. We can fix it later.
public:
ProgramDescBind
*
prog_
;
// not_own
BlockDesc
*
desc_
;
// not_own
bool
need_update_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录