Commit 35e6b5e1
Authored on Dec 03, 2018 by Xin Pan
polish
test=develop
Parent: b80fe826
Showing 4 changed files with 14 additions and 27 deletions (+14 -27)
paddle/fluid/imperative/layer.cc  +8 -22
paddle/fluid/imperative/layer.h   +1 -4
paddle/fluid/imperative/tracer.h  +1 -1
tools/print_signatures.py        +4 -0
paddle/fluid/imperative/layer.cc
```diff
@@ -75,16 +75,6 @@ class Autograd {
   }

  private:
-  void AccumGrads(int grad_idx, Variable* grad,
-                  std::vector<Variable*>* op_grads) {
-    if (!(*op_grads)[grad_idx]) {
-      // FIXME(panyx0718): This should be a deep copy.
-      (*op_grads)[grad_idx] = grad;
-      return;
-    }
-    AddTo(grad, (*op_grads)[grad_idx]);
-  }
-
   std::map<OpBase*, int> ComputeDepCounts(OpBase* op) {
     std::map<OpBase*, int> ret;
@@ -108,14 +98,6 @@ class Autograd {
     return ret;
   }

-  std::vector<Variable*> CreateOpGrads(size_t count) {
-    std::vector<Variable*> op_grads;
-    for (size_t i = 0; i < count; ++i) {
-      op_grads.push_back(nullptr);
-    }
-    return op_grads;
-  }
-
   framework::Scope* scope_;
 };
```
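The deleted `AccumGrads` helper implemented the usual assign-on-first-write, accumulate-afterwards gradient pattern (with the noted FIXME that the first write should have been a deep copy rather than a pointer assignment). A minimal standalone sketch of that pattern, using plain `float` buffers instead of Paddle's `Variable`/`AddTo`, purely for illustration:

```cpp
#include <vector>

// Hypothetical stand-in for Paddle's AddTo(src, dst): dst += src, elementwise.
void AddTo(const std::vector<float>& src, std::vector<float>* dst) {
  for (size_t i = 0; i < dst->size(); ++i) (*dst)[i] += src[i];
}

// The first gradient reaching a slot is assigned; later ones are accumulated.
void AccumGrads(int grad_idx, const std::vector<float>& grad,
                std::vector<std::vector<float>*>* op_grads) {
  if (!(*op_grads)[grad_idx]) {
    // Paddle's version stored the incoming pointer here (hence the FIXME
    // asking for a deep copy); this sketch copies to keep ownership simple.
    (*op_grads)[grad_idx] = new std::vector<float>(grad);
    return;
  }
  AddTo(grad, (*op_grads)[grad_idx]);
}
```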
```diff
@@ -133,7 +115,7 @@ framework::Variable* CreateVariable(const std::string& name,
     varname = string::Sprintf("%s@%d", varname, id);
   }

-  LOG(ERROR) << "creating var " << varname;
+  VLOG(3) << "creating var " << varname;
   framework::Variable* var = scope->Var(varname);
   framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
```
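Most of this commit is the same mechanical change: unconditional `LOG(ERROR)` debug output is demoted to `VLOG(3)`. With glog (the logging library Paddle builds on), `VLOG` calls stay compiled in but print only when the verbosity level is raised, e.g. by setting the `GLOG_v` environment variable. A minimal standalone sketch:

```cpp
#include <glog/logging.h>

int main(int argc, char* argv[]) {
  google::InitGoogleLogging(argv[0]);
  // Always emitted, regardless of verbosity:
  LOG(ERROR) << "creating var X";
  // Emitted only when verbosity >= 3, e.g. run as: GLOG_v=3 ./a.out
  VLOG(3) << "creating var X";
  return 0;
}
```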
```diff
@@ -165,22 +147,25 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
   for (const std::string& grad_invar : grad_op_desc_->InputArgumentNames()) {
     if (grad_to_var_->find(grad_invar) == grad_to_var_->end()) {
+      // grad op inputs can be forward inputs, so not in grad_to_var.
       continue;
     }
-    LOG(ERROR) << "op grad in var " << grad_invar;
+    VLOG(3) << "op grad in var " << grad_invar;
     block_->FindRecursiveOrCreateVar(grad_invar);
     framework::Variable* var = scope->Var(grad_invar);
     const std::string& invar = grad_to_var_->at(grad_invar);
     for (VarBase* varbase : *output_vars_) {
+      // Use the accumulated grads_ by sharing the input with grads_.
       if (varbase->var_desc_->Name() == invar) {
         var->GetMutable<framework::LoDTensor>()->ShareDataWith(
             varbase->grads_->Get<framework::LoDTensor>());
         break;
       }
     }
   }

   for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) {
-    LOG(ERROR) << "grad outvar " << outvar;
+    VLOG(3) << "grad outvar " << outvar;
     block_->FindRecursiveOrCreateVar(outvar);
     framework::Variable* var = scope->Var(outvar);
     if (!var->IsInitialized()) {
```
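The inner loop above wires the grad op's input variable to the already-accumulated `grads_` tensor via `ShareDataWith`, which aliases the underlying allocation instead of copying it. A rough sketch of that aliasing behavior (the shape and place used here are illustrative assumptions, not taken from the commit):

```cpp
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/place.h"

void ShareDemo() {
  namespace fw = paddle::framework;
  fw::LoDTensor grads;  // stands in for varbase->grads_
  grads.Resize(fw::make_ddim({2, 3}));
  grads.mutable_data<float>(paddle::platform::CPUPlace());

  fw::LoDTensor input;  // stands in for the scope variable's tensor
  // No copy happens: `input` now points at the same buffer as `grads`,
  // so the grad op reads the accumulated gradient directly.
  input.ShareDataWith(grads);
}
```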
```diff
@@ -199,6 +184,7 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
   opbase->Run(*scope, platform::CPUPlace());

+  // `ret` matches exactly with `input_vars_` of forward op.
   std::vector<Variable*> ret;
   for (size_t i = 0; i < input_vars_->size(); ++i) {
     bool found = false;
```
```diff
@@ -207,7 +193,7 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
       VarBase* origin_var = (*input_vars_)[i];
       std::string orig_var = grad_to_var_->at(outvar);
       PADDLE_ENFORCE(origin_var->var_desc_->Name() == orig_var);
-      LOG(ERROR) << "apply grad " << outvar << " with origin " << orig_var;
+      VLOG(3) << "apply grad " << outvar << " with origin " << orig_var;
       origin_var->ApplyGrad(scope, var);
       found = true;
       ret.push_back(var);
```
paddle/fluid/imperative/layer.h
```diff
@@ -36,10 +36,7 @@ class VarBase {
         var_(nullptr),
         grads_(nullptr) {}

-  virtual ~VarBase() {
-    LOG(ERROR) << "deleting var";
-    LOG(ERROR) << "done deleting var";
-  }
+  virtual ~VarBase() {}

   void ApplyGrad(framework::Scope* scope, framework::Variable* grad);
```
paddle/fluid/imperative/tracer.h
```diff
@@ -55,7 +55,7 @@ class Tracer {
              framework::BlockDesc* block) {
    framework::Scope* scope = GetScope(block);
    framework::OpDesc* op_desc = op->op_desc_;
-    LOG(ERROR) << "tracer tracing " << op_desc->Type();
+    VLOG(3) << "tracer tracing " << op_desc->Type();
    op_desc->InferShape(*block);
    op_desc->InferVarType(block);
    std::unique_ptr<framework::OperatorBase> op_base =
```
tools/print_signatures.py
```diff
@@ -27,6 +27,8 @@ import pydoc
 member_dict = collections.OrderedDict()

+experimental_namespace = {"paddle.fluid.imperative"}
+

 def visit_member(parent_name, member):
     cur_name = ".".join([parent_name, member.__name__])
```
```diff
@@ -50,6 +52,8 @@ def visit_member(parent_name, member):
 def visit_all_module(mod):
+    if (mod.__name__ in experimental_namespace):
+        return
     for member_name in (
             name
             for name in (mod.__all__ if hasattr(mod, "__all__") else dir(mod))
```
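`tools/print_signatures.py` dumps the signatures of Paddle's public API so that interface changes can be reviewed; the new `experimental_namespace` set makes `visit_all_module` return early for namespaces whose API is still experimental, keeping `paddle.fluid.imperative` out of that dump. A small standalone demo of the guard (the fake modules and print harness are illustrative only, not part of the tool):

```python
import types

experimental_namespace = {"paddle.fluid.imperative"}

def visit_all_module(mod):
    # Experimental namespaces are excluded from the signature dump,
    # so their unstable APIs do not trip the compatibility check.
    if mod.__name__ in experimental_namespace:
        return "skipped"
    return "visited"

print(visit_all_module(types.ModuleType("paddle.fluid.imperative")))  # skipped
print(visit_all_module(types.ModuleType("paddle.fluid")))             # visited
```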
登录