Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
00d5375e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
00d5375e
编写于
9月 16, 2019
作者:
C
Chen Weihang
提交者:
GitHub
9月 16, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add prune_backward function to cover complicated test_program.clone situation (#19772)
上级
99c78b77
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
232 addition
and
19 deletion
+232
-19
paddle/fluid/framework/prune.cc
paddle/fluid/framework/prune.cc
+183
-6
paddle/fluid/framework/prune.h
paddle/fluid/framework/prune.h
+5
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+3
-0
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+7
-13
python/paddle/fluid/tests/unittests/test_program_prune_backward.py
...ddle/fluid/tests/unittests/test_program_prune_backward.py
+34
-0
未找到文件。
paddle/fluid/framework/prune.cc
浏览文件 @
00d5375e
...
@@ -17,19 +17,40 @@ limitations under the License. */
...
@@ -17,19 +17,40 @@ limitations under the License. */
#include <glog/logging.h>
#include <glog/logging.h>
#include <algorithm>
#include <algorithm>
#include <memory>
#include <set>
#include <set>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/program_desc.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
const
char
kFeedOpType
[]
=
"feed"
;
const
char
kFeedOpType
[]
=
"feed"
;
const
char
kFetchOpType
[]
=
"fetch"
;
const
char
kFetchOpType
[]
=
"fetch"
;
bool
HasDependentVar
(
const
proto
::
OpDesc
&
op_desc
,
bool
HasDependentInputVar
(
const
std
::
set
<
std
::
string
>&
dependent_vars
)
{
const
proto
::
OpDesc
&
op_desc
,
const
std
::
unordered_set
<
std
::
string
>&
dependent_vars
)
{
for
(
auto
&
var
:
op_desc
.
inputs
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
if
(
dependent_vars
.
count
(
argu
)
!=
0
)
{
return
true
;
}
}
}
return
false
;
}
bool
HasDependentOutputVar
(
const
proto
::
OpDesc
&
op_desc
,
const
std
::
unordered_set
<
std
::
string
>&
dependent_vars
)
{
for
(
auto
&
var
:
op_desc
.
outputs
())
{
for
(
auto
&
var
:
op_desc
.
outputs
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
if
(
dependent_vars
.
count
(
argu
)
!=
0
)
{
if
(
dependent_vars
.
count
(
argu
)
!=
0
)
{
...
@@ -47,6 +68,14 @@ bool IsTarget(const proto::OpDesc& op_desc) {
...
@@ -47,6 +68,14 @@ bool IsTarget(const proto::OpDesc& op_desc) {
return
false
;
return
false
;
}
}
bool
HasTrueTarget
(
const
proto
::
OpDesc
&
op_desc
)
{
return
op_desc
.
has_is_target
()
&&
op_desc
.
is_target
();
}
bool
HasFalseTarget
(
const
proto
::
OpDesc
&
op_desc
)
{
return
op_desc
.
has_is_target
()
&&
!
op_desc
.
is_target
();
}
int
GetSubBlockIndex
(
const
proto
::
OpDesc
&
op_desc
)
{
int
GetSubBlockIndex
(
const
proto
::
OpDesc
&
op_desc
)
{
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
if
(
attr
.
type
()
==
proto
::
AttrType
::
BLOCK
)
{
if
(
attr
.
type
()
==
proto
::
AttrType
::
BLOCK
)
{
...
@@ -61,6 +90,24 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
...
@@ -61,6 +90,24 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
return
GetSubBlockIndex
(
op_desc
)
>
0
;
return
GetSubBlockIndex
(
op_desc
)
>
0
;
}
}
void
AppendOpInputVarNames
(
const
proto
::
OpDesc
&
op_desc
,
std
::
unordered_set
<
std
::
string
>*
vars_set
)
{
for
(
auto
&
var
:
op_desc
.
inputs
())
{
for
(
auto
&
arg
:
var
.
arguments
())
{
vars_set
->
emplace
(
arg
);
}
}
}
void
AppendOpOutputVarNames
(
const
proto
::
OpDesc
&
op_desc
,
std
::
unordered_set
<
std
::
string
>*
vars_set
)
{
for
(
auto
&
var
:
op_desc
.
outputs
())
{
for
(
auto
&
arg
:
var
.
arguments
())
{
vars_set
->
emplace
(
arg
);
}
}
}
// block_id is the idx of the current block in the input desc
// block_id is the idx of the current block in the input desc
// parent_block_id is the idx of the parent of the current block
// parent_block_id is the idx of the parent of the current block
// in the output desc, -1 means the current block is global block
// in the output desc, -1 means the current block is global block
...
@@ -68,7 +115,7 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
...
@@ -68,7 +115,7 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
// the child block to help pruning
// the child block to help pruning
void
prune_impl
(
const
proto
::
ProgramDesc
&
input
,
proto
::
ProgramDesc
*
output
,
void
prune_impl
(
const
proto
::
ProgramDesc
&
input
,
proto
::
ProgramDesc
*
output
,
int
block_id
,
int
parent_block_id
,
int
block_id
,
int
parent_block_id
,
std
::
set
<
std
::
string
>*
dependent_vars
,
std
::
unordered_
set
<
std
::
string
>*
dependent_vars
,
const
std
::
set
<
std
::
string
>
feed_var_names
)
{
const
std
::
set
<
std
::
string
>
feed_var_names
)
{
auto
&
block
=
input
.
blocks
(
block_id
);
auto
&
block
=
input
.
blocks
(
block_id
);
auto
&
ops
=
block
.
ops
();
auto
&
ops
=
block
.
ops
();
...
@@ -91,7 +138,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
...
@@ -91,7 +138,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
std
::
vector
<
bool
>
should_run
;
std
::
vector
<
bool
>
should_run
;
for
(
auto
op_iter
=
ops
.
rbegin
();
op_iter
!=
ops
.
rend
();
++
op_iter
)
{
for
(
auto
op_iter
=
ops
.
rbegin
();
op_iter
!=
ops
.
rend
();
++
op_iter
)
{
auto
&
op_desc
=
*
op_iter
;
auto
&
op_desc
=
*
op_iter
;
if
(
IsTarget
(
op_desc
)
||
HasDependentVar
(
op_desc
,
*
dependent_vars
))
{
if
(
IsTarget
(
op_desc
)
||
HasDependent
Output
Var
(
op_desc
,
*
dependent_vars
))
{
// insert its input to the dependency graph
// insert its input to the dependency graph
for
(
auto
&
var
:
op_desc
.
inputs
())
{
for
(
auto
&
var
:
op_desc
.
inputs
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
...
@@ -127,7 +174,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
...
@@ -127,7 +174,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
*
op
=
input
.
blocks
(
block_id
).
ops
(
i
);
*
op
=
input
.
blocks
(
block_id
).
ops
(
i
);
if
(
HasSubBlock
(
*
op
))
{
if
(
HasSubBlock
(
*
op
))
{
// create sub_block_dependent_vars here to help prune the sub block
// create sub_block_dependent_vars here to help prune the sub block
std
::
set
<
std
::
string
>
sub_block_dependent_vars
;
std
::
unordered_
set
<
std
::
string
>
sub_block_dependent_vars
;
for
(
auto
&
var
:
op
->
inputs
())
{
for
(
auto
&
var
:
op
->
inputs
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
if
(
feed_var_names
.
count
(
argu
)
==
0
)
{
if
(
feed_var_names
.
count
(
argu
)
==
0
)
{
...
@@ -188,9 +235,139 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
...
@@ -188,9 +235,139 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
void
Prune
(
const
proto
::
ProgramDesc
&
input
,
void
Prune
(
const
proto
::
ProgramDesc
&
input
,
const
std
::
set
<
std
::
string
>&
feed_var_names
,
const
std
::
set
<
std
::
string
>&
feed_var_names
,
proto
::
ProgramDesc
*
output
)
{
proto
::
ProgramDesc
*
output
)
{
std
::
set
<
std
::
string
>
dependent_vars
;
std
::
unordered_
set
<
std
::
string
>
dependent_vars
;
output
->
clear_blocks
();
output
->
clear_blocks
();
prune_impl
(
input
,
output
,
0
,
-
1
,
&
dependent_vars
,
feed_var_names
);
prune_impl
(
input
,
output
,
0
,
-
1
,
&
dependent_vars
,
feed_var_names
);
}
}
void
CloneWholeBlock
(
proto
::
ProgramDesc
*
input
,
proto
::
ProgramDesc
*
output
,
int
block_id
,
int
parent_block_id
)
{
auto
*
block_field
=
output
->
mutable_blocks
();
*
block_field
->
Add
()
=
input
->
blocks
(
block_id
);
int
output_block_id
=
output
->
blocks_size
()
-
1
;
auto
*
output_block
=
output
->
mutable_blocks
(
output_block_id
);
output_block
->
set_idx
(
output_block_id
);
output_block
->
set_parent_idx
(
parent_block_id
);
}
void
PruneBackwardImpl
(
proto
::
ProgramDesc
*
input
,
proto
::
ProgramDesc
*
output
,
int
block_id
,
int
parent_block_id
)
{
// Step 1. Copy the current input block to output
CloneWholeBlock
(
input
,
output
,
block_id
,
parent_block_id
);
int
output_block_id
=
output
->
blocks_size
()
-
1
;
auto
*
output_block
=
output
->
mutable_blocks
(
output_block_id
);
// Step 2. Mark forward ops on main branch
auto
*
ops
=
input
->
mutable_blocks
(
block_id
)
->
mutable_ops
();
std
::
unordered_set
<
std
::
string
>
op_input_vars
;
std
::
unordered_set
<
std
::
string
>
op_output_vars
;
for
(
auto
op_iter
=
ops
->
rbegin
();
op_iter
!=
ops
->
rend
();
++
op_iter
)
{
auto
&
op_desc
=
*
op_iter
;
if
(
HasTrueTarget
(
op_desc
)
||
HasDependentOutputVar
(
op_desc
,
op_input_vars
))
{
op_desc
.
set_is_target
(
true
);
AppendOpInputVarNames
(
op_desc
,
&
op_input_vars
);
AppendOpOutputVarNames
(
op_desc
,
&
op_output_vars
);
}
}
// Step 3. Mark backward & optimize ops on main branch
std
::
unordered_set
<
std
::
string
>
gradop_input_vars
;
std
::
unordered_set
<
std
::
string
>
gradop_output_vars
;
for
(
auto
op_iter
=
ops
->
begin
();
op_iter
!=
ops
->
end
();
++
op_iter
)
{
auto
&
op_desc
=
*
op_iter
;
if
(
HasFalseTarget
(
op_desc
)
||
HasDependentInputVar
(
op_desc
,
gradop_output_vars
))
{
op_desc
.
set_is_target
(
false
);
AppendOpInputVarNames
(
op_desc
,
&
gradop_input_vars
);
AppendOpOutputVarNames
(
op_desc
,
&
gradop_output_vars
);
}
}
// Step 4. Mark ops need to be reserved on sub-branch
for
(
auto
op_iter
=
ops
->
rbegin
();
op_iter
!=
ops
->
rend
();
++
op_iter
)
{
auto
&
op_desc
=
*
op_iter
;
if
(
!
op_desc
.
has_is_target
())
{
if
(
HasDependentOutputVar
(
op_desc
,
gradop_input_vars
))
{
op_desc
.
set_is_target
(
false
);
AppendOpInputVarNames
(
op_desc
,
&
gradop_input_vars
);
}
else
{
op_desc
.
set_is_target
(
true
);
AppendOpInputVarNames
(
op_desc
,
&
op_input_vars
);
AppendOpOutputVarNames
(
op_desc
,
&
op_output_vars
);
}
}
}
// Step 5. Copy the forward ops to new ProgramDesc
// Note: The proto::ProgramDesc doesn't have interface
// to remove op and var
auto
*
op_field
=
output_block
->
mutable_ops
();
op_field
->
Clear
();
for
(
auto
op_iter
=
ops
->
begin
();
op_iter
!=
ops
->
end
();
++
op_iter
)
{
if
(
IsTarget
(
*
op_iter
))
{
auto
*
op
=
op_field
->
Add
();
*
op
=
*
op_iter
;
if
(
HasSubBlock
(
*
op
))
{
CloneWholeBlock
(
input
,
output
,
GetSubBlockIndex
(
*
op
),
output_block_id
);
}
}
}
// Step 6. Copy the forward vars to new ProgramDesc
// construct all var's map before clear
auto
*
var_field
=
output_block
->
mutable_vars
();
std
::
unordered_map
<
std
::
string
,
proto
::
VarDesc
>
var_map
;
for
(
const
auto
&
var
:
*
var_field
)
{
var_map
[
var
.
name
()]
=
var
;
}
std
::
unordered_set
<
std
::
string
>
var_names
;
var_names
.
insert
(
op_input_vars
.
begin
(),
op_input_vars
.
end
());
var_names
.
insert
(
op_output_vars
.
begin
(),
op_output_vars
.
end
());
var_field
->
Clear
();
for
(
const
auto
&
name
:
var_names
)
{
*
var_field
->
Add
()
=
var_map
[
name
];
}
}
std
::
unique_ptr
<
framework
::
ProgramDesc
>
PruneBackward
(
const
framework
::
ProgramDesc
&
origin
)
{
// Copy original ProgramDesc, origin can't be change
framework
::
ProgramDesc
origin_clone
(
origin
);
// Step 1. Update loss op's role & set loss op to be target
// The loss op's op_role is (kForward | kLoss)
// The input ProgramDesc should have loss operator.
auto
ops
=
origin_clone
.
Block
(
0
).
AllOps
();
bool
has_loss_op
=
false
;
for
(
auto
op
:
ops
)
{
int
op_role
=
boost
::
get
<
int
>
(
op
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()));
if
(
op_role
==
(
static_cast
<
int
>
(
OpRole
::
kForward
)
|
static_cast
<
int
>
(
OpRole
::
kLoss
)))
{
op
->
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
static_cast
<
int
>
(
OpRole
::
kForward
));
op
->
SetIsTarget
(
true
);
has_loss_op
=
true
;
}
else
if
(
op_role
==
(
static_cast
<
int
>
(
OpRole
::
kBackward
)
|
static_cast
<
int
>
(
OpRole
::
kLoss
)))
{
op
->
SetIsTarget
(
false
);
break
;
}
}
PADDLE_ENFORCE_EQ
(
has_loss_op
,
true
,
"The Program need to be pruned its backward part"
"should have loss operator."
);
// Step 2. Prune backward
proto
::
ProgramDesc
pruned_desc
;
pruned_desc
.
clear_blocks
();
PruneBackwardImpl
(
origin_clone
.
Proto
(),
&
pruned_desc
,
0
,
-
1
);
// Step 3. Contruct new framework::ProgramDesc
return
std
::
unique_ptr
<
framework
::
ProgramDesc
>
(
new
framework
::
ProgramDesc
(
pruned_desc
));
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/prune.h
浏览文件 @
00d5375e
...
@@ -14,9 +14,11 @@ limitations under the License. */
...
@@ -14,9 +14,11 @@ limitations under the License. */
#pragma once
#pragma once
#include <memory>
#include <set>
#include <set>
#include <string>
#include <string>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -26,5 +28,8 @@ void Prune(const proto::ProgramDesc& input,
...
@@ -26,5 +28,8 @@ void Prune(const proto::ProgramDesc& input,
const
std
::
set
<
std
::
string
>&
feed_var_names
,
const
std
::
set
<
std
::
string
>&
feed_var_names
,
proto
::
ProgramDesc
*
output
);
proto
::
ProgramDesc
*
output
);
std
::
unique_ptr
<
framework
::
ProgramDesc
>
PruneBackward
(
const
framework
::
ProgramDesc
&
origin
);
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/pybind/pybind.cc
浏览文件 @
00d5375e
...
@@ -761,6 +761,9 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -761,6 +761,9 @@ All parameter, weight, gradient are variables in Paddle.
Prune
(
*
prog_with_targets
.
Proto
(),
feeded_var_names
,
&
pruned_desc
);
Prune
(
*
prog_with_targets
.
Proto
(),
feeded_var_names
,
&
pruned_desc
);
return
new
ProgramDesc
(
pruned_desc
);
return
new
ProgramDesc
(
pruned_desc
);
});
});
m
.
def
(
"prune_backward"
,
[](
const
framework
::
ProgramDesc
&
program
)
{
return
PruneBackward
(
program
);
});
m
.
def
(
"empty_var_name"
,
m
.
def
(
"empty_var_name"
,
[]()
{
return
std
::
string
(
framework
::
kEmptyVarName
);
});
[]()
{
return
std
::
string
(
framework
::
kEmptyVarName
);
});
m
.
def
(
"grad_var_suffix"
,
m
.
def
(
"grad_var_suffix"
,
...
...
python/paddle/fluid/framework.py
浏览文件 @
00d5375e
...
@@ -3235,9 +3235,13 @@ class Program(object):
...
@@ -3235,9 +3235,13 @@ class Program(object):
"""
"""
if
for_test
:
if
for_test
:
if
self
.
_appending_grad_times
>
0
:
if
self
.
_appending_grad_times
>
0
:
loss_op
=
self
.
_find_loss_op
()
forward_prog
=
Program
()
assert
loss_op
is
not
None
,
"The optimized network should have loss operator."
forward_prog
.
desc
=
core
.
prune_backward
(
self
.
desc
)
forward_prog
=
self
.
_prune
([],
loss_op
)
forward_prog
.
blocks
=
[
Block
(
forward_prog
,
i
)
for
i
in
six
.
moves
.
range
(
forward_prog
.
desc
.
num_blocks
())
]
forward_prog
.
_sync_with_cpp
()
p
=
forward_prog
.
_inference_optimize
(
prune_read_op
=
False
)
p
=
forward_prog
.
_inference_optimize
(
prune_read_op
=
False
)
else
:
else
:
p
=
self
.
_inference_optimize
(
prune_read_op
=
False
)
p
=
self
.
_inference_optimize
(
prune_read_op
=
False
)
...
@@ -3637,16 +3641,6 @@ class Program(object):
...
@@ -3637,16 +3641,6 @@ class Program(object):
for
each_var
in
list
(
each_block
.
vars
.
values
()):
for
each_var
in
list
(
each_block
.
vars
.
values
()):
yield
each_var
yield
each_var
def
_find_loss_op
(
self
):
loss_op
=
None
op_role_key
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
()
forward_loss
=
int
(
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
)
|
int
(
core
.
op_proto_and_checker_maker
.
OpRole
.
Loss
)
for
op
in
self
.
global_block
().
ops
:
if
int
(
op
.
all_attrs
()[
op_role_key
])
==
forward_loss
:
loss_op
=
op
return
loss_op
class
Parameter
(
Variable
):
class
Parameter
(
Variable
):
"""
"""
...
...
python/paddle/fluid/tests/unittests/test_program_prune_backward.py
浏览文件 @
00d5375e
...
@@ -52,6 +52,25 @@ def lstm_net(use_feed):
...
@@ -52,6 +52,25 @@ def lstm_net(use_feed):
return
avg_cost
return
avg_cost
def
simple_fc_net_with_accuracy
(
use_feed
):
img
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
784
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
hidden
=
img
for
_
in
range
(
4
):
hidden
=
fluid
.
layers
.
fc
(
hidden
,
size
=
200
,
act
=
'relu'
,
bias_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
1.0
)))
prediction
=
fluid
.
layers
.
fc
(
hidden
,
size
=
10
,
act
=
'softmax'
)
loss
=
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
label
)
loss
=
fluid
.
layers
.
mean
(
loss
)
accuracy_out
=
fluid
.
layers
.
accuracy
(
input
=
prediction
,
label
=
label
,
k
=
5
)
return
loss
class
TestProgramPruneBackward
(
unittest
.
TestCase
):
class
TestProgramPruneBackward
(
unittest
.
TestCase
):
def
program_compare
(
self
,
program_a
,
program_b
):
def
program_compare
(
self
,
program_a
,
program_b
):
assert
isinstance
(
assert
isinstance
(
...
@@ -109,6 +128,21 @@ class TestProgramPruneBackward(unittest.TestCase):
...
@@ -109,6 +128,21 @@ class TestProgramPruneBackward(unittest.TestCase):
"label"
:
label
},
"label"
:
label
},
optimizer
=
optimizer
)
optimizer
=
optimizer
)
def
test_simple_fc_net_with_accuracy
(
self
):
def
optimizer
():
optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
,
regularization
=
fluid
.
regularizer
.
L2Decay
(
1e-4
))
return
optimizer
with
self
.
program_scope_guard
():
img
,
label
=
init_data
()
self
.
check_prune_correctness
(
method
=
simple_fc_net_with_accuracy
,
feed_dict
=
{
"image"
:
img
,
"label"
:
label
},
optimizer
=
optimizer
)
def
test_batchnorm_fc
(
self
):
def
test_batchnorm_fc
(
self
):
def
optimizer
():
def
optimizer
():
optimizer
=
fluid
.
optimizer
.
SGD
(
optimizer
=
fluid
.
optimizer
.
SGD
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录