Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
4d0d0eca
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4d0d0eca
编写于
3月 21, 2022
作者:
Y
yaoxuefeng
提交者:
GitHub
3月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
mod base (#40702)
上级
382e460b
变更
36
隐藏空白更改
内联
并排
Showing
36 changed file
with
499 addition
and
40 deletion
+499
-40
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
+61
-0
paddle/fluid/distributed/ps/service/brpc_ps_client.h
paddle/fluid/distributed/ps/service/brpc_ps_client.h
+9
-0
paddle/fluid/distributed/ps/service/brpc_ps_server.h
paddle/fluid/distributed/ps/service/brpc_ps_server.h
+1
-1
paddle/fluid/distributed/ps/service/graph_brpc_server.h
paddle/fluid/distributed/ps/service/graph_brpc_server.h
+1
-1
paddle/fluid/distributed/ps/service/ps_client.h
paddle/fluid/distributed/ps/service/ps_client.h
+46
-0
paddle/fluid/distributed/ps/service/ps_local_client.cc
paddle/fluid/distributed/ps/service/ps_local_client.cc
+73
-0
paddle/fluid/distributed/ps/service/ps_local_client.h
paddle/fluid/distributed/ps/service/ps_local_client.h
+8
-0
paddle/fluid/distributed/ps/service/ps_local_server.h
paddle/fluid/distributed/ps/service/ps_local_server.h
+0
-1
paddle/fluid/distributed/ps/service/server.cc
paddle/fluid/distributed/ps/service/server.cc
+0
-2
paddle/fluid/distributed/ps/service/server.h
paddle/fluid/distributed/ps/service/server.h
+0
-15
paddle/fluid/distributed/ps/table/accessor.h
paddle/fluid/distributed/ps/table/accessor.h
+14
-0
paddle/fluid/distributed/ps/table/common_dense_table.cc
paddle/fluid/distributed/ps/table/common_dense_table.cc
+15
-0
paddle/fluid/distributed/ps/table/common_dense_table.h
paddle/fluid/distributed/ps/table/common_dense_table.h
+2
-0
paddle/fluid/distributed/ps/table/common_graph_table.h
paddle/fluid/distributed/ps/table/common_graph_table.h
+3
-0
paddle/fluid/distributed/ps/table/common_sparse_table.cc
paddle/fluid/distributed/ps/table/common_sparse_table.cc
+26
-0
paddle/fluid/distributed/ps/table/common_sparse_table.h
paddle/fluid/distributed/ps/table/common_sparse_table.h
+3
-0
paddle/fluid/distributed/ps/table/common_table.h
paddle/fluid/distributed/ps/table/common_table.h
+3
-0
paddle/fluid/distributed/ps/table/ctr_accessor.cc
paddle/fluid/distributed/ps/table/ctr_accessor.cc
+10
-0
paddle/fluid/distributed/ps/table/ctr_accessor.h
paddle/fluid/distributed/ps/table/ctr_accessor.h
+1
-0
paddle/fluid/distributed/ps/table/ctr_double_accessor.cc
paddle/fluid/distributed/ps/table/ctr_double_accessor.cc
+10
-0
paddle/fluid/distributed/ps/table/ctr_double_accessor.h
paddle/fluid/distributed/ps/table/ctr_double_accessor.h
+1
-0
paddle/fluid/distributed/ps/table/depends/sparse_utils.h
paddle/fluid/distributed/ps/table/depends/sparse_utils.h
+1
-1
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc
+10
-0
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h
+1
-0
paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h
paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h
+2
-0
paddle/fluid/distributed/ps/table/memory_sparse_table.cc
paddle/fluid/distributed/ps/table/memory_sparse_table.cc
+20
-0
paddle/fluid/distributed/ps/table/memory_sparse_table.h
paddle/fluid/distributed/ps/table/memory_sparse_table.h
+3
-0
paddle/fluid/distributed/ps/table/ssd_sparse_table.cc
paddle/fluid/distributed/ps/table/ssd_sparse_table.cc
+15
-0
paddle/fluid/distributed/ps/table/ssd_sparse_table.h
paddle/fluid/distributed/ps/table/ssd_sparse_table.h
+3
-0
paddle/fluid/distributed/ps/table/table.h
paddle/fluid/distributed/ps/table/table.h
+26
-0
paddle/fluid/distributed/ps/table/tensor_accessor.cc
paddle/fluid/distributed/ps/table/tensor_accessor.cc
+10
-0
paddle/fluid/distributed/ps/table/tensor_accessor.h
paddle/fluid/distributed/ps/table/tensor_accessor.h
+1
-0
paddle/fluid/distributed/ps/table/tensor_table.h
paddle/fluid/distributed/ps/table/tensor_table.h
+2
-0
paddle/fluid/distributed/ps/wrapper/fleet.cc
paddle/fluid/distributed/ps/wrapper/fleet.cc
+26
-0
paddle/fluid/distributed/ps/wrapper/fleet.h
paddle/fluid/distributed/ps/wrapper/fleet.h
+8
-1
paddle/fluid/distributed/ps/wrapper/ps_wrapper.h
paddle/fluid/distributed/ps/wrapper/ps_wrapper.h
+84
-18
未找到文件。
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
浏览文件 @
4d0d0eca
...
...
@@ -414,6 +414,16 @@ std::future<int32_t> BrpcPsClient::load(uint32_t table_id,
return
send_cmd
(
table_id
,
PS_LOAD_ONE_TABLE
,
{
epoch
,
mode
});
}
// Context-based load. A negative table_id sends PS_LOAD_ALL_TABLE (with
// table id -1); any other table_id sends PS_LOAD_ONE_TABLE for that table.
// epoch and mode are forwarded unchanged as the command parameters.
std::future<int32_t> BrpcPsClient::Load(const LoadSaveContext &load_context) {
  const bool single_table = !(load_context.table_id < 0);
  if (single_table) {
    return send_cmd(load_context.table_id, PS_LOAD_ONE_TABLE,
                    {load_context.epoch, load_context.mode});
  }
  return send_cmd(-1, PS_LOAD_ALL_TABLE,
                  {load_context.epoch, load_context.mode});
}
std
::
future
<
int32_t
>
BrpcPsClient
::
save
(
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
{
VLOG
(
1
)
<<
"BrpcPsClient::save path "
<<
epoch
;
...
...
@@ -427,6 +437,19 @@ std::future<int32_t> BrpcPsClient::save(uint32_t table_id,
return
send_save_cmd
(
table_id
,
PS_SAVE_ONE_TABLE
,
{
epoch
,
mode
});
}
// Context-based save. A negative table_id sends PS_SAVE_ALL_TABLE (with
// table id -1); any other table_id sends PS_SAVE_ONE_TABLE for that table.
// epoch and mode are forwarded unchanged as the command parameters.
std::future<int32_t> BrpcPsClient::Save(const LoadSaveContext &save_context) {
  const bool single_table = !(save_context.table_id < 0);
  if (single_table) {
    VLOG(1) << "BrpcPsClient::save one table path " << save_context.epoch
            << " table_id " << save_context.table_id;
    return send_save_cmd(save_context.table_id, PS_SAVE_ONE_TABLE,
                         {save_context.epoch, save_context.mode});
  }
  VLOG(1) << "BrpcPsClient::save path " << save_context.epoch;
  return send_save_cmd(-1, PS_SAVE_ALL_TABLE,
                       {save_context.epoch, save_context.mode});
}
// Sends PS_CLEAR_ALL_TABLE with table id -1 and an empty parameter list,
// returning the future produced by send_cmd.
std::future<int32_t> BrpcPsClient::clear() {
  return send_cmd(-1, PS_CLEAR_ALL_TABLE, {});
}
...
...
@@ -505,6 +528,44 @@ std::future<int32_t> BrpcPsClient::barrier(size_t table_id,
return
send_cmd
(
table_id
,
PS_BARRIER
,
{
std
::
to_string
(
barrier_type
)});
}
// Context-based pull entry point.
//
// Dispatches on pull_context.value_type / training_mode:
//   - Dense          -> pull_dense(dense_values, num, table)
//   - Sparse + Geo   -> pull_sparse_param(...)
//   - Sparse + Async -> pull_sparse(...)
// and returns the future produced by the selected call.
//
// Any other sparse training mode (e.g. Sync) is not handled here; in that
// case an already-satisfied future carrying -1 is returned so that callers
// never wait on, or read from, an invalid future.
std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
  if (pull_context.value_type == Dense) {
    // dense_values is already a Region*, no cast is required.
    return pull_dense(pull_context.dense_values, pull_context.num,
                      pull_context.table);
  }

  // Sparse pull: the context carries type-erased key / value buffers.
  uint64_t *keys = static_cast<uint64_t *>(pull_context.keys);
  float **select_values =
      reinterpret_cast<float **>(pull_context.sparse_values);
  const size_t table_id = pull_context.table;
  const size_t num = pull_context.num;
  const bool is_training = pull_context.is_training;

  if (pull_context.training_mode == Geo) {
    return pull_sparse_param(select_values, table_id, keys, num, is_training);
  }
  if (pull_context.training_mode == Async) {
    return pull_sparse(select_values, table_id, keys, num, is_training);
  }

  // Unsupported mode: nothing was requested, report failure explicitly
  // instead of flowing off the end of a non-void function.
  std::promise<int32_t> unsupported;
  unsupported.set_value(-1);
  return unsupported.get_future();
}
// Context-based push entry point.
//
// Dispatches on push_context.value_type / training_mode:
//   - Dense          -> push_dense(push_dense_values, num, table)
//   - Sparse + Async -> push_sparse(table, keys, push_values, num)
// and returns the future produced by the selected call.
//
// Sparse + Geo is still unimplemented (TODO(zhaocaibei)) and other modes are
// not handled; in those cases an already-satisfied future carrying -1 is
// returned so that callers never wait on, or read from, an invalid future.
std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) {
  if (push_context.value_type == Dense) {
    const Region *dense_region = push_context.push_context.push_dense_values;
    return push_dense(dense_region, push_context.num, push_context.table);
  }

  // Sparse push.
  if (push_context.training_mode == Async) {
    const size_t table_id = push_context.table;
    const size_t num = push_context.num;
    const uint64_t *keys = push_context.push_context.keys;
    const float **update_values = push_context.push_context.push_values;
    return push_sparse(table_id, keys, update_values, num);
  }

  // Geo: TODO(zhaocaibei). Nothing was sent, so report failure explicitly
  // instead of flowing off the end of a non-void function.
  std::promise<int32_t> unsupported;
  unsupported.set_value(-1);
  return unsupported.get_future();
}
std
::
future
<
int32_t
>
BrpcPsClient
::
pull_geo_param
(
size_t
table_id
,
std
::
vector
<
float
>
*
values
,
std
::
vector
<
uint64_t
>
*
keys
,
...
...
paddle/fluid/distributed/ps/service/brpc_ps_client.h
浏览文件 @
4d0d0eca
...
...
@@ -163,12 +163,17 @@ class BrpcPsClient : public PSClient {
std
::
future
<
int32_t
>
load
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
override
;
std
::
future
<
int32_t
>
Load
(
const
LoadSaveContext
&
load_context
)
override
;
std
::
future
<
int32_t
>
save
(
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
override
;
std
::
future
<
int32_t
>
save
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
override
;
virtual
std
::
future
<
int32_t
>
Save
(
const
LoadSaveContext
&
save_context
)
override
;
std
::
future
<
int32_t
>
clear
()
override
;
std
::
future
<
int32_t
>
clear
(
uint32_t
table_id
)
override
;
...
...
@@ -199,6 +204,10 @@ class BrpcPsClient : public PSClient {
const
uint64_t
*
keys
,
size_t
num
,
bool
is_training
);
virtual
std
::
future
<
int32_t
>
Pull
(
RequestContext
&
pull_context
)
override
;
virtual
std
::
future
<
int32_t
>
Push
(
RequestContext
&
push_context
)
override
;
virtual
std
::
future
<
int32_t
>
print_table_stat
(
uint32_t
table_id
);
virtual
std
::
future
<
int32_t
>
barrier
(
size_t
table_id
,
uint32_t
barrier_type
);
...
...
paddle/fluid/distributed/ps/service/brpc_ps_server.h
浏览文件 @
4d0d0eca
...
...
@@ -51,7 +51,7 @@ class BrpcPsServer : public PSServer {
_server
.
Join
();
return
0
;
}
virtual
int32_t
port
();
int32_t
port
();
private:
virtual
int32_t
initialize
();
...
...
paddle/fluid/distributed/ps/service/graph_brpc_server.h
浏览文件 @
4d0d0eca
...
...
@@ -43,7 +43,7 @@ class GraphBrpcServer : public PSServer {
_server
.
Join
();
return
0
;
}
virtual
int32_t
port
();
int32_t
port
();
std
::
condition_variable
*
export_cv
()
{
return
&
cv_
;
}
...
...
paddle/fluid/distributed/ps/service/ps_client.h
浏览文件 @
4d0d0eca
...
...
@@ -26,6 +26,7 @@
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/platform/timer.h"
namespace
paddle
{
...
...
@@ -59,6 +60,41 @@ class PSClientClosure : public google::protobuf::Closure {
std
::
vector
<
std
::
shared_ptr
<
std
::
promise
<
int32_t
>>>
_promises
;
};
// Options for the context-based PSClient::Load / PSClient::Save calls.
struct LoadSaveContext {
  // Target table. The client implementations in this change treat a negative
  // value as "all tables".
  int table_id;
  // Forwarded as the first load/save argument; the BRPC client logs it as the
  // save path.
  std::string epoch;
  // Forwarded as the second load/save argument.
  std::string mode;
};
enum
TrainingMode
{
Async
=
0
,
Sync
=
1
,
Geo
=
3
};
enum
TrainingPhase
{
Init
=
0
,
Train
=
1
,
Save
=
2
};
// enum ValueType {
// Sparse = 0,
// Dense = 1
// };
// Read-only buffers for a client-side push; embedded in RequestContext.
struct PushContext {
  const uint64_t *keys;             // sparse keys to push
  const float **push_values;        // sparse update values (array of pointers)
  const Region *push_dense_values;  // dense regions to push
};
// Parameters of a context-based PSClient::Pull / PSClient::Push request.
// None of the members has a default initializer, so callers must set every
// field the chosen code path reads.
struct RequestContext {
  int table;                     // id of the target table
  TrainingMode training_mode;    // Async = 0, Sync = 1, Geo = 3
  TrainingPhase training_phase;  // Init = 0, Train = 1, Save = 2
  ValueType value_type;          // Sparse = 0, Dense = 1
  void *keys;                    // sparse pull: keys (read as uint64_t*)
  void **sparse_values;          // for sparse values
  Region *dense_values;          // for dense values
  PushContext push_context;      // buffers used by Push
  size_t num;                    // element / region count passed to the call
  bool is_training;              // forwarded to the sparse pull calls
  void *callback;                // forwarded to push_dense_raw_gradient
};
class
PSClient
{
public:
PSClient
()
{}
...
...
@@ -86,6 +122,9 @@ class PSClient {
// 指定table数据load
virtual
std
::
future
<
int32_t
>
load
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
=
0
;
// context配置load选项
virtual
std
::
future
<
int32_t
>
Load
(
const
LoadSaveContext
&
load_context
)
=
0
;
// 全量table数据save value_accessor根据mode,可能有不同的save条件
virtual
std
::
future
<
int32_t
>
save
(
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
=
0
;
...
...
@@ -93,6 +132,8 @@ class PSClient {
virtual
std
::
future
<
int32_t
>
save
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
=
0
;
virtual
std
::
future
<
int32_t
>
Save
(
const
LoadSaveContext
&
save_context
)
=
0
;
// 清空table数据
virtual
std
::
future
<
int32_t
>
clear
()
=
0
;
virtual
std
::
future
<
int32_t
>
clear
(
uint32_t
table_id
)
=
0
;
...
...
@@ -107,6 +148,8 @@ class PSClient {
virtual
std
::
future
<
int32_t
>
pull_dense
(
Region
*
regions
,
size_t
region_num
,
size_t
table_id
)
=
0
;
// 保留
virtual
std
::
future
<
int32_t
>
Push
(
RequestContext
&
push_context
)
=
0
;
// firstly push dense param for parameter server
// this is necessary because dense weight initialized in trainer on cold
// start
...
...
@@ -117,6 +160,9 @@ class PSClient {
virtual
std
::
future
<
int32_t
>
push_dense
(
const
Region
*
regions
,
size_t
region_num
,
size_t
table_id
)
=
0
;
virtual
std
::
future
<
int32_t
>
Pull
(
RequestContext
&
pull_context
)
=
0
;
// 使用keys进行pull请求,结果填充values
// keys和values的个数均为num个,每个value占用select_size空间
// future结束前keys和values缓冲区不能再次使用
...
...
paddle/fluid/distributed/ps/service/ps_local_client.cc
浏览文件 @
4d0d0eca
...
...
@@ -56,6 +56,19 @@ int32_t PsLocalClient::initialize() {
return
done
();
}
// Context-based load for the local client. A non-negative table_id loads that
// single table directly; a negative table_id calls load() for every entry of
// _table_map. The per-table results are not inspected; done() is returned.
std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
  if (!(load_context.table_id < 0)) {
    table(load_context.table_id)
        ->load(load_context.epoch, load_context.mode);
    return done();
  }
  for (auto& entry : _table_map) {
    load(entry.first, load_context.epoch, load_context.mode);
  }
  return done();
}
::
std
::
future
<
int32_t
>
PsLocalClient
::
save
(
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
{
// TODO
...
...
@@ -74,6 +87,21 @@ int32_t PsLocalClient::initialize() {
return
done
();
}
// Context-based save for the local client. A non-negative table_id flushes
// and then saves that single table; a negative table_id calls save() for
// every entry of _table_map. The per-table results are not inspected; done()
// is returned.
::std::future<int32_t> PsLocalClient::Save(
    const LoadSaveContext& save_context) {
  if (!(save_context.table_id < 0)) {
    auto* target = table(save_context.table_id);
    target->flush();
    target->save(save_context.epoch, save_context.mode);
    return done();
  }
  for (auto& entry : _table_map) {
    save(entry.first, save_context.epoch, save_context.mode);
  }
  return done();
}
::
std
::
future
<
int32_t
>
PsLocalClient
::
clear
()
{
// TODO
return
done
();
...
...
@@ -93,6 +121,51 @@ int32_t PsLocalClient::initialize() {
return
done
();
}
// Context-based pull for the local client.
//
//   - Dense  -> pull_dense(dense_values, num, table)
//   - Sparse -> pull_sparse_ptr(sparse_values as char**, table, keys, num)
//
// The future produced by the selected call is returned; previously the
// function fell off its end without returning a value.
::std::future<int32_t> PsLocalClient::Pull(RequestContext& pull_context) {
  if (pull_context.value_type == Dense) {
    // dense_values is already a Region*, no cast is required.
    return pull_dense(pull_context.dense_values, pull_context.num,
                      pull_context.table);
  }

  // Sparse pull: the context carries type-erased key / value buffers.
  uint64_t* keys = static_cast<uint64_t*>(pull_context.keys);
  char** select_values = reinterpret_cast<char**>(pull_context.sparse_values);
  const size_t table_id = pull_context.table;
  const size_t num = pull_context.num;
  return pull_sparse_ptr(select_values, table_id, keys, num);
}
// Context-based push for the local client.
//
//   - Dense, phase Init    -> push_dense_param(push_dense_values, num, table)
//   - Dense, mode Geo      -> push_dense_raw_gradient(table, dense_values as
//                             float*, num, callback)
//   - Dense, other modes   -> push_dense(push_dense_values, num, table)
//   - Sparse, mode Async   -> push_sparse(table, keys, push_values, num)
//   - Sparse, other modes  -> not implemented yet (TODO); returns done()
//
// The future produced by the selected call is returned; previously the
// function fell off its end without returning a value.
::std::future<int32_t> PsLocalClient::Push(RequestContext& push_context) {
  if (push_context.value_type == Dense) {
    if (push_context.training_phase == Init) {
      const Region* regions = push_context.push_context.push_dense_values;
      const size_t region_num = push_context.num;
      return push_dense_param(regions, region_num, push_context.table);
    }
    if (push_context.training_mode == Geo) {
      // geo: dense_values is reinterpreted as a flat float buffer of `num`
      // elements.
      float* total_send_data =
          reinterpret_cast<float*>(push_context.dense_values);
      const size_t total_send_data_size = push_context.num;
      return push_dense_raw_gradient(push_context.table, total_send_data,
                                     total_send_data_size,
                                     push_context.callback);
    }
    // async and sync
    const Region* regions = push_context.push_context.push_dense_values;
    const size_t region_num = push_context.num;
    return push_dense(regions, region_num, push_context.table);
  }

  // push sparse
  if (push_context.training_mode == Async) {
    const uint64_t* keys = push_context.push_context.keys;
    const float** update_values = push_context.push_context.push_values;
    const size_t table_id = push_context.table;
    const size_t num = push_context.num;
    return push_sparse(table_id, keys, update_values, num);
  }

  // TODO: other sparse training modes are not implemented; follow this
  // file's convention for unimplemented operations and return done().
  return done();
}
::
std
::
future
<
int32_t
>
PsLocalClient
::
pull_dense
(
Region
*
regions
,
size_t
region_num
,
size_t
table_id
)
{
...
...
paddle/fluid/distributed/ps/service/ps_local_client.h
浏览文件 @
4d0d0eca
...
...
@@ -39,12 +39,16 @@ class PsLocalClient : public PSClient {
virtual
::
std
::
future
<
int32_t
>
load
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
override
;
virtual
std
::
future
<
int32_t
>
Load
(
const
LoadSaveContext
&
load_context
)
override
;
virtual
::
std
::
future
<
int32_t
>
save
(
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
override
;
virtual
::
std
::
future
<
int32_t
>
save
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
override
;
virtual
std
::
future
<
int32_t
>
Save
(
const
LoadSaveContext
&
save_context
)
override
;
virtual
::
std
::
future
<
int32_t
>
clear
()
override
;
virtual
::
std
::
future
<
int32_t
>
clear
(
uint32_t
table_id
)
override
;
...
...
@@ -55,6 +59,10 @@ class PsLocalClient : public PSClient {
virtual
::
std
::
future
<
int32_t
>
pull_dense
(
Region
*
regions
,
size_t
region_num
,
size_t
table_id
);
virtual
::
std
::
future
<
int32_t
>
Pull
(
RequestContext
&
pull_context
)
override
;
virtual
::
std
::
future
<
int32_t
>
Push
(
RequestContext
&
push_context
)
override
;
virtual
::
std
::
future
<
int32_t
>
push_dense
(
const
Region
*
regions
,
size_t
region_num
,
size_t
table_id
);
...
...
paddle/fluid/distributed/ps/service/ps_local_server.h
浏览文件 @
4d0d0eca
...
...
@@ -28,7 +28,6 @@ class PsLocalServer : public PSServer {
virtual
uint64_t
start
()
{
return
0
;
}
virtual
uint64_t
start
(
const
std
::
string
&
ip
,
uint32_t
port
)
{
return
0
;
}
virtual
int32_t
stop
()
{
return
0
;
}
virtual
int32_t
port
()
{
return
0
;
}
virtual
int32_t
configure
(
const
PSParameter
&
config
,
PSEnvironment
&
env
,
size_t
server_rank
,
const
std
::
vector
<
framework
::
ProgramDesc
>
&
server_sub_program
=
{})
{
...
...
paddle/fluid/distributed/ps/service/server.cc
浏览文件 @
4d0d0eca
...
...
@@ -67,8 +67,6 @@ int32_t PSServer::configure(
_config
=
config
.
server_param
();
_rank
=
server_rank
;
_environment
=
&
env
;
_shuffled_ins
=
paddle
::
framework
::
MakeChannel
<
std
::
pair
<
uint64_t
,
std
::
string
>>
();
size_t
shard_num
=
env
.
get_ps_servers
().
size
();
const
auto
&
downpour_param
=
_config
.
downpour_server_param
();
...
...
paddle/fluid/distributed/ps/service/server.h
浏览文件 @
4d0d0eca
...
...
@@ -69,11 +69,6 @@ class PSServer {
const
PSParameter
&
config
,
PSEnvironment
&
env
,
size_t
server_rank
,
const
std
::
vector
<
framework
::
ProgramDesc
>
&
server_sub_program
=
{});
// return server_ip
virtual
std
::
string
ip
()
{
return
butil
::
my_ip_cstr
();
}
// return server_port
virtual
int32_t
port
()
=
0
;
virtual
uint64_t
start
(
const
std
::
string
&
ip
,
uint32_t
port
)
=
0
;
virtual
int32_t
stop
()
=
0
;
...
...
@@ -94,15 +89,6 @@ class PSServer {
return
&
_table_map
;
}
typedef
std
::
function
<
int32_t
(
int
,
int
,
const
std
::
string
&
)
>
MsgHandlerFunc
;
virtual
int
registe_pserver2pserver_msg_handler
(
int
msg_type
,
MsgHandlerFunc
handler
)
{
_msg_handler_map
[
msg_type
]
=
handler
;
return
0
;
}
paddle
::
framework
::
Channel
<
std
::
pair
<
uint64_t
,
std
::
string
>>
_shuffled_ins
;
protected:
virtual
int32_t
initialize
()
=
0
;
...
...
@@ -111,7 +97,6 @@ class PSServer {
ServerParameter
_config
;
PSEnvironment
*
_environment
;
std
::
unordered_map
<
uint32_t
,
std
::
shared_ptr
<
Table
>>
_table_map
;
std
::
unordered_map
<
int32_t
,
MsgHandlerFunc
>
_msg_handler_map
;
protected:
std
::
shared_ptr
<
framework
::
Scope
>
scope_
;
...
...
paddle/fluid/distributed/ps/table/accessor.h
浏览文件 @
4d0d0eca
...
...
@@ -45,6 +45,17 @@ struct DataConverter {
std
::
string
deconverter
;
};
// Plain summary of an accessor's layout, filled in by
// ValueAccessor::GetTableInfo(). The accessor implementations in this change
// copy each field from the accessor method of the same name; mf_size is not
// assigned by any of them.
struct AccessorInfo {
  size_t dim;          // value dimension (from dim())
  size_t size;         // from size()
  size_t select_size;  // from select_size()
  size_t select_dim;   // from select_dim()
  size_t update_size;  // from update_size()
  size_t update_dim;   // from update_dim()
  size_t mf_size;      // NOTE(review): never written by the visible GetTableInfo() overrides
  size_t fea_dim;      // from fea_dim()
};
class
ValueAccessor
{
public:
ValueAccessor
()
{}
...
...
@@ -68,6 +79,8 @@ class ValueAccessor {
}
virtual
int
initialize
()
=
0
;
virtual
void
GetTableInfo
(
AccessorInfo
&
info
)
=
0
;
// value维度
virtual
size_t
dim
()
=
0
;
// value各个维度的size
...
...
@@ -163,6 +176,7 @@ class ValueAccessor {
TableAccessorParameter
_config
;
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
struct
DataConverter
>>
_data_coverter_map
;
AccessorInfo
_accessor_info
;
};
REGISTER_PSCORE_REGISTERER
(
ValueAccessor
);
}
// namespace distributed
...
...
paddle/fluid/distributed/ps/table/common_dense_table.cc
浏览文件 @
4d0d0eca
...
...
@@ -128,6 +128,21 @@ int32_t CommonDenseTable::set_global_lr(float* lr) {
return
0
;
}
// Context-based dense pull: requires a Dense context and copies the table
// values into context.pull_context.values through pull_dense().
int32_t CommonDenseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Dense);
  float* destination = context.pull_context.values;
  const int32_t status = pull_dense(destination, context.num);
  return status;
}
// Context-based dense push: requires a Dense context and forwards
// context.push_context.values to push_dense().
//
// Returns the push_dense() status, or 0 without doing anything when no push
// buffer was supplied. The guard checks the *push* buffer; it previously
// tested pull_context.values, which could skip a valid push or pass a null
// push buffer to push_dense().
int32_t CommonDenseTable::Push(TableContext& context) {
  CHECK(context.value_type == Dense);
  const float* values = context.push_context.values;
  if (values == nullptr) {
    return 0;
  }
  return push_dense(values, context.num);
}
int32_t
CommonDenseTable
::
pull_dense
(
float
*
pull_values
,
size_t
num
)
{
std
::
copy
(
values_
[
param_idx_
].
begin
(),
values_
[
param_idx_
].
end
(),
pull_values
);
...
...
paddle/fluid/distributed/ps/table/common_dense_table.h
浏览文件 @
4d0d0eca
...
...
@@ -40,6 +40,8 @@ class CommonDenseTable : public DenseTable {
const
std
::
string
&
name
);
virtual
int32_t
initialize_value
();
virtual
int32_t
initialize_optimizer
();
virtual
int32_t
Pull
(
TableContext
&
context
);
virtual
int32_t
Push
(
TableContext
&
context
);
int32_t
pull_dense
(
float
*
pull_values
,
size_t
num
)
override
;
int32_t
push_dense_param
(
const
float
*
values
,
size_t
num
)
override
;
int32_t
push_dense
(
const
float
*
values
,
size_t
num
)
override
;
...
...
paddle/fluid/distributed/ps/table/common_graph_table.h
浏览文件 @
4d0d0eca
...
...
@@ -454,6 +454,9 @@ class GraphTable : public SparseTable {
int32_t
get_server_index_by_id
(
int64_t
id
);
Node
*
find_node
(
int64_t
id
);
virtual
int32_t
Pull
(
TableContext
&
context
)
{
return
0
;
}
virtual
int32_t
Push
(
TableContext
&
context
)
{
return
0
;
}
virtual
int32_t
pull_sparse
(
float
*
values
,
const
PullSparseValue
&
pull_value
)
{
return
0
;
...
...
paddle/fluid/distributed/ps/table/common_sparse_table.cc
浏览文件 @
4d0d0eca
...
...
@@ -355,6 +355,32 @@ int32_t CommonSparseTable::pour() {
return
0
;
}
// Context-based sparse pull: requires a Sparse context. With use_ptr set the
// pointer-based path (pull_sparse_ptr with keys / ptr_values) is taken,
// otherwise the value-based path (pull_sparse with values / pull_value).
int32_t CommonSparseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Sparse);
  if (!context.use_ptr) {
    const PullSparseValue& request = context.pull_context.pull_value;
    float* output = context.pull_context.values;
    return pull_sparse(output, request);
  }
  char** output_ptrs = context.pull_context.ptr_values;
  const uint64_t* feature_keys = context.pull_context.keys;
  return pull_sparse_ptr(output_ptrs, feature_keys, context.num);
}
// Context-based sparse push: requires a Sparse context.
//
// When a flat value buffer (push_context.values) is supplied it is pushed
// with the `const float*` overload of push_sparse(); otherwise the
// pointer-array buffer (push_context.ptr_values) is pushed with the
// `const float**` overload.
//
// The selector inspects the *push* buffer; it previously tested
// pull_context.values, so the chosen overload depended on unrelated pull
// state and could receive a null push buffer.
int32_t CommonSparseTable::Push(TableContext& context) {
  CHECK(context.value_type == Sparse);
  const uint64_t* keys = context.push_context.keys;
  if (context.push_context.values != nullptr) {
    const float* values = context.push_context.values;
    return push_sparse(keys, values, context.num);
  }
  const float** values = context.push_context.ptr_values;
  return push_sparse(keys, values, context.num);
}
int32_t
CommonSparseTable
::
pull_sparse
(
float
*
pull_values
,
const
PullSparseValue
&
pull_value
)
{
auto
shard_num
=
task_pool_size_
;
...
...
paddle/fluid/distributed/ps/table/common_sparse_table.h
浏览文件 @
4d0d0eca
...
...
@@ -121,6 +121,9 @@ class CommonSparseTable : public SparseTable {
virtual
int32_t
push_dense
(
const
float
*
values
,
size_t
num
)
{
return
0
;
}
// unused method end
virtual
int32_t
Pull
(
TableContext
&
context
);
virtual
int32_t
Push
(
TableContext
&
context
);
virtual
int32_t
initialize
();
virtual
int32_t
initialize_shard
()
{
return
0
;
}
virtual
int32_t
initialize_value
();
...
...
paddle/fluid/distributed/ps/table/common_table.h
浏览文件 @
4d0d0eca
...
...
@@ -119,6 +119,9 @@ class BarrierTable : public Table {
virtual
void
*
get_shard
(
size_t
shard_idx
)
{
return
0
;
}
virtual
int32_t
Pull
(
TableContext
&
context
)
{
return
0
;
}
virtual
int32_t
Push
(
TableContext
&
context
)
{
return
0
;
}
int32_t
pull_dense
(
float
*
values
,
size_t
num
)
override
{
return
0
;
}
int32_t
push_dense
(
const
float
*
values
,
size_t
num
)
override
{
return
0
;
}
...
...
paddle/fluid/distributed/ps/table/ctr_accessor.cc
浏览文件 @
4d0d0eca
...
...
@@ -38,6 +38,16 @@ int CtrCommonAccessor::initialize() {
return
0
;
}
// Copies this accessor's layout figures into `info`, one field per accessor
// method of the same name. info.mf_size is left untouched.
void CtrCommonAccessor::GetTableInfo(AccessorInfo& info) {
  // Stored value.
  info.dim = this->dim();
  info.size = this->size();
  // Pull (select) view.
  info.select_dim = this->select_dim();
  info.select_size = this->select_size();
  // Push (update) view.
  info.update_dim = this->update_dim();
  info.update_size = this->update_size();
  // Feature dimension.
  info.fea_dim = this->fea_dim();
}
// Value dimension: delegates to common_feature_value.dim().
size_t CtrCommonAccessor::dim() { return common_feature_value.dim(); }
size_t
CtrCommonAccessor
::
dim_size
(
size_t
dim
)
{
...
...
paddle/fluid/distributed/ps/table/ctr_accessor.h
浏览文件 @
4d0d0eca
...
...
@@ -126,6 +126,7 @@ class CtrCommonAccessor : public ValueAccessor {
virtual
int
initialize
();
virtual
~
CtrCommonAccessor
()
{}
virtual
void
GetTableInfo
(
AccessorInfo
&
info
);
// value维度
virtual
size_t
dim
();
// value各个维度的size
...
...
paddle/fluid/distributed/ps/table/ctr_double_accessor.cc
浏览文件 @
4d0d0eca
...
...
@@ -37,6 +37,16 @@ int DownpourCtrDoubleAccessor::initialize() {
return
0
;
}
// Copies this accessor's layout figures into `info`, one field per accessor
// method of the same name. info.mf_size is left untouched.
void DownpourCtrDoubleAccessor::GetTableInfo(AccessorInfo& info) {
  // Stored value.
  info.dim = this->dim();
  info.size = this->size();
  // Pull (select) view.
  info.select_dim = this->select_dim();
  info.select_size = this->select_size();
  // Push (update) view.
  info.update_dim = this->update_dim();
  info.update_size = this->update_size();
  // Feature dimension.
  info.fea_dim = this->fea_dim();
}
size_t
DownpourCtrDoubleAccessor
::
dim
()
{
auto
embedx_dim
=
_config
.
embedx_dim
();
return
DownpourCtrDoubleFeatureValue
::
dim
(
embedx_dim
);
...
...
paddle/fluid/distributed/ps/table/ctr_double_accessor.h
浏览文件 @
4d0d0eca
...
...
@@ -168,6 +168,7 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
DownpourCtrDoubleAccessor
()
{}
virtual
~
DownpourCtrDoubleAccessor
()
{}
virtual
int
initialize
();
virtual
void
GetTableInfo
(
AccessorInfo
&
info
);
// value维度
virtual
size_t
dim
();
// value各个维度的size
...
...
paddle/fluid/distributed/ps/table/depends/sparse_utils.h
浏览文件 @
4d0d0eca
...
...
@@ -58,7 +58,7 @@ struct PullSparseValue {
std
::
vector
<
int
>*
offset_shard
)
const
{
offset_shard
->
reserve
(
numel_
/
shard_num
+
1
);
for
(
int
x
=
0
;
x
<
numel_
;
++
x
)
{
if
(
feasigns_
[
x
]
%
shard_num
==
shard_id
)
{
if
(
int
(
feasigns_
[
x
]
%
shard_num
)
==
shard_id
)
{
offset_shard
->
push_back
(
x
);
}
}
...
...
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc
浏览文件 @
4d0d0eca
...
...
@@ -37,6 +37,16 @@ int DownpourCtrAccessor::initialize() {
return
0
;
}
// Copies this accessor's layout figures into `info`, one field per accessor
// method of the same name. info.mf_size is left untouched.
void DownpourCtrAccessor::GetTableInfo(AccessorInfo& info) {
  // Stored value.
  info.dim = this->dim();
  info.size = this->size();
  // Pull (select) view.
  info.select_dim = this->select_dim();
  info.select_size = this->select_size();
  // Push (update) view.
  info.update_dim = this->update_dim();
  info.update_size = this->update_size();
  // Feature dimension.
  info.fea_dim = this->fea_dim();
}
size_t
DownpourCtrAccessor
::
dim
()
{
auto
embedx_dim
=
_config
.
embedx_dim
();
return
DownpourCtrFeatureValue
::
dim
(
embedx_dim
);
...
...
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h
浏览文件 @
4d0d0eca
...
...
@@ -160,6 +160,7 @@ class DownpourCtrAccessor : public ValueAccessor {
virtual
~
DownpourCtrAccessor
()
{}
virtual
int
initialize
();
virtual
void
GetTableInfo
(
AccessorInfo
&
info
);
// value维度
virtual
size_t
dim
();
// value各个维度的size
...
...
paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h
浏览文件 @
4d0d0eca
...
...
@@ -48,6 +48,8 @@ class MemorySparseGeoTable : public SparseTable {
virtual
int32_t
save
(
const
std
::
string
&
path
,
const
std
::
string
&
param
)
{
return
0
;
}
virtual
int32_t
Pull
(
TableContext
&
context
)
{
return
0
;
}
virtual
int32_t
Push
(
TableContext
&
context
)
{
return
0
;
}
virtual
int32_t
flush
()
{
return
0
;
}
virtual
int32_t
shrink
(
const
std
::
string
&
param
)
{
return
0
;
}
virtual
void
clear
()
{
return
;
}
...
...
paddle/fluid/distributed/ps/table/memory_sparse_table.cc
浏览文件 @
4d0d0eca
...
...
@@ -390,6 +390,26 @@ std::pair<int64_t, int64_t> MemorySparseTable::print_table_stat() {
return
{
feasign_size
,
mf_size
};
}
// Context-based sparse pull: requires a Sparse context. With use_ptr set the
// pointer-based path (pull_sparse_ptr with keys / ptr_values) is taken,
// otherwise the value-based path (pull_sparse with values / pull_value).
int32_t MemorySparseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Sparse);
  if (!context.use_ptr) {
    const PullSparseValue& request = context.pull_context.pull_value;
    float* output = context.pull_context.values;
    return pull_sparse(output, request);
  }
  char** output_ptrs = context.pull_context.ptr_values;
  const uint64_t* feature_keys = context.pull_context.keys;
  return pull_sparse_ptr(output_ptrs, feature_keys, context.num);
}
// Context-based sparse push: requires a Sparse context and always pushes the
// pointer-array buffer (push_context.ptr_values) for push_context.keys.
int32_t MemorySparseTable::Push(TableContext& context) {
  CHECK(context.value_type == Sparse);
  const uint64_t* feature_keys = context.push_context.keys;
  const int32_t status =
      push_sparse(feature_keys, context.push_context.ptr_values, context.num);
  return status;
}
int32_t
MemorySparseTable
::
pull_sparse
(
float
*
pull_values
,
const
PullSparseValue
&
pull_value
)
{
CostTimer
timer
(
"pserver_sparse_select_all"
);
...
...
paddle/fluid/distributed/ps/table/memory_sparse_table.h
浏览文件 @
4d0d0eca
...
...
@@ -48,6 +48,9 @@ class MemorySparseTable : public SparseTable {
virtual
int32_t
push_dense
(
const
float
*
values
,
size_t
num
)
{
return
0
;
}
// unused method end
virtual
int32_t
Pull
(
TableContext
&
context
);
virtual
int32_t
Push
(
TableContext
&
context
);
virtual
int32_t
initialize
();
virtual
int32_t
initialize_shard
()
{
return
0
;
}
virtual
int32_t
initialize_value
();
...
...
paddle/fluid/distributed/ps/table/ssd_sparse_table.cc
浏览文件 @
4d0d0eca
...
...
@@ -61,6 +61,21 @@ int32_t SSDSparseTable::initialize() {
return
0
;
}
// Context-based sparse pull: requires a Sparse context. With use_ptr set the
// pointer-based path (pull_sparse_ptr with keys / ptr_values) is taken,
// otherwise the value-based path (pull_sparse with values / pull_value).
int32_t SSDSparseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Sparse);
  if (!context.use_ptr) {
    const PullSparseValue& request = context.pull_context.pull_value;
    float* output = context.pull_context.values;
    return pull_sparse(output, request);
  }
  char** output_ptrs = context.pull_context.ptr_values;
  const uint64_t* feature_keys = context.pull_context.keys;
  return pull_sparse_ptr(output_ptrs, feature_keys, context.num);
}
// Context-based push is a no-op for this table: the context is ignored and 0
// is returned.
int32_t SSDSparseTable::Push(TableContext& context) { return 0; }
int32_t
SSDSparseTable
::
pull_sparse
(
float
*
pull_values
,
const
PullSparseValue
&
pull_value
)
{
auto
shard_num
=
task_pool_size_
;
...
...
paddle/fluid/distributed/ps/table/ssd_sparse_table.h
浏览文件 @
4d0d0eca
...
...
@@ -42,6 +42,9 @@ class SSDSparseTable : public CommonSparseTable {
// exchange data
virtual
int32_t
update_table
();
virtual
int32_t
Pull
(
TableContext
&
context
);
virtual
int32_t
Push
(
TableContext
&
context
);
virtual
int32_t
pull_sparse
(
float
*
values
,
const
PullSparseValue
&
pull_value
);
virtual
int32_t
pull_sparse_ptr
(
char
**
pull_values
,
const
uint64_t
*
keys
,
...
...
paddle/fluid/distributed/ps/table/table.h
浏览文件 @
4d0d0eca
...
...
@@ -32,6 +32,30 @@
namespace
paddle
{
namespace
distributed
{
enum
ValueType
{
Sparse
=
0
,
Dense
=
1
};
// Buffers for a table-side pull (see Table::Pull).
struct PullContext {
  const uint64_t* keys;              // keys for the pointer-based sparse pull
  // Request for the value-based sparse pull. Declared as a const by-value
  // member, so it cannot be reassigned after the struct is constructed.
  const PullSparseValue pull_value;
  float* values;      // output buffer for dense pull / value-based sparse pull
  char** ptr_values;  // output pointers for the pointer-based sparse pull
};
// Read-only buffers for a table-side push (see Table::Push).
struct TablePushContext {
  const uint64_t* keys;      // sparse keys to update
  const float* values;       // flat value buffer (dense push / sparse push)
  const float** ptr_values;  // pointer-array value buffer (sparse push)
};
// Argument bundle for the context-based Table::Pull / Table::Push calls.
struct TableContext {
  ValueType value_type;           // checked by the table implementations
  PullContext pull_context;       // buffers read by Pull
  TablePushContext push_context;  // buffers read by Push
  size_t num;                     // element count forwarded to the table call
  bool use_ptr;  // sparse Pull: true selects the pointer-based path
};
class
Table
{
public:
Table
()
{}
...
...
@@ -39,6 +63,8 @@ class Table {
virtual
int32_t
initialize
(
const
TableParameter
&
config
,
const
FsClientParameter
&
fs_config
);
// Context-based pull entry point. Pure virtual: every concrete table has to
// provide its own implementation. The result is reported as an int32_t.
virtual
int32_t
Pull
(
TableContext
&
context
)
=
0
;
// Context-based push entry point. Pure virtual: every concrete table has to
// provide its own implementation. The result is reported as an int32_t.
virtual
int32_t
Push
(
TableContext
&
context
)
=
0
;
virtual
int32_t
pull_dense
(
float
*
values
,
size_t
num
)
=
0
;
virtual
int32_t
push_dense
(
const
float
*
values
,
size_t
num
)
=
0
;
// for push global_step
...
...
paddle/fluid/distributed/ps/table/tensor_accessor.cc
浏览文件 @
4d0d0eca
...
...
@@ -20,6 +20,16 @@ namespace distributed {
// Nothing to set up for this accessor; always returns 0.
int CommMergeAccessor::initialize() {
  return 0;
}
// Fills `info` with this accessor's layout: each field is assigned the result
// of the accessor method of the same name.
void CommMergeAccessor::GetTableInfo(AccessorInfo& info) {
  // Full value.
  info.dim = dim();
  info.size = size();

  // Select (pull) view.
  info.select_dim = select_dim();
  info.select_size = select_size();

  // Update (push) view.
  info.update_dim = update_dim();
  info.update_size = update_size();

  info.fea_dim = fea_dim();
}
// Dimension of the value. This accessor always reports 0.
size_t CommMergeAccessor::dim() {
  return 0;
}
...
...
paddle/fluid/distributed/ps/table/tensor_accessor.h
浏览文件 @
4d0d0eca
...
...
@@ -30,6 +30,7 @@ class CommMergeAccessor : public ValueAccessor {
CommMergeAccessor
()
{}
virtual
~
CommMergeAccessor
()
{}
virtual
int
initialize
();
virtual
void
GetTableInfo
(
AccessorInfo
&
info
);
// dimension of the value
virtual
size_t
dim
();
// size of each dimension of the value
...
...
paddle/fluid/distributed/ps/table/tensor_table.h
浏览文件 @
4d0d0eca
...
...
@@ -48,6 +48,8 @@ class TensorTable : public Table {
TensorTable
()
{}
virtual
~
TensorTable
()
{}
// Context-based pull. Stub implementation of Table::Pull: the context is not
// inspected and 0 is always returned.
// Marked `override`, like the neighbouring pull_dense/push_dense, so the
// compiler verifies the signature against the base class.
int32_t Pull(TableContext& context) override {
  return 0;
}
// Context-based push. Stub implementation of Table::Push: the context is not
// inspected and 0 is always returned.
// Marked `override`, like the neighbouring pull_dense/push_dense, so the
// compiler verifies the signature against the base class.
int32_t Push(TableContext& context) override {
  return 0;
}
// Dense pull stub: neither `values` nor `num` is used; always returns 0.
int32_t pull_dense(float* values, size_t num) override {
  return 0;
}
// Dense push stub: neither `values` nor `num` is used; always returns 0.
int32_t push_dense(const float* values, size_t num) override {
  return 0;
}
...
...
paddle/fluid/distributed/ps/wrapper/fleet.cc
浏览文件 @
4d0d0eca
...
...
@@ -30,6 +30,32 @@ bool FleetWrapper::is_initialized_ = false;
std
::
shared_ptr
<
paddle
::
distributed
::
PSCore
>
FleetWrapper
::
pserver_ptr_
=
NULL
;
// PSWrapper::Stop implementation: simply delegates to StopServer().
void FleetWrapper::Stop() {
  StopServer();
}
// Loads model data as described by `context`.
//
// Dispatch:
//   * negative table id            -> LoadModel(path, mode)           (load all)
//   * table id >= 0, meta non-empty -> LoadSparseOnServer(path, meta, table_id)
//   * table id >= 0, meta empty     -> LoadModelOneTable(table_id, path, mode)
//
// WrapperContext::table_id is declared as uint32_t, so comparing it with 0
// directly makes `>= 0` always true and `< 0` always false, which left the
// load-all branch unreachable. The id is therefore reinterpreted as a signed
// 32-bit value before the sign test, so that a caller storing -1 in the field
// reaches LoadModel().
void FleetWrapper::Load(WrapperContext& context) {
  const auto signed_table_id = static_cast<int32_t>(context.table_id);

  if (signed_table_id < 0) {
    // load all
    LoadModel(context.path, context.mode);
    return;
  }

  if (!context.meta.empty()) {
    LoadSparseOnServer(context.path, context.meta, context.table_id);
    return;
  }

  // load one table
  LoadModelOneTable(context.table_id, context.path, context.mode);
}
// Saves model data as described by `context`.
//
// Dispatch:
//   * negative table id -> SaveModel(path, mode)                    (save all)
//   * otherwise         -> SaveModelOneTable(table_id, path, mode)
//
// WrapperContext::table_id is declared as uint32_t, so a direct `< 0` test on
// it is always false and SaveModel() could never be reached. The id is
// reinterpreted as a signed 32-bit value before the sign test, so that a
// caller storing -1 in the field selects the save-all path.
void FleetWrapper::Save(WrapperContext& context) {
  const auto signed_table_id = static_cast<int32_t>(context.table_id);

  if (signed_table_id < 0) {
    SaveModel(context.path, context.mode);
    return;
  }

  SaveModelOneTable(context.table_id, context.path, context.mode);
}
void
FleetWrapper
::
SetClient2ClientConfig
(
int
request_timeout_ms
,
int
connect_timeout_ms
,
int
max_retry
)
{
...
...
paddle/fluid/distributed/ps/wrapper/fleet.h
浏览文件 @
4d0d0eca
...
...
@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include "paddle/fluid/distributed/ps/wrapper/ps_wrapper.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/io/shell.h"
...
...
@@ -54,7 +55,7 @@ using framework::Variable;
using
RpcCtxMap
=
std
::
unordered_map
<
std
::
string
,
CommContext
>
;
class
FleetWrapper
{
class
FleetWrapper
:
public
PSWrapper
{
public:
virtual
~
FleetWrapper
()
{}
FleetWrapper
()
{
...
...
@@ -68,7 +69,13 @@ class FleetWrapper {
// pserver request max retry
client2client_max_retry_
=
3
;
}
// PSWrapper::Initialize implementation. Currently a stub: `context` is not
// inspected and 0 is always returned.
// Marked `override` to match the sibling Stop/Load/Save declarations and to
// have the compiler verify the signature against PSWrapper.
virtual int32_t Initialize(InitContext& context) override {
  return 0;
}
virtual
void
Stop
()
override
;
virtual
void
Load
(
WrapperContext
&
context
)
override
;
virtual
void
Save
(
WrapperContext
&
context
)
override
;
// set client to client communication config
void
SetClient2ClientConfig
(
int
request_timeout_ms
,
int
connect_timeout_ms
,
int
max_retry
);
...
...
paddle/fluid/distributed/ps/wrapper/ps_wrapper.h
浏览文件 @
4d0d0eca
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_FLUID_DISTRIBUTED_PS_WRAPPER_PS_WRAPPER_H_
#define PADDLE_FLUID_DISTRIBUTED_PS_WRAPPER_PS_WRAPPER_H_
#endif // PADDLE_FLUID_DISTRIBUTED_PS_WRAPPER_PS_WRAPPER_H_
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/io/shell.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace
paddle
{
namespace
framework
{
class
Scope
;
class
SelectedRows
;
class
Variable
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
distributed
{
class
PSCore
;
using
framework
::
LoDTensor
;
using
framework
::
Scope
;
using
phi
::
SelectedRows
;
using
framework
::
Variable
;
using
RpcCtxMap
=
std
::
unordered_map
<
std
::
string
,
CommContext
>
;
// Arguments for PSWrapper::Load / PSWrapper::Save.
//
// `path`, `mode` and `meta` are const, so an instance has to be populated at
// construction time (e.g. by aggregate initialisation).
// NOTE(review): `table_id` is unsigned, so a direct `table_id < 0` test on it
// can never be true -- confirm how callers are expected to request "all
// tables".
struct WrapperContext {
  uint32_t table_id;       // table to operate on
  const std::string path;  // location of the model data
  const int mode;          // mode value forwarded to the load/save routines
  const std::string meta;  // when non-empty, Load uses LoadSparseOnServer
};
// Arguments for PSWrapper::Initialize.
struct InitContext {
  // Device ids (for gpu). Const: must be supplied at construction time.
  const std::vector<int> dev_ids;
};
// Abstract interface of a parameter-server wrapper. Concrete wrappers must
// implement all four operations below.
class PSWrapper {
 public:
  PSWrapper() = default;
  virtual ~PSWrapper() = default;

  // init server
  virtual int32_t Initialize(InitContext& context) = 0;

  // Stops the wrapper.
  virtual void Stop() = 0;

  // Loads data as described by `context`.
  virtual void Load(WrapperContext& context) = 0;

  // Saves data as described by `context`.
  virtual void Save(WrapperContext& context) = 0;
};
}
// end namespace distributed
}
// end namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录