Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
00d5375e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
00d5375e
编写于
9月 16, 2019
作者:
C
Chen Weihang
提交者:
GitHub
9月 16, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add prune_backward function to cover complicated test_program.clone situation (#19772)
上级
99c78b77
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
232 addition
and
19 deletion
+232
-19
paddle/fluid/framework/prune.cc
paddle/fluid/framework/prune.cc
+183
-6
paddle/fluid/framework/prune.h
paddle/fluid/framework/prune.h
+5
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+3
-0
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+7
-13
python/paddle/fluid/tests/unittests/test_program_prune_backward.py
...ddle/fluid/tests/unittests/test_program_prune_backward.py
+34
-0
未找到文件。
paddle/fluid/framework/prune.cc
浏览文件 @
00d5375e
...
@@ -17,19 +17,40 @@ limitations under the License. */
...
@@ -17,19 +17,40 @@ limitations under the License. */
#include <glog/logging.h>
#include <glog/logging.h>
#include <algorithm>
#include <algorithm>
#include <memory>
#include <set>
#include <set>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/program_desc.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
const
char
kFeedOpType
[]
=
"feed"
;
const
char
kFeedOpType
[]
=
"feed"
;
const
char
kFetchOpType
[]
=
"fetch"
;
const
char
kFetchOpType
[]
=
"fetch"
;
bool
HasDependentVar
(
const
proto
::
OpDesc
&
op_desc
,
bool
HasDependentInputVar
(
const
std
::
set
<
std
::
string
>&
dependent_vars
)
{
const
proto
::
OpDesc
&
op_desc
,
const
std
::
unordered_set
<
std
::
string
>&
dependent_vars
)
{
for
(
auto
&
var
:
op_desc
.
inputs
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
if
(
dependent_vars
.
count
(
argu
)
!=
0
)
{
return
true
;
}
}
}
return
false
;
}
bool
HasDependentOutputVar
(
const
proto
::
OpDesc
&
op_desc
,
const
std
::
unordered_set
<
std
::
string
>&
dependent_vars
)
{
for
(
auto
&
var
:
op_desc
.
outputs
())
{
for
(
auto
&
var
:
op_desc
.
outputs
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
if
(
dependent_vars
.
count
(
argu
)
!=
0
)
{
if
(
dependent_vars
.
count
(
argu
)
!=
0
)
{
...
@@ -47,6 +68,14 @@ bool IsTarget(const proto::OpDesc& op_desc) {
...
@@ -47,6 +68,14 @@ bool IsTarget(const proto::OpDesc& op_desc) {
return
false
;
return
false
;
}
}
bool
HasTrueTarget
(
const
proto
::
OpDesc
&
op_desc
)
{
return
op_desc
.
has_is_target
()
&&
op_desc
.
is_target
();
}
bool
HasFalseTarget
(
const
proto
::
OpDesc
&
op_desc
)
{
return
op_desc
.
has_is_target
()
&&
!
op_desc
.
is_target
();
}
int
GetSubBlockIndex
(
const
proto
::
OpDesc
&
op_desc
)
{
int
GetSubBlockIndex
(
const
proto
::
OpDesc
&
op_desc
)
{
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
if
(
attr
.
type
()
==
proto
::
AttrType
::
BLOCK
)
{
if
(
attr
.
type
()
==
proto
::
AttrType
::
BLOCK
)
{
...
@@ -61,6 +90,24 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
...
@@ -61,6 +90,24 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
return
GetSubBlockIndex
(
op_desc
)
>
0
;
return
GetSubBlockIndex
(
op_desc
)
>
0
;
}
}
void
AppendOpInputVarNames
(
const
proto
::
OpDesc
&
op_desc
,
std
::
unordered_set
<
std
::
string
>*
vars_set
)
{
for
(
auto
&
var
:
op_desc
.
inputs
())
{
for
(
auto
&
arg
:
var
.
arguments
())
{
vars_set
->
emplace
(
arg
);
}
}
}
void
AppendOpOutputVarNames
(
const
proto
::
OpDesc
&
op_desc
,
std
::
unordered_set
<
std
::
string
>*
vars_set
)
{
for
(
auto
&
var
:
op_desc
.
outputs
())
{
for
(
auto
&
arg
:
var
.
arguments
())
{
vars_set
->
emplace
(
arg
);
}
}
}
// block_id is the idx of the current block in the input desc
// block_id is the idx of the current block in the input desc
// parent_block_id is the idx of the parent of the current block
// parent_block_id is the idx of the parent of the current block
// in the output desc, -1 means the current block is global block
// in the output desc, -1 means the current block is global block
...
@@ -68,7 +115,7 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
...
@@ -68,7 +115,7 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
// the child block to help pruning
// the child block to help pruning
void
prune_impl
(
const
proto
::
ProgramDesc
&
input
,
proto
::
ProgramDesc
*
output
,
void
prune_impl
(
const
proto
::
ProgramDesc
&
input
,
proto
::
ProgramDesc
*
output
,
int
block_id
,
int
parent_block_id
,
int
block_id
,
int
parent_block_id
,
std
::
set
<
std
::
string
>*
dependent_vars
,
std
::
unordered_
set
<
std
::
string
>*
dependent_vars
,
const
std
::
set
<
std
::
string
>
feed_var_names
)
{
const
std
::
set
<
std
::
string
>
feed_var_names
)
{
auto
&
block
=
input
.
blocks
(
block_id
);
auto
&
block
=
input
.
blocks
(
block_id
);
auto
&
ops
=
block
.
ops
();
auto
&
ops
=
block
.
ops
();
...
@@ -91,7 +138,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
...
@@ -91,7 +138,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
std
::
vector
<
bool
>
should_run
;
std
::
vector
<
bool
>
should_run
;
for
(
auto
op_iter
=
ops
.
rbegin
();
op_iter
!=
ops
.
rend
();
++
op_iter
)
{
for
(
auto
op_iter
=
ops
.
rbegin
();
op_iter
!=
ops
.
rend
();
++
op_iter
)
{
auto
&
op_desc
=
*
op_iter
;
auto
&
op_desc
=
*
op_iter
;
if
(
IsTarget
(
op_desc
)
||
HasDependentVar
(
op_desc
,
*
dependent_vars
))
{
if
(
IsTarget
(
op_desc
)
||
HasDependent
Output
Var
(
op_desc
,
*
dependent_vars
))
{
// insert its input to the dependency graph
// insert its input to the dependency graph
for
(
auto
&
var
:
op_desc
.
inputs
())
{
for
(
auto
&
var
:
op_desc
.
inputs
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
...
@@ -127,7 +174,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
...
@@ -127,7 +174,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
*
op
=
input
.
blocks
(
block_id
).
ops
(
i
);
*
op
=
input
.
blocks
(
block_id
).
ops
(
i
);
if
(
HasSubBlock
(
*
op
))
{
if
(
HasSubBlock
(
*
op
))
{
// create sub_block_dependent_vars here to help prune the sub block
// create sub_block_dependent_vars here to help prune the sub block
std
::
set
<
std
::
string
>
sub_block_dependent_vars
;
std
::
unordered_
set
<
std
::
string
>
sub_block_dependent_vars
;
for
(
auto
&
var
:
op
->
inputs
())
{
for
(
auto
&
var
:
op
->
inputs
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
if
(
feed_var_names
.
count
(
argu
)
==
0
)
{
if
(
feed_var_names
.
count
(
argu
)
==
0
)
{
...
@@ -188,9 +235,139 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
...
@@ -188,9 +235,139 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
void
Prune
(
const
proto
::
ProgramDesc
&
input
,
void
Prune
(
const
proto
::
ProgramDesc
&
input
,
const
std
::
set
<
std
::
string
>&
feed_var_names
,
const
std
::
set
<
std
::
string
>&
feed_var_names
,
proto
::
ProgramDesc
*
output
)
{
proto
::
ProgramDesc
*
output
)
{
std
::
set
<
std
::
string
>
dependent_vars
;
std
::
unordered_
set
<
std
::
string
>
dependent_vars
;
output
->
clear_blocks
();
output
->
clear_blocks
();
prune_impl
(
input
,
output
,
0
,
-
1
,
&
dependent_vars
,
feed_var_names
);
prune_impl
(
input
,
output
,
0
,
-
1
,
&
dependent_vars
,
feed_var_names
);
}
}
void
CloneWholeBlock
(
proto
::
ProgramDesc
*
input
,
proto
::
ProgramDesc
*
output
,
int
block_id
,
int
parent_block_id
)
{
auto
*
block_field
=
output
->
mutable_blocks
();
*
block_field
->
Add
()
=
input
->
blocks
(
block_id
);
int
output_block_id
=
output
->
blocks_size
()
-
1
;
auto
*
output_block
=
output
->
mutable_blocks
(
output_block_id
);
output_block
->
set_idx
(
output_block_id
);
output_block
->
set_parent_idx
(
parent_block_id
);
}
void
PruneBackwardImpl
(
proto
::
ProgramDesc
*
input
,
proto
::
ProgramDesc
*
output
,
int
block_id
,
int
parent_block_id
)
{
// Step 1. Copy the current input block to output
CloneWholeBlock
(
input
,
output
,
block_id
,
parent_block_id
);
int
output_block_id
=
output
->
blocks_size
()
-
1
;
auto
*
output_block
=
output
->
mutable_blocks
(
output_block_id
);
// Step 2. Mark forward ops on main branch
auto
*
ops
=
input
->
mutable_blocks
(
block_id
)
->
mutable_ops
();
std
::
unordered_set
<
std
::
string
>
op_input_vars
;
std
::
unordered_set
<
std
::
string
>
op_output_vars
;
for
(
auto
op_iter
=
ops
->
rbegin
();
op_iter
!=
ops
->
rend
();
++
op_iter
)
{
auto
&
op_desc
=
*
op_iter
;
if
(
HasTrueTarget
(
op_desc
)
||
HasDependentOutputVar
(
op_desc
,
op_input_vars
))
{
op_desc
.
set_is_target
(
true
);
AppendOpInputVarNames
(
op_desc
,
&
op_input_vars
);
AppendOpOutputVarNames
(
op_desc
,
&
op_output_vars
);
}
}
// Step 3. Mark backward & optimize ops on main branch
std
::
unordered_set
<
std
::
string
>
gradop_input_vars
;
std
::
unordered_set
<
std
::
string
>
gradop_output_vars
;
for
(
auto
op_iter
=
ops
->
begin
();
op_iter
!=
ops
->
end
();
++
op_iter
)
{
auto
&
op_desc
=
*
op_iter
;
if
(
HasFalseTarget
(
op_desc
)
||
HasDependentInputVar
(
op_desc
,
gradop_output_vars
))
{
op_desc
.
set_is_target
(
false
);
AppendOpInputVarNames
(
op_desc
,
&
gradop_input_vars
);
AppendOpOutputVarNames
(
op_desc
,
&
gradop_output_vars
);
}
}
// Step 4. Mark ops need to be reserved on sub-branch
for
(
auto
op_iter
=
ops
->
rbegin
();
op_iter
!=
ops
->
rend
();
++
op_iter
)
{
auto
&
op_desc
=
*
op_iter
;
if
(
!
op_desc
.
has_is_target
())
{
if
(
HasDependentOutputVar
(
op_desc
,
gradop_input_vars
))
{
op_desc
.
set_is_target
(
false
);
AppendOpInputVarNames
(
op_desc
,
&
gradop_input_vars
);
}
else
{
op_desc
.
set_is_target
(
true
);
AppendOpInputVarNames
(
op_desc
,
&
op_input_vars
);
AppendOpOutputVarNames
(
op_desc
,
&
op_output_vars
);
}
}
}
// Step 5. Copy the forward ops to new ProgramDesc
// Note: The proto::ProgramDesc doesn't have interface
// to remove op and var
auto
*
op_field
=
output_block
->
mutable_ops
();
op_field
->
Clear
();
for
(
auto
op_iter
=
ops
->
begin
();
op_iter
!=
ops
->
end
();
++
op_iter
)
{
if
(
IsTarget
(
*
op_iter
))
{
auto
*
op
=
op_field
->
Add
();
*
op
=
*
op_iter
;
if
(
HasSubBlock
(
*
op
))
{
CloneWholeBlock
(
input
,
output
,
GetSubBlockIndex
(
*
op
),
output_block_id
);
}
}
}
// Step 6. Copy the forward vars to new ProgramDesc
// construct all var's map before clear
auto
*
var_field
=
output_block
->
mutable_vars
();
std
::
unordered_map
<
std
::
string
,
proto
::
VarDesc
>
var_map
;
for
(
const
auto
&
var
:
*
var_field
)
{
var_map
[
var
.
name
()]
=
var
;
}
std
::
unordered_set
<
std
::
string
>
var_names
;
var_names
.
insert
(
op_input_vars
.
begin
(),
op_input_vars
.
end
());
var_names
.
insert
(
op_output_vars
.
begin
(),
op_output_vars
.
end
());
var_field
->
Clear
();
for
(
const
auto
&
name
:
var_names
)
{
*
var_field
->
Add
()
=
var_map
[
name
];
}
}
std
::
unique_ptr
<
framework
::
ProgramDesc
>
PruneBackward
(
const
framework
::
ProgramDesc
&
origin
)
{
// Copy original ProgramDesc, origin can't be change
framework
::
ProgramDesc
origin_clone
(
origin
);
// Step 1. Update loss op's role & set loss op to be target
// The loss op's op_role is (kForward | kLoss)
// The input ProgramDesc should have loss operator.
auto
ops
=
origin_clone
.
Block
(
0
).
AllOps
();
bool
has_loss_op
=
false
;
for
(
auto
op
:
ops
)
{
int
op_role
=
boost
::
get
<
int
>
(
op
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()));
if
(
op_role
==
(
static_cast
<
int
>
(
OpRole
::
kForward
)
|
static_cast
<
int
>
(
OpRole
::
kLoss
)))
{
op
->
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
static_cast
<
int
>
(
OpRole
::
kForward
));
op
->
SetIsTarget
(
true
);
has_loss_op
=
true
;
}
else
if
(
op_role
==
(
static_cast
<
int
>
(
OpRole
::
kBackward
)
|
static_cast
<
int
>
(
OpRole
::
kLoss
)))
{
op
->
SetIsTarget
(
false
);
break
;
}
}
PADDLE_ENFORCE_EQ
(
has_loss_op
,
true
,
"The Program need to be pruned its backward part"
"should have loss operator."
);
// Step 2. Prune backward
proto
::
ProgramDesc
pruned_desc
;
pruned_desc
.
clear_blocks
();
PruneBackwardImpl
(
origin_clone
.
Proto
(),
&
pruned_desc
,
0
,
-
1
);
// Step 3. Contruct new framework::ProgramDesc
return
std
::
unique_ptr
<
framework
::
ProgramDesc
>
(
new
framework
::
ProgramDesc
(
pruned_desc
));
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/prune.h
浏览文件 @
00d5375e
...
@@ -14,9 +14,11 @@ limitations under the License. */
...
@@ -14,9 +14,11 @@ limitations under the License. */
#pragma once
#pragma once
#include <memory>
#include <set>
#include <set>
#include <string>
#include <string>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -26,5 +28,8 @@ void Prune(const proto::ProgramDesc& input,
...
@@ -26,5 +28,8 @@ void Prune(const proto::ProgramDesc& input,
const
std
::
set
<
std
::
string
>&
feed_var_names
,
const
std
::
set
<
std
::
string
>&
feed_var_names
,
proto
::
ProgramDesc
*
output
);
proto
::
ProgramDesc
*
output
);
std
::
unique_ptr
<
framework
::
ProgramDesc
>
PruneBackward
(
const
framework
::
ProgramDesc
&
origin
);
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/pybind/pybind.cc
浏览文件 @
00d5375e
...
@@ -761,6 +761,9 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -761,6 +761,9 @@ All parameter, weight, gradient are variables in Paddle.
Prune
(
*
prog_with_targets
.
Proto
(),
feeded_var_names
,
&
pruned_desc
);
Prune
(
*
prog_with_targets
.
Proto
(),
feeded_var_names
,
&
pruned_desc
);
return
new
ProgramDesc
(
pruned_desc
);
return
new
ProgramDesc
(
pruned_desc
);
});
});
m
.
def
(
"prune_backward"
,
[](
const
framework
::
ProgramDesc
&
program
)
{
return
PruneBackward
(
program
);
});
m
.
def
(
"empty_var_name"
,
m
.
def
(
"empty_var_name"
,
[]()
{
return
std
::
string
(
framework
::
kEmptyVarName
);
});
[]()
{
return
std
::
string
(
framework
::
kEmptyVarName
);
});
m
.
def
(
"grad_var_suffix"
,
m
.
def
(
"grad_var_suffix"
,
...
...
python/paddle/fluid/framework.py
浏览文件 @
00d5375e
...
@@ -3235,9 +3235,13 @@ class Program(object):
...
@@ -3235,9 +3235,13 @@ class Program(object):
"""
"""
if
for_test
:
if
for_test
:
if
self
.
_appending_grad_times
>
0
:
if
self
.
_appending_grad_times
>
0
:
loss_op
=
self
.
_find_loss_op
()
forward_prog
=
Program
()
assert
loss_op
is
not
None
,
"The optimized network should have loss operator."
forward_prog
.
desc
=
core
.
prune_backward
(
self
.
desc
)
forward_prog
=
self
.
_prune
([],
loss_op
)
forward_prog
.
blocks
=
[
Block
(
forward_prog
,
i
)
for
i
in
six
.
moves
.
range
(
forward_prog
.
desc
.
num_blocks
())
]
forward_prog
.
_sync_with_cpp
()
p
=
forward_prog
.
_inference_optimize
(
prune_read_op
=
False
)
p
=
forward_prog
.
_inference_optimize
(
prune_read_op
=
False
)
else
:
else
:
p
=
self
.
_inference_optimize
(
prune_read_op
=
False
)
p
=
self
.
_inference_optimize
(
prune_read_op
=
False
)
...
@@ -3637,16 +3641,6 @@ class Program(object):
...
@@ -3637,16 +3641,6 @@ class Program(object):
for
each_var
in
list
(
each_block
.
vars
.
values
()):
for
each_var
in
list
(
each_block
.
vars
.
values
()):
yield
each_var
yield
each_var
def
_find_loss_op
(
self
):
loss_op
=
None
op_role_key
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
()
forward_loss
=
int
(
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
)
|
int
(
core
.
op_proto_and_checker_maker
.
OpRole
.
Loss
)
for
op
in
self
.
global_block
().
ops
:
if
int
(
op
.
all_attrs
()[
op_role_key
])
==
forward_loss
:
loss_op
=
op
return
loss_op
class
Parameter
(
Variable
):
class
Parameter
(
Variable
):
"""
"""
...
...
python/paddle/fluid/tests/unittests/test_program_prune_backward.py
浏览文件 @
00d5375e
...
@@ -52,6 +52,25 @@ def lstm_net(use_feed):
...
@@ -52,6 +52,25 @@ def lstm_net(use_feed):
return
avg_cost
return
avg_cost
def
simple_fc_net_with_accuracy
(
use_feed
):
img
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
784
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
hidden
=
img
for
_
in
range
(
4
):
hidden
=
fluid
.
layers
.
fc
(
hidden
,
size
=
200
,
act
=
'relu'
,
bias_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
1.0
)))
prediction
=
fluid
.
layers
.
fc
(
hidden
,
size
=
10
,
act
=
'softmax'
)
loss
=
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
label
)
loss
=
fluid
.
layers
.
mean
(
loss
)
accuracy_out
=
fluid
.
layers
.
accuracy
(
input
=
prediction
,
label
=
label
,
k
=
5
)
return
loss
class
TestProgramPruneBackward
(
unittest
.
TestCase
):
class
TestProgramPruneBackward
(
unittest
.
TestCase
):
def
program_compare
(
self
,
program_a
,
program_b
):
def
program_compare
(
self
,
program_a
,
program_b
):
assert
isinstance
(
assert
isinstance
(
...
@@ -109,6 +128,21 @@ class TestProgramPruneBackward(unittest.TestCase):
...
@@ -109,6 +128,21 @@ class TestProgramPruneBackward(unittest.TestCase):
"label"
:
label
},
"label"
:
label
},
optimizer
=
optimizer
)
optimizer
=
optimizer
)
def
test_simple_fc_net_with_accuracy
(
self
):
def
optimizer
():
optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
,
regularization
=
fluid
.
regularizer
.
L2Decay
(
1e-4
))
return
optimizer
with
self
.
program_scope_guard
():
img
,
label
=
init_data
()
self
.
check_prune_correctness
(
method
=
simple_fc_net_with_accuracy
,
feed_dict
=
{
"image"
:
img
,
"label"
:
label
},
optimizer
=
optimizer
)
def
test_batchnorm_fc
(
self
):
def
test_batchnorm_fc
(
self
):
def
optimizer
():
def
optimizer
():
optimizer
=
fluid
.
optimizer
.
SGD
(
optimizer
=
fluid
.
optimizer
.
SGD
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录