Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
4d0d0eca
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4d0d0eca
编写于
3月 21, 2022
作者:
Y
yaoxuefeng
提交者:
GitHub
3月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
mod base (#40702)
上级
382e460b
变更
36
显示空白变更内容
内联
并排
Showing
36 changed file
with
499 addition
and
40 deletion
+499
-40
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
+61
-0
paddle/fluid/distributed/ps/service/brpc_ps_client.h
paddle/fluid/distributed/ps/service/brpc_ps_client.h
+9
-0
paddle/fluid/distributed/ps/service/brpc_ps_server.h
paddle/fluid/distributed/ps/service/brpc_ps_server.h
+1
-1
paddle/fluid/distributed/ps/service/graph_brpc_server.h
paddle/fluid/distributed/ps/service/graph_brpc_server.h
+1
-1
paddle/fluid/distributed/ps/service/ps_client.h
paddle/fluid/distributed/ps/service/ps_client.h
+46
-0
paddle/fluid/distributed/ps/service/ps_local_client.cc
paddle/fluid/distributed/ps/service/ps_local_client.cc
+73
-0
paddle/fluid/distributed/ps/service/ps_local_client.h
paddle/fluid/distributed/ps/service/ps_local_client.h
+8
-0
paddle/fluid/distributed/ps/service/ps_local_server.h
paddle/fluid/distributed/ps/service/ps_local_server.h
+0
-1
paddle/fluid/distributed/ps/service/server.cc
paddle/fluid/distributed/ps/service/server.cc
+0
-2
paddle/fluid/distributed/ps/service/server.h
paddle/fluid/distributed/ps/service/server.h
+0
-15
paddle/fluid/distributed/ps/table/accessor.h
paddle/fluid/distributed/ps/table/accessor.h
+14
-0
paddle/fluid/distributed/ps/table/common_dense_table.cc
paddle/fluid/distributed/ps/table/common_dense_table.cc
+15
-0
paddle/fluid/distributed/ps/table/common_dense_table.h
paddle/fluid/distributed/ps/table/common_dense_table.h
+2
-0
paddle/fluid/distributed/ps/table/common_graph_table.h
paddle/fluid/distributed/ps/table/common_graph_table.h
+3
-0
paddle/fluid/distributed/ps/table/common_sparse_table.cc
paddle/fluid/distributed/ps/table/common_sparse_table.cc
+26
-0
paddle/fluid/distributed/ps/table/common_sparse_table.h
paddle/fluid/distributed/ps/table/common_sparse_table.h
+3
-0
paddle/fluid/distributed/ps/table/common_table.h
paddle/fluid/distributed/ps/table/common_table.h
+3
-0
paddle/fluid/distributed/ps/table/ctr_accessor.cc
paddle/fluid/distributed/ps/table/ctr_accessor.cc
+10
-0
paddle/fluid/distributed/ps/table/ctr_accessor.h
paddle/fluid/distributed/ps/table/ctr_accessor.h
+1
-0
paddle/fluid/distributed/ps/table/ctr_double_accessor.cc
paddle/fluid/distributed/ps/table/ctr_double_accessor.cc
+10
-0
paddle/fluid/distributed/ps/table/ctr_double_accessor.h
paddle/fluid/distributed/ps/table/ctr_double_accessor.h
+1
-0
paddle/fluid/distributed/ps/table/depends/sparse_utils.h
paddle/fluid/distributed/ps/table/depends/sparse_utils.h
+1
-1
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc
+10
-0
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h
+1
-0
paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h
paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h
+2
-0
paddle/fluid/distributed/ps/table/memory_sparse_table.cc
paddle/fluid/distributed/ps/table/memory_sparse_table.cc
+20
-0
paddle/fluid/distributed/ps/table/memory_sparse_table.h
paddle/fluid/distributed/ps/table/memory_sparse_table.h
+3
-0
paddle/fluid/distributed/ps/table/ssd_sparse_table.cc
paddle/fluid/distributed/ps/table/ssd_sparse_table.cc
+15
-0
paddle/fluid/distributed/ps/table/ssd_sparse_table.h
paddle/fluid/distributed/ps/table/ssd_sparse_table.h
+3
-0
paddle/fluid/distributed/ps/table/table.h
paddle/fluid/distributed/ps/table/table.h
+26
-0
paddle/fluid/distributed/ps/table/tensor_accessor.cc
paddle/fluid/distributed/ps/table/tensor_accessor.cc
+10
-0
paddle/fluid/distributed/ps/table/tensor_accessor.h
paddle/fluid/distributed/ps/table/tensor_accessor.h
+1
-0
paddle/fluid/distributed/ps/table/tensor_table.h
paddle/fluid/distributed/ps/table/tensor_table.h
+2
-0
paddle/fluid/distributed/ps/wrapper/fleet.cc
paddle/fluid/distributed/ps/wrapper/fleet.cc
+26
-0
paddle/fluid/distributed/ps/wrapper/fleet.h
paddle/fluid/distributed/ps/wrapper/fleet.h
+8
-1
paddle/fluid/distributed/ps/wrapper/ps_wrapper.h
paddle/fluid/distributed/ps/wrapper/ps_wrapper.h
+84
-18
未找到文件。
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
浏览文件 @
4d0d0eca
...
@@ -414,6 +414,16 @@ std::future<int32_t> BrpcPsClient::load(uint32_t table_id,
...
@@ -414,6 +414,16 @@ std::future<int32_t> BrpcPsClient::load(uint32_t table_id,
return
send_cmd
(
table_id
,
PS_LOAD_ONE_TABLE
,
{
epoch
,
mode
});
return
send_cmd
(
table_id
,
PS_LOAD_ONE_TABLE
,
{
epoch
,
mode
});
}
}
std
::
future
<
int32_t
>
BrpcPsClient
::
Load
(
const
LoadSaveContext
&
load_context
)
{
if
(
load_context
.
table_id
<
0
)
{
return
send_cmd
(
-
1
,
PS_LOAD_ALL_TABLE
,
{
load_context
.
epoch
,
load_context
.
mode
});
}
else
{
return
send_cmd
(
load_context
.
table_id
,
PS_LOAD_ONE_TABLE
,
{
load_context
.
epoch
,
load_context
.
mode
});
}
}
std
::
future
<
int32_t
>
BrpcPsClient
::
save
(
const
std
::
string
&
epoch
,
std
::
future
<
int32_t
>
BrpcPsClient
::
save
(
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
{
const
std
::
string
&
mode
)
{
VLOG
(
1
)
<<
"BrpcPsClient::save path "
<<
epoch
;
VLOG
(
1
)
<<
"BrpcPsClient::save path "
<<
epoch
;
...
@@ -427,6 +437,19 @@ std::future<int32_t> BrpcPsClient::save(uint32_t table_id,
...
@@ -427,6 +437,19 @@ std::future<int32_t> BrpcPsClient::save(uint32_t table_id,
return
send_save_cmd
(
table_id
,
PS_SAVE_ONE_TABLE
,
{
epoch
,
mode
});
return
send_save_cmd
(
table_id
,
PS_SAVE_ONE_TABLE
,
{
epoch
,
mode
});
}
}
std
::
future
<
int32_t
>
BrpcPsClient
::
Save
(
const
LoadSaveContext
&
save_context
)
{
if
(
save_context
.
table_id
<
0
)
{
VLOG
(
1
)
<<
"BrpcPsClient::save path "
<<
save_context
.
epoch
;
return
send_save_cmd
(
-
1
,
PS_SAVE_ALL_TABLE
,
{
save_context
.
epoch
,
save_context
.
mode
});
}
else
{
VLOG
(
1
)
<<
"BrpcPsClient::save one table path "
<<
save_context
.
epoch
<<
" table_id "
<<
save_context
.
table_id
;
return
send_save_cmd
(
save_context
.
table_id
,
PS_SAVE_ONE_TABLE
,
{
save_context
.
epoch
,
save_context
.
mode
});
}
}
std
::
future
<
int32_t
>
BrpcPsClient
::
clear
()
{
std
::
future
<
int32_t
>
BrpcPsClient
::
clear
()
{
return
send_cmd
(
-
1
,
PS_CLEAR_ALL_TABLE
,
{});
return
send_cmd
(
-
1
,
PS_CLEAR_ALL_TABLE
,
{});
}
}
...
@@ -505,6 +528,44 @@ std::future<int32_t> BrpcPsClient::barrier(size_t table_id,
...
@@ -505,6 +528,44 @@ std::future<int32_t> BrpcPsClient::barrier(size_t table_id,
return
send_cmd
(
table_id
,
PS_BARRIER
,
{
std
::
to_string
(
barrier_type
)});
return
send_cmd
(
table_id
,
PS_BARRIER
,
{
std
::
to_string
(
barrier_type
)});
}
}
std
::
future
<
int32_t
>
BrpcPsClient
::
Pull
(
RequestContext
&
pull_context
)
{
if
(
pull_context
.
value_type
==
Dense
)
{
// pull dense
Region
*
dense_region
=
reinterpret_cast
<
Region
*>
(
pull_context
.
dense_values
);
pull_dense
(
dense_region
,
pull_context
.
num
,
pull_context
.
table
);
}
else
{
// pull sparse
uint64_t
*
keys
=
reinterpret_cast
<
uint64_t
*>
(
pull_context
.
keys
);
float
**
select_values
=
reinterpret_cast
<
float
**>
(
pull_context
.
sparse_values
);
size_t
table_id
=
pull_context
.
table
;
size_t
num
=
pull_context
.
num
;
bool
is_training
=
pull_context
.
is_training
;
if
(
pull_context
.
training_mode
==
Geo
)
{
// for geo
pull_sparse_param
(
select_values
,
table_id
,
keys
,
num
,
is_training
);
}
else
if
(
pull_context
.
training_mode
==
Async
)
{
// for async
pull_sparse
(
select_values
,
table_id
,
keys
,
num
,
is_training
);
}
}
}
std
::
future
<
int32_t
>
BrpcPsClient
::
Push
(
RequestContext
&
push_context
)
{
if
(
push_context
.
value_type
==
Dense
)
{
// push dense
const
Region
*
dense_region
=
push_context
.
push_context
.
push_dense_values
;
push_dense
(
dense_region
,
push_context
.
num
,
push_context
.
table
);
}
else
{
// push sparse
size_t
table_id
=
push_context
.
table
;
size_t
num
=
push_context
.
num
;
bool
is_training
=
push_context
.
is_training
;
if
(
push_context
.
training_mode
==
Geo
)
{
// for geo
// TODO(zhaocaibei)
}
else
if
(
push_context
.
training_mode
==
Async
)
{
// for async
const
uint64_t
*
keys
=
push_context
.
push_context
.
keys
;
const
float
**
update_values
=
push_context
.
push_context
.
push_values
;
push_sparse
(
table_id
,
keys
,
update_values
,
num
);
}
}
}
std
::
future
<
int32_t
>
BrpcPsClient
::
pull_geo_param
(
size_t
table_id
,
std
::
future
<
int32_t
>
BrpcPsClient
::
pull_geo_param
(
size_t
table_id
,
std
::
vector
<
float
>
*
values
,
std
::
vector
<
float
>
*
values
,
std
::
vector
<
uint64_t
>
*
keys
,
std
::
vector
<
uint64_t
>
*
keys
,
...
...
paddle/fluid/distributed/ps/service/brpc_ps_client.h
浏览文件 @
4d0d0eca
...
@@ -163,12 +163,17 @@ class BrpcPsClient : public PSClient {
...
@@ -163,12 +163,17 @@ class BrpcPsClient : public PSClient {
std
::
future
<
int32_t
>
load
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
std
::
future
<
int32_t
>
load
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
override
;
const
std
::
string
&
mode
)
override
;
std
::
future
<
int32_t
>
Load
(
const
LoadSaveContext
&
load_context
)
override
;
std
::
future
<
int32_t
>
save
(
const
std
::
string
&
epoch
,
std
::
future
<
int32_t
>
save
(
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
override
;
const
std
::
string
&
mode
)
override
;
std
::
future
<
int32_t
>
save
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
std
::
future
<
int32_t
>
save
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
override
;
const
std
::
string
&
mode
)
override
;
virtual
std
::
future
<
int32_t
>
Save
(
const
LoadSaveContext
&
save_context
)
override
;
std
::
future
<
int32_t
>
clear
()
override
;
std
::
future
<
int32_t
>
clear
()
override
;
std
::
future
<
int32_t
>
clear
(
uint32_t
table_id
)
override
;
std
::
future
<
int32_t
>
clear
(
uint32_t
table_id
)
override
;
...
@@ -199,6 +204,10 @@ class BrpcPsClient : public PSClient {
...
@@ -199,6 +204,10 @@ class BrpcPsClient : public PSClient {
const
uint64_t
*
keys
,
const
uint64_t
*
keys
,
size_t
num
,
bool
is_training
);
size_t
num
,
bool
is_training
);
virtual
std
::
future
<
int32_t
>
Pull
(
RequestContext
&
pull_context
)
override
;
virtual
std
::
future
<
int32_t
>
Push
(
RequestContext
&
push_context
)
override
;
virtual
std
::
future
<
int32_t
>
print_table_stat
(
uint32_t
table_id
);
virtual
std
::
future
<
int32_t
>
print_table_stat
(
uint32_t
table_id
);
virtual
std
::
future
<
int32_t
>
barrier
(
size_t
table_id
,
uint32_t
barrier_type
);
virtual
std
::
future
<
int32_t
>
barrier
(
size_t
table_id
,
uint32_t
barrier_type
);
...
...
paddle/fluid/distributed/ps/service/brpc_ps_server.h
浏览文件 @
4d0d0eca
...
@@ -51,7 +51,7 @@ class BrpcPsServer : public PSServer {
...
@@ -51,7 +51,7 @@ class BrpcPsServer : public PSServer {
_server
.
Join
();
_server
.
Join
();
return
0
;
return
0
;
}
}
virtual
int32_t
port
();
int32_t
port
();
private:
private:
virtual
int32_t
initialize
();
virtual
int32_t
initialize
();
...
...
paddle/fluid/distributed/ps/service/graph_brpc_server.h
浏览文件 @
4d0d0eca
...
@@ -43,7 +43,7 @@ class GraphBrpcServer : public PSServer {
...
@@ -43,7 +43,7 @@ class GraphBrpcServer : public PSServer {
_server
.
Join
();
_server
.
Join
();
return
0
;
return
0
;
}
}
virtual
int32_t
port
();
int32_t
port
();
std
::
condition_variable
*
export_cv
()
{
return
&
cv_
;
}
std
::
condition_variable
*
export_cv
()
{
return
&
cv_
;
}
...
...
paddle/fluid/distributed/ps/service/ps_client.h
浏览文件 @
4d0d0eca
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/platform/timer.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -59,6 +60,41 @@ class PSClientClosure : public google::protobuf::Closure {
...
@@ -59,6 +60,41 @@ class PSClientClosure : public google::protobuf::Closure {
std
::
vector
<
std
::
shared_ptr
<
std
::
promise
<
int32_t
>>>
_promises
;
std
::
vector
<
std
::
shared_ptr
<
std
::
promise
<
int32_t
>>>
_promises
;
};
};
struct
LoadSaveContext
{
int
table_id
;
std
::
string
epoch
;
std
::
string
mode
;
};
enum
TrainingMode
{
Async
=
0
,
Sync
=
1
,
Geo
=
3
};
enum
TrainingPhase
{
Init
=
0
,
Train
=
1
,
Save
=
2
};
// enum ValueType {
// Sparse = 0,
// Dense = 1
// };
struct
PushContext
{
const
uint64_t
*
keys
;
const
float
**
push_values
;
const
Region
*
push_dense_values
;
};
struct
RequestContext
{
int
table
;
TrainingMode
training_mode
;
// 1 for async, 2 for geo, 3 for sync
TrainingPhase
training_phase
;
// 1 for init, 2 for train
ValueType
value_type
;
// 1 for sparse, 2 for dense
void
*
keys
;
void
**
sparse_values
;
// for sparse values
Region
*
dense_values
;
// for dense values
PushContext
push_context
;
size_t
num
;
bool
is_training
;
void
*
callback
;
};
class
PSClient
{
class
PSClient
{
public:
public:
PSClient
()
{}
PSClient
()
{}
...
@@ -86,6 +122,9 @@ class PSClient {
...
@@ -86,6 +122,9 @@ class PSClient {
// 指定table数据load
// 指定table数据load
virtual
std
::
future
<
int32_t
>
load
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
virtual
std
::
future
<
int32_t
>
load
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
=
0
;
const
std
::
string
&
mode
)
=
0
;
// context配置load选项
virtual
std
::
future
<
int32_t
>
Load
(
const
LoadSaveContext
&
load_context
)
=
0
;
// 全量table数据save value_accessor根据mode,可能有不同的save条件
// 全量table数据save value_accessor根据mode,可能有不同的save条件
virtual
std
::
future
<
int32_t
>
save
(
const
std
::
string
&
epoch
,
virtual
std
::
future
<
int32_t
>
save
(
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
=
0
;
const
std
::
string
&
mode
)
=
0
;
...
@@ -93,6 +132,8 @@ class PSClient {
...
@@ -93,6 +132,8 @@ class PSClient {
virtual
std
::
future
<
int32_t
>
save
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
virtual
std
::
future
<
int32_t
>
save
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
=
0
;
const
std
::
string
&
mode
)
=
0
;
virtual
std
::
future
<
int32_t
>
Save
(
const
LoadSaveContext
&
save_context
)
=
0
;
// 清空table数据
// 清空table数据
virtual
std
::
future
<
int32_t
>
clear
()
=
0
;
virtual
std
::
future
<
int32_t
>
clear
()
=
0
;
virtual
std
::
future
<
int32_t
>
clear
(
uint32_t
table_id
)
=
0
;
virtual
std
::
future
<
int32_t
>
clear
(
uint32_t
table_id
)
=
0
;
...
@@ -107,6 +148,8 @@ class PSClient {
...
@@ -107,6 +148,8 @@ class PSClient {
virtual
std
::
future
<
int32_t
>
pull_dense
(
Region
*
regions
,
size_t
region_num
,
virtual
std
::
future
<
int32_t
>
pull_dense
(
Region
*
regions
,
size_t
region_num
,
size_t
table_id
)
=
0
;
// 保留
size_t
table_id
)
=
0
;
// 保留
virtual
std
::
future
<
int32_t
>
Push
(
RequestContext
&
push_context
)
=
0
;
// firstly push dense param for parameter server
// firstly push dense param for parameter server
// this is neccessary because dense weight initialized in trainer on cold
// this is neccessary because dense weight initialized in trainer on cold
// start
// start
...
@@ -117,6 +160,9 @@ class PSClient {
...
@@ -117,6 +160,9 @@ class PSClient {
virtual
std
::
future
<
int32_t
>
push_dense
(
const
Region
*
regions
,
virtual
std
::
future
<
int32_t
>
push_dense
(
const
Region
*
regions
,
size_t
region_num
,
size_t
region_num
,
size_t
table_id
)
=
0
;
size_t
table_id
)
=
0
;
virtual
std
::
future
<
int32_t
>
Pull
(
RequestContext
&
pull_context
)
=
0
;
// 使用keys进行pull请求,结果填充values
// 使用keys进行pull请求,结果填充values
// keys和values的个数均为num个,每个value占用select_size空间
// keys和values的个数均为num个,每个value占用select_size空间
// future结束前keys和values缓冲区不能再次使用
// future结束前keys和values缓冲区不能再次使用
...
...
paddle/fluid/distributed/ps/service/ps_local_client.cc
浏览文件 @
4d0d0eca
...
@@ -56,6 +56,19 @@ int32_t PsLocalClient::initialize() {
...
@@ -56,6 +56,19 @@ int32_t PsLocalClient::initialize() {
return
done
();
return
done
();
}
}
std
::
future
<
int32_t
>
PsLocalClient
::
Load
(
const
LoadSaveContext
&
load_context
)
{
if
(
load_context
.
table_id
<
0
)
{
for
(
auto
&
it
:
_table_map
)
{
load
(
it
.
first
,
load_context
.
epoch
,
load_context
.
mode
);
}
return
done
();
}
else
{
auto
*
table_ptr
=
table
(
load_context
.
table_id
);
table_ptr
->
load
(
load_context
.
epoch
,
load_context
.
mode
);
return
done
();
}
}
::
std
::
future
<
int32_t
>
PsLocalClient
::
save
(
const
std
::
string
&
epoch
,
::
std
::
future
<
int32_t
>
PsLocalClient
::
save
(
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
{
const
std
::
string
&
mode
)
{
// TODO
// TODO
...
@@ -74,6 +87,21 @@ int32_t PsLocalClient::initialize() {
...
@@ -74,6 +87,21 @@ int32_t PsLocalClient::initialize() {
return
done
();
return
done
();
}
}
::
std
::
future
<
int32_t
>
PsLocalClient
::
Save
(
const
LoadSaveContext
&
save_context
)
{
if
(
save_context
.
table_id
<
0
)
{
for
(
auto
&
it
:
_table_map
)
{
save
(
it
.
first
,
save_context
.
epoch
,
save_context
.
mode
);
}
return
done
();
}
else
{
auto
*
table_ptr
=
table
(
save_context
.
table_id
);
table_ptr
->
flush
();
table_ptr
->
save
(
save_context
.
epoch
,
save_context
.
mode
);
return
done
();
}
}
::
std
::
future
<
int32_t
>
PsLocalClient
::
clear
()
{
::
std
::
future
<
int32_t
>
PsLocalClient
::
clear
()
{
// TODO
// TODO
return
done
();
return
done
();
...
@@ -93,6 +121,51 @@ int32_t PsLocalClient::initialize() {
...
@@ -93,6 +121,51 @@ int32_t PsLocalClient::initialize() {
return
done
();
return
done
();
}
}
::
std
::
future
<
int32_t
>
PsLocalClient
::
Pull
(
RequestContext
&
pull_context
)
{
if
(
pull_context
.
value_type
==
Dense
)
{
// pull dense
Region
*
dense_region
=
reinterpret_cast
<
Region
*>
(
pull_context
.
dense_values
);
pull_dense
(
dense_region
,
pull_context
.
num
,
pull_context
.
table
);
}
else
{
// pull sparse
uint64_t
*
keys
=
reinterpret_cast
<
uint64_t
*>
(
pull_context
.
keys
);
char
**
select_values
=
reinterpret_cast
<
char
**>
(
pull_context
.
sparse_values
);
size_t
table_id
=
pull_context
.
table
;
size_t
num
=
pull_context
.
num
;
pull_sparse_ptr
(
select_values
,
table_id
,
keys
,
num
);
}
}
::
std
::
future
<
int32_t
>
PsLocalClient
::
Push
(
RequestContext
&
push_context
)
{
if
(
push_context
.
value_type
==
Dense
)
{
// push dense
if
(
push_context
.
training_phase
==
Init
)
{
const
Region
*
regions
=
push_context
.
push_context
.
push_dense_values
;
size_t
region_num
=
push_context
.
num
;
push_dense_param
(
regions
,
region_num
,
push_context
.
table
);
}
else
{
if
(
push_context
.
training_mode
==
Geo
)
{
// geo
float
*
total_send_data
=
reinterpret_cast
<
float
*>
(
push_context
.
dense_values
);
size_t
total_send_data_size
=
push_context
.
num
;
push_dense_raw_gradient
(
push_context
.
table
,
total_send_data
,
total_send_data_size
,
push_context
.
callback
);
}
else
{
// async and sync
const
Region
*
regions
=
push_context
.
push_context
.
push_dense_values
;
size_t
region_num
=
push_context
.
num
;
push_dense
(
regions
,
region_num
,
push_context
.
table
);
}
}
}
else
{
// push sparse
if
(
push_context
.
training_mode
==
Async
)
{
const
uint64_t
*
keys
=
push_context
.
push_context
.
keys
;
const
float
**
update_values
=
push_context
.
push_context
.
push_values
;
size_t
table_id
=
push_context
.
table
;
size_t
num
=
push_context
.
num
;
push_sparse
(
table_id
,
keys
,
update_values
,
num
);
}
else
{
// TODO
}
}
}
::
std
::
future
<
int32_t
>
PsLocalClient
::
pull_dense
(
Region
*
regions
,
::
std
::
future
<
int32_t
>
PsLocalClient
::
pull_dense
(
Region
*
regions
,
size_t
region_num
,
size_t
region_num
,
size_t
table_id
)
{
size_t
table_id
)
{
...
...
paddle/fluid/distributed/ps/service/ps_local_client.h
浏览文件 @
4d0d0eca
...
@@ -39,12 +39,16 @@ class PsLocalClient : public PSClient {
...
@@ -39,12 +39,16 @@ class PsLocalClient : public PSClient {
virtual
::
std
::
future
<
int32_t
>
load
(
uint32_t
table_id
,
virtual
::
std
::
future
<
int32_t
>
load
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
override
;
const
std
::
string
&
mode
)
override
;
virtual
std
::
future
<
int32_t
>
Load
(
const
LoadSaveContext
&
load_context
)
override
;
virtual
::
std
::
future
<
int32_t
>
save
(
const
std
::
string
&
epoch
,
virtual
::
std
::
future
<
int32_t
>
save
(
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
override
;
const
std
::
string
&
mode
)
override
;
virtual
::
std
::
future
<
int32_t
>
save
(
uint32_t
table_id
,
virtual
::
std
::
future
<
int32_t
>
save
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
override
;
const
std
::
string
&
mode
)
override
;
virtual
std
::
future
<
int32_t
>
Save
(
const
LoadSaveContext
&
save_context
)
override
;
virtual
::
std
::
future
<
int32_t
>
clear
()
override
;
virtual
::
std
::
future
<
int32_t
>
clear
()
override
;
virtual
::
std
::
future
<
int32_t
>
clear
(
uint32_t
table_id
)
override
;
virtual
::
std
::
future
<
int32_t
>
clear
(
uint32_t
table_id
)
override
;
...
@@ -55,6 +59,10 @@ class PsLocalClient : public PSClient {
...
@@ -55,6 +59,10 @@ class PsLocalClient : public PSClient {
virtual
::
std
::
future
<
int32_t
>
pull_dense
(
Region
*
regions
,
size_t
region_num
,
virtual
::
std
::
future
<
int32_t
>
pull_dense
(
Region
*
regions
,
size_t
region_num
,
size_t
table_id
);
size_t
table_id
);
virtual
::
std
::
future
<
int32_t
>
Pull
(
RequestContext
&
pull_context
)
override
;
virtual
::
std
::
future
<
int32_t
>
Push
(
RequestContext
&
push_context
)
override
;
virtual
::
std
::
future
<
int32_t
>
push_dense
(
const
Region
*
regions
,
virtual
::
std
::
future
<
int32_t
>
push_dense
(
const
Region
*
regions
,
size_t
region_num
,
size_t
table_id
);
size_t
region_num
,
size_t
table_id
);
...
...
paddle/fluid/distributed/ps/service/ps_local_server.h
浏览文件 @
4d0d0eca
...
@@ -28,7 +28,6 @@ class PsLocalServer : public PSServer {
...
@@ -28,7 +28,6 @@ class PsLocalServer : public PSServer {
virtual
uint64_t
start
()
{
return
0
;
}
virtual
uint64_t
start
()
{
return
0
;
}
virtual
uint64_t
start
(
const
std
::
string
&
ip
,
uint32_t
port
)
{
return
0
;
}
virtual
uint64_t
start
(
const
std
::
string
&
ip
,
uint32_t
port
)
{
return
0
;
}
virtual
int32_t
stop
()
{
return
0
;
}
virtual
int32_t
stop
()
{
return
0
;
}
virtual
int32_t
port
()
{
return
0
;
}
virtual
int32_t
configure
(
virtual
int32_t
configure
(
const
PSParameter
&
config
,
PSEnvironment
&
env
,
size_t
server_rank
,
const
PSParameter
&
config
,
PSEnvironment
&
env
,
size_t
server_rank
,
const
std
::
vector
<
framework
::
ProgramDesc
>
&
server_sub_program
=
{})
{
const
std
::
vector
<
framework
::
ProgramDesc
>
&
server_sub_program
=
{})
{
...
...
paddle/fluid/distributed/ps/service/server.cc
浏览文件 @
4d0d0eca
...
@@ -67,8 +67,6 @@ int32_t PSServer::configure(
...
@@ -67,8 +67,6 @@ int32_t PSServer::configure(
_config
=
config
.
server_param
();
_config
=
config
.
server_param
();
_rank
=
server_rank
;
_rank
=
server_rank
;
_environment
=
&
env
;
_environment
=
&
env
;
_shuffled_ins
=
paddle
::
framework
::
MakeChannel
<
std
::
pair
<
uint64_t
,
std
::
string
>>
();
size_t
shard_num
=
env
.
get_ps_servers
().
size
();
size_t
shard_num
=
env
.
get_ps_servers
().
size
();
const
auto
&
downpour_param
=
_config
.
downpour_server_param
();
const
auto
&
downpour_param
=
_config
.
downpour_server_param
();
...
...
paddle/fluid/distributed/ps/service/server.h
浏览文件 @
4d0d0eca
...
@@ -69,11 +69,6 @@ class PSServer {
...
@@ -69,11 +69,6 @@ class PSServer {
const
PSParameter
&
config
,
PSEnvironment
&
env
,
size_t
server_rank
,
const
PSParameter
&
config
,
PSEnvironment
&
env
,
size_t
server_rank
,
const
std
::
vector
<
framework
::
ProgramDesc
>
&
server_sub_program
=
{});
const
std
::
vector
<
framework
::
ProgramDesc
>
&
server_sub_program
=
{});
// return server_ip
virtual
std
::
string
ip
()
{
return
butil
::
my_ip_cstr
();
}
// return server_port
virtual
int32_t
port
()
=
0
;
virtual
uint64_t
start
(
const
std
::
string
&
ip
,
uint32_t
port
)
=
0
;
virtual
uint64_t
start
(
const
std
::
string
&
ip
,
uint32_t
port
)
=
0
;
virtual
int32_t
stop
()
=
0
;
virtual
int32_t
stop
()
=
0
;
...
@@ -94,15 +89,6 @@ class PSServer {
...
@@ -94,15 +89,6 @@ class PSServer {
return
&
_table_map
;
return
&
_table_map
;
}
}
typedef
std
::
function
<
int32_t
(
int
,
int
,
const
std
::
string
&
)
>
MsgHandlerFunc
;
virtual
int
registe_pserver2pserver_msg_handler
(
int
msg_type
,
MsgHandlerFunc
handler
)
{
_msg_handler_map
[
msg_type
]
=
handler
;
return
0
;
}
paddle
::
framework
::
Channel
<
std
::
pair
<
uint64_t
,
std
::
string
>>
_shuffled_ins
;
protected:
protected:
virtual
int32_t
initialize
()
=
0
;
virtual
int32_t
initialize
()
=
0
;
...
@@ -111,7 +97,6 @@ class PSServer {
...
@@ -111,7 +97,6 @@ class PSServer {
ServerParameter
_config
;
ServerParameter
_config
;
PSEnvironment
*
_environment
;
PSEnvironment
*
_environment
;
std
::
unordered_map
<
uint32_t
,
std
::
shared_ptr
<
Table
>>
_table_map
;
std
::
unordered_map
<
uint32_t
,
std
::
shared_ptr
<
Table
>>
_table_map
;
std
::
unordered_map
<
int32_t
,
MsgHandlerFunc
>
_msg_handler_map
;
protected:
protected:
std
::
shared_ptr
<
framework
::
Scope
>
scope_
;
std
::
shared_ptr
<
framework
::
Scope
>
scope_
;
...
...
paddle/fluid/distributed/ps/table/accessor.h
浏览文件 @
4d0d0eca
...
@@ -45,6 +45,17 @@ struct DataConverter {
...
@@ -45,6 +45,17 @@ struct DataConverter {
std
::
string
deconverter
;
std
::
string
deconverter
;
};
};
struct
AccessorInfo
{
size_t
dim
;
size_t
size
;
size_t
select_size
;
size_t
select_dim
;
size_t
update_size
;
size_t
update_dim
;
size_t
mf_size
;
size_t
fea_dim
;
};
class
ValueAccessor
{
class
ValueAccessor
{
public:
public:
ValueAccessor
()
{}
ValueAccessor
()
{}
...
@@ -68,6 +79,8 @@ class ValueAccessor {
...
@@ -68,6 +79,8 @@ class ValueAccessor {
}
}
virtual
int
initialize
()
=
0
;
virtual
int
initialize
()
=
0
;
virtual
void
GetTableInfo
(
AccessorInfo
&
info
)
=
0
;
// value维度
// value维度
virtual
size_t
dim
()
=
0
;
virtual
size_t
dim
()
=
0
;
// value各个维度的size
// value各个维度的size
...
@@ -163,6 +176,7 @@ class ValueAccessor {
...
@@ -163,6 +176,7 @@ class ValueAccessor {
TableAccessorParameter
_config
;
TableAccessorParameter
_config
;
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
struct
DataConverter
>>
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
struct
DataConverter
>>
_data_coverter_map
;
_data_coverter_map
;
AccessorInfo
_accessor_info
;
};
};
REGISTER_PSCORE_REGISTERER
(
ValueAccessor
);
REGISTER_PSCORE_REGISTERER
(
ValueAccessor
);
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/ps/table/common_dense_table.cc
浏览文件 @
4d0d0eca
...
@@ -128,6 +128,21 @@ int32_t CommonDenseTable::set_global_lr(float* lr) {
...
@@ -128,6 +128,21 @@ int32_t CommonDenseTable::set_global_lr(float* lr) {
return
0
;
return
0
;
}
}
int32_t
CommonDenseTable
::
Pull
(
TableContext
&
context
)
{
CHECK
(
context
.
value_type
==
Dense
);
float
*
pull_values
=
context
.
pull_context
.
values
;
return
pull_dense
(
pull_values
,
context
.
num
);
}
int32_t
CommonDenseTable
::
Push
(
TableContext
&
context
)
{
CHECK
(
context
.
value_type
==
Dense
);
if
(
context
.
pull_context
.
values
!=
nullptr
)
{
const
float
*
values
=
context
.
push_context
.
values
;
return
push_dense
(
values
,
context
.
num
);
}
return
0
;
}
int32_t
CommonDenseTable
::
pull_dense
(
float
*
pull_values
,
size_t
num
)
{
int32_t
CommonDenseTable
::
pull_dense
(
float
*
pull_values
,
size_t
num
)
{
std
::
copy
(
values_
[
param_idx_
].
begin
(),
values_
[
param_idx_
].
end
(),
std
::
copy
(
values_
[
param_idx_
].
begin
(),
values_
[
param_idx_
].
end
(),
pull_values
);
pull_values
);
...
...
paddle/fluid/distributed/ps/table/common_dense_table.h
浏览文件 @
4d0d0eca
...
@@ -40,6 +40,8 @@ class CommonDenseTable : public DenseTable {
...
@@ -40,6 +40,8 @@ class CommonDenseTable : public DenseTable {
const
std
::
string
&
name
);
const
std
::
string
&
name
);
virtual
int32_t
initialize_value
();
virtual
int32_t
initialize_value
();
virtual
int32_t
initialize_optimizer
();
virtual
int32_t
initialize_optimizer
();
virtual
int32_t
Pull
(
TableContext
&
context
);
virtual
int32_t
Push
(
TableContext
&
context
);
int32_t
pull_dense
(
float
*
pull_values
,
size_t
num
)
override
;
int32_t
pull_dense
(
float
*
pull_values
,
size_t
num
)
override
;
int32_t
push_dense_param
(
const
float
*
values
,
size_t
num
)
override
;
int32_t
push_dense_param
(
const
float
*
values
,
size_t
num
)
override
;
int32_t
push_dense
(
const
float
*
values
,
size_t
num
)
override
;
int32_t
push_dense
(
const
float
*
values
,
size_t
num
)
override
;
...
...
paddle/fluid/distributed/ps/table/common_graph_table.h
浏览文件 @
4d0d0eca
...
@@ -454,6 +454,9 @@ class GraphTable : public SparseTable {
...
@@ -454,6 +454,9 @@ class GraphTable : public SparseTable {
int32_t
get_server_index_by_id
(
int64_t
id
);
int32_t
get_server_index_by_id
(
int64_t
id
);
Node
*
find_node
(
int64_t
id
);
Node
*
find_node
(
int64_t
id
);
virtual
int32_t
Pull
(
TableContext
&
context
)
{
return
0
;
}
virtual
int32_t
Push
(
TableContext
&
context
)
{
return
0
;
}
virtual
int32_t
pull_sparse
(
float
*
values
,
virtual
int32_t
pull_sparse
(
float
*
values
,
const
PullSparseValue
&
pull_value
)
{
const
PullSparseValue
&
pull_value
)
{
return
0
;
return
0
;
...
...
paddle/fluid/distributed/ps/table/common_sparse_table.cc
浏览文件 @
4d0d0eca
...
@@ -355,6 +355,32 @@ int32_t CommonSparseTable::pour() {
...
@@ -355,6 +355,32 @@ int32_t CommonSparseTable::pour() {
return
0
;
return
0
;
}
}
int32_t
CommonSparseTable
::
Pull
(
TableContext
&
context
)
{
CHECK
(
context
.
value_type
==
Sparse
);
if
(
context
.
use_ptr
)
{
char
**
pull_values
=
context
.
pull_context
.
ptr_values
;
const
uint64_t
*
keys
=
context
.
pull_context
.
keys
;
return
pull_sparse_ptr
(
pull_values
,
keys
,
context
.
num
);
}
else
{
float
*
pull_values
=
context
.
pull_context
.
values
;
const
PullSparseValue
&
pull_value
=
context
.
pull_context
.
pull_value
;
return
pull_sparse
(
pull_values
,
pull_value
);
}
}
int32_t
CommonSparseTable
::
Push
(
TableContext
&
context
)
{
CHECK
(
context
.
value_type
==
Sparse
);
if
(
context
.
pull_context
.
values
!=
nullptr
)
{
const
float
*
values
=
context
.
push_context
.
values
;
const
uint64_t
*
keys
=
context
.
push_context
.
keys
;
return
push_sparse
(
keys
,
values
,
context
.
num
);
}
else
{
const
float
**
values
=
context
.
push_context
.
ptr_values
;
const
uint64_t
*
keys
=
context
.
push_context
.
keys
;
return
push_sparse
(
keys
,
values
,
context
.
num
);
}
}
int32_t
CommonSparseTable
::
pull_sparse
(
float
*
pull_values
,
int32_t
CommonSparseTable
::
pull_sparse
(
float
*
pull_values
,
const
PullSparseValue
&
pull_value
)
{
const
PullSparseValue
&
pull_value
)
{
auto
shard_num
=
task_pool_size_
;
auto
shard_num
=
task_pool_size_
;
...
...
paddle/fluid/distributed/ps/table/common_sparse_table.h
浏览文件 @
4d0d0eca
...
@@ -121,6 +121,9 @@ class CommonSparseTable : public SparseTable {
...
@@ -121,6 +121,9 @@ class CommonSparseTable : public SparseTable {
virtual
int32_t
push_dense
(
const
float
*
values
,
size_t
num
)
{
return
0
;
}
virtual
int32_t
push_dense
(
const
float
*
values
,
size_t
num
)
{
return
0
;
}
// unused method end
// unused method end
virtual
int32_t
Pull
(
TableContext
&
context
);
virtual
int32_t
Push
(
TableContext
&
context
);
virtual
int32_t
initialize
();
virtual
int32_t
initialize
();
virtual
int32_t
initialize_shard
()
{
return
0
;
}
virtual
int32_t
initialize_shard
()
{
return
0
;
}
virtual
int32_t
initialize_value
();
virtual
int32_t
initialize_value
();
...
...
paddle/fluid/distributed/ps/table/common_table.h
浏览文件 @
4d0d0eca
...
@@ -119,6 +119,9 @@ class BarrierTable : public Table {
...
@@ -119,6 +119,9 @@ class BarrierTable : public Table {
virtual
void
*
get_shard
(
size_t
shard_idx
)
{
return
0
;
}
virtual
void
*
get_shard
(
size_t
shard_idx
)
{
return
0
;
}
virtual
int32_t
Pull
(
TableContext
&
context
)
{
return
0
;
}
virtual
int32_t
Push
(
TableContext
&
context
)
{
return
0
;
}
int32_t
pull_dense
(
float
*
values
,
size_t
num
)
override
{
return
0
;
}
int32_t
pull_dense
(
float
*
values
,
size_t
num
)
override
{
return
0
;
}
int32_t
push_dense
(
const
float
*
values
,
size_t
num
)
override
{
return
0
;
}
int32_t
push_dense
(
const
float
*
values
,
size_t
num
)
override
{
return
0
;
}
...
...
paddle/fluid/distributed/ps/table/ctr_accessor.cc
浏览文件 @
4d0d0eca
...
@@ -38,6 +38,16 @@ int CtrCommonAccessor::initialize() {
...
@@ -38,6 +38,16 @@ int CtrCommonAccessor::initialize() {
return
0
;
return
0
;
}
}
void
CtrCommonAccessor
::
GetTableInfo
(
AccessorInfo
&
info
)
{
info
.
dim
=
dim
();
info
.
size
=
size
();
info
.
select_dim
=
select_dim
();
info
.
select_size
=
select_size
();
info
.
update_dim
=
update_dim
();
info
.
update_size
=
update_size
();
info
.
fea_dim
=
fea_dim
();
}
size_t
CtrCommonAccessor
::
dim
()
{
return
common_feature_value
.
dim
();
}
size_t
CtrCommonAccessor
::
dim
()
{
return
common_feature_value
.
dim
();
}
size_t
CtrCommonAccessor
::
dim_size
(
size_t
dim
)
{
size_t
CtrCommonAccessor
::
dim_size
(
size_t
dim
)
{
...
...
paddle/fluid/distributed/ps/table/ctr_accessor.h
浏览文件 @
4d0d0eca
...
@@ -126,6 +126,7 @@ class CtrCommonAccessor : public ValueAccessor {
...
@@ -126,6 +126,7 @@ class CtrCommonAccessor : public ValueAccessor {
virtual
int
initialize
();
virtual
int
initialize
();
virtual
~
CtrCommonAccessor
()
{}
virtual
~
CtrCommonAccessor
()
{}
virtual
void
GetTableInfo
(
AccessorInfo
&
info
);
// value维度
// value维度
virtual
size_t
dim
();
virtual
size_t
dim
();
// value各个维度的size
// value各个维度的size
...
...
paddle/fluid/distributed/ps/table/ctr_double_accessor.cc
浏览文件 @
4d0d0eca
...
@@ -37,6 +37,16 @@ int DownpourCtrDoubleAccessor::initialize() {
...
@@ -37,6 +37,16 @@ int DownpourCtrDoubleAccessor::initialize() {
return
0
;
return
0
;
}
}
void
DownpourCtrDoubleAccessor
::
GetTableInfo
(
AccessorInfo
&
info
)
{
info
.
dim
=
dim
();
info
.
size
=
size
();
info
.
select_dim
=
select_dim
();
info
.
select_size
=
select_size
();
info
.
update_dim
=
update_dim
();
info
.
update_size
=
update_size
();
info
.
fea_dim
=
fea_dim
();
}
size_t
DownpourCtrDoubleAccessor
::
dim
()
{
size_t
DownpourCtrDoubleAccessor
::
dim
()
{
auto
embedx_dim
=
_config
.
embedx_dim
();
auto
embedx_dim
=
_config
.
embedx_dim
();
return
DownpourCtrDoubleFeatureValue
::
dim
(
embedx_dim
);
return
DownpourCtrDoubleFeatureValue
::
dim
(
embedx_dim
);
...
...
paddle/fluid/distributed/ps/table/ctr_double_accessor.h
浏览文件 @
4d0d0eca
...
@@ -168,6 +168,7 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
...
@@ -168,6 +168,7 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
DownpourCtrDoubleAccessor
()
{}
DownpourCtrDoubleAccessor
()
{}
virtual
~
DownpourCtrDoubleAccessor
()
{}
virtual
~
DownpourCtrDoubleAccessor
()
{}
virtual
int
initialize
();
virtual
int
initialize
();
virtual
void
GetTableInfo
(
AccessorInfo
&
info
);
// value维度
// value维度
virtual
size_t
dim
();
virtual
size_t
dim
();
// value各个维度的size
// value各个维度的size
...
...
paddle/fluid/distributed/ps/table/depends/sparse_utils.h
浏览文件 @
4d0d0eca
...
@@ -58,7 +58,7 @@ struct PullSparseValue {
...
@@ -58,7 +58,7 @@ struct PullSparseValue {
std
::
vector
<
int
>*
offset_shard
)
const
{
std
::
vector
<
int
>*
offset_shard
)
const
{
offset_shard
->
reserve
(
numel_
/
shard_num
+
1
);
offset_shard
->
reserve
(
numel_
/
shard_num
+
1
);
for
(
int
x
=
0
;
x
<
numel_
;
++
x
)
{
for
(
int
x
=
0
;
x
<
numel_
;
++
x
)
{
if
(
feasigns_
[
x
]
%
shard_num
==
shard_id
)
{
if
(
int
(
feasigns_
[
x
]
%
shard_num
)
==
shard_id
)
{
offset_shard
->
push_back
(
x
);
offset_shard
->
push_back
(
x
);
}
}
}
}
...
...
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc
浏览文件 @
4d0d0eca
...
@@ -37,6 +37,16 @@ int DownpourCtrAccessor::initialize() {
...
@@ -37,6 +37,16 @@ int DownpourCtrAccessor::initialize() {
return
0
;
return
0
;
}
}
void
DownpourCtrAccessor
::
GetTableInfo
(
AccessorInfo
&
info
)
{
info
.
dim
=
dim
();
info
.
size
=
size
();
info
.
select_dim
=
select_dim
();
info
.
select_size
=
select_size
();
info
.
update_dim
=
update_dim
();
info
.
update_size
=
update_size
();
info
.
fea_dim
=
fea_dim
();
}
size_t
DownpourCtrAccessor
::
dim
()
{
size_t
DownpourCtrAccessor
::
dim
()
{
auto
embedx_dim
=
_config
.
embedx_dim
();
auto
embedx_dim
=
_config
.
embedx_dim
();
return
DownpourCtrFeatureValue
::
dim
(
embedx_dim
);
return
DownpourCtrFeatureValue
::
dim
(
embedx_dim
);
...
...
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h
浏览文件 @
4d0d0eca
...
@@ -160,6 +160,7 @@ class DownpourCtrAccessor : public ValueAccessor {
...
@@ -160,6 +160,7 @@ class DownpourCtrAccessor : public ValueAccessor {
virtual
~
DownpourCtrAccessor
()
{}
virtual
~
DownpourCtrAccessor
()
{}
virtual
int
initialize
();
virtual
int
initialize
();
virtual
void
GetTableInfo
(
AccessorInfo
&
info
);
// value维度
// value维度
virtual
size_t
dim
();
virtual
size_t
dim
();
// value各个维度的size
// value各个维度的size
...
...
paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h
浏览文件 @
4d0d0eca
...
@@ -48,6 +48,8 @@ class MemorySparseGeoTable : public SparseTable {
...
@@ -48,6 +48,8 @@ class MemorySparseGeoTable : public SparseTable {
virtual
int32_t
save
(
const
std
::
string
&
path
,
const
std
::
string
&
param
)
{
virtual
int32_t
save
(
const
std
::
string
&
path
,
const
std
::
string
&
param
)
{
return
0
;
return
0
;
}
}
virtual
int32_t
Pull
(
TableContext
&
context
)
{
return
0
;
}
virtual
int32_t
Push
(
TableContext
&
context
)
{
return
0
;
}
virtual
int32_t
flush
()
{
return
0
;
}
virtual
int32_t
flush
()
{
return
0
;
}
virtual
int32_t
shrink
(
const
std
::
string
&
param
)
{
return
0
;
}
virtual
int32_t
shrink
(
const
std
::
string
&
param
)
{
return
0
;
}
virtual
void
clear
()
{
return
;
}
virtual
void
clear
()
{
return
;
}
...
...
paddle/fluid/distributed/ps/table/memory_sparse_table.cc
浏览文件 @
4d0d0eca
...
@@ -390,6 +390,26 @@ std::pair<int64_t, int64_t> MemorySparseTable::print_table_stat() {
...
@@ -390,6 +390,26 @@ std::pair<int64_t, int64_t> MemorySparseTable::print_table_stat() {
return
{
feasign_size
,
mf_size
};
return
{
feasign_size
,
mf_size
};
}
}
int32_t
MemorySparseTable
::
Pull
(
TableContext
&
context
)
{
CHECK
(
context
.
value_type
==
Sparse
);
if
(
context
.
use_ptr
)
{
char
**
pull_values
=
context
.
pull_context
.
ptr_values
;
const
uint64_t
*
keys
=
context
.
pull_context
.
keys
;
return
pull_sparse_ptr
(
pull_values
,
keys
,
context
.
num
);
}
else
{
float
*
pull_values
=
context
.
pull_context
.
values
;
const
PullSparseValue
&
pull_value
=
context
.
pull_context
.
pull_value
;
return
pull_sparse
(
pull_values
,
pull_value
);
}
}
int32_t
MemorySparseTable
::
Push
(
TableContext
&
context
)
{
CHECK
(
context
.
value_type
==
Sparse
);
const
uint64_t
*
keys
=
context
.
push_context
.
keys
;
return
push_sparse
(
keys
,
context
.
push_context
.
ptr_values
,
context
.
num
);
}
int32_t
MemorySparseTable
::
pull_sparse
(
float
*
pull_values
,
int32_t
MemorySparseTable
::
pull_sparse
(
float
*
pull_values
,
const
PullSparseValue
&
pull_value
)
{
const
PullSparseValue
&
pull_value
)
{
CostTimer
timer
(
"pserver_sparse_select_all"
);
CostTimer
timer
(
"pserver_sparse_select_all"
);
...
...
paddle/fluid/distributed/ps/table/memory_sparse_table.h
浏览文件 @
4d0d0eca
...
@@ -48,6 +48,9 @@ class MemorySparseTable : public SparseTable {
...
@@ -48,6 +48,9 @@ class MemorySparseTable : public SparseTable {
virtual
int32_t
push_dense
(
const
float
*
values
,
size_t
num
)
{
return
0
;
}
virtual
int32_t
push_dense
(
const
float
*
values
,
size_t
num
)
{
return
0
;
}
// unused method end
// unused method end
virtual
int32_t
Pull
(
TableContext
&
context
);
virtual
int32_t
Push
(
TableContext
&
context
);
virtual
int32_t
initialize
();
virtual
int32_t
initialize
();
virtual
int32_t
initialize_shard
()
{
return
0
;
}
virtual
int32_t
initialize_shard
()
{
return
0
;
}
virtual
int32_t
initialize_value
();
virtual
int32_t
initialize_value
();
...
...
paddle/fluid/distributed/ps/table/ssd_sparse_table.cc
浏览文件 @
4d0d0eca
...
@@ -61,6 +61,21 @@ int32_t SSDSparseTable::initialize() {
...
@@ -61,6 +61,21 @@ int32_t SSDSparseTable::initialize() {
return
0
;
return
0
;
}
}
int32_t
SSDSparseTable
::
Pull
(
TableContext
&
context
)
{
CHECK
(
context
.
value_type
==
Sparse
);
if
(
context
.
use_ptr
)
{
char
**
pull_values
=
context
.
pull_context
.
ptr_values
;
const
uint64_t
*
keys
=
context
.
pull_context
.
keys
;
return
pull_sparse_ptr
(
pull_values
,
keys
,
context
.
num
);
}
else
{
float
*
pull_values
=
context
.
pull_context
.
values
;
const
PullSparseValue
&
pull_value
=
context
.
pull_context
.
pull_value
;
return
pull_sparse
(
pull_values
,
pull_value
);
}
}
int32_t
SSDSparseTable
::
Push
(
TableContext
&
context
)
{
return
0
;
}
int32_t
SSDSparseTable
::
pull_sparse
(
float
*
pull_values
,
int32_t
SSDSparseTable
::
pull_sparse
(
float
*
pull_values
,
const
PullSparseValue
&
pull_value
)
{
const
PullSparseValue
&
pull_value
)
{
auto
shard_num
=
task_pool_size_
;
auto
shard_num
=
task_pool_size_
;
...
...
paddle/fluid/distributed/ps/table/ssd_sparse_table.h
浏览文件 @
4d0d0eca
...
@@ -42,6 +42,9 @@ class SSDSparseTable : public CommonSparseTable {
...
@@ -42,6 +42,9 @@ class SSDSparseTable : public CommonSparseTable {
// exchange data
// exchange data
virtual
int32_t
update_table
();
virtual
int32_t
update_table
();
virtual
int32_t
Pull
(
TableContext
&
context
);
virtual
int32_t
Push
(
TableContext
&
context
);
virtual
int32_t
pull_sparse
(
float
*
values
,
const
PullSparseValue
&
pull_value
);
virtual
int32_t
pull_sparse
(
float
*
values
,
const
PullSparseValue
&
pull_value
);
virtual
int32_t
pull_sparse_ptr
(
char
**
pull_values
,
const
uint64_t
*
keys
,
virtual
int32_t
pull_sparse_ptr
(
char
**
pull_values
,
const
uint64_t
*
keys
,
...
...
paddle/fluid/distributed/ps/table/table.h
浏览文件 @
4d0d0eca
...
@@ -32,6 +32,30 @@
...
@@ -32,6 +32,30 @@
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
enum
ValueType
{
Sparse
=
0
,
Dense
=
1
};
struct
PullContext
{
const
uint64_t
*
keys
;
const
PullSparseValue
pull_value
;
float
*
values
;
char
**
ptr_values
;
};
struct
TablePushContext
{
const
uint64_t
*
keys
;
const
float
*
values
;
const
float
**
ptr_values
;
};
struct
TableContext
{
ValueType
value_type
;
PullContext
pull_context
;
TablePushContext
push_context
;
size_t
num
;
bool
use_ptr
;
};
class
Table
{
class
Table
{
public:
public:
Table
()
{}
Table
()
{}
...
@@ -39,6 +63,8 @@ class Table {
...
@@ -39,6 +63,8 @@ class Table {
virtual
int32_t
initialize
(
const
TableParameter
&
config
,
virtual
int32_t
initialize
(
const
TableParameter
&
config
,
const
FsClientParameter
&
fs_config
);
const
FsClientParameter
&
fs_config
);
virtual
int32_t
Pull
(
TableContext
&
context
)
=
0
;
virtual
int32_t
Push
(
TableContext
&
context
)
=
0
;
virtual
int32_t
pull_dense
(
float
*
values
,
size_t
num
)
=
0
;
virtual
int32_t
pull_dense
(
float
*
values
,
size_t
num
)
=
0
;
virtual
int32_t
push_dense
(
const
float
*
values
,
size_t
num
)
=
0
;
virtual
int32_t
push_dense
(
const
float
*
values
,
size_t
num
)
=
0
;
// for push global_step
// for push global_step
...
...
paddle/fluid/distributed/ps/table/tensor_accessor.cc
浏览文件 @
4d0d0eca
...
@@ -20,6 +20,16 @@ namespace distributed {
...
@@ -20,6 +20,16 @@ namespace distributed {
int
CommMergeAccessor
::
initialize
()
{
return
0
;
}
int
CommMergeAccessor
::
initialize
()
{
return
0
;
}
void
CommMergeAccessor
::
GetTableInfo
(
AccessorInfo
&
info
)
{
info
.
dim
=
dim
();
info
.
size
=
size
();
info
.
select_dim
=
select_dim
();
info
.
select_size
=
select_size
();
info
.
update_dim
=
update_dim
();
info
.
update_size
=
update_size
();
info
.
fea_dim
=
fea_dim
();
}
// value 维度
// value 维度
size_t
CommMergeAccessor
::
dim
()
{
return
0
;
}
size_t
CommMergeAccessor
::
dim
()
{
return
0
;
}
...
...
paddle/fluid/distributed/ps/table/tensor_accessor.h
浏览文件 @
4d0d0eca
...
@@ -30,6 +30,7 @@ class CommMergeAccessor : public ValueAccessor {
...
@@ -30,6 +30,7 @@ class CommMergeAccessor : public ValueAccessor {
CommMergeAccessor
()
{}
CommMergeAccessor
()
{}
virtual
~
CommMergeAccessor
()
{}
virtual
~
CommMergeAccessor
()
{}
virtual
int
initialize
();
virtual
int
initialize
();
virtual
void
GetTableInfo
(
AccessorInfo
&
info
);
// value维度
// value维度
virtual
size_t
dim
();
virtual
size_t
dim
();
// value各个维度的size
// value各个维度的size
...
...
paddle/fluid/distributed/ps/table/tensor_table.h
浏览文件 @
4d0d0eca
...
@@ -48,6 +48,8 @@ class TensorTable : public Table {
...
@@ -48,6 +48,8 @@ class TensorTable : public Table {
TensorTable
()
{}
TensorTable
()
{}
virtual
~
TensorTable
()
{}
virtual
~
TensorTable
()
{}
virtual
int32_t
Pull
(
TableContext
&
context
)
{
return
0
;
}
virtual
int32_t
Push
(
TableContext
&
context
)
{
return
0
;
}
int32_t
pull_dense
(
float
*
values
,
size_t
num
)
override
{
return
0
;
}
int32_t
pull_dense
(
float
*
values
,
size_t
num
)
override
{
return
0
;
}
int32_t
push_dense
(
const
float
*
values
,
size_t
num
)
override
{
return
0
;
}
int32_t
push_dense
(
const
float
*
values
,
size_t
num
)
override
{
return
0
;
}
...
...
paddle/fluid/distributed/ps/wrapper/fleet.cc
浏览文件 @
4d0d0eca
...
@@ -30,6 +30,32 @@ bool FleetWrapper::is_initialized_ = false;
...
@@ -30,6 +30,32 @@ bool FleetWrapper::is_initialized_ = false;
std
::
shared_ptr
<
paddle
::
distributed
::
PSCore
>
FleetWrapper
::
pserver_ptr_
=
NULL
;
std
::
shared_ptr
<
paddle
::
distributed
::
PSCore
>
FleetWrapper
::
pserver_ptr_
=
NULL
;
void
FleetWrapper
::
Stop
()
{
StopServer
();
}
void
FleetWrapper
::
Load
(
WrapperContext
&
context
)
{
auto
table_id
=
context
.
table_id
;
if
(
table_id
>=
0
&&
context
.
meta
!=
""
)
{
LoadSparseOnServer
(
context
.
path
,
context
.
meta
,
context
.
table_id
);
return
;
}
if
(
table_id
<
0
)
{
// laod all
LoadModel
(
context
.
path
,
context
.
mode
);
}
else
{
// load one table
LoadModelOneTable
(
table_id
,
context
.
path
,
context
.
mode
);
}
return
;
}
void
FleetWrapper
::
Save
(
WrapperContext
&
context
)
{
auto
table_id
=
context
.
table_id
;
if
(
table_id
<
0
)
{
SaveModel
(
context
.
path
,
context
.
mode
);
}
else
{
SaveModelOneTable
(
table_id
,
context
.
path
,
context
.
mode
);
}
return
;
}
void
FleetWrapper
::
SetClient2ClientConfig
(
int
request_timeout_ms
,
void
FleetWrapper
::
SetClient2ClientConfig
(
int
request_timeout_ms
,
int
connect_timeout_ms
,
int
connect_timeout_ms
,
int
max_retry
)
{
int
max_retry
)
{
...
...
paddle/fluid/distributed/ps/wrapper/fleet.h
浏览文件 @
4d0d0eca
...
@@ -25,6 +25,7 @@ limitations under the License. */
...
@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include "paddle/fluid/distributed/ps/wrapper/ps_wrapper.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/io/shell.h"
#include "paddle/fluid/framework/io/shell.h"
...
@@ -54,7 +55,7 @@ using framework::Variable;
...
@@ -54,7 +55,7 @@ using framework::Variable;
using
RpcCtxMap
=
std
::
unordered_map
<
std
::
string
,
CommContext
>
;
using
RpcCtxMap
=
std
::
unordered_map
<
std
::
string
,
CommContext
>
;
class
FleetWrapper
{
class
FleetWrapper
:
public
PSWrapper
{
public:
public:
virtual
~
FleetWrapper
()
{}
virtual
~
FleetWrapper
()
{}
FleetWrapper
()
{
FleetWrapper
()
{
...
@@ -68,7 +69,13 @@ class FleetWrapper {
...
@@ -68,7 +69,13 @@ class FleetWrapper {
// pserver request max retry
// pserver request max retry
client2client_max_retry_
=
3
;
client2client_max_retry_
=
3
;
}
}
virtual
int32_t
Initialize
(
InitContext
&
context
)
{
return
0
;
}
virtual
void
Stop
()
override
;
virtual
void
Load
(
WrapperContext
&
context
)
override
;
virtual
void
Save
(
WrapperContext
&
context
)
override
;
// set client to client communication config
// set client to client communication config
void
SetClient2ClientConfig
(
int
request_timeout_ms
,
int
connect_timeout_ms
,
void
SetClient2ClientConfig
(
int
request_timeout_ms
,
int
connect_timeout_ms
,
int
max_retry
);
int
max_retry
);
...
...
paddle/fluid/distributed/ps/wrapper/ps_wrapper.h
浏览文件 @
4d0d0eca
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
// limitations under the License.
limitations under the License. */
#ifndef PADDLE_FLUID_DISTRIBUTED_PS_WRAPPER_PS_WRAPPER_H_
#pragma once
#define PADDLE_FLUID_DISTRIBUTED_PS_WRAPPER_PS_WRAPPER_H_
#include <atomic>
#endif // PADDLE_FLUID_DISTRIBUTED_PS_WRAPPER_PS_WRAPPER_H_
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/io/shell.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace
paddle
{
namespace
framework
{
class
Scope
;
class
SelectedRows
;
class
Variable
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
distributed
{
class
PSCore
;
using
framework
::
LoDTensor
;
using
framework
::
Scope
;
using
phi
::
SelectedRows
;
using
framework
::
Variable
;
using
RpcCtxMap
=
std
::
unordered_map
<
std
::
string
,
CommContext
>
;
struct
WrapperContext
{
uint32_t
table_id
;
const
std
::
string
path
;
const
int
mode
;
const
std
::
string
meta
;
};
struct
InitContext
{
const
std
::
vector
<
int
>
dev_ids
;
// for gpu
};
class
PSWrapper
{
public:
virtual
~
PSWrapper
()
{}
PSWrapper
()
{}
// init server
virtual
int32_t
Initialize
(
InitContext
&
context
)
=
0
;
virtual
void
Stop
()
=
0
;
virtual
void
Load
(
WrapperContext
&
context
)
=
0
;
virtual
void
Save
(
WrapperContext
&
context
)
=
0
;
};
}
// end namespace distributed
}
// end namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录