Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
97cb5479
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
97cb5479
编写于
5月 09, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
change PE strategy
上级
303277f0
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
168 addition
and
45 deletion
+168
-45
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+64
-9
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+9
-2
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+6
-5
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+2
-1
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+4
-2
python/paddle/fluid/parallel_executor.py
python/paddle/fluid/parallel_executor.py
+7
-2
python/paddle/fluid/tests/unittests/test_parallel_executor.py
...on/paddle/fluid/tests/unittests/test_parallel_executor.py
+76
-24
未找到文件。
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
97cb5479
...
...
@@ -37,20 +37,26 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
platform
::
NCCLContextMap
*
nccl_ctxs
,
bool
use_default_grad_scale
)
platform
::
NCCLContextMap
*
nccl_ctxs
,
bool
use_default_grad_scale
,
bool
balance_parameter_opt_between_cards
)
:
loss_var_name_
(
loss_var_name
),
places_
(
places
),
local_scopes_
(
local_scopes
),
nccl_ctxs_
(
nccl_ctxs
)
{
nccl_ctxs_
(
nccl_ctxs
),
balance_parameter_opt_between_cards_
(
balance_parameter_opt_between_cards
)
{
#else
MultiDevSSAGraphBuilder
::
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
bool
use_default_grad_scale
)
const
std
::
vector
<
Scope
*>
&
local_scopes
,
bool
use_default_grad_scale
,
bool
balance_parameter_opt_between_cards
)
:
loss_var_name_
(
loss_var_name
),
places_
(
places
),
local_scopes_
(
local_scopes
)
{
local_scopes_
(
local_scopes
),
balance_parameter_opt_between_cards_
(
balance_parameter_opt_between_cards
)
{
#endif
for
(
auto
&
p
:
params
)
{
grad_names_
.
insert
(
GradVarName
(
p
));
...
...
@@ -124,6 +130,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
// Find "send" op first for split is in front of send.
OpDesc
*
send_op
=
GetSendOpDesc
(
program
);
size_t
cur_device_id
=
0
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
var_name_on_devices
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
bcast_var_name_set
;
var_name_on_devices
.
resize
(
places_
.
size
());
bcast_var_name_set
.
resize
(
places_
.
size
());
bool
is_forwarding
=
true
;
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
if
(
op
->
Type
()
==
"send"
)
{
...
...
@@ -139,17 +151,33 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
is_forwarding
=
false
;
}
else
{
CreateComputationalOps
(
&
result
,
*
op
,
places_
.
size
());
int
op_dev_id
=
GetOpDeviceID
(
var_name_on_devices
,
*
op
);
if
(
op_dev_id
==
-
1
)
{
// var on all device
CreateComputationalOps
(
&
result
,
*
op
,
places_
.
size
());
}
else
{
CreateComputationalOp
(
&
result
,
*
op
,
op_dev_id
);
for
(
auto
&
var_name
:
op
->
OutputArgumentNames
())
{
var_name_on_devices
[
op_dev_id
].
emplace
(
var_name
);
}
}
if
(
!
is_forwarding
&&
places_
.
size
()
>
1
)
{
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
for
(
auto
&
og
:
op
->
OutputArgumentNames
())
{
if
(
IsParameterGradientOnce
(
og
,
&
og_has_been_broadcast
))
{
if
(
IsSparseGradient
(
var_types
,
og
))
{
CreateReduceOp
(
&
result
,
og
,
0
);
CreateBroadcastOp
(
&
result
,
og
,
0
);
if
(
balance_parameter_opt_between_cards_
)
{
CreateReduceOp
(
&
result
,
og
,
cur_device_id
);
var_name_on_devices
[
cur_device_id
].
emplace
(
og
);
bcast_var_name_set
[
cur_device_id
].
emplace
(
og
.
substr
(
0
,
og
.
size
()
-
strlen
(
kGradVarSuffix
)));
cur_device_id
=
(
cur_device_id
+
1
)
%
places_
.
size
();
}
else
{
InsertNCCLAllReduceOp
(
&
result
,
og
);
if
(
IsSparseGradient
(
var_types
,
og
))
{
CreateReduceOp
(
&
result
,
og
,
0
);
CreateBroadcastOp
(
&
result
,
og
,
0
);
}
else
{
InsertNCCLAllReduceOp
(
&
result
,
og
);
}
}
}
}
...
...
@@ -157,6 +185,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
}
// Insert BCast Ops
for
(
size_t
dev_id
=
0
;
dev_id
<
bcast_var_name_set
.
size
();
++
dev_id
)
{
auto
&
to_bcast_set
=
bcast_var_name_set
[
dev_id
];
for
(
auto
&
bcast_name
:
to_bcast_set
)
{
CreateBroadcastOp
(
&
result
,
bcast_name
,
dev_id
);
}
}
/*
Dependency graph has been constructed. However, there are still data
harzaeds need to be handled.
...
...
@@ -265,6 +300,26 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
return
is_pg_once
;
}
int
MultiDevSSAGraphBuilder
::
GetOpDeviceID
(
const
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
&
var_name_on_devices
,
const
OpDesc
&
op
)
const
{
if
(
!
balance_parameter_opt_between_cards_
)
{
return
-
1
;
}
int
var_dev_id
=
-
1
;
for
(
auto
&
var_name
:
op
.
InputArgumentNames
())
{
if
(
var_dev_id
!=
-
1
)
break
;
for
(
size_t
i
=
0
;
i
<
var_name_on_devices
.
size
();
++
i
)
{
if
(
var_name_on_devices
[
i
].
count
(
var_name
))
{
var_dev_id
=
static_cast
<
int
>
(
i
);
break
;
}
}
}
return
var_dev_id
;
}
void
MultiDevSSAGraphBuilder
::
CreateScaleLossGradOp
(
SSAGraph
*
result
)
const
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
// Insert ScaleCost OpHandle
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
97cb5479
...
...
@@ -36,13 +36,15 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
platform
::
NCCLContextMap
*
nccl_ctxs
,
bool
use_default_grad_scale
);
bool
use_default_grad_scale
,
bool
balance_parameter_opt_between_cards
);
#else
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
bool
use_default_grad_scale
);
bool
use_default_grad_scale
,
bool
balance_parameter_opt_between_cards
);
#endif
std
::
unique_ptr
<
SSAGraph
>
Build
(
const
ProgramDesc
&
program
)
const
override
;
...
...
@@ -60,6 +62,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
#ifdef PADDLE_WITH_CUDA
platform
::
NCCLContextMap
*
nccl_ctxs_
;
#endif
bool
balance_parameter_opt_between_cards_
;
bool
use_default_grad_scale_
;
bool
IsScaleLossOp
(
const
OpDesc
&
op
)
const
;
...
...
@@ -84,6 +87,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const
std
::
string
&
og
,
std
::
unordered_set
<
std
::
string
>
*
og_has_been_broadcast
)
const
;
int
GetOpDeviceID
(
const
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
&
var_name_on_devices
,
const
OpDesc
&
op
)
const
;
void
InsertNCCLAllReduceOp
(
SSAGraph
*
result
,
const
std
::
string
&
og
)
const
;
void
CreateBroadcastOp
(
SSAGraph
*
result
,
const
std
::
string
&
p_name
,
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
97cb5479
...
...
@@ -58,7 +58,7 @@ ParallelExecutor::ParallelExecutor(
const
std
::
unordered_set
<
std
::
string
>
&
bcast_vars
,
const
ProgramDesc
&
main_program
,
const
std
::
string
&
loss_var_name
,
Scope
*
scope
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
bool
allow_op_delay
,
bool
use_default_grad_scale
)
bool
use_default_grad_scale
,
bool
balance_parameter_opt_between_cards
)
:
member_
(
new
ParallelExecutorPrivate
(
places
))
{
member_
->
global_scope_
=
scope
;
...
...
@@ -93,11 +93,12 @@ ParallelExecutor::ParallelExecutor(
#ifdef PADDLE_WITH_CUDA
details
::
MultiDevSSAGraphBuilder
builder
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
member_
->
nccl_ctxs_
.
get
(),
use_default_grad_scale
);
member_
->
nccl_ctxs_
.
get
(),
use_default_grad_scale
,
balance_parameter_opt_between_cards
);
#else
details
::
MultiDevSSAGraphBuilder
builder
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
use_default_grad_scale
);
details
::
MultiDevSSAGraphBuilder
builder
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
use_default_grad_scale
,
balance_parameter_opt_between_cards
);
#endif
auto
graph
=
builder
.
Build
(
main_program
);
...
...
paddle/fluid/framework/parallel_executor.h
浏览文件 @
97cb5479
...
...
@@ -40,7 +40,8 @@ class ParallelExecutor {
const
ProgramDesc
&
main_program
,
const
std
::
string
&
loss_var_name
,
Scope
*
scope
,
const
std
::
vector
<
Scope
*>&
local_scopes
,
bool
allow_op_delay
,
bool
use_default_grad_scale
);
bool
allow_op_delay
,
bool
use_default_grad_scale
,
bool
balance_parameter_opt_between_cards
);
~
ParallelExecutor
();
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
97cb5479
...
...
@@ -502,11 +502,13 @@ All parameter, weight, gradient are variables in Paddle.
const
std
::
unordered_set
<
std
::
string
>
&
bcast_vars
,
const
ProgramDesc
&
main_program
,
const
std
::
string
&
loss_var_name
,
Scope
*
scope
,
std
::
vector
<
Scope
*>
&
local_scopes
,
bool
allow_op_delay
,
bool
use_default_grad_scale
)
{
bool
allow_op_delay
,
bool
use_default_grad_scale
,
bool
balance_parameter_opt_between_cards
)
{
new
(
&
self
)
ParallelExecutor
(
num_threads
,
use_event
,
places
,
params
,
bcast_vars
,
main_program
,
loss_var_name
,
scope
,
local_scopes
,
allow_op_delay
,
use_default_grad_scale
);
allow_op_delay
,
use_default_grad_scale
,
balance_parameter_opt_between_cards
);
})
.
def
(
"bcast_params"
,
&
ParallelExecutor
::
BCastParamsToGPUs
)
// NOTE: even we return a vec<Scope*>* to Python use reference policy.
...
...
python/paddle/fluid/parallel_executor.py
浏览文件 @
97cb5479
...
...
@@ -30,7 +30,8 @@ class ParallelExecutor(object):
num_threads
=
None
,
allow_op_delay
=
False
,
share_vars_from
=
None
,
use_default_grad_scale
=
True
):
use_default_grad_scale
=
True
,
balance_parameter_opt_between_cards
=
False
):
"""
ParallelExecutor can run program in parallel.
...
...
@@ -51,6 +52,9 @@ class ParallelExecutor(object):
gradients of each device and scaled gradients would be
aggregated. Otherwise, a customized scale value should be fed
to the network.
balance_parameter_opt_between_cards(bool, default True): Whether
updating different gradients on different cards. Currently, it
is not recommended.
Returns:
A ParallelExecutor object.
...
...
@@ -129,7 +133,8 @@ class ParallelExecutor(object):
scope
,
local_scopes
,
allow_op_delay
,
use_default_grad_scale
)
use_default_grad_scale
,
balance_parameter_opt_between_cards
)
self
.
scope
=
scope
...
...
python/paddle/fluid/tests/unittests/test_parallel_executor.py
浏览文件 @
97cb5479
...
...
@@ -205,7 +205,8 @@ class TestParallelExecutorBase(unittest.TestCase):
allow_op_delay
=
False
,
feed_dict
=
None
,
seed
=
None
,
use_parallel_executor
=
True
):
use_parallel_executor
=
True
,
balance_parameter_opt_between_cards
=
False
):
def
run_executor
(
exe
,
feed
,
fetch_list
,
program
=
None
):
if
isinstance
(
exe
,
fluid
.
ParallelExecutor
):
res
=
exe
.
run
(
fetch_list
=
fetch_list
,
feed
=
feed
)
...
...
@@ -234,7 +235,11 @@ class TestParallelExecutorBase(unittest.TestCase):
if
use_parallel_executor
:
exe
=
fluid
.
ParallelExecutor
(
True
,
loss_name
=
loss
.
name
,
allow_op_delay
=
allow_op_delay
)
True
,
loss_name
=
loss
.
name
,
allow_op_delay
=
allow_op_delay
,
balance_parameter_opt_between_cards
=
balance_parameter_opt_between_cards
)
else
:
exe
=
fluid
.
Executor
(
place
=
place
)
...
...
@@ -280,20 +285,27 @@ class TestMNIST(TestParallelExecutorBase):
fluid
.
recordio_writer
.
convert_reader_to_recordio_file
(
'./mnist.recordio'
,
reader
,
feeder
)
def
check_simple_fc_convergence
(
self
):
def
check_simple_fc_convergence
(
self
,
balance_parameter_opt_between_cards
):
self
.
check_network_convergence
(
simple_fc_net
)
self
.
check_network_convergence
(
simple_fc_net
,
allow_op_delay
=
True
)
img
=
np
.
zeros
(
shape
=
[
32
,
784
],
dtype
=
'float32'
)
label
=
np
.
ones
(
shape
=
[
32
,
1
],
dtype
=
'int64'
)
self
.
check_network_convergence
(
simple_fc_net
,
feed_dict
=
{
"image"
:
img
,
"label"
:
label
})
simple_fc_net
,
feed_dict
=
{
"image"
:
img
,
"label"
:
label
},
balance_parameter_opt_between_cards
=
balance_parameter_opt_between_cards
)
def
test_simple_fc
(
self
):
self
.
check_simple_fc_convergence
()
self
.
check_simple_fc_convergence
(
False
)
def
test_simple_fc_with_new_strategy
(
self
):
self
.
check_simple_fc_convergence
(
True
)
def
check_simple_fc_parallel_accuracy
(
self
):
def
check_simple_fc_parallel_accuracy
(
self
,
balance_parameter_opt_between_cards
):
img
=
np
.
zeros
(
shape
=
[
32
,
784
],
dtype
=
'float32'
)
label
=
np
.
ones
(
shape
=
[
32
,
1
],
dtype
=
'int64'
)
single_first_loss
,
single_last_loss
=
self
.
check_network_convergence
(
...
...
@@ -307,7 +319,9 @@ class TestMNIST(TestParallelExecutorBase):
seed
=
1000
,
feed_dict
=
{
"image"
:
img
,
"label"
:
label
},
use_parallel_executor
=
True
)
use_parallel_executor
=
True
,
balance_parameter_opt_between_cards
=
balance_parameter_opt_between_cards
)
for
p_f
in
parallel_first_loss
:
self
.
assertAlmostEquals
(
p_f
,
single_first_loss
[
0
],
delta
=
1e-6
)
...
...
@@ -315,18 +329,28 @@ class TestMNIST(TestParallelExecutorBase):
self
.
assertAlmostEquals
(
p_l
,
single_last_loss
[
0
],
delta
=
1e-6
)
def
test_simple_fc_parallel_accuracy
(
self
):
self
.
check_simple_fc_parallel_accuracy
()
self
.
check_simple_fc_parallel_accuracy
(
False
)
def
check_batchnorm_fc_convergence
(
self
):
def
test_simple_fc_parallel_accuracy_with_new_strategy
(
self
):
self
.
check_simple_fc_parallel_accuracy
(
True
)
def
check_batchnorm_fc_convergence
(
self
,
balance_parameter_opt_between_cards
):
self
.
check_network_convergence
(
fc_with_batchnorm
)
img
=
np
.
zeros
(
shape
=
[
32
,
784
],
dtype
=
'float32'
)
label
=
np
.
ones
(
shape
=
[
32
,
1
],
dtype
=
'int64'
)
self
.
check_network_convergence
(
fc_with_batchnorm
,
feed_dict
=
{
"image"
:
img
,
"label"
:
label
})
fc_with_batchnorm
,
feed_dict
=
{
"image"
:
img
,
"label"
:
label
},
balance_parameter_opt_between_cards
=
balance_parameter_opt_between_cards
)
def
test_batchnorm_fc
(
self
):
self
.
check_batchnorm_fc_convergence
()
self
.
check_batchnorm_fc_convergence
(
False
)
def
test_batchnorm_fc_with_new_strategy
(
self
):
self
.
check_batchnorm_fc_convergence
(
True
)
class
TestResnet
(
TestParallelExecutorBase
):
...
...
@@ -348,17 +372,22 @@ class TestResnet(TestParallelExecutorBase):
# fluid.recordio_writer.convert_reader_to_recordio_file(
# "./flowers.recordio", reader, feeder, compressor=fluid.core.RecordIOWriter.Compressor.NoCompress)
def
check_resnet_convergence
(
self
):
def
check_resnet_convergence
(
self
,
balance_parameter_opt_between_cards
):
import
functools
batch_size
=
2
self
.
check_network_convergence
(
functools
.
partial
(
SE_ResNeXt50Small
,
batch_size
=
batch_size
),
iter
=
20
,
batch_size
=
batch_size
)
batch_size
=
batch_size
,
balance_parameter_opt_between_cards
=
balance_parameter_opt_between_cards
)
def
test_resnet
(
self
):
self
.
check_resnet_convergence
()
self
.
check_resnet_convergence
(
False
)
def
test_resnet_with_new_strategy
(
self
):
self
.
check_resnet_convergence
(
True
)
class
ModelHyperParams
(
object
):
...
...
@@ -519,7 +548,7 @@ class TestTransformer(TestParallelExecutorBase):
class
ParallelExecutorTestingDuringTraining
(
unittest
.
TestCase
):
def
check_network_convergence
(
self
):
def
check_network_convergence
(
self
,
balance_parameter_opt_between_cards
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
,
startup
):
...
...
@@ -539,12 +568,18 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
feed_dict
=
{
'image'
:
image
,
'label'
:
label
}
train_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
True
,
loss_name
=
loss
.
name
,
main_program
=
main
)
use_cuda
=
True
,
loss_name
=
loss
.
name
,
main_program
=
main
,
balance_parameter_opt_between_cards
=
balance_parameter_opt_between_cards
)
test_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
True
,
main_program
=
test_program
,
share_vars_from
=
train_exe
)
share_vars_from
=
train_exe
,
balance_parameter_opt_between_cards
=
balance_parameter_opt_between_cards
)
for
i
in
xrange
(
5
):
test_loss
,
=
test_exe
.
run
([
loss
.
name
],
feed
=
feed_dict
)
...
...
@@ -558,8 +593,11 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
"Train loss: "
+
str
(
train_loss
)
+
"
\n
Test loss:"
+
str
(
test_loss
))
def
test_parallel
(
self
):
self
.
check_network_convergence
()
def
test_parallel_testing
(
self
):
self
.
check_network_convergence
(
False
)
def
test_parallel_testing_with_new_strategy
(
self
):
self
.
check_network_convergence
(
True
)
import
paddle.dataset.conll05
as
conll05
...
...
@@ -579,7 +617,7 @@ embedding_name = 'emb'
def
db_lstm
(
word
,
predicate
,
ctx_n2
,
ctx_n1
,
ctx_0
,
ctx_p1
,
ctx_p2
,
mark
,
is_sparse
,
**
ignored
):
is_sparse
,
balance_parameter_opt_between_cards
,
**
ignored
):
# 8 features
predicate_embedding
=
fluid
.
layers
.
embedding
(
input
=
predicate
,
...
...
@@ -648,7 +686,9 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
class
TestCRFModel
(
unittest
.
TestCase
):
def
check_network_convergence
(
self
,
is_sparse
):
def
check_network_convergence
(
self
,
is_sparse
,
balance_parameter_opt_between_cards
=
False
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
,
startup
):
...
...
@@ -696,7 +736,11 @@ class TestCRFModel(unittest.TestCase):
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup
)
pe
=
fluid
.
ParallelExecutor
(
use_cuda
=
True
,
loss_name
=
avg_cost
.
name
)
pe
=
fluid
.
ParallelExecutor
(
use_cuda
=
True
,
loss_name
=
avg_cost
.
name
,
balance_parameter_opt_between_cards
=
balance_parameter_opt_between_cards
)
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
...
...
@@ -718,6 +762,14 @@ class TestCRFModel(unittest.TestCase):
def
test_update_dense_parameter
(
self
):
self
.
check_network_convergence
(
is_sparse
=
False
)
def
test_update_sparse_parameter_with_new_strategy
(
self
):
self
.
check_network_convergence
(
is_sparse
=
False
,
balance_parameter_opt_between_cards
=
True
)
def
test_update_dense_parameter_with_new_strategy
(
self
):
self
.
check_network_convergence
(
is_sparse
=
False
,
balance_parameter_opt_between_cards
=
True
)
# test fetch all the variables of global_block
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录