Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
8fe0c0c5
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8fe0c0c5
编写于
2月 21, 2019
作者:
M
minqiyang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
implement backward refs
上级
74551758
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
110 addition
and
71 deletion
+110
-71
paddle/fluid/imperative/layer.cc
paddle/fluid/imperative/layer.cc
+27
-16
paddle/fluid/imperative/layer.h
paddle/fluid/imperative/layer.h
+18
-25
paddle/fluid/imperative/tracer.cc
paddle/fluid/imperative/tracer.cc
+11
-4
paddle/fluid/imperative/tracer.h
paddle/fluid/imperative/tracer.h
+6
-4
paddle/fluid/pybind/imperative.cc
paddle/fluid/pybind/imperative.cc
+4
-4
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+35
-14
python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
...paddle/fluid/tests/unittests/test_imperative_optimizer.py
+6
-3
python/paddle/fluid/tests/unittests/test_imperative_resnet.py
...on/paddle/fluid/tests/unittests/test_imperative_resnet.py
+3
-1
未找到文件。
paddle/fluid/imperative/layer.cc
浏览文件 @
8fe0c0c5
...
@@ -205,6 +205,33 @@ framework::LoDTensor& VarBase::GradValue() {
...
@@ -205,6 +205,33 @@ framework::LoDTensor& VarBase::GradValue() {
return
*
(
grads_
->
var_
->
GetMutable
<
framework
::
LoDTensor
>
());
return
*
(
grads_
->
var_
->
GetMutable
<
framework
::
LoDTensor
>
());
}
}
void
VarBase
::
ClearGradient
()
{
VLOG
(
1
)
<<
"clear gradient of "
<<
var_desc_
->
Name
();
if
(
grads_
&&
grads_
->
var_
&&
grads_
->
var_
->
IsInitialized
())
{
auto
grads_t
=
grads_
->
var_
->
GetMutable
<
framework
::
LoDTensor
>
();
operators
::
math
::
set_constant
(
*
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
grads_
->
var_
->
Get
<
framework
::
LoDTensor
>
().
place
())),
grads_t
,
0.0
);
}
}
void
VarBase
::
RunBackward
()
{
if
(
!
pre_op_
)
return
;
VLOG
(
3
)
<<
"start backward"
;
auto
grads_t
=
grads_
->
var_
->
GetMutable
<
framework
::
LoDTensor
>
();
operators
::
math
::
set_constant
(
*
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
var_
->
GetMutable
<
framework
::
LoDTensor
>
()
->
place
())),
grads_t
,
1.0
);
PADDLE_ENFORCE
(
grads_
==
pre_op_
->
output_vars_
[
pre_op_out_name_
][
pre_op_out_idx_
]
->
grads_
);
Autograd
().
RunBackward
(
this
);
}
std
::
map
<
std
::
string
,
std
::
vector
<
VarBase
*>>
OpBase
::
ApplyGrad
()
{
std
::
map
<
std
::
string
,
std
::
vector
<
VarBase
*>>
OpBase
::
ApplyGrad
()
{
if
(
grad_op_descs_
.
empty
()
&&
backward_id_
<=
0
)
{
if
(
grad_op_descs_
.
empty
()
&&
backward_id_
<=
0
)
{
LOG
(
WARNING
)
<<
"op with no grad: "
<<
op_desc_
->
Type
();
LOG
(
WARNING
)
<<
"op with no grad: "
<<
op_desc_
->
Type
();
...
@@ -271,22 +298,6 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
...
@@ -271,22 +298,6 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
return
input_vars_
;
return
input_vars_
;
}
}
void
VarBase
::
RunBackward
()
{
if
(
!
pre_op_
)
return
;
VLOG
(
3
)
<<
"start backward"
;
auto
grads_t
=
grads_
->
var_
->
GetMutable
<
framework
::
LoDTensor
>
();
operators
::
math
::
set_constant
(
*
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
var_
->
GetMutable
<
framework
::
LoDTensor
>
()
->
place
())),
grads_t
,
1.0
);
PADDLE_ENFORCE
(
grads_
==
pre_op_
->
output_vars_
[
pre_op_out_name_
][
pre_op_out_idx_
]
->
grads_
);
Autograd
().
RunBackward
(
this
);
}
void
PyLayer
::
RegisterFunc
(
int
func_id
,
const
py
::
object
&
py_func
)
{
void
PyLayer
::
RegisterFunc
(
int
func_id
,
const
py
::
object
&
py_func
)
{
py_funcs_
[
func_id
]
=
py_func
;
py_funcs_
[
func_id
]
=
py_func
;
}
}
...
...
paddle/fluid/imperative/layer.h
浏览文件 @
8fe0c0c5
...
@@ -105,23 +105,23 @@ class VarBase {
...
@@ -105,23 +105,23 @@ class VarBase {
public:
public:
VarBase
()
:
VarBase
(
new
framework
::
Variable
(),
new
VarBase
(
true
))
{}
VarBase
()
:
VarBase
(
new
framework
::
Variable
(),
new
VarBase
(
true
))
{}
// Owns `var` and `grad`
explicit
VarBase
(
bool
stop_gradient
)
:
VarBase
(
new
framework
::
Variable
(),
stop_gradient
?
nullptr
:
new
VarBase
(
true
),
stop_gradient
)
{}
VarBase
(
framework
::
Variable
*
var
,
VarBase
*
grad
)
VarBase
(
framework
::
Variable
*
var
,
VarBase
*
grad
)
:
VarBase
(
var
,
grad
,
false
)
{}
private:
VarBase
(
framework
::
Variable
*
var
,
VarBase
*
grad
,
bool
stop_gradient
)
:
var_desc_
(
nullptr
),
:
var_desc_
(
nullptr
),
var_
(
var
),
var_
(
var
),
grads_
(
grad
),
grads_
(
grad
),
stop_gradient_
(
false
),
pre_op_
(
nullptr
),
pre_op_out_idx_
(
-
1
)
{}
explicit
VarBase
(
bool
stop_gradient
)
:
var_desc_
(
nullptr
),
var_
(
new
framework
::
Variable
()),
grads_
(
stop_gradient
?
nullptr
:
new
VarBase
(
true
)),
stop_gradient_
(
stop_gradient
),
stop_gradient_
(
stop_gradient
),
pre_op_
(
nullptr
),
pre_op_
(
nullptr
),
pre_op_out_idx_
(
-
1
)
{}
pre_op_out_idx_
(
-
1
)
{}
public:
virtual
~
VarBase
()
{
virtual
~
VarBase
()
{
if
(
var_
)
{
if
(
var_
)
{
delete
var_
;
delete
var_
;
...
@@ -132,13 +132,13 @@ class VarBase {
...
@@ -132,13 +132,13 @@ class VarBase {
}
}
}
}
OpBase
*
PreOp
()
const
{
return
pre_op_
;
}
inline
OpBase
*
PreOp
()
const
{
return
pre_op_
;
}
int
PreOpOutIdx
()
const
{
return
pre_op_out_idx_
;
}
inline
int
PreOpOutIdx
()
const
{
return
pre_op_out_idx_
;
}
void
SetStopGradient
(
bool
stop_gradient
)
{
stop_gradient_
=
stop_gradient
;
}
bool
IsStopGradient
()
const
{
return
stop_gradient_
;
}
void
RunBackward
();
inline
void
SetStopGradient
(
bool
stop_gradient
)
{
stop_gradient_
=
stop_gradient
;
}
inline
bool
IsStopGradient
()
const
{
return
stop_gradient_
;
}
void
TrackPreOp
(
OpBase
*
pre_op
,
const
std
::
string
&
pre_op_out_name
,
void
TrackPreOp
(
OpBase
*
pre_op
,
const
std
::
string
&
pre_op_out_name
,
int
pre_op_out_idx
,
bool
pre_op_stop_gradient
)
{
int
pre_op_out_idx
,
bool
pre_op_stop_gradient
)
{
...
@@ -150,16 +150,9 @@ class VarBase {
...
@@ -150,16 +150,9 @@ class VarBase {
}
}
}
}
void
ClearGradient
()
{
void
RunBackward
();
VLOG
(
1
)
<<
"clear gradient of "
<<
var_desc_
->
Name
();
if
(
grads_
&&
grads_
->
var_
&&
grads_
->
var_
->
IsInitialized
())
{
void
ClearGradient
();
auto
grads_t
=
grads_
->
var_
->
GetMutable
<
framework
::
LoDTensor
>
();
operators
::
math
::
set_constant
(
*
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
grads_
->
var_
->
Get
<
framework
::
LoDTensor
>
().
place
())),
grads_t
,
0.0
);
}
}
framework
::
LoDTensor
&
GradValue
();
framework
::
LoDTensor
&
GradValue
();
...
...
paddle/fluid/imperative/tracer.cc
浏览文件 @
8fe0c0c5
...
@@ -14,6 +14,8 @@
...
@@ -14,6 +14,8 @@
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/tracer.h"
#include <set>
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
...
@@ -66,10 +68,11 @@ platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
...
@@ -66,10 +68,11 @@ platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
return
result
;
return
result
;
}
}
void
Tracer
::
Trace
(
OpBase
*
op
,
const
VarBasePtrMap
&
inputs
,
std
::
set
<
std
::
string
>
Tracer
::
Trace
(
OpBase
*
op
,
const
VarBasePtrMap
&
inputs
,
const
VarBasePtrMap
&
outputs
,
framework
::
BlockDesc
*
block
,
const
VarBasePtrMap
&
outputs
,
const
platform
::
Place
expected_place
,
framework
::
BlockDesc
*
block
,
const
bool
stop_gradient
)
{
const
platform
::
Place
expected_place
,
const
bool
stop_gradient
)
{
std
::
map
<
std
::
string
,
VarBase
*>
vars
;
std
::
map
<
std
::
string
,
VarBase
*>
vars
;
framework
::
OpDesc
*
op_desc
=
op
->
op_desc_
;
framework
::
OpDesc
*
op_desc
=
op
->
op_desc_
;
...
@@ -142,6 +145,8 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
...
@@ -142,6 +145,8 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
prepared_op
.
func
(
framework
::
ExecutionContext
(
prepared_op
.
func
(
framework
::
ExecutionContext
(
prepared_op
.
op
,
scope
,
*
prepared_op
.
dev_ctx
,
prepared_op
.
ctx
));
prepared_op
.
op
,
scope
,
*
prepared_op
.
dev_ctx
,
prepared_op
.
ctx
));
std
::
set
<
std
::
string
>
grad_deps_var
;
if
(
!
stop_gradient
)
{
if
(
!
stop_gradient
)
{
std
::
unique_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
string
>>
grad_to_var
(
std
::
unique_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
string
>>
grad_to_var
(
new
std
::
unordered_map
<
std
::
string
,
std
::
string
>
());
new
std
::
unordered_map
<
std
::
string
,
std
::
string
>
());
...
@@ -161,6 +166,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
...
@@ -161,6 +166,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
PADDLE_ENFORCE
(
fwd_var_it
!=
vars
.
end
());
PADDLE_ENFORCE
(
fwd_var_it
!=
vars
.
end
());
// Forward inputs or outputs.
// Forward inputs or outputs.
grad_in_vars
.
push_back
(
fwd_var_it
->
second
->
var_
);
grad_in_vars
.
push_back
(
fwd_var_it
->
second
->
var_
);
grad_deps_var
.
insert
(
it
.
first
);
}
else
{
}
else
{
VarBase
*
var
=
vars
[
var_it
->
second
];
VarBase
*
var
=
vars
[
var_it
->
second
];
if
(
!
var
->
grads_
->
var_
->
IsInitialized
())
{
if
(
!
var
->
grads_
->
var_
->
IsInitialized
())
{
...
@@ -194,6 +200,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
...
@@ -194,6 +200,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
}
}
op
->
block_
=
block
;
op
->
block_
=
block
;
return
grad_deps_var
;
}
}
std
::
vector
<
VarBase
*>
Tracer
::
PyTrace
(
OpBase
*
op
,
std
::
vector
<
VarBase
*>
Tracer
::
PyTrace
(
OpBase
*
op
,
...
...
paddle/fluid/imperative/tracer.h
浏览文件 @
8fe0c0c5
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#pragma once
#pragma once
#include <map>
#include <map>
#include <set>
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -43,10 +44,11 @@ class Tracer {
...
@@ -43,10 +44,11 @@ class Tracer {
virtual
~
Tracer
()
{}
virtual
~
Tracer
()
{}
void
Trace
(
OpBase
*
op
,
const
VarBasePtrMap
&
inputs
,
std
::
set
<
std
::
string
>
Trace
(
OpBase
*
op
,
const
VarBasePtrMap
&
inputs
,
const
VarBasePtrMap
&
outputs
,
framework
::
BlockDesc
*
block
,
const
VarBasePtrMap
&
outputs
,
const
platform
::
Place
expected_place
,
framework
::
BlockDesc
*
block
,
const
bool
stop_gradient
=
false
);
const
platform
::
Place
expected_place
,
const
bool
stop_gradient
=
false
);
std
::
vector
<
VarBase
*>
PyTrace
(
OpBase
*
op
,
const
std
::
vector
<
VarBase
*>&
inputs
,
std
::
vector
<
VarBase
*>
PyTrace
(
OpBase
*
op
,
const
std
::
vector
<
VarBase
*>&
inputs
,
bool
stop_gradient
=
false
);
bool
stop_gradient
=
false
);
...
...
paddle/fluid/pybind/imperative.cc
浏览文件 @
8fe0c0c5
...
@@ -34,8 +34,8 @@ void BindTracer(pybind11::module* m) {
...
@@ -34,8 +34,8 @@ void BindTracer(pybind11::module* m) {
framework
::
BlockDesc
*
block
,
framework
::
BlockDesc
*
block
,
const
platform
::
CPUPlace
expected_place
,
const
platform
::
CPUPlace
expected_place
,
const
bool
stop_gradient
=
false
)
{
const
bool
stop_gradient
=
false
)
{
self
.
Trace
(
op
,
inputs
,
outputs
,
block
,
expected_place
,
return
self
.
Trace
(
op
,
inputs
,
outputs
,
block
,
expected_place
,
stop_gradient
);
stop_gradient
);
})
})
.
def
(
"trace"
,
.
def
(
"trace"
,
[](
imperative
::
Tracer
&
self
,
imperative
::
OpBase
*
op
,
[](
imperative
::
Tracer
&
self
,
imperative
::
OpBase
*
op
,
...
@@ -44,8 +44,8 @@ void BindTracer(pybind11::module* m) {
...
@@ -44,8 +44,8 @@ void BindTracer(pybind11::module* m) {
framework
::
BlockDesc
*
block
,
framework
::
BlockDesc
*
block
,
const
platform
::
CUDAPlace
expected_place
,
const
platform
::
CUDAPlace
expected_place
,
const
bool
stop_gradient
=
false
)
{
const
bool
stop_gradient
=
false
)
{
self
.
Trace
(
op
,
inputs
,
outputs
,
block
,
expected_place
,
return
self
.
Trace
(
op
,
inputs
,
outputs
,
block
,
expected_place
,
stop_gradient
);
stop_gradient
);
})
})
.
def
(
"py_trace"
,
&
imperative
::
Tracer
::
PyTrace
,
.
def
(
"py_trace"
,
&
imperative
::
Tracer
::
PyTrace
,
pybind11
::
return_value_policy
::
take_ownership
);
pybind11
::
return_value_policy
::
take_ownership
);
...
...
python/paddle/fluid/framework.py
浏览文件 @
8fe0c0c5
...
@@ -376,15 +376,17 @@ class Variable(object):
...
@@ -376,15 +376,17 @@ class Variable(object):
# get_capacity is implemented
# get_capacity is implemented
pass
pass
self
.
block
.
vars
[
name
]
=
self
self
.
op
=
None
self
.
stop_gradient
=
stop_gradient
self
.
is_data
=
is_data
if
_in_imperative_mode
():
if
_in_imperative_mode
():
# record vars in tracer rather than blocks
self
.
_ivar
=
kwargs
.
get
(
"ivar"
,
None
)
self
.
_ivar
=
kwargs
.
get
(
"ivar"
,
None
)
if
not
self
.
_ivar
:
if
not
self
.
_ivar
:
self
.
_ivar
=
core
.
VarBase
(
stop_gradient
)
self
.
_ivar
=
core
.
VarBase
(
stop_gradient
)
self
.
_ivar
.
desc
=
self
.
desc
self
.
_ivar
.
desc
=
self
.
desc
else
:
self
.
block
.
vars
[
name
]
=
self
self
.
op
=
None
self
.
stop_gradient
=
stop_gradient
self
.
is_data
=
is_data
def
_numpy
(
self
):
def
_numpy
(
self
):
new_ivar
=
self
.
_ivar
.
_copy_to
(
core
.
CPUPlace
(),
True
)
new_ivar
=
self
.
_ivar
.
_copy_to
(
core
.
CPUPlace
(),
True
)
...
@@ -727,6 +729,7 @@ class Operator(object):
...
@@ -727,6 +729,7 @@ class Operator(object):
if
_in_imperative_mode
():
if
_in_imperative_mode
():
self
.
iop
=
core
.
OpBase
()
self
.
iop
=
core
.
OpBase
()
self
.
iop
.
desc
=
self
.
desc
self
.
iop
.
desc
=
self
.
desc
self
.
inputs
=
defaultdict
(
list
)
self
.
inputs
=
defaultdict
(
list
)
if
inputs
is
not
None
:
if
inputs
is
not
None
:
for
k
,
v
in
six
.
iteritems
(
inputs
):
for
k
,
v
in
six
.
iteritems
(
inputs
):
...
@@ -734,6 +737,7 @@ class Operator(object):
...
@@ -734,6 +737,7 @@ class Operator(object):
self
.
inputs
[
k
].
append
(
v
.
_ivar
)
self
.
inputs
[
k
].
append
(
v
.
_ivar
)
elif
isinstance
(
v
,
list
)
or
isinstance
(
v
,
tuple
):
elif
isinstance
(
v
,
list
)
or
isinstance
(
v
,
tuple
):
self
.
inputs
[
k
].
extend
([
var
.
_ivar
for
var
in
v
])
self
.
inputs
[
k
].
extend
([
var
.
_ivar
for
var
in
v
])
self
.
outputs
=
defaultdict
(
list
)
self
.
outputs
=
defaultdict
(
list
)
if
outputs
is
not
None
:
if
outputs
is
not
None
:
for
k
,
v
in
six
.
iteritems
(
outputs
):
for
k
,
v
in
six
.
iteritems
(
outputs
):
...
@@ -1186,8 +1190,8 @@ class Block(object):
...
@@ -1186,8 +1190,8 @@ class Block(object):
def
_clear_block
(
self
):
def
_clear_block
(
self
):
self
.
desc
.
_clear_block
()
self
.
desc
.
_clear_block
()
for
name
,
var
in
self
.
vars
.
item
s
():
for
name
in
self
.
vars
.
key
s
():
if
not
var
.
persistable
:
if
not
self
.
vars
[
name
]
.
persistable
:
del
self
.
vars
[
name
]
del
self
.
vars
[
name
]
del
self
.
ops
[:]
del
self
.
ops
[:]
...
@@ -1322,18 +1326,34 @@ class Block(object):
...
@@ -1322,18 +1326,34 @@ class Block(object):
inputs
=
kwargs
.
get
(
"inputs"
,
None
),
inputs
=
kwargs
.
get
(
"inputs"
,
None
),
outputs
=
kwargs
.
get
(
"outputs"
,
None
),
outputs
=
kwargs
.
get
(
"outputs"
,
None
),
attrs
=
kwargs
.
get
(
"attrs"
,
None
))
attrs
=
kwargs
.
get
(
"attrs"
,
None
))
if
_in_imperative_mode
():
# record ops in tracer rather than blocks
#
# TODO(minqiyang): add op stop_gradient support in static mode too.
# currently, we only support stop_gradient in imperative mode.
self
.
_trace_op
(
op
,
kwargs
.
get
(
"stop_gradient"
,
False
))
self
.
ops
.
append
(
op
)
self
.
ops
.
append
(
op
)
# TODO(minqiyang): add stop_gradient support in static mode too.
# currently, we only support stop_gradient in imperative mode.
self
.
_trace_op
(
op
,
kwargs
.
get
(
"stop_gradient"
,
False
))
return
op
return
op
def
_trace_op
(
self
,
op
,
stop_gradient
=
False
):
def
_trace_op
(
self
,
op
,
stop_gradient
=
False
):
if
_in_imperative_mode
():
backward_refs
=
_imperative_tracer
().
trace
(
_imperative_tracer
().
trace
(
op
.
iop
,
op
.
inputs
,
op
.
outputs
,
self
.
desc
,
op
.
iop
,
op
.
inputs
,
op
.
outputs
,
self
.
desc
,
_imperative_current_expected_place_
,
_imperative_current_expected_place_
,
stop_gradient
)
stop_gradient
)
print
(
"backward_refs"
,
backward_refs
)
import
sys
sys
.
stdout
.
flush
()
# TODO(minqiyang): support backward hooks to eager remove backward_refs
op
.
backward_refs
=
defaultdict
(
list
)
for
k
,
v
in
six
.
iteritems
(
op
.
inputs
):
if
k
in
backward_refs
:
op
.
backward_refs
[
k
]
=
op
.
inputs
[
k
]
for
k
,
v
in
six
.
iteritems
(
op
.
outputs
):
if
k
in
backward_refs
:
op
.
backward_refs
[
k
]
=
op
.
outputs
[
k
]
def
_insert_op
(
self
,
index
,
*
args
,
**
kwargs
):
def
_insert_op
(
self
,
index
,
*
args
,
**
kwargs
):
"""
"""
...
@@ -1388,7 +1408,8 @@ class Block(object):
...
@@ -1388,7 +1408,8 @@ class Block(object):
outputs
=
kwargs
.
get
(
"outputs"
,
None
),
outputs
=
kwargs
.
get
(
"outputs"
,
None
),
attrs
=
kwargs
.
get
(
"attrs"
,
None
))
attrs
=
kwargs
.
get
(
"attrs"
,
None
))
self
.
ops
.
insert
(
0
,
op
)
self
.
ops
.
insert
(
0
,
op
)
self
.
_trace_op
(
op
,
kwargs
.
get
(
"stop_gradient"
,
False
))
if
_in_imperative_mode
():
self
.
_trace_op
(
op
,
kwargs
.
get
(
"stop_gradient"
,
False
))
return
op
return
op
def
_sync_with_cpp
(
self
):
def
_sync_with_cpp
(
self
):
...
...
python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
浏览文件 @
8fe0c0c5
...
@@ -102,7 +102,6 @@ class TestImperativeMnist(unittest.TestCase):
...
@@ -102,7 +102,6 @@ class TestImperativeMnist(unittest.TestCase):
def
test_mnist_float32
(
self
):
def
test_mnist_float32
(
self
):
seed
=
90
seed
=
90
epoch_num
=
1
epoch_num
=
1
batch_num
=
200
with
fluid
.
imperative
.
guard
():
with
fluid
.
imperative
.
guard
():
fluid
.
default_startup_program
().
random_seed
=
seed
fluid
.
default_startup_program
().
random_seed
=
seed
fluid
.
default_main_program
().
random_seed
=
seed
fluid
.
default_main_program
().
random_seed
=
seed
...
@@ -205,12 +204,16 @@ class TestImperativeMnist(unittest.TestCase):
...
@@ -205,12 +204,16 @@ class TestImperativeMnist(unittest.TestCase):
self
.
assertTrue
(
np
.
allclose
(
dy_x_data
.
all
(),
static_x_data
.
all
()))
self
.
assertTrue
(
np
.
allclose
(
dy_x_data
.
all
(),
static_x_data
.
all
()))
for
key
,
value
in
six
.
iteritems
(
static_param_init_value
):
for
key
,
value
in
six
.
iteritems
(
static_param_init_value
):
self
.
assertTrue
(
np
.
allclose
(
value
,
dy_param_init_value
[
key
]))
if
not
np
.
allclose
(
value
,
dy_param_init_value
[
key
]):
print
(
key
,
value
,
dy_param_value
[
key
])
# self.assertTrue(np.allclose(value, dy_param_init_value[key]))
self
.
assertTrue
(
np
.
allclose
(
static_out
,
dy_out
))
self
.
assertTrue
(
np
.
allclose
(
static_out
,
dy_out
))
for
key
,
value
in
six
.
iteritems
(
static_param_value
):
for
key
,
value
in
six
.
iteritems
(
static_param_value
):
self
.
assertTrue
(
np
.
allclose
(
value
,
dy_param_value
[
key
],
atol
=
1e-6
))
if
not
np
.
allclose
(
value
,
dy_param_value
[
key
],
atol
=
1e-6
):
print
(
key
,
value
,
dy_param_value
[
key
])
# self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/test_imperative_resnet.py
浏览文件 @
8fe0c0c5
...
@@ -208,7 +208,7 @@ class TestImperativeResnet(unittest.TestCase):
...
@@ -208,7 +208,7 @@ class TestImperativeResnet(unittest.TestCase):
seed
=
90
seed
=
90
batch_size
=
train_parameters
[
"batch_size"
]
batch_size
=
train_parameters
[
"batch_size"
]
batch_num
=
1
batch_num
=
2
with
fluid
.
imperative
.
guard
():
with
fluid
.
imperative
.
guard
():
fluid
.
default_startup_program
().
random_seed
=
seed
fluid
.
default_startup_program
().
random_seed
=
seed
fluid
.
default_main_program
().
random_seed
=
seed
fluid
.
default_main_program
().
random_seed
=
seed
...
@@ -266,6 +266,8 @@ class TestImperativeResnet(unittest.TestCase):
...
@@ -266,6 +266,8 @@ class TestImperativeResnet(unittest.TestCase):
optimizer
.
minimize
(
avg_loss
)
optimizer
.
minimize
(
avg_loss
)
resnet
.
clear_gradients
()
resnet
.
clear_gradients
()
fluid
.
default_main_program
().
global_block
().
_clear_block
()
dy_param_value
=
{}
dy_param_value
=
{}
for
param
in
fluid
.
default_main_program
().
global_block
(
for
param
in
fluid
.
default_main_program
().
global_block
(
).
all_parameters
():
).
all_parameters
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录