Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1514eec6
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1514eec6
编写于
11月 30, 2021
作者:
Z
zhaocaibei123
提交者:
GitHub
11月 30, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
pscore global shuffle&default accessor config (#37626)
上级
2f4c089b
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
323 addition
and
95 deletion
+323
-95
paddle/fluid/framework/distributed_strategy.proto
paddle/fluid/framework/distributed_strategy.proto
+56
-25
python/paddle/distributed/fleet/base/distributed_strategy.py
python/paddle/distributed/fleet/base/distributed_strategy.py
+20
-4
python/paddle/distributed/fleet/base/fleet_base.py
python/paddle/distributed/fleet/base/fleet_base.py
+1
-1
python/paddle/distributed/fleet/runtime/the_one_ps.py
python/paddle/distributed/fleet/runtime/the_one_ps.py
+127
-57
python/paddle/fluid/dataset.py
python/paddle/fluid/dataset.py
+24
-6
python/paddle/fluid/tests/unittests/dist_fleet_ctr.py
python/paddle/fluid/tests/unittests/dist_fleet_ctr.py
+52
-1
python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py
python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py
+35
-0
python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py
.../fluid/tests/unittests/test_fleet_distributed_strategy.py
+8
-1
未找到文件。
paddle/fluid/framework/distributed_strategy.proto
浏览文件 @
1514eec6
...
...
@@ -181,7 +181,7 @@ enum TableType {
message
TableParameter
{
optional
uint64
table_id
=
1
;
optional
string
table_class
=
2
;
optional
uint64
shard_num
=
3
;
optional
uint64
shard_num
=
3
[
default
=
1000
]
;
optional
TableType
type
=
4
;
optional
TableAccessorParameter
accessor
=
5
;
}
...
...
@@ -190,42 +190,73 @@ message TableAccessorParameter {
optional
string
accessor_class
=
1
;
optional
SGDParameter
embed_sgd_param
=
2
;
optional
SGDParameter
embedx_sgd_param
=
3
;
optional
uint32
fea_dim
=
4
;
// for sparse table, this means field size of one
// value; for dense table, this means total value
// num
optional
uint32
embedx_dim
=
5
;
// embedx feature size
optional
uint32
embedx_threshold
=
6
;
// embedx feature create threshold
optional
uint32
fea_dim
=
4
[
default
=
11
];
// field size of one value
optional
uint32
embedx_dim
=
5
[
default
=
8
];
// embedx feature size
optional
uint32
embedx_threshold
=
6
[
default
=
10
];
// embedx feature create threshold
optional
CtrAccessorParameter
ctr_accessor_param
=
7
;
repeated
TableAccessorSaveParameter
table_accessor_save_param
=
8
;
}
// TODO(guanqun): add NaiveSGD/Adam...
message
SGDParameter
{
optional
string
name
=
1
;
optional
SGDRuleParameter
adagrad
=
2
;
optional
SparseNaiveSGDRuleParameter
naive
=
2
;
optional
SparseAdagradSGDRuleParameter
adagrad
=
3
;
optional
SparseAdamSGDParameter
adam
=
4
;
}
message
SGDRuleParameter
{
optional
double
learning_rate
=
1
;
optional
double
initial_g2sum
=
2
;
optional
double
initial_range
=
3
[
default
=
0
];
message
SparseNaiveSGDRuleParameter
{
// SparseNaiveSGDRule
optional
double
learning_rate
=
1
[
default
=
0.05
];
optional
double
initial_range
=
2
[
default
=
0.0001
];
repeated
float
weight_bounds
=
3
;
}
message
SparseAdagradSGDRuleParameter
{
// SparseAdaGradSGDRule|StdAdaGradSGDRule
optional
double
learning_rate
=
1
[
default
=
0.05
];
optional
double
initial_g2sum
=
2
[
default
=
3.0
];
optional
double
initial_range
=
3
[
default
=
0.0001
];
repeated
float
weight_bounds
=
4
;
}
message
SparseAdamSGDParameter
{
// SparseAdamSGDRule
optional
double
learning_rate
=
1
[
default
=
0.001
];
optional
double
initial_range
=
2
[
default
=
0.0001
];
optional
double
beta1_decay_rate
=
3
[
default
=
0.9
];
optional
double
beta2_decay_rate
=
4
[
default
=
0.999
];
optional
double
ada_epsilon
=
5
[
default
=
1e-08
];
repeated
float
weight_bounds
=
6
;
}
message
CtrAccessorParameter
{
optional
float
nonclk_coeff
=
1
;
// to calculate show_click_score
optional
float
click_coeff
=
2
;
// to calculate show_click_score
optional
float
base_threshold
=
3
;
// show_click_score > base_threshold, this feature can be saved
optional
float
delta_threshold
=
4
;
// delta_score > delta_threshold, this feature can be saved
optional
float
delta_keep_days
=
5
;
// unseen_day < delta_keep_days, this feature can be saved
optional
float
show_click_decay_rate
=
6
;
// show/click will update to
optional
float
nonclk_coeff
=
1
[
default
=
0.1
];
// to calculate show_click_score
optional
float
click_coeff
=
2
[
default
=
1
];
// to calculate show_click_score
optional
float
base_threshold
=
3
[
default
=
1.5
];
// show_click_score > base_threshold, this feature can be saved
optional
float
delta_threshold
=
4
[
default
=
0.25
];
// delta_score > delta_threshold, this feature can be saved
optional
float
delta_keep_days
=
5
[
default
=
16
];
// unseen_day < delta_keep_days, this feature can be saved
optional
float
show_click_decay_rate
=
6
[
default
=
0.98
];
// show/click will update to
// show/click *
// show_click_decay_rate after a day
optional
float
delete_threshold
=
7
;
// threshold to shrink a feasign
optional
float
delete_after_unseen_days
=
8
;
optional
int32
ssd_unseenday_threshold
=
9
;
optional
float
delete_threshold
=
7
[
default
=
0.8
];
// threshold to shrink a feasign
optional
float
delete_after_unseen_days
=
8
[
default
=
30
];
optional
int32
ssd_unseenday_threshold
=
9
[
default
=
1
];
}
message
TableAccessorSaveParameter
{
optional
uint32
param
=
1
;
optional
string
converter
=
2
;
optional
string
deconverter
=
3
;
}
message
FsClientParameter
{
...
...
python/paddle/distributed/fleet/base/distributed_strategy.py
浏览文件 @
1514eec6
...
...
@@ -470,21 +470,37 @@ class DistributedStrategy(object):
from
google.protobuf.descriptor
import
FieldDescriptor
table_param
=
self
.
strategy
.
downpour_table_param
def
set_table_config
(
msg
,
config_name
,
configs
):
def
set_table_config
(
msg
,
config_name
,
configs
,
index
=
0
):
for
field
in
msg
.
DESCRIPTOR
.
fields
:
name
=
config_name
+
"."
+
field
.
name
if
field
.
type
==
FieldDescriptor
.
TYPE_MESSAGE
:
print
(
"message:"
,
name
)
set_table_config
(
getattr
(
msg
,
field
.
name
),
name
,
configs
)
if
field
.
label
==
FieldDescriptor
.
LABEL_REPEATED
:
if
name
+
".num"
not
in
configs
:
continue
num
=
configs
[
name
+
".num"
]
print
(
"message num:"
,
name
,
num
)
for
i
in
range
(
num
):
data
=
getattr
(
msg
,
field
.
name
).
add
()
set_table_config
(
data
,
name
,
configs
,
i
)
else
:
set_table_config
(
getattr
(
msg
,
field
.
name
),
name
,
configs
)
else
:
print
(
"not message:"
,
name
)
if
name
not
in
configs
:
continue
if
field
.
label
==
FieldDescriptor
.
LABEL_REPEATED
:
getattr
(
msg
,
field
.
name
).
extend
(
configs
[
name
])
else
:
if
type
(
configs
[
name
])
==
list
:
setattr
(
msg
,
field
.
name
,
configs
[
name
][
index
])
else
:
setattr
(
msg
,
field
.
name
,
configs
[
name
])
if
not
configs
:
print
(
"table configs is empty"
)
else
:
set_table_config
(
table_param
,
"table_parameters"
,
configs
)
@
property
...
...
python/paddle/distributed/fleet/base/fleet_base.py
浏览文件 @
1514eec6
...
...
@@ -823,7 +823,7 @@ class Fleet(object):
self
.
_runtime_handle
.
_save_persistables
(
executor
,
dirname
,
main_program
,
mode
)
def
shrink
(
self
,
threshold
):
def
shrink
(
self
,
threshold
=
None
):
self
.
_runtime_handle
.
_shrink
(
threshold
)
def
distributed_optimizer
(
self
,
optimizer
,
strategy
=
None
):
...
...
python/paddle/distributed/fleet/runtime/the_one_ps.py
浏览文件 @
1514eec6
...
...
@@ -24,7 +24,6 @@ from paddle.fluid.parallel_executor import ParallelExecutor
from
paddle.fluid.framework
import
Variable
,
Parameter
from
.runtime_base
import
RuntimeBase
from
..base.private_helper_function
import
wait_server_ready
import
paddle.distributed.fleet
as
fleet
__all__
=
[]
...
...
@@ -53,6 +52,70 @@ def parse_table_class(varname, o_main_program):
return
"MemorySparseTable"
def
get_default_accessor_proto
(
accessor
,
varname
,
o_main_program
):
embedding_dim
=
0
for
var
in
o_main_program
.
list_vars
():
if
var
.
name
==
varname
:
print
(
"var:"
,
var
)
print
(
"var.shape:"
,
var
.
shape
)
embedding_dim
=
var
.
shape
[
1
]
print
(
"sparse dim:"
,
embedding_dim
)
break
accessor
.
accessor_class
=
"CtrCommonAccessor"
accessor
.
fea_dim
=
embedding_dim
+
2
accessor
.
embedx_dim
=
embedding_dim
-
1
accessor
.
embedx_threshold
=
0
ctr_accessor_param
=
accessor
.
ctr_accessor_param
ctr_accessor_param
.
nonclk_coeff
=
0.1
ctr_accessor_param
.
click_coeff
=
1.0
ctr_accessor_param
.
base_threshold
=
0
ctr_accessor_param
.
delta_threshold
=
0
ctr_accessor_param
.
delta_keep_days
=
16
ctr_accessor_param
.
show_click_decay_rate
=
1
ctr_accessor_param
.
delete_threshold
=
0
ctr_accessor_param
.
delete_after_unseen_days
=
30
ctr_accessor_param
.
ssd_unseenday_threshold
=
1
embed_sgd_param
=
accessor
.
embed_sgd_param
embed_sgd_param
.
name
=
"SparseAdaGradSGDRule"
embed_sgd_param
.
adagrad
.
learning_rate
=
0.05
embed_sgd_param
.
adagrad
.
initial_g2sum
=
3.0
embed_sgd_param
.
adagrad
.
initial_range
=
0.0001
embed_sgd_param
.
adagrad
.
weight_bounds
.
append
(
-
10.0
)
embed_sgd_param
.
adagrad
.
weight_bounds
.
append
(
10.0
)
embedx_sgd_param
=
accessor
.
embedx_sgd_param
embedx_sgd_param
.
name
=
"SparseAdaGradSGDRule"
embedx_sgd_param
.
adagrad
.
learning_rate
=
0.05
embedx_sgd_param
.
adagrad
.
initial_g2sum
=
3.0
embedx_sgd_param
.
adagrad
.
initial_range
=
0.0001
embedx_sgd_param
.
adagrad
.
weight_bounds
.
append
(
-
10.0
)
embedx_sgd_param
.
adagrad
.
weight_bounds
.
append
(
10.0
)
def
check_embedding_dim
(
accessor
,
varname
,
o_main_program
):
embedding_dim
=
0
for
var
in
o_main_program
.
list_vars
():
if
var
.
name
==
varname
:
print
(
"var:"
,
var
)
print
(
"var.shape:"
,
var
.
shape
)
embedding_dim
=
var
.
shape
[
1
]
print
(
"sparse dim:"
,
embedding_dim
)
break
fea_dim
=
accessor
.
fea_dim
if
fea_dim
!=
embedding_dim
+
2
:
raise
ValueError
(
"The fea_dim is wrong, it will be sparse_embedding_dim + 2: {}, but got {}"
.
format
(
embedding_dim
+
2
,
fea_dim
))
embedx_dim
=
accessor
.
embedx_dim
if
embedx_dim
!=
embedding_dim
-
1
:
raise
ValueError
(
"The embedx_dim is wrong, it will be sparse_embedding_dim - 1: {}, but got {}"
.
format
(
embedding_dim
-
1
,
embedx_dim
))
class
Accessor
:
def
__init__
(
self
):
self
.
accessor_class
=
""
...
...
@@ -344,6 +407,11 @@ class Table:
self
.
accessor_proto
=
None
def
to_string
(
self
,
indent
):
# if self.id == 1:
# proto_txt = ''
# with open('./sparse_table.prototxt') as f:
# proto_txt = f.read()
# return proto_txt
table_str
=
"{}downpour_table_param {{{}
\n
{}}}"
attrs
=
""
...
...
@@ -586,6 +654,8 @@ class TheOnePSRuntime(RuntimeBase):
return
kwargs
proto_txt
=
str
(
worker
)
+
"
\n
"
+
str
(
server
)
with
open
(
'proto_txt'
,
'w'
)
as
f
:
f
.
write
(
proto_txt
)
debug
=
bool
(
int
(
os
.
getenv
(
"PSERVER_DEBUG"
,
"0"
)))
...
...
@@ -846,55 +916,55 @@ class TheOnePSRuntime(RuntimeBase):
if
self
.
compiled_strategy
.
is_geo_mode
():
table
.
table_class
=
"SparseGeoTable"
else
:
import
copy
table_proto
=
copy
.
deepcopy
(
self
.
context
[
"user_defined_strategy"
].
sparse_table_configs
)
print
(
'table proto:'
,
table_proto
)
print
(
'table_class:'
,
table_proto
.
table_class
)
print
(
'shard_num:'
,
table_proto
.
shard_num
)
print
(
'table_proto.accessor:'
,
table_proto
.
accessor
)
print
(
'accessor.IsInitialized'
,
table_proto
.
accessor
.
IsInitialized
())
print
(
'accessor.ByteSize'
,
table_proto
.
accessor
.
ByteSize
())
if
table_proto
.
table_class
:
print
(
'table_proto.table_class is true'
)
table
.
table_class
=
table_proto
.
table_class
else
:
table
.
table_class
=
parse_table_class
(
common
.
table_name
,
self
.
origin_main_program
)
table_proto
=
self
.
context
[
"user_defined_strategy"
].
sparse_table_configs
if
table
.
table_class
!=
'MemorySparseTable'
:
table
.
table_class
=
'MemorySparseTable'
warnings
.
warn
(
"The PS mode must use MemorySparseTable."
)
if
table_proto
.
shard_num
:
print
(
'table_proto.shard_num is true'
)
table
.
shard_num
=
table_proto
.
shard_num
else
:
table
.
shard_num
=
1000
warnings
.
warn
(
"The shard_num of sparse table is not set, use default value 1000."
)
if
table_proto
.
accessor
.
ByteSize
()
==
0
:
print
(
'table_proto.accessor is false'
)
get_default_accessor_proto
(
table_proto
.
accessor
,
common
.
table_name
,
self
.
origin_main_program
)
warnings
.
warn
(
"The accessor of sparse table is not set, use default value."
)
check_embedding_dim
(
table_proto
.
accessor
,
common
.
table_name
,
self
.
origin_main_program
)
print
(
'accessor.ByteSize'
,
table_proto
.
accessor
.
ByteSize
())
from
google.protobuf
import
text_format
table
.
accessor_proto
=
text_format
.
MessageToString
(
table_proto
.
accessor
)
print
(
'table proto:'
,
table_proto
)
if
table
.
table_class
==
'MemorySparseTable'
and
table
.
accessor_proto
==
''
:
emb_dim
=
ctx
.
sections
()[
1
]
table
.
shard_num
=
1950
table
.
accessor_proto
=
'accessor_class: "CtrCommonAccessor"
\n
'
\
'embed_sgd_param {
\n
'
\
' name: "SparseAdaGradSGDRule"
\n
'
\
' adagrad {
\n
'
\
' learning_rate: 0.05
\n
'
\
' initial_g2sum: 3.0
\n
'
\
' initial_range: 0.0001
\n
'
\
' weight_bounds: -10.0
\n
'
\
' weight_bounds: 10.0
\n
'
\
' }
\n
'
\
'}
\n
'
\
'embedx_sgd_param {
\n
'
\
' name: "SparseAdaGradSGDRule"
\n
'
\
' adagrad {
\n
'
\
' learning_rate: 0.05
\n
'
\
' initial_g2sum: 3.0
\n
'
\
' initial_range: 0.0001
\n
'
\
' weight_bounds: -10.0
\n
'
\
' weight_bounds: 10.0
\n
'
\
' }
\n
'
\
'}
\n
'
\
'fea_dim: '
+
str
(
emb_dim
+
2
)
+
'
\n
'
\
'embedx_dim: '
+
str
(
emb_dim
-
1
)
+
'
\n
'
\
'embedx_threshold: 10
\n
'
\
'ctr_accessor_param {
\n
'
\
' nonclk_coeff: 0.1
\n
'
\
' click_coeff: 1.0
\n
'
\
' base_threshold: 1.5
\n
'
\
' delta_threshold: 0.25
\n
'
\
' delta_keep_days: 16.0
\n
'
\
' show_click_decay_rate: 0.98
\n
'
\
' delete_threshold: 0.8
\n
'
\
' delete_after_unseen_days: 30.0
\n
'
\
' ssd_unseenday_threshold: 1
\n
'
\
'}'
print
(
"the_one_ps table_proto:"
,
table
.
accessor_proto
)
else
:
table
.
type
=
"PS_DENSE_TABLE"
table
.
table_class
=
"CommonDenseTable"
...
...
@@ -916,7 +986,6 @@ class TheOnePSRuntime(RuntimeBase):
common
.
sync
=
"true"
else
:
common
.
sync
=
"false"
table
.
common
=
common
if
table
.
table_class
!=
'MemorySparseTable'
:
...
...
@@ -1108,8 +1177,6 @@ class TheOnePSRuntime(RuntimeBase):
TheOnePSRuntime
.
__exclude_vars
(
saved_varnames
),
main_program
.
list_vars
()))
self
.
_communicator
.
pull_dense
(
denses
)
import
paddle
for
var
in
remaining_vars
:
# if var.name not in recv_dense_varnames:
...
...
@@ -1209,9 +1276,8 @@ class TheOnePSRuntime(RuntimeBase):
split_dense_table
=
self
.
role_maker
.
_is_heter_parameter_server_mode
,
use_origin_program
=
True
)
print
(
"the one ps sparses:"
,
sparses
)
sparse_names
=
[]
for
id
,
name
in
sparses
.
items
():
sparse_names
.
extend
(
name
)
sparse_names
=
self
.
_save_sparse_params
(
executor
,
dirname
,
sparses
,
main_program
,
mode
)
print
(
"the one ps sparse names:"
,
sparse_names
)
denses
=
self
.
compiled_strategy
.
get_the_one_recv_context
(
...
...
@@ -1225,7 +1291,7 @@ class TheOnePSRuntime(RuntimeBase):
generate_vars
=
[
var
for
var
in
generate_vars
]
remaining_vars
=
list
(
filter
(
TheOnePSRuntime
.
__exclude_vars
(
generate_vars
+
sparse_names
),
TheOnePSRuntime
.
__exclude_vars
(
sparse_names
),
infer_program
.
list_vars
()))
print
(
"remain_vars:"
,
[
var
.
name
for
var
in
remaining_vars
])
for
var
in
remaining_vars
:
...
...
@@ -1235,9 +1301,6 @@ class TheOnePSRuntime(RuntimeBase):
os
.
path
.
join
(
model_path
,
var
.
name
),
use_binary_format
=
True
)
self
.
_ps_inference_save_persistables
(
executor
,
dirname
,
infer_program
,
mode
)
def
_save_inference_model
(
self
,
*
args
,
**
kwargs
):
self
.
_ps_inference_save_inference_model
(
*
args
,
**
kwargs
)
...
...
@@ -1314,8 +1377,15 @@ class TheOnePSRuntime(RuntimeBase):
self
.
_load_distributed_persistables
(
path
,
mode
)
else
:
self
.
_ps_inference_load_inference_model
(
path
,
mode
)
# self._load_distributed_persistables(path, mode=mode)
def
_shrink
(
self
,
threshold
):
def
_shrink
(
self
,
threshold
=
None
):
if
threshold
is
not
None
:
warnings
.
warn
(
"The param threshold is not used in MemorySparseTable, if you need to shrink, please set the config of accessor"
)
else
:
threshold
=
0
import
paddle.distributed.fleet
as
fleet
fleet
.
util
.
barrier
()
if
self
.
role_maker
.
_is_first_worker
():
...
...
python/paddle/fluid/dataset.py
浏览文件 @
1514eec6
...
...
@@ -862,7 +862,11 @@ class InMemoryDataset(DatasetBase):
thread_num(int): shuffle thread num. Default is 12.
"""
from
paddle.fluid.incubate.fleet.parameter_server.pslib
import
PSLib
if
fleet
is
not
None
:
if
not
isinstance
(
fleet
,
PSLib
):
fleet
.
barrier_worker
()
else
:
fleet
.
_role_maker
.
barrier_worker
()
if
self
.
trainer_num
==
-
1
:
self
.
trainer_num
=
fleet
.
worker_num
()
...
...
@@ -875,13 +879,22 @@ class InMemoryDataset(DatasetBase):
self
.
dataset
.
set_fleet_send_batch_size
(
self
.
fleet_send_batch_size
)
self
.
dataset
.
set_fleet_send_sleep_seconds
(
self
.
fleet_send_sleep_seconds
)
if
fleet
is
not
None
:
if
not
isinstance
(
fleet
,
PSLib
):
fleet
.
barrier_worker
()
else
:
fleet
.
_role_maker
.
barrier_worker
()
self
.
dataset
.
global_shuffle
(
thread_num
)
if
fleet
is
not
None
:
if
not
isinstance
(
fleet
,
PSLib
):
fleet
.
barrier_worker
()
else
:
fleet
.
_role_maker
.
barrier_worker
()
if
self
.
merge_by_lineid
:
self
.
dataset
.
merge_by_lineid
()
if
fleet
is
not
None
:
if
not
isinstance
(
fleet
,
PSLib
):
fleet
.
barrier_worker
()
else
:
fleet
.
_role_maker
.
barrier_worker
()
@
deprecated
(
...
...
@@ -1011,8 +1024,13 @@ class InMemoryDataset(DatasetBase):
import
numpy
as
np
local_data_size
=
self
.
dataset
.
get_shuffle_data_size
()
local_data_size
=
np
.
array
([
local_data_size
])
print
(
'global shuffle local_data_size: '
,
local_data_size
)
if
fleet
is
not
None
:
from
paddle.fluid.incubate.fleet.parameter_server.pslib
import
PSLib
global_data_size
=
local_data_size
*
0
if
not
isinstance
(
fleet
,
PSLib
):
global_data_size
=
fleet
.
util
.
all_reduce
(
local_data_size
)
else
:
fleet
.
_role_maker
.
all_reduce_worker
(
local_data_size
,
global_data_size
)
return
global_data_size
[
0
]
...
...
python/paddle/fluid/tests/unittests/dist_fleet_ctr.py
浏览文件 @
1514eec6
...
...
@@ -241,7 +241,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
self
.
check_model_right
(
model_dir
)
shutil
.
rmtree
(
model_dir
)
def
do_dataset_training
(
self
,
fleet
):
def
do_dataset_training
_queuedataset
(
self
,
fleet
):
train_file_list
=
ctr_dataset_reader
.
prepare_fake_data
()
exe
=
self
.
get_executor
()
...
...
@@ -288,5 +288,56 @@ class TestDistCTR2x2(FleetDistRunnerBase):
if
dirname
:
fleet
.
save_persistables
(
exe
,
dirname
=
dirname
)
def
do_dataset_training
(
self
,
fleet
):
train_file_list
=
ctr_dataset_reader
.
prepare_fake_data
()
exe
=
self
.
get_executor
()
exe
.
run
(
fluid
.
default_startup_program
())
fleet
.
init_worker
()
thread_num
=
2
batch_size
=
128
filelist
=
train_file_list
# config dataset
dataset
=
fluid
.
DatasetFactory
().
create_dataset
(
"InMemoryDataset"
)
dataset
.
set_use_var
(
self
.
feeds
)
dataset
.
set_batch_size
(
128
)
dataset
.
set_thread
(
2
)
dataset
.
set_filelist
(
filelist
)
dataset
.
set_pipe_command
(
'python ctr_dataset_reader.py'
)
dataset
.
load_into_memory
()
dataset
.
global_shuffle
(
fleet
,
12
)
##TODO: thread configure
shuffle_data_size
=
dataset
.
get_shuffle_data_size
(
fleet
)
local_data_size
=
dataset
.
get_shuffle_data_size
()
data_size_list
=
fleet
.
util
.
all_gather
(
local_data_size
)
print
(
'after global_shuffle data_size_list: '
,
data_size_list
)
print
(
'after global_shuffle data_size: '
,
shuffle_data_size
)
for
epoch_id
in
range
(
1
):
pass_start
=
time
.
time
()
exe
.
train_from_dataset
(
program
=
fluid
.
default_main_program
(),
dataset
=
dataset
,
fetch_list
=
[
self
.
avg_cost
],
fetch_info
=
[
"cost"
],
print_period
=
2
,
debug
=
int
(
os
.
getenv
(
"Debug"
,
"0"
)))
pass_time
=
time
.
time
()
-
pass_start
dataset
.
release_memory
()
if
os
.
getenv
(
"SAVE_MODEL"
)
==
"1"
:
model_dir
=
tempfile
.
mkdtemp
()
fleet
.
save_inference_model
(
exe
,
model_dir
,
[
feed
.
name
for
feed
in
self
.
feeds
],
self
.
avg_cost
)
self
.
check_model_right
(
model_dir
)
shutil
.
rmtree
(
model_dir
)
dirname
=
os
.
getenv
(
"SAVE_DIRNAME"
,
None
)
if
dirname
:
fleet
.
save_persistables
(
exe
,
dirname
=
dirname
)
if
__name__
==
"__main__"
:
runtime_main
(
TestDistCTR2x2
)
python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py
浏览文件 @
1514eec6
...
...
@@ -20,6 +20,41 @@ import tempfile
from
test_dist_fleet_base
import
TestFleetBase
class
TestDistMnistAsyncInMemoryDataset2x2
(
TestFleetBase
):
def
_setup_config
(
self
):
self
.
_mode
=
"async"
#self._reader = "pyreader"
self
.
_reader
=
"dataset"
def
check_with_place
(
self
,
model_file
,
delta
=
1e-3
,
check_error_log
=
False
,
need_envs
=
{}):
required_envs
=
{
"PATH"
:
os
.
getenv
(
"PATH"
,
""
),
"PYTHONPATH"
:
os
.
getenv
(
"PYTHONPATH"
,
""
),
"LD_LIBRARY_PATH"
:
os
.
getenv
(
"LD_LIBRARY_PATH"
,
""
),
"FLAGS_rpc_deadline"
:
"5000"
,
# 5sec to fail fast
"http_proxy"
:
""
,
"CPU_NUM"
:
"2"
,
"LOG_DIRNAME"
:
"/tmp"
,
"LOG_PREFIX"
:
self
.
__class__
.
__name__
,
}
required_envs
.
update
(
need_envs
)
if
check_error_log
:
required_envs
[
"GLOG_v"
]
=
"3"
required_envs
[
"GLOG_logtostderr"
]
=
"1"
tr0_losses
,
tr1_losses
=
self
.
_run_cluster
(
model_file
,
required_envs
)
def
test_dist_train
(
self
):
self
.
check_with_place
(
"dist_fleet_ctr.py"
,
delta
=
1e-5
,
check_error_log
=
False
)
class
TestDistMnistAsync2x2
(
TestFleetBase
):
def
_setup_config
(
self
):
self
.
_mode
=
"async"
...
...
python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py
浏览文件 @
1514eec6
...
...
@@ -259,11 +259,18 @@ class TestStrategyConfig(unittest.TestCase):
strategy
=
paddle
.
distributed
.
fleet
.
DistributedStrategy
()
configs
=
{
"table_parameters.accessor.embed_sgd_param.adagrad.learning_rate"
:
0.05
0.05
,
"table_parameters.accessor.table_accessor_save_param.num"
:
2
,
"table_parameters.accessor.table_accessor_save_param.param"
:
[
1
,
2
]
}
strategy
.
sparse_table_configs
=
configs
self
.
assertEqual
(
strategy
.
sparse_table_configs
.
accessor
.
embed_sgd_param
.
adagrad
.
learning_rate
,
0.05
)
self
.
assertEqual
(
strategy
.
sparse_table_configs
.
accessor
.
table_accessor_save_param
[
0
].
param
,
1
)
strategy
.
adam_d2sum
=
True
self
.
assertEqual
(
strategy
.
adam_d2sum
,
True
)
strategy
.
fs_client_param
=
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录