Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
caa6b596
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
caa6b596
编写于
12月 14, 2018
作者:
H
heqiaozhi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add hdfs_utils & helper & node doc
上级
37596000
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
238 addition
and
72 deletion
+238
-72
python/paddle/fluid/contrib/utils/hdfs_utils.py
python/paddle/fluid/contrib/utils/hdfs_utils.py
+116
-47
python/paddle/fluid/distributed/helper.py
python/paddle/fluid/distributed/helper.py
+29
-5
python/paddle/fluid/distributed/node.py
python/paddle/fluid/distributed/node.py
+93
-20
未找到文件。
python/paddle/fluid/contrib/utils/hdfs_utils.py
浏览文件 @
caa6b596
...
@@ -32,6 +32,28 @@ _logger.setLevel(logging.INFO)
...
@@ -32,6 +32,28 @@ _logger.setLevel(logging.INFO)
class
HDFSClient
(
object
):
class
HDFSClient
(
object
):
"""
A tool of HDFS
Args:
hadoop_home (string): hadoop_home
configs (dict): hadoop config, it is a dict, please contain
\
key "fs.default.name" and "hadoop.job.ugi"
Can be a float value
Examples:
hadoop_home = "/home/client/hadoop-client/hadoop/"
configs = {
"fs.default.name": "hdfs://xxx.hadoop.com:54310",
"hadoop.job.ugi": "hello,hello123"
}
client = HDFSClient(hadoop_home, configs)
client.ls("/user/com/train-25")
files = client.lsr("/user/com/train-25/models")
"""
def
__init__
(
self
,
hadoop_home
,
configs
):
def
__init__
(
self
,
hadoop_home
,
configs
):
self
.
pre_commands
=
[]
self
.
pre_commands
=
[]
hadoop_bin
=
'%s/bin/hadoop'
%
hadoop_home
hadoop_bin
=
'%s/bin/hadoop'
%
hadoop_home
...
@@ -55,7 +77,10 @@ class HDFSClient(object):
...
@@ -55,7 +77,10 @@ class HDFSClient(object):
whole_commands
=
" "
.
join
(
whole_commands
)
whole_commands
=
" "
.
join
(
whole_commands
)
for
x
in
range
(
retry_times
+
1
):
for
x
in
range
(
retry_times
+
1
):
proc
=
subprocess
.
Popen
(
proc
=
subprocess
.
Popen
(
whole_commands
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
shell
=
True
)
whole_commands
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
shell
=
True
)
(
output
,
errors
)
=
proc
.
communicate
()
(
output
,
errors
)
=
proc
.
communicate
()
ret_code
,
ret_out
,
ret_err
=
proc
.
returncode
,
output
,
errors
ret_code
,
ret_out
,
ret_err
=
proc
.
returncode
,
output
,
errors
if
ret_code
:
if
ret_code
:
...
@@ -69,10 +94,12 @@ class HDFSClient(object):
...
@@ -69,10 +94,12 @@ class HDFSClient(object):
def
upload
(
self
,
hdfs_path
,
local_path
,
overwrite
=
False
,
retry_times
=
5
):
def
upload
(
self
,
hdfs_path
,
local_path
,
overwrite
=
False
,
retry_times
=
5
):
"""
"""
upload the local file to hdfs
upload the local file to hdfs
args:
Args:
local_file_path: the local file path
hdfs_path: hdfs path, target path
remote_file_path: default value(${OUTPUT_PATH}/${SYS_USER_ID}/${SYS_JOB_ID}/tmp)
local_path: local file path, source path
return:
overwrite: will overwrite the original file
retry_times: max times retry to upload
Returns:
True or False
True or False
"""
"""
assert
hdfs_path
is
not
None
assert
hdfs_path
is
not
None
...
@@ -115,10 +142,12 @@ class HDFSClient(object):
...
@@ -115,10 +142,12 @@ class HDFSClient(object):
def
download
(
self
,
hdfs_path
,
local_path
,
overwrite
=
False
,
unzip
=
False
):
def
download
(
self
,
hdfs_path
,
local_path
,
overwrite
=
False
,
unzip
=
False
):
"""
"""
download from hdfs
download from hdfs
args:
Args:
local_file_path: the local file path
hdfs_path: hdfs path, target path
remote_file_path: remote dir on hdfs
local_path: local file path, source path
return:
overwrite: will remove original file and overwrite it.
unzip: ignore this param
Returns
True or False
True or False
"""
"""
_logger
.
info
(
'Downloading %r to %r.'
,
hdfs_path
,
local_path
)
_logger
.
info
(
'Downloading %r to %r.'
,
hdfs_path
,
local_path
)
...
@@ -160,11 +189,11 @@ class HDFSClient(object):
...
@@ -160,11 +189,11 @@ class HDFSClient(object):
def
is_exist
(
self
,
hdfs_path
=
None
):
def
is_exist
(
self
,
hdfs_path
=
None
):
"""
"""
whether the remote hdfs path exists?
whether the remote hdfs path exists?
a
rgs:
A
rgs:
remote_file
_path: default value(${OUTPUT_PATH}/${SYS_USER_ID}/${SYS_JOB_ID}/tmp)
hdfs
_path: default value(${OUTPUT_PATH}/${SYS_USER_ID}/${SYS_JOB_ID}/tmp)
fs_name: The default values are the same as in the job configuration
fs_name: The default values are the same as in the job configuration
fs_ugi: The default values are the same as in the job configuration
fs_ugi: The default values are the same as in the job configuration
return
:
Returns
:
True or False
True or False
"""
"""
exist_cmd
=
[
'-test'
,
'-e'
,
hdfs_path
]
exist_cmd
=
[
'-test'
,
'-e'
,
hdfs_path
]
...
@@ -183,11 +212,11 @@ class HDFSClient(object):
...
@@ -183,11 +212,11 @@ class HDFSClient(object):
def
is_dir
(
self
,
hdfs_path
=
None
):
def
is_dir
(
self
,
hdfs_path
=
None
):
"""
"""
whether the remote hdfs path exists?
whether the remote hdfs path exists?
a
rgs:
A
rgs:
remote_file_path: default value(${OUTPUT_PATH}/${SYS_USER_ID}/${SYS_JOB_ID}/tmp)
remote_file_path: default value(${OUTPUT_PATH}/${SYS_USER_ID}/${SYS_JOB_ID}/tmp)
fs_name: The default values are the same as in the job configuration
fs_name: The default values are the same as in the job configuration
fs_ugi: The default values are the same as in the job configuration
fs_ugi: The default values are the same as in the job configuration
return
:
Returns
:
True or False
True or False
"""
"""
...
@@ -207,15 +236,17 @@ class HDFSClient(object):
...
@@ -207,15 +236,17 @@ class HDFSClient(object):
return
True
return
True
def
delete
(
self
,
hdfs_path
):
def
delete
(
self
,
hdfs_path
):
"""Remove a file or directory from HDFS.
"""
Remove a file or directory from HDFS.
:param hdfs_path: HDFS path.
:param recursive: Recursively delete files and directories. By default,
this method will raise an :class:`HdfsError` if trying to delete a
non-empty directory.
This function returns `True` if the deletion was successful and `False` if
Args:
no file or directory previously existed at `hdfs_path`.
param hdfs_path: HDFS path.
param recursive: Recursively delete files and directories. By default,
this method will raise an :class:`HdfsError` if trying to delete a
non-empty directory.
Returns:
This function returns `True` if the deletion was successful and `False` if
no file or directory previously existed at `hdfs_path`.
"""
"""
_logger
.
info
(
'Deleting %r.'
,
hdfs_path
)
_logger
.
info
(
'Deleting %r.'
,
hdfs_path
)
...
@@ -241,14 +272,17 @@ class HDFSClient(object):
...
@@ -241,14 +272,17 @@ class HDFSClient(object):
return
True
return
True
def
rename
(
self
,
hdfs_src_path
,
hdfs_dst_path
,
overwrite
=
False
):
def
rename
(
self
,
hdfs_src_path
,
hdfs_dst_path
,
overwrite
=
False
):
"""Move a file or folder.
"""
Rename a file or folder.
:param hdfs_src_path: Source path.
Args:
:param hdfs_dst_path: Destination path. If the path already exists and is
:param hdfs_src_path: Source path.
a directory, the source will be moved into it. If the path exists and is
:param hdfs_dst_path: Destination path. If the path already exists and is
a file, or if a parent destination directory is missing, this method will
a directory, the source will be moved into it. If the path exists and is
raise an :class:`HdfsError`.
a file, or if a parent destination directory is missing, this method will
raise an :class:`HdfsError`.
Returns:
This function returns `True` if the rename was successful and `False` if
rename was faild.
"""
"""
assert
hdfs_src_path
is
not
None
assert
hdfs_src_path
is
not
None
assert
hdfs_dst_path
is
not
None
assert
hdfs_dst_path
is
not
None
...
@@ -274,6 +308,11 @@ class HDFSClient(object):
...
@@ -274,6 +308,11 @@ class HDFSClient(object):
@
staticmethod
@
staticmethod
def
make_local_dirs
(
local_path
):
def
make_local_dirs
(
local_path
):
"""
create a directiory local, is same to mkdir
Args:
local_path: local path that wants to create a directiory.
"""
try
:
try
:
os
.
makedirs
(
local_path
)
os
.
makedirs
(
local_path
)
except
OSError
as
e
:
except
OSError
as
e
:
...
@@ -282,9 +321,11 @@ class HDFSClient(object):
...
@@ -282,9 +321,11 @@ class HDFSClient(object):
def
makedirs
(
self
,
hdfs_path
):
def
makedirs
(
self
,
hdfs_path
):
"""Create a remote directory, recursively if necessary.
"""Create a remote directory, recursively if necessary.
Args:
:param hdfs_path: Remote path. Intermediate directories will be created
:param hdfs_path: Remote path. Intermediate directories will be created
appropriately.
appropriately.
Returns:
True if make a directories was successful, False when make a directiries was failed.
"""
"""
_logger
.
info
(
'Creating directories to %r.'
,
hdfs_path
)
_logger
.
info
(
'Creating directories to %r.'
,
hdfs_path
)
assert
hdfs_path
is
not
None
assert
hdfs_path
is
not
None
...
@@ -304,6 +345,13 @@ class HDFSClient(object):
...
@@ -304,6 +345,13 @@ class HDFSClient(object):
return
True
return
True
def
ls
(
self
,
hdfs_path
):
def
ls
(
self
,
hdfs_path
):
"""
ls a hdfs_path.
Args:
:param hdfs_path: hdfs_path will be ls.
Returns:
This function returns a `list` that contaion all files in the hdfs_path.
"""
assert
hdfs_path
is
not
None
assert
hdfs_path
is
not
None
if
not
self
.
is_exist
(
hdfs_path
):
if
not
self
.
is_exist
(
hdfs_path
):
...
@@ -329,6 +377,14 @@ class HDFSClient(object):
...
@@ -329,6 +377,14 @@ class HDFSClient(object):
return
ret_lines
return
ret_lines
def
lsr
(
self
,
hdfs_path
,
only_file
=
True
,
sort
=
True
):
def
lsr
(
self
,
hdfs_path
,
only_file
=
True
,
sort
=
True
):
"""
ls a hdfs_path sort by time.
Args:
:param hdfs_path: hdfs_path will be ls.
Returns:
This function returns a `list` that contaion all files sorted by time in the hdfs_path.
"""
def
sort_by_time
(
v1
,
v2
):
def
sort_by_time
(
v1
,
v2
):
v1_time
=
datetime
.
strptime
(
v1
[
1
],
'%Y-%m-%d %H:%M'
)
v1_time
=
datetime
.
strptime
(
v1
[
1
],
'%Y-%m-%d %H:%M'
)
v2_time
=
datetime
.
strptime
(
v2
[
1
],
'%Y-%m-%d %H:%M'
)
v2_time
=
datetime
.
strptime
(
v2
[
1
],
'%Y-%m-%d %H:%M'
)
...
@@ -372,12 +428,15 @@ def multi_upload(client,
...
@@ -372,12 +428,15 @@ def multi_upload(client,
multi_processes
=
5
,
multi_processes
=
5
,
overwrite
=
False
):
overwrite
=
False
):
"""
"""
:param overwrite: will overwrite hdfs file or not
Upload file to hdfs.
:param multi_processes: the upload data process at the same time, default=5
Args:
:param client: instance of HDFSClient
:param overwrite: will overwrite hdfs file or not
:param hdfs_path: path on hdfs
:param multi_processes: the upload data process at the same time, default=5
:param local_path: path on local
:param client: instance of HDFSClient
:return:
:param hdfs_path: path on hdfs
:param local_path: path on local
Returns:
"""
"""
def
__subprocess_upload
(
datas
):
def
__subprocess_upload
(
datas
):
...
@@ -387,6 +446,13 @@ def multi_upload(client,
...
@@ -387,6 +446,13 @@ def multi_upload(client,
client
.
upload
(
hdfs_re_path
,
data
,
overwrite
,
retry_times
=
5
)
client
.
upload
(
hdfs_re_path
,
data
,
overwrite
,
retry_times
=
5
)
def
get_local_files
(
path
):
def
get_local_files
(
path
):
"""
Get all local files
Args:
path: local file path
Returns:
A list that contation all files in the path.
"""
rlist
=
[]
rlist
=
[]
if
not
os
.
path
.
isdir
(
path
):
if
not
os
.
path
.
isdir
(
path
):
...
@@ -431,14 +497,17 @@ def multi_download(client,
...
@@ -431,14 +497,17 @@ def multi_download(client,
multi_processes
=
5
):
multi_processes
=
5
):
"""
"""
multi_download
multi_download
:param client: instance of HDFSClient
Args:
:param hdfs_path: path on hdfs
:param client: instance of HDFSClient
:param local_path: path on local
:param hdfs_path: path on hdfs
:param trainer_id: current trainer id
:param local_path: path on local
:param trainers: all trainers number
:param trainer_id: current trainer id
:param file_cnt: all file number
:param trainers: all trainers number
:param multi_processes: the download data process at the same time, default=5
:param file_cnt: all file number
:return: None
:param multi_processes: the download data process at the same time, default=5
:return: None
Returns:
A list that be downloaded.
"""
"""
def
__subprocess_download
(
datas
):
def
__subprocess_download
(
datas
):
...
...
python/paddle/fluid/distributed/helper.py
浏览文件 @
caa6b596
...
@@ -15,13 +15,26 @@
...
@@ -15,13 +15,26 @@
from
mpi4py
import
MPI
from
mpi4py
import
MPI
import
ps_pb2
as
pslib
import
ps_pb2
as
pslib
class
FileSystem
(
object
):
class
FileSystem
(
object
):
def
__init__
(
self
,
fs_type
=
"afs"
,
"""
A file system that support async_executor hadoop client desc.
Args:
fs_type (string): fs_type, for example is "afs"
user (string): hadoop param
passwd (string): hadoop param
hadoop bin (string): hadoop param
Examples:
fs = FileSystm()
"""
def
__init__
(
self
,
fs_type
=
"afs"
,
uri
=
"afs://tianqi.afs.baidu.com:9902"
,
uri
=
"afs://tianqi.afs.baidu.com:9902"
,
user
=
None
,
user
=
None
,
passwd
=
None
,
passwd
=
None
,
hadoop_bin
=
""
,
hadoop_bin
=
""
):
afs_conf
=
None
):
assert
user
!=
None
assert
user
!=
None
assert
passwd
!=
None
assert
passwd
!=
None
assert
hadoop_bin
!=
None
assert
hadoop_bin
!=
None
...
@@ -38,9 +51,22 @@ class FileSystem(object):
...
@@ -38,9 +51,22 @@ class FileSystem(object):
#self.fs_client.afs_conf = afs_conf if not afs_conf else ""
#self.fs_client.afs_conf = afs_conf if not afs_conf else ""
def
get_desc
(
self
):
def
get_desc
(
self
):
"""
get hadoop desc.
"""
return
self
.
fs_client
return
self
.
fs_client
class
MPIHelper
(
object
):
class
MPIHelper
(
object
):
"""
MPIHelper is a wrapper of mpi4py, supprot get_rank get_size etc.
Args:
No params
Examples:
mh = MPIHelper()
mh.get_ip()
"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
comm
=
MPI
.
COMM_WORLD
self
.
comm
=
MPI
.
COMM_WORLD
...
@@ -61,5 +87,3 @@ class MPIHelper(object):
...
@@ -61,5 +87,3 @@ class MPIHelper(object):
def
finalize
(
self
):
def
finalize
(
self
):
MPI
.
Finalize
()
MPI
.
Finalize
()
python/paddle/fluid/distributed/node.py
浏览文件 @
caa6b596
...
@@ -13,17 +13,34 @@
...
@@ -13,17 +13,34 @@
import
ps_pb2
as
pslib
import
ps_pb2
as
pslib
class
Server
(
object
):
class
Server
(
object
):
"""
A Server basic class.
"""
def
__init__
(
self
):
def
__init__
(
self
):
pass
pass
class
Worker
(
object
):
class
Worker
(
object
):
"""
A Worker basic class.
"""
def
__init__
(
self
):
def
__init__
(
self
):
pass
pass
class
DownpourServer
(
Server
):
class
DownpourServer
(
Server
):
"""
DownpourServer class is used to generate server program_desc
Args:
server: it is pslib.ServerParameter()
Examples:
server = DownpourServer()
"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
server_
=
pslib
.
ServerParameter
()
self
.
server_
=
pslib
.
ServerParameter
()
self
.
server_
.
downpour_server_param
.
service_param
.
start_server_port
=
0
self
.
server_
.
downpour_server_param
.
service_param
.
start_server_port
=
0
...
@@ -33,8 +50,18 @@ class DownpourServer(Server):
...
@@ -33,8 +50,18 @@ class DownpourServer(Server):
self
.
server_
.
downpour_server_param
.
service_param
.
start_server_port
=
0
self
.
server_
.
downpour_server_param
.
service_param
.
start_server_port
=
0
self
.
server_
.
downpour_server_param
.
service_param
.
server_thread_num
=
12
self
.
server_
.
downpour_server_param
.
service_param
.
server_thread_num
=
12
def
add_sparse_table
(
self
,
table_id
,
learning_rate
,
def
add_sparse_table
(
self
,
table_id
,
learning_rate
,
slot_key_vars
,
slot_key_vars
,
slot_value_var
):
slot_value_var
):
"""
Args:
table_id(int): id of sparse params table
learning_rate(float): the learning rate used to update parameters.
\
Can be a float value
slot_key_vars(string): slot key id
slot_value_var(string): slot key value after embedding
Returns:
return None
"""
table
=
self
.
server_
.
downpour_server_param
.
downpour_table_param
.
add
()
table
=
self
.
server_
.
downpour_server_param
.
downpour_table_param
.
add
()
table
.
table_id
=
table_id
table
.
table_id
=
table_id
table
.
table_class
=
"DownpourSparseTable"
table
.
table_class
=
"DownpourSparseTable"
...
@@ -44,10 +71,10 @@ class DownpourServer(Server):
...
@@ -44,10 +71,10 @@ class DownpourServer(Server):
table
.
accessor
.
sparse_sgd_param
.
initial_g2sum
=
3
table
.
accessor
.
sparse_sgd_param
.
initial_g2sum
=
3
table
.
accessor
.
sparse_sgd_param
.
initial_range
=
1e-4
table
.
accessor
.
sparse_sgd_param
.
initial_range
=
1e-4
table
.
accessor
.
sparse_sgd_param
.
weight_bounds
.
extend
([
-
10
,
10
])
table
.
accessor
.
sparse_sgd_param
.
weight_bounds
.
extend
([
-
10
,
10
])
table
.
accessor
.
embedx_dim
=
8
table
.
accessor
.
embedx_dim
=
8
table
.
accessor
.
embedx_threshold
=
5
table
.
accessor
.
embedx_threshold
=
5
table
.
accessor
.
fea_dim
=
11
table
.
accessor
.
fea_dim
=
11
#table.accessor.fea_dim = abs(reduce(lambda x, y: x * y,
#table.accessor.fea_dim = abs(reduce(lambda x, y: x * y,
# slot_value_var[0].shape, 1))
# slot_value_var[0].shape, 1))
table
.
accessor
.
downpour_accessor_param
.
nonclk_coeff
=
0.1
table
.
accessor
.
downpour_accessor_param
.
nonclk_coeff
=
0.1
...
@@ -58,53 +85,99 @@ class DownpourServer(Server):
...
@@ -58,53 +85,99 @@ class DownpourServer(Server):
table
.
accessor
.
downpour_accessor_param
.
show_click_decay_rate
=
0.999
table
.
accessor
.
downpour_accessor_param
.
show_click_decay_rate
=
0.999
table
.
accessor
.
downpour_accessor_param
.
delete_threshold
=
0.8
table
.
accessor
.
downpour_accessor_param
.
delete_threshold
=
0.8
def
add_dense_table
(
self
,
table_id
,
learning_rate
,
def
add_dense_table
(
self
,
table_id
,
learning_rate
,
param_var
,
grad_var
):
param_var
,
grad_var
):
"""
Args:
table_id(int): id of sparse params table
learning_rate(float): the learning rate used to update parameters.
\
Can be a float value
param_var(list): all dense param. it is a list.
grad_var(list): all dense grad parm it is a list.
Returns:
return None
"""
table
=
self
.
server_
.
downpour_server_param
.
downpour_table_param
.
add
()
table
=
self
.
server_
.
downpour_server_param
.
downpour_table_param
.
add
()
table
.
table_id
=
table_id
table
.
table_id
=
table_id
table
.
table_class
=
"DownpourDenseTable"
table
.
table_class
=
"DownpourDenseTable"
table
.
type
=
pslib
.
PS_DENSE_TABLE
table
.
type
=
pslib
.
PS_DENSE_TABLE
table
.
accessor
.
accessor_class
=
"DownpourDenseValueAccessor"
table
.
accessor
.
accessor_class
=
"DownpourDenseValueAccessor"
table
.
accessor
.
dense_sgd_param
.
name
=
"adam"
table
.
accessor
.
dense_sgd_param
.
name
=
"adam"
table
.
accessor
.
dense_sgd_param
.
adam
.
learning_rate
=
learning_rate
table
.
accessor
.
dense_sgd_param
.
adam
.
learning_rate
=
learning_rate
table
.
accessor
.
dense_sgd_param
.
adam
.
avg_decay_rate
=
0.999993
table
.
accessor
.
dense_sgd_param
.
adam
.
avg_decay_rate
=
0.999993
table
.
accessor
.
dense_sgd_param
.
adam
.
ada_decay_rate
=
0.9999
table
.
accessor
.
dense_sgd_param
.
adam
.
ada_decay_rate
=
0.9999
table
.
accessor
.
dense_sgd_param
.
adam
.
ada_epsilon
=
1e-8
table
.
accessor
.
dense_sgd_param
.
adam
.
ada_epsilon
=
1e-8
table
.
accessor
.
dense_sgd_param
.
adam
.
mom_decay_rate
=
0.99
table
.
accessor
.
dense_sgd_param
.
adam
.
mom_decay_rate
=
0.99
table
.
accessor
.
dense_sgd_param
.
naive
.
learning_rate
=
0.0002
table
.
accessor
.
dense_sgd_param
.
naive
.
learning_rate
=
0.0002
fea_dim
=
0
fea_dim
=
0
for
param
in
filter
(
lambda
x
:
x
.
name
.
find
(
"embedding"
)
==
-
1
,
param_var
):
for
param
in
filter
(
lambda
x
:
x
.
name
.
find
(
"embedding"
)
==
-
1
,
param_var
):
fea_dim
+=
reduce
(
lambda
x
,
y
:
x
*
y
,
param
.
shape
,
1
)
fea_dim
+=
reduce
(
lambda
x
,
y
:
x
*
y
,
param
.
shape
,
1
)
table
.
accessor
.
fea_dim
=
fea_dim
table
.
accessor
.
fea_dim
=
fea_dim
def
get_desc
(
self
):
def
get_desc
(
self
):
"""
Return downpour server program_desc
"""
return
self
.
server_
return
self
.
server_
class
DownpourWorker
(
Worker
):
class
DownpourWorker
(
Worker
):
"""
DownpourWorker class is used to generate worker program_desc
Args:
window (int): push params frequency
worker: it is pslib.DownpourTrainerParameter
Examples:
worker = DownpourWorker(1)
"""
def
__init__
(
self
,
window
):
def
__init__
(
self
,
window
):
self
.
window
=
window
self
.
window
=
window
self
.
worker_
=
pslib
.
DownpourTrainerParameter
()
self
.
worker_
=
pslib
.
DownpourTrainerParameter
()
#self.worker_.pull_dense_per_batch = window
#self.worker_.pull_dense_per_batch = window
#self.worker_.push_dense_per_batch = window
#self.worker_.push_dense_per_batch = window
def
add_sparse_table
(
self
,
table_id
,
learning_rate
,
def
add_sparse_table
(
self
,
table_id
,
learning_rate
,
slot_key_vars
,
slot_key_vars
,
slot_value_vars
):
slot_value_vars
):
"""
Args:
table_id(int): id of sparse params table
learning_rate(float): the learning rate used to update parameters.
\
Can be a float value
slot_key_vars(string): slot key id
slot_value_var(string): slot key value after embedding
Returns:
return None
"""
table
=
self
.
worker_
.
sparse_table
.
add
()
table
=
self
.
worker_
.
sparse_table
.
add
()
table
.
table_id
=
table_id
table
.
table_id
=
table_id
table
.
slot_key
.
extend
(
table
.
slot_key
.
extend
([
var
.
name
for
var
in
slot_key_vars
])
[
var
.
name
for
var
in
slot_key_vars
])
table
.
slot_value
.
extend
([
var
.
name
for
var
in
slot_value_vars
])
table
.
slot_value
.
extend
(
[
var
.
name
for
var
in
slot_value_vars
])
table
.
slot_gradient
.
extend
(
table
.
slot_gradient
.
extend
(
[
var
.
name
+
"@GRAD"
for
var
in
slot_value_vars
])
[
var
.
name
+
"@GRAD"
for
var
in
slot_value_vars
])
def
add_dense_table
(
self
,
table_id
,
learning_rate
,
def
add_dense_table
(
self
,
table_id
,
learning_rate
,
param_vars
,
grad_vars
):
param_vars
,
grad_vars
):
"""
Args:
table_id(int): id of sparse params table
learning_rate(float): the learning rate used to update parameters.
\
Can be a float value
param_var(list): all dense param. it is a list.
grad_var(list): all dense grad parm it is a list.
Returns:
return None
"""
table
=
self
.
worker_
.
dense_table
.
add
()
table
=
self
.
worker_
.
dense_table
.
add
()
table
.
table_id
=
table_id
table
.
table_id
=
table_id
table
.
dense_variable_name
.
extend
(
filter
(
lambda
x
:
x
.
find
(
"embedding"
)
==
-
1
,
[
p
.
name
for
p
in
param_vars
]))
table
.
dense_variable_name
.
extend
(
table
.
dense_gradient_variable_name
.
extend
(
filter
(
lambda
x
:
x
.
find
(
"embedding"
)
==
-
1
,
[
g
.
name
for
g
in
grad_vars
]))
filter
(
lambda
x
:
x
.
find
(
"embedding"
)
==
-
1
,
[
p
.
name
for
p
in
param_vars
]))
table
.
dense_gradient_variable_name
.
extend
(
filter
(
lambda
x
:
x
.
find
(
"embedding"
)
==
-
1
,
[
g
.
name
for
g
in
grad_vars
]))
def
get_desc
(
self
):
def
get_desc
(
self
):
"""
Return downpour worker program_desc
"""
return
self
.
worker_
return
self
.
worker_
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录