Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
93c16d96
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
93c16d96
编写于
12月 02, 2018
作者:
X
Xin Pan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish the autograd (need to verify correctness)
test=develop
上级
c3236f82
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
37 addition
and
61 deletion
+37
-61
paddle/fluid/imperative/layer.cc
paddle/fluid/imperative/layer.cc
+37
-61
未找到文件。
paddle/fluid/imperative/layer.cc
浏览文件 @
93c16d96
...
...
@@ -46,20 +46,16 @@ class Autograd {
void
RunBackward
(
VarBase
*
var
)
{
PADDLE_ENFORCE
(
var
->
pre_op_
->
op_desc_
);
// TODO(panyx0718): Only create vars that "require_grad"
std
::
vector
<
Variable
*>
op_grads
=
CreateOpGrads
(
var
->
pre_op_
->
output_vars_
->
size
());
op_grads
[
var
->
pre_op_out_idx_
]
=
var
->
grads_
;
// TODO(panyx0718): Only create for vars that "require_grad"
(
*
var
->
pre_op_
->
output_vars_
)[
var
->
pre_op_out_idx_
]
->
grads_
=
var
->
grads_
;
std
::
deque
<
std
::
pair
<
OpBase
*
,
std
::
vector
<
Variable
*>>
>
ready
;
ready
.
push_back
(
std
::
make_pair
(
var
->
pre_op_
,
op_grads
)
);
std
::
deque
<
OpBase
*
>
ready
;
ready
.
push_back
(
var
->
pre_op_
);
std
::
map
<
OpBase
*
,
int
>
dep_counts
=
ComputeDepCounts
(
var
->
pre_op_
);
std
::
map
<
OpBase
*
,
std
::
vector
<
Variable
*>>
visited
;
while
(
!
ready
.
empty
())
{
OpBase
*
ready_op
=
ready
.
front
().
first
;
std
::
vector
<
Variable
*>
ready_op_grads
=
ready
.
front
().
second
;
OpBase
*
ready_op
=
ready
.
front
();
ready
.
pop_front
();
std
::
vector
<
Variable
*>
input_grads
=
ready_op
->
ApplyGrad
(
scope_
);
...
...
@@ -67,29 +63,12 @@ class Autograd {
if
(
!
input_grads
[
i
])
continue
;
OpBase
*
pre_op
=
ready_op
->
pre_ops_
->
at
(
i
);
if
(
!
pre_op
)
continue
;
int
pre_op_out_idx
=
ready_op
->
pre_ops_out_idx_
->
at
(
i
);
dep_counts
[
pre_op
]
-=
1
;
PADDLE_ENFORCE
(
dep_counts
[
pre_op
]
>=
0
);
bool
pre_op_ready
=
dep_counts
[
pre_op
]
==
0
;
if
(
pre_op_ready
)
{
if
(
visited
.
find
(
pre_op
)
==
visited
.
end
())
{
PADDLE_ENFORCE
(
pre_op
->
output_vars_
->
size
()
==
1
);
visited
[
pre_op
]
=
{
input_grads
[
i
]};
}
else
{
std
::
vector
<
Variable
*>&
pre_op_grads
=
visited
[
pre_op
];
AccumGrads
(
pre_op_out_idx
,
input_grads
[
i
],
&
pre_op_grads
);
}
ready
.
push_back
(
std
::
make_pair
(
pre_op
,
visited
[
pre_op
]));
}
else
{
if
(
visited
.
find
(
pre_op
)
==
visited
.
end
())
{
// TODO(panyx0718): Only create vars that "require_grad"
visited
[
pre_op
]
=
CreateOpGrads
(
var
->
pre_op_
->
output_vars_
->
size
());
}
else
{
}
std
::
vector
<
Variable
*>&
pre_op_grads
=
visited
[
pre_op
];
AccumGrads
(
pre_op_out_idx
,
input_grads
[
i
],
&
pre_op_grads
);
ready
.
push_back
(
pre_op
);
}
}
}
...
...
@@ -184,27 +163,22 @@ void VarBase::ApplyGrad(framework::Scope* scope, Variable* grad) {
std
::
vector
<
Variable
*>
OpBase
::
ApplyGrad
(
framework
::
Scope
*
scope
)
{
VLOG
(
3
)
<<
"op grad "
<<
grad_op_desc_
->
Type
();
for
(
const
std
::
string
&
invar
:
grad_op_desc_
->
InputArgumentNames
())
{
block_
->
FindRecursiveOrCreateVar
(
invar
);
framework
::
Variable
*
var
=
scope
->
Var
(
invar
);
LOG
(
ERROR
)
<<
"op grad in var "
<<
invar
;
if
(
!
var
->
IsInitialized
())
{
framework
::
VarDesc
*
var_desc
=
block_
->
FindVar
(
invar
);
if
(
var_desc
->
GetType
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
LOG
(
ERROR
)
<<
"grad op invar init "
<<
invar
;
var
->
GetMutable
<
framework
::
LoDTensor
>
();
}
else
{
LOG
(
ERROR
)
<<
"tracer doesn't support yet"
;
for
(
const
std
::
string
&
grad_invar
:
grad_op_desc_
->
InputArgumentNames
())
{
if
(
grad_to_var_
->
find
(
grad_invar
)
==
grad_to_var_
->
end
())
{
continue
;
}
LOG
(
ERROR
)
<<
"op grad in var "
<<
grad_invar
;
block_
->
FindRecursiveOrCreateVar
(
grad_invar
);
framework
::
Variable
*
var
=
scope
->
Var
(
grad_invar
);
const
std
::
string
&
invar
=
grad_to_var_
->
at
(
grad_invar
);
for
(
VarBase
*
varbase
:
*
output_vars_
)
{
if
(
varbase
->
var_desc_
->
Name
()
==
invar
)
{
var
->
GetMutable
<
framework
::
LoDTensor
>
()
->
ShareDataWith
(
varbase
->
grads_
->
Get
<
framework
::
LoDTensor
>
());
}
}
else
{
var
->
GetMutable
<
framework
::
LoDTensor
>
()
->
type
();
}
}
std
::
vector
<
Variable
*>
ret
;
for
(
size_t
i
=
0
;
i
<
input_vars_
->
size
();
++
i
)
{
ret
.
push_back
(
nullptr
);
}
for
(
const
std
::
string
&
outvar
:
grad_op_desc_
->
OutputArgumentNames
())
{
LOG
(
ERROR
)
<<
"grad outvar "
<<
outvar
;
block_
->
FindRecursiveOrCreateVar
(
outvar
);
...
...
@@ -225,23 +199,25 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
opbase
->
Run
(
*
scope
,
platform
::
CPUPlace
());
for
(
const
std
::
string
&
outvar
:
grad_op_desc_
->
OutputArgumentNames
())
{
if
(
grad_to_var_
->
find
(
outvar
)
!=
grad_to_var_
->
end
())
{
std
::
string
origin_var
=
(
*
grad_to_var_
)[
outvar
];
for
(
size_t
i
=
0
;
i
<
input_vars_
->
size
();
++
i
)
{
VarBase
*
origin_in_var
=
(
*
input_vars_
)[
i
];
if
(
origin_in_var
->
var_desc_
->
Name
()
==
origin_var
)
{
framework
::
Variable
*
var
=
scope
->
FindVar
(
outvar
);
LOG
(
ERROR
)
<<
"apply grad "
<<
outvar
<<
" with origin "
<<
origin_var
;
origin_in_var
->
ApplyGrad
(
scope
,
var
);
ret
[
i
]
=
var
;
// TODO(panyx0718): There might be 2 var with the same name. We
// currently assume the are the same Variable*. So it doesn't matter
// which one is used.
break
;
}
}
std
::
vector
<
Variable
*>
ret
;
for
(
size_t
i
=
0
;
i
<
input_vars_
->
size
();
++
i
)
{
bool
found
=
false
;
for
(
const
std
::
string
&
outvar
:
grad_op_desc_
->
OutputArgumentNames
())
{
Variable
*
var
=
scope
->
FindVar
(
outvar
);
VarBase
*
origin_var
=
(
*
input_vars_
)[
i
];
std
::
string
orig_var
=
grad_to_var_
->
at
(
outvar
);
PADDLE_ENFORCE
(
origin_var
->
var_desc_
->
Name
()
==
orig_var
);
LOG
(
ERROR
)
<<
"apply grad "
<<
outvar
<<
" with origin "
<<
orig_var
;
origin_var
->
ApplyGrad
(
scope
,
var
);
found
=
true
;
ret
.
push_back
(
var
);
// TODO(panyx0718): There might be another outvar with the same name.
// In that case, it doesn't matter the first one or the second one is
// used.
break
;
}
if
(
!
found
)
{
ret
.
push_back
(
nullptr
);
}
}
return
ret
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录