Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d479ae17
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d479ae17
编写于
1月 12, 2021
作者:
C
Chengmo
提交者:
GitHub
1月 12, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【Paddle.Fleet】Support local save sparse param (#30175)
* add save tensor support Co-authored-by:
N
seiriosPlus
<
tangwei12@baidu.com
>
上级
113810c5
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
135 addition
and
16 deletion
+135
-16
paddle/fluid/distributed/fleet.cc
paddle/fluid/distributed/fleet.cc
+10
-0
paddle/fluid/distributed/fleet.h
paddle/fluid/distributed/fleet.h
+4
-0
paddle/fluid/distributed/service/brpc_ps_client.cc
paddle/fluid/distributed/service/brpc_ps_client.cc
+79
-0
paddle/fluid/distributed/service/brpc_ps_client.h
paddle/fluid/distributed/service/brpc_ps_client.h
+7
-0
paddle/fluid/distributed/service/ps_client.h
paddle/fluid/distributed/service/ps_client.h
+5
-0
paddle/fluid/distributed/table/common_sparse_table.cc
paddle/fluid/distributed/table/common_sparse_table.cc
+3
-1
paddle/fluid/pybind/fleet_py.cc
paddle/fluid/pybind/fleet_py.cc
+1
-0
python/paddle/distributed/fleet/base/fleet_base.py
python/paddle/distributed/fleet/base/fleet_base.py
+1
-1
python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py
...buted/fleet/meta_optimizers/parameter_server_optimizer.py
+5
-5
python/paddle/distributed/fleet/runtime/the_one_ps.py
python/paddle/distributed/fleet/runtime/the_one_ps.py
+20
-9
未找到文件。
paddle/fluid/distributed/fleet.cc
浏览文件 @
d479ae17
...
@@ -459,6 +459,16 @@ void FleetWrapper::SaveModelOneTable(const uint64_t table_id,
...
@@ -459,6 +459,16 @@ void FleetWrapper::SaveModelOneTable(const uint64_t table_id,
}
}
}
}
void
FleetWrapper
::
RecvAndSaveTable
(
const
uint64_t
table_id
,
const
std
::
string
&
path
)
{
auto
*
communicator
=
Communicator
::
GetInstance
();
auto
ret
=
communicator
->
_worker_ptr
->
recv_and_save_table
(
table_id
,
path
);
if
(
ret
!=
0
)
{
LOG
(
ERROR
)
<<
"save model of table id: "
<<
table_id
<<
", to path: "
<<
path
<<
" failed"
;
}
}
void
FleetWrapper
::
PrintTableStat
(
const
uint64_t
table_id
)
{
void
FleetWrapper
::
PrintTableStat
(
const
uint64_t
table_id
)
{
auto
*
communicator
=
Communicator
::
GetInstance
();
auto
*
communicator
=
Communicator
::
GetInstance
();
auto
ret
=
communicator
->
_worker_ptr
->
print_table_stat
(
table_id
);
auto
ret
=
communicator
->
_worker_ptr
->
print_table_stat
(
table_id
);
...
...
paddle/fluid/distributed/fleet.h
浏览文件 @
d479ae17
...
@@ -198,6 +198,10 @@ class FleetWrapper {
...
@@ -198,6 +198,10 @@ class FleetWrapper {
// mode = 1, save delta feature, which means save diff
// mode = 1, save delta feature, which means save diff
void
SaveModelOneTable
(
const
uint64_t
table_id
,
const
std
::
string
&
path
,
void
SaveModelOneTable
(
const
uint64_t
table_id
,
const
std
::
string
&
path
,
const
int
mode
);
const
int
mode
);
// recv table from server and save it in LodTensor
void
RecvAndSaveTable
(
const
uint64_t
table_id
,
const
std
::
string
&
path
);
// clear all models, release their memory
// clear all models, release their memory
void
ClearModel
();
void
ClearModel
();
// clear one table
// clear one table
...
...
paddle/fluid/distributed/service/brpc_ps_client.cc
浏览文件 @
d479ae17
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include <algorithm>
#include <algorithm>
#include <memory>
#include <memory>
#include <sstream>
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -21,6 +22,7 @@
...
@@ -21,6 +22,7 @@
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/string_helper.h"
const
static
int
max_port
=
65535
;
const
static
int
max_port
=
65535
;
...
@@ -55,6 +57,16 @@ DEFINE_int32(pserver_connect_timeout_ms, 10000,
...
@@ -55,6 +57,16 @@ DEFINE_int32(pserver_connect_timeout_ms, 10000,
DEFINE_int32
(
pserver_sparse_merge_thread
,
1
,
"pserver sparse merge thread num"
);
DEFINE_int32
(
pserver_sparse_merge_thread
,
1
,
"pserver sparse merge thread num"
);
namespace
paddle
{
namespace
framework
{
class
Scope
;
class
Variable
;
}
// namespace framework
namespace
platform
{
class
DeviceContext
;
}
// namespace platform
}
// namespace paddle
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
...
@@ -903,5 +915,72 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
...
@@ -903,5 +915,72 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
return
fut
;
return
fut
;
}
}
int32_t
BrpcPsClient
::
recv_and_save_table
(
const
uint64_t
table_id
,
const
std
::
string
&
path
)
{
// get var information
std
::
string
var_name
=
""
;
int64_t
var_num
=
0
;
int64_t
var_shape
=
0
;
const
auto
&
worker_param
=
_config
.
worker_param
().
downpour_worker_param
();
for
(
size_t
i
=
0
;
i
<
worker_param
.
downpour_table_param_size
();
++
i
)
{
if
(
worker_param
.
downpour_table_param
(
i
).
table_id
()
==
table_id
)
{
var_name
=
worker_param
.
downpour_table_param
(
i
).
common
().
table_name
();
var_num
=
worker_param
.
downpour_table_param
(
i
).
accessor
().
fea_dim
();
var_shape
=
worker_param
.
downpour_table_param
(
i
).
accessor
().
embedx_dim
();
break
;
}
}
PADDLE_ENFORCE_NE
(
var_name
,
""
,
platform
::
errors
::
InvalidArgument
(
"Cannot find table id %d to save variables."
,
table_id
));
std
::
string
var_store
=
string
::
Sprintf
(
"%s"
,
path
);
MkDirRecursively
(
var_store
.
c_str
());
// pull sparse from server
std
::
vector
<
float
>
save_huge_vec
(
var_num
*
var_shape
);
std
::
vector
<
uint64_t
>
save_key
(
var_num
);
std
::
vector
<
float
*>
save_vec
;
for
(
size_t
i
=
0
;
i
<
save_key
.
size
();
++
i
)
{
save_key
[
i
]
=
i
;
save_vec
.
push_back
(
save_huge_vec
.
data
()
+
i
*
var_shape
);
}
auto
status
=
pull_sparse
((
float
**
)
save_vec
.
data
(),
table_id
,
save_key
.
data
(),
save_key
.
size
());
status
.
wait
();
// create lod tensor
std
::
shared_ptr
<
framework
::
Scope
>
scope
;
scope
.
reset
(
new
framework
::
Scope
());
auto
place
=
platform
::
CPUPlace
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
framework
::
Variable
*
var
=
scope
->
Var
(
var_name
);
framework
::
LoDTensor
*
var_tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
std
::
vector
<
int64_t
>
vec_dim
=
{
var_num
,
var_shape
};
var_tensor
->
Resize
(
framework
::
make_ddim
(
vec_dim
));
// copy and save
float
*
tensor_data
=
var_tensor
->
mutable_data
<
float
>
(
place
);
memcpy
(
tensor_data
,
save_huge_vec
.
data
(),
var_num
*
var_shape
*
sizeof
(
float
));
std
::
string
file_name
=
string
::
Sprintf
(
"%s/%s"
,
var_store
,
var_name
);
std
::
ofstream
fout
(
file_name
,
std
::
ios
::
binary
);
PADDLE_ENFORCE_EQ
(
static_cast
<
bool
>
(
fout
),
true
,
platform
::
errors
::
Unavailable
(
"Cannot open %s to save variables."
,
file_name
));
framework
::
SerializeToStream
(
fout
,
*
var_tensor
,
dev_ctx
);
fout
.
close
();
return
0
;
}
}
// namespace distributed
}
// namespace distributed
}
// namespace paddle
}
// namespace paddle
paddle/fluid/distributed/service/brpc_ps_client.h
浏览文件 @
d479ae17
...
@@ -22,6 +22,9 @@
...
@@ -22,6 +22,9 @@
#include "brpc/controller.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/service/ps_client.h"
#include "paddle/fluid/distributed/service/ps_client.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
...
@@ -148,6 +151,10 @@ class BrpcPsClient : public PSClient {
...
@@ -148,6 +151,10 @@ class BrpcPsClient : public PSClient {
virtual
std
::
future
<
int32_t
>
send_client2client_msg
(
virtual
std
::
future
<
int32_t
>
send_client2client_msg
(
int
msg_type
,
int
to_client_id
,
const
std
::
string
&
msg
)
override
;
int
msg_type
,
int
to_client_id
,
const
std
::
string
&
msg
)
override
;
// for local save sparse
virtual
int32_t
recv_and_save_table
(
const
uint64_t
table_id
,
const
std
::
string
&
path
);
private:
private:
virtual
int32_t
initialize
()
override
;
virtual
int32_t
initialize
()
override
;
...
...
paddle/fluid/distributed/service/ps_client.h
浏览文件 @
d479ae17
...
@@ -134,6 +134,11 @@ class PSClient {
...
@@ -134,6 +134,11 @@ class PSClient {
virtual
std
::
future
<
int32_t
>
push_global_step
(
int
table_id
,
virtual
std
::
future
<
int32_t
>
push_global_step
(
int
table_id
,
int64_t
*
total_send_data
,
int64_t
*
total_send_data
,
void
*
done
)
=
0
;
void
*
done
)
=
0
;
// recv table from server and save it in LodTensor
virtual
int32_t
recv_and_save_table
(
const
uint64_t
table_id
,
const
std
::
string
&
path
)
=
0
;
virtual
void
finalize_worker
()
=
0
;
virtual
void
finalize_worker
()
=
0
;
// client to client, 消息发送
// client to client, 消息发送
virtual
std
::
future
<
int32_t
>
send_client2client_msg
(
int
msg_type
,
virtual
std
::
future
<
int32_t
>
send_client2client_msg
(
int
msg_type
,
...
...
paddle/fluid/distributed/table/common_sparse_table.cc
浏览文件 @
d479ae17
...
@@ -21,6 +21,7 @@
...
@@ -21,6 +21,7 @@
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/string/string_helper.h"
#define PSERVER_SAVE_SUFFIX "_txt"
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
...
@@ -290,7 +291,8 @@ int32_t CommonSparseTable::save(const std::string& dirname,
...
@@ -290,7 +291,8 @@ int32_t CommonSparseTable::save(const std::string& dirname,
VLOG
(
0
)
<<
"sparse table save: "
<<
dirname
<<
" mode: "
<<
mode
;
VLOG
(
0
)
<<
"sparse table save: "
<<
dirname
<<
" mode: "
<<
mode
;
auto
varname
=
_config
.
common
().
table_name
();
auto
varname
=
_config
.
common
().
table_name
();
std
::
string
var_store
=
string
::
Sprintf
(
"%s/%s"
,
dirname
,
varname
);
std
::
string
var_store
=
string
::
Sprintf
(
"%s/%s%s"
,
dirname
,
varname
,
PSERVER_SAVE_SUFFIX
);
MkDirRecursively
(
var_store
.
c_str
());
MkDirRecursively
(
var_store
.
c_str
());
VLOG
(
3
)
<<
"save "
<<
varname
<<
" in dir: "
<<
var_store
<<
" begin"
;
VLOG
(
3
)
<<
"save "
<<
varname
<<
" in dir: "
<<
var_store
<<
" begin"
;
...
...
paddle/fluid/pybind/fleet_py.cc
浏览文件 @
d479ae17
...
@@ -58,6 +58,7 @@ void BindDistFleetWrapper(py::module* m) {
...
@@ -58,6 +58,7 @@ void BindDistFleetWrapper(py::module* m) {
.
def
(
"pull_dense_params"
,
&
FleetWrapper
::
PullDenseVarsSync
)
.
def
(
"pull_dense_params"
,
&
FleetWrapper
::
PullDenseVarsSync
)
.
def
(
"save_all_model"
,
&
FleetWrapper
::
SaveModel
)
.
def
(
"save_all_model"
,
&
FleetWrapper
::
SaveModel
)
.
def
(
"save_one_model"
,
&
FleetWrapper
::
SaveModelOneTable
)
.
def
(
"save_one_model"
,
&
FleetWrapper
::
SaveModelOneTable
)
.
def
(
"recv_and_save_model"
,
&
FleetWrapper
::
RecvAndSaveTable
)
.
def
(
"sparse_table_stat"
,
&
FleetWrapper
::
PrintTableStat
)
.
def
(
"sparse_table_stat"
,
&
FleetWrapper
::
PrintTableStat
)
.
def
(
"stop_server"
,
&
FleetWrapper
::
StopServer
)
.
def
(
"stop_server"
,
&
FleetWrapper
::
StopServer
)
.
def
(
"stop_worker"
,
&
FleetWrapper
::
FinalizeWorker
)
.
def
(
"stop_worker"
,
&
FleetWrapper
::
FinalizeWorker
)
...
...
python/paddle/distributed/fleet/base/fleet_base.py
浏览文件 @
d479ae17
...
@@ -545,7 +545,7 @@ class Fleet(object):
...
@@ -545,7 +545,7 @@ class Fleet(object):
executor
,
dirname
,
feeded_var_names
,
target_vars
,
main_program
,
executor
,
dirname
,
feeded_var_names
,
target_vars
,
main_program
,
export_for_deployment
)
export_for_deployment
)
def
save_persistables
(
self
,
executor
,
dirname
,
main_program
=
None
,
mode
=
1
):
def
save_persistables
(
self
,
executor
,
dirname
,
main_program
=
None
,
mode
=
0
):
"""
"""
saves all persistable tensors from :code:`main_program` to
saves all persistable tensors from :code:`main_program` to
...
...
python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py
浏览文件 @
d479ae17
...
@@ -64,12 +64,12 @@ class ParameterServerOptimizer(MetaOptimizerBase):
...
@@ -64,12 +64,12 @@ class ParameterServerOptimizer(MetaOptimizerBase):
_main
=
compiled_config
.
origin_main_program
.
clone
()
_main
=
compiled_config
.
origin_main_program
.
clone
()
_startup
=
compiled_config
.
origin_startup_program
.
clone
()
_startup
=
compiled_config
.
origin_startup_program
.
clone
()
from
paddle.fluid.incubate.fleet.parameter_server.ir.public
import
_add_lr_decay_table_pass
_add_lr_decay_table_pass
(
_main
,
compiled_config
,
self
.
user_defined_strategy
.
a_sync_configs
[
"lr_decay_steps"
])
if
not
compiled_config
.
is_geo_mode
():
if
not
compiled_config
.
is_geo_mode
():
from
paddle.fluid.incubate.fleet.parameter_server.ir.public
import
_add_lr_decay_table_pass
_add_lr_decay_table_pass
(
_main
,
compiled_config
,
self
.
user_defined_strategy
.
a_sync_configs
[
"lr_decay_steps"
])
# for main program
# for main program
_main
=
worker
.
delete_optimizer_pass
(
_main
,
compiled_config
)
_main
=
worker
.
delete_optimizer_pass
(
_main
,
compiled_config
)
_main
=
worker
.
distributed_ops_pass
(
_main
,
compiled_config
)
_main
=
worker
.
distributed_ops_pass
(
_main
,
compiled_config
)
...
...
python/paddle/distributed/fleet/runtime/the_one_ps.py
浏览文件 @
d479ae17
...
@@ -851,15 +851,26 @@ class TheOnePSRuntime(RuntimeBase):
...
@@ -851,15 +851,26 @@ class TheOnePSRuntime(RuntimeBase):
return
is_valid
return
is_valid
def
_save_sparse_params
(
self
,
executor
,
dirname
,
context
,
main_program
):
def
_save_sparse_params
(
self
,
executor
,
dirname
,
context
,
main_program
,
mode
):
from
paddle.fluid.incubate.fleet.parameter_server.ir.public
import
get_sparse_tablenames
distributed_varnames
=
get_sparse_tablenames
(
self
.
compiled_strategy
.
origin_main_program
,
True
)
values
=
[]
values
=
[]
for
id
,
names
in
context
.
items
():
for
id
,
names
in
context
.
items
():
if
names
not
in
distributed_varnames
:
# only save sparse param to local
self
.
_worker
.
recv_and_save_model
(
id
,
dirname
)
# save sparse & distributed param on server
self
.
_worker
.
save_one_model
(
id
,
dirname
,
mode
)
values
.
extend
(
names
)
values
.
extend
(
names
)
self
.
_worker
.
save_one_model
(
id
,
dirname
,
0
)
return
values
return
values
def
_save_distributed_persistables
(
self
,
executor
,
dirname
,
main_program
,
def
_save_distributed_persistables
(
self
,
mode
):
executor
,
dirname
,
main_program
,
mode
=
0
):
denses
=
self
.
compiled_strategy
.
get_the_one_recv_context
(
denses
=
self
.
compiled_strategy
.
get_the_one_recv_context
(
is_dense
=
True
,
is_dense
=
True
,
...
@@ -870,14 +881,14 @@ class TheOnePSRuntime(RuntimeBase):
...
@@ -870,14 +881,14 @@ class TheOnePSRuntime(RuntimeBase):
split_dense_table
=
self
.
role_maker
.
_is_heter_parameter_server_mode
,
split_dense_table
=
self
.
role_maker
.
_is_heter_parameter_server_mode
,
use_origin_program
=
True
)
use_origin_program
=
True
)
recv_sparse_varnames
=
self
.
_save_sparse_params
(
executor
,
dirname
,
sparse_varnames
=
self
.
_save_sparse_params
(
executor
,
dirname
,
sparses
,
sparses
,
main_program
)
main_program
,
mode
)
recv_dense_varnames
=
[]
recv_dense_varnames
=
[]
for
id
,
names
in
denses
.
items
():
for
id
,
names
in
denses
.
items
():
recv_dense_varnames
.
extend
(
names
)
recv_dense_varnames
.
extend
(
names
)
saved_varnames
=
recv_
sparse_varnames
saved_varnames
=
sparse_varnames
remaining_vars
=
list
(
remaining_vars
=
list
(
filter
(
filter
(
...
@@ -925,6 +936,7 @@ class TheOnePSRuntime(RuntimeBase):
...
@@ -925,6 +936,7 @@ class TheOnePSRuntime(RuntimeBase):
"in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
"in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
)
)
# Todo(MrChengmo): Save optimizer status
self
.
_save_distributed_persistables
(
executor
,
dirname
,
main_program
,
self
.
_save_distributed_persistables
(
executor
,
dirname
,
main_program
,
mode
)
mode
)
...
@@ -971,8 +983,7 @@ class TheOnePSRuntime(RuntimeBase):
...
@@ -971,8 +983,7 @@ class TheOnePSRuntime(RuntimeBase):
program
=
Program
.
parse_from_string
(
program_desc_str
)
program
=
Program
.
parse_from_string
(
program_desc_str
)
program
.
_copy_dist_param_info_from
(
fluid
.
default_main_program
())
program
.
_copy_dist_param_info_from
(
fluid
.
default_main_program
())
self
.
_ps_inference_save_persistables
(
self
.
_ps_inference_save_persistables
(
executor
,
dirname
,
program
)
executor
,
dirname
,
program
,
mode
=
0
)
def
_save_inference_model
(
self
,
*
args
,
**
kwargs
):
def
_save_inference_model
(
self
,
*
args
,
**
kwargs
):
self
.
_ps_inference_save_inference_model
(
*
args
,
**
kwargs
)
self
.
_ps_inference_save_inference_model
(
*
args
,
**
kwargs
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录