Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
29c6bcbf
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
29c6bcbf
编写于
11月 01, 2021
作者:
Z
zhaocaibei123
提交者:
GitHub
11月 01, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
memory sparse table & brpc communication upgrade dependency (#36734)
上级
249081b6
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
640 addition
and
16 deletion
+640
-16
paddle/fluid/distributed/CMakeLists.txt
paddle/fluid/distributed/CMakeLists.txt
+1
-0
paddle/fluid/distributed/common/CMakeLists.txt
paddle/fluid/distributed/common/CMakeLists.txt
+4
-0
paddle/fluid/distributed/common/afs_warpper.cc
paddle/fluid/distributed/common/afs_warpper.cc
+89
-0
paddle/fluid/distributed/common/afs_warpper.h
paddle/fluid/distributed/common/afs_warpper.h
+156
-0
paddle/fluid/distributed/common/cost_timer.h
paddle/fluid/distributed/common/cost_timer.h
+93
-0
paddle/fluid/distributed/common/utils.h
paddle/fluid/distributed/common/utils.h
+15
-0
paddle/fluid/distributed/service/env.h
paddle/fluid/distributed/service/env.h
+4
-3
paddle/fluid/distributed/service/ps_client.h
paddle/fluid/distributed/service/ps_client.h
+55
-7
paddle/fluid/distributed/table/accessor.h
paddle/fluid/distributed/table/accessor.h
+3
-6
paddle/fluid/distributed/table/depends/dense.h
paddle/fluid/distributed/table/depends/dense.h
+154
-0
paddle/fluid/framework/distributed_strategy.proto
paddle/fluid/framework/distributed_strategy.proto
+66
-0
未找到文件。
paddle/fluid/distributed/CMakeLists.txt
浏览文件 @
29c6bcbf
...
...
@@ -11,6 +11,7 @@ if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
"
${
DISTRIBUTE_COMPILE_FLAGS
}
-faligned-new"
)
endif
()
add_subdirectory
(
common
)
add_subdirectory
(
service
)
add_subdirectory
(
table
)
add_subdirectory
(
test
)
...
...
paddle/fluid/distributed/common/CMakeLists.txt
0 → 100644
浏览文件 @
29c6bcbf
cc_library
(
afs_wrapper SRCS afs_warpper.cc DEPS fs ps_framework_proto
)
#set_property(GLOBAL PROPERTY COMMON_DEPS afs_warpper)
paddle/fluid/distributed/common/afs_warpper.cc
0 → 100644
浏览文件 @
29c6bcbf
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/common/afs_warpper.h"
#include "paddle/fluid/framework/io/fs.h"
namespace
paddle
{
namespace
distributed
{
// AfsClient impl
int
AfsClient
::
initialize
(
const
FsClientParameter
&
fs_client_param
)
{
// temporarily implemented with hdfs-client
return
initialize
(
fs_client_param
.
hadoop_bin
(),
fs_client_param
.
uri
(),
fs_client_param
.
user
(),
fs_client_param
.
passwd
(),
fs_client_param
.
buffer_size
());
}
int
AfsClient
::
initialize
(
const
std
::
string
&
hadoop_bin
,
const
std
::
string
&
uri
,
const
std
::
string
&
user
,
const
std
::
string
&
passwd
,
int
buffer_size_param
)
{
return
initialize
(
hadoop_bin
,
uri
,
paddle
::
string
::
format_string
(
"%s,%s"
,
user
.
c_str
(),
passwd
.
c_str
()),
buffer_size_param
);
}
int
AfsClient
::
initialize
(
const
std
::
string
&
hadoop_bin
,
const
std
::
string
&
uri
,
const
std
::
string
&
ugi
,
int
buffer_size_param
)
{
// temporarily implemented with hdfs-client
size_t
buffer_size
=
1L
<<
25
;
// 32MB
if
(
buffer_size_param
>
static_cast
<
int
>
(
buffer_size
))
{
buffer_size
=
buffer_size_param
;
}
paddle
::
framework
::
hdfs_set_buffer_size
(
buffer_size
);
paddle
::
framework
::
hdfs_set_command
(
paddle
::
string
::
format_string
(
"2>>./hdfs_err.log %s fs -Dfs.default.name=%s -Dhadoop.job.ugi=%s "
"-Ddfs.client.block.write.retries=15 -Ddfs.rpc.timeout=300000"
,
hadoop_bin
.
c_str
(),
uri
.
c_str
(),
ugi
.
c_str
()));
return
0
;
}
// open file in 'w' or 'r'
std
::
shared_ptr
<
FsReadChannel
>
AfsClient
::
open_r
(
const
FsChannelConfig
&
config
,
uint32_t
buffer_size
,
int
*
err_no
)
{
std
::
shared_ptr
<
FsReadChannel
>
channel
=
std
::
make_shared
<
FsReadChannel
>
(
buffer_size
);
std
::
shared_ptr
<
FILE
>
fp
=
paddle
::
framework
::
fs_open_read
(
config
.
path
,
err_no
,
config
.
deconverter
);
channel
->
open
(
fp
,
config
);
return
channel
;
}
std
::
shared_ptr
<
FsWriteChannel
>
AfsClient
::
open_w
(
const
FsChannelConfig
&
config
,
uint32_t
buffer_size
,
int
*
err_no
)
{
std
::
shared_ptr
<
FsWriteChannel
>
channel
=
std
::
make_shared
<
FsWriteChannel
>
(
buffer_size
);
std
::
shared_ptr
<
FILE
>
fp
=
paddle
::
framework
::
fs_open_write
(
config
.
path
,
err_no
,
config
.
converter
);
channel
->
open
(
fp
,
config
);
return
channel
;
}
// remove file in path, path maybe a reg, such as 'part-000-*'
void
AfsClient
::
remove
(
const
std
::
string
&
path
)
{
return
paddle
::
framework
::
fs_remove
(
path
);
}
void
AfsClient
::
remove_dir
(
const
std
::
string
&
dir
)
{
return
paddle
::
framework
::
fs_remove
(
dir
);
}
// list files in path, path maybe a dir with reg
std
::
vector
<
std
::
string
>
AfsClient
::
list
(
const
std
::
string
&
path
)
{
return
paddle
::
framework
::
fs_list
(
path
);
}
// exist or not
bool
AfsClient
::
exist
(
const
std
::
string
&
dir
)
{
return
paddle
::
framework
::
fs_exists
(
dir
);
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/common/afs_warpper.h
0 → 100644
浏览文件 @
29c6bcbf
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/string/string_helper.h"
namespace
paddle
{
namespace
distributed
{
struct
FsDataConverter
{
std
::
string
converter
;
std
::
string
deconverter
;
};
struct
FsChannelConfig
{
std
::
string
path
;
// path of file
std
::
string
converter
;
// data converter
std
::
string
deconverter
;
};
class
FsReadChannel
{
public:
FsReadChannel
()
:
_buffer_size
(
0
)
{}
explicit
FsReadChannel
(
uint32_t
buffer_size
)
:
_buffer_size
(
buffer_size
)
{}
virtual
~
FsReadChannel
()
{}
FsReadChannel
(
FsReadChannel
&&
)
=
delete
;
FsReadChannel
(
const
FsReadChannel
&
)
=
delete
;
int
open
(
std
::
shared_ptr
<
FILE
>
fp
,
const
FsChannelConfig
&
config
)
{
_file
=
fp
;
return
0
;
}
inline
int
close
()
{
_file
.
reset
();
return
0
;
}
inline
uint32_t
read_line
(
std
::
string
&
line_data
)
{
// NOLINT
line_data
.
clear
();
char
buffer
=
'\0'
;
size_t
read_count
=
0
;
while
(
1
==
fread
(
&
buffer
,
1
,
1
,
_file
.
get
())
&&
buffer
!=
'\n'
)
{
++
read_count
;
line_data
.
append
(
&
buffer
,
1
);
}
if
(
read_count
==
0
&&
buffer
!=
'\n'
)
{
return
-
1
;
}
return
0
;
}
private:
uint32_t
_buffer_size
;
FsChannelConfig
_config
;
std
::
shared_ptr
<
FILE
>
_file
;
};
class
FsWriteChannel
{
public:
FsWriteChannel
()
:
_buffer_size
(
0
)
{}
explicit
FsWriteChannel
(
uint32_t
buffer_size
)
:
_buffer_size
(
buffer_size
)
{}
virtual
~
FsWriteChannel
()
{}
FsWriteChannel
(
FsWriteChannel
&&
)
=
delete
;
FsWriteChannel
(
const
FsWriteChannel
&
)
=
delete
;
int
open
(
std
::
shared_ptr
<
FILE
>
fp
,
const
FsChannelConfig
&
config
)
{
_file
=
fp
;
// the buffer has set in fs.cc
// if (_buffer_size != 0) {
// _buffer = std::shared_ptr<char>(new char[_buffer_size]);
// CHECK(0 == setvbuf(&*_file, _buffer.get(), _IOFBF, _buffer_size));
//}
return
0
;
}
inline
void
flush
()
{
return
;
}
inline
int
close
()
{
flush
();
_file
.
reset
();
return
0
;
}
inline
uint32_t
write_line
(
const
char
*
data
,
uint32_t
size
)
{
size_t
write_count
=
fwrite_unlocked
(
data
,
1
,
size
,
_file
.
get
());
if
(
write_count
!=
size
)
{
return
-
1
;
}
write_count
=
fwrite_unlocked
(
"
\n
"
,
1
,
1
,
_file
.
get
());
if
(
write_count
!=
1
)
{
return
-
1
;
}
return
0
;
}
inline
uint32_t
write_line
(
const
std
::
string
&
data
)
{
return
write_line
(
data
.
c_str
(),
data
.
size
());
}
private:
uint32_t
_buffer_size
;
FsChannelConfig
_config
;
std
::
shared_ptr
<
FILE
>
_file
;
std
::
shared_ptr
<
char
>
_buffer
;
};
class
AfsClient
{
public:
AfsClient
()
{}
virtual
~
AfsClient
()
{}
AfsClient
(
AfsClient
&&
)
=
delete
;
AfsClient
(
const
AfsClient
&
)
=
delete
;
int
initialize
(
const
FsClientParameter
&
fs_client_param
);
int
initialize
(
const
std
::
string
&
hadoop_bin
,
const
std
::
string
&
uri
,
const
std
::
string
&
user
,
const
std
::
string
&
passwd
,
int
buffer_size_param
=
(
1L
<<
25
));
int
initialize
(
const
std
::
string
&
hadoop_bin
,
const
std
::
string
&
uri
,
const
std
::
string
&
ugi
,
int
buffer_size_param
=
(
1L
<<
25
));
// open file in 'w' or 'r'
std
::
shared_ptr
<
FsReadChannel
>
open_r
(
const
FsChannelConfig
&
config
,
uint32_t
buffer_size
=
0
,
int
*
err_no
=
nullptr
);
std
::
shared_ptr
<
FsWriteChannel
>
open_w
(
const
FsChannelConfig
&
config
,
uint32_t
buffer_size
=
0
,
int
*
err_no
=
nullptr
);
// remove file in path, path maybe a reg, such as 'part-000-*'
void
remove
(
const
std
::
string
&
path
);
void
remove_dir
(
const
std
::
string
&
dir
);
// list files in path, path maybe a dir with reg
std
::
vector
<
std
::
string
>
list
(
const
std
::
string
&
path
);
// exist or not
bool
exist
(
const
std
::
string
&
dir
);
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/common/cost_timer.h
0 → 100644
浏览文件 @
29c6bcbf
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include "butil/time.h"
#include "bvar/latency_recorder.h"
#include "glog/logging.h"
namespace
paddle
{
namespace
distributed
{
struct
CostProfilerNode
{
std
::
shared_ptr
<
bvar
::
LatencyRecorder
>
recorder
;
};
class
CostProfiler
{
public:
~
CostProfiler
()
{}
static
CostProfiler
&
instance
()
{
static
CostProfiler
profiler
;
return
profiler
;
}
void
register_profiler
(
const
std
::
string
&
label
)
{
if
(
_cost_profiler_map
.
find
(
label
)
!=
_cost_profiler_map
.
end
())
{
return
;
}
auto
profiler_node
=
std
::
make_shared
<
CostProfilerNode
>
();
profiler_node
->
recorder
.
reset
(
new
bvar
::
LatencyRecorder
(
"cost_profiler"
,
label
));
_cost_profiler_map
[
label
]
=
profiler_node
;
}
CostProfilerNode
*
profiler
(
const
std
::
string
&
label
)
{
auto
itr
=
_cost_profiler_map
.
find
(
label
);
if
(
itr
!=
_cost_profiler_map
.
end
())
{
return
itr
->
second
.
get
();
}
return
NULL
;
}
private:
CostProfiler
()
{}
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
CostProfilerNode
>>
_cost_profiler_map
;
};
class
CostTimer
{
public:
explicit
CostTimer
(
const
std
::
string
&
label
)
{
_label
=
label
;
auto
&
profiler
=
CostProfiler
::
instance
();
_profiler_node
=
profiler
.
profiler
(
label
);
// 如果不在profiler中,则使用log输出耗时信息
_is_print_cost
=
_profiler_node
==
NULL
;
_start_time_ms
=
butil
::
gettimeofday_ms
();
}
explicit
CostTimer
(
CostProfilerNode
&
profiler_node
)
{
// NOLINT
_is_print_cost
=
false
;
_profiler_node
=
&
profiler_node
;
_start_time_ms
=
butil
::
gettimeofday_ms
();
}
~
CostTimer
()
{
if
(
_is_print_cost
)
{
LOG
(
INFO
)
<<
"CostTimer label:"
<<
_label
<<
", cost:"
<<
butil
::
gettimeofday_ms
()
-
_start_time_ms
<<
"ms"
;
}
else
{
*
(
_profiler_node
->
recorder
)
<<
butil
::
gettimeofday_ms
()
-
_start_time_ms
;
}
}
private:
std
::
string
_label
;
bool
_is_print_cost
;
uint64_t
_start_time_ms
;
CostProfilerNode
*
_profiler_node
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/common/utils.h
浏览文件 @
29c6bcbf
...
...
@@ -52,6 +52,20 @@ inline void ADD(int n, const T* x, const T y, T* z) {
}
}
template
<
typename
T
>
inline
void
DIV
(
int
n
,
const
T
x
,
const
T
*
y
,
T
*
z
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
/
y
[
i
];
}
}
template
<
typename
T
>
inline
void
ELE_MUL
(
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
*
y
[
i
];
}
}
static
bool
StartWith
(
const
std
::
string
&
str
,
const
std
::
string
&
substr
)
{
return
str
.
find
(
substr
)
==
0
;
}
...
...
@@ -91,5 +105,6 @@ inline double GetCurrentUS() {
gettimeofday
(
&
time
,
NULL
);
return
1e+6
*
time
.
tv_sec
+
time
.
tv_usec
;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/service/env.h
浏览文件 @
29c6bcbf
...
...
@@ -144,8 +144,8 @@ class PSEnvironment {
virtual
std
::
vector
<
uint64_t
>
get_client_info
()
{
std
::
vector
<
uint64_t
>
client_info
;
for
(
auto
&
i
:
_ps_client_
sign_se
t
)
{
client_info
.
push_back
(
i
);
for
(
auto
&
i
:
_ps_client_
lis
t
)
{
client_info
.
push_back
(
i
.
serialize_to_uint64
()
);
}
return
client_info
;
}
...
...
@@ -250,7 +250,7 @@ class PaddlePSEnvironment : public PSEnvironment {
return
0
;
}
virtual
int32_t
set_ps_clients
(
std
::
vector
<
std
::
string
>
*
host_sign_list
,
virtual
int32_t
set_ps_clients
(
const
std
::
vector
<
std
::
string
>
*
host_sign_list
,
int
node_num
)
{
_ps_client_list
.
clear
();
_ps_client_sign_set
.
clear
();
...
...
@@ -265,6 +265,7 @@ class PaddlePSEnvironment : public PSEnvironment {
std
::
sort
(
_ps_client_list
.
begin
(),
_ps_client_list
.
end
(),
[](
const
PSHost
&
h1
,
const
PSHost
&
h2
)
{
return
h1
.
rank
<
h2
.
rank
;
});
VLOG
(
1
)
<<
"env.set_ps_clients done
\n
"
;
return
0
;
}
...
...
paddle/fluid/distributed/service/ps_client.h
浏览文件 @
29c6bcbf
...
...
@@ -20,11 +20,13 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/service/env.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/table/accessor.h"
#include "paddle/fluid/distributed/table/graph/graph_node.h"
#include "paddle/fluid/platform/timer.h"
namespace
paddle
{
namespace
distributed
{
...
...
@@ -35,7 +37,7 @@ using paddle::distributed::PsResponseMessage;
typedef
std
::
function
<
void
(
void
*
)
>
PSClientCallBack
;
class
PSClientClosure
:
public
google
::
protobuf
::
Closure
{
public:
PSClientClosure
(
PSClientCallBack
callback
)
:
_callback
(
callback
)
{}
explicit
PSClientClosure
(
PSClientCallBack
callback
)
:
_callback
(
callback
)
{}
virtual
~
PSClientClosure
()
{}
virtual
void
set_promise_value
(
int
value
)
{
for
(
auto
&
promise
:
_promises
)
{
...
...
@@ -43,12 +45,17 @@ class PSClientClosure : public google::protobuf::Closure {
}
}
void
add_promise
(
std
::
shared_ptr
<
std
::
promise
<
int32_t
>>
&
promise
)
{
void
add_promise
(
std
::
shared_ptr
<
std
::
promise
<
int32_t
>>
&
promise
)
{
// NOLINT
_promises
.
push_back
(
promise
);
}
void
add_timer
(
std
::
shared_ptr
<
CostTimer
>
&
timer
)
{
// NOLINT
_timers
.
push_back
(
timer
);
}
protected:
PSClientCallBack
_callback
;
std
::
vector
<
std
::
shared_ptr
<
CostTimer
>>
_timers
;
std
::
vector
<
std
::
shared_ptr
<
std
::
promise
<
int32_t
>>>
_promises
;
};
...
...
@@ -59,11 +66,11 @@ class PSClient {
PSClient
(
PSClient
&&
)
=
delete
;
PSClient
(
const
PSClient
&
)
=
delete
;
virtual
int32_t
configure
(
virtual
int32_t
configure
(
// NOLINT
const
PSParameter
&
config
,
const
std
::
map
<
uint64_t
,
std
::
vector
<
paddle
::
distributed
::
Region
>>
&
regions
,
PSEnvironment
&
_env
,
size_t
client_id
)
final
;
PSEnvironment
&
_env
,
size_t
client_id
)
final
;
// NOLINT
virtual
int32_t
create_client2client_connection
(
int
pserver_timeout_ms
,
int
pserver_connect_timeout_ms
,
...
...
@@ -86,7 +93,7 @@ class PSClient {
virtual
std
::
future
<
int32_t
>
save
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
=
0
;
//清空table数据
//
清空table数据
virtual
std
::
future
<
int32_t
>
clear
()
=
0
;
virtual
std
::
future
<
int32_t
>
clear
(
uint32_t
table_id
)
=
0
;
...
...
@@ -98,7 +105,7 @@ class PSClient {
// server将参数区块中配置的某一维提取返回
// 返回数据解包后填充到累计的多个buffer中
virtual
std
::
future
<
int32_t
>
pull_dense
(
Region
*
regions
,
size_t
region_num
,
size_t
table_id
)
=
0
;
//保留
size_t
table_id
)
=
0
;
//
保留
// firstly push dense param for parameter server
// this is neccessary because dense weight initialized in trainer on cold
...
...
@@ -107,6 +114,9 @@ class PSClient {
size_t
region_num
,
size_t
table_id
)
=
0
;
// virtual std::future<int32_t> push_dense(const Region *regions,
// size_t region_num,
// size_t table_id) = 0;
// 使用keys进行pull请求,结果填充values
// keys和values的个数均为num个,每个value占用select_size空间
// future结束前keys和values缓冲区不能再次使用
...
...
@@ -212,6 +222,10 @@ class PSClient {
const
uint64_t
*
keys
,
const
float
**
update_values
,
size_t
num
,
void
*
done
)
=
0
;
// virtual std::future<int32_t> push_sparse(size_t table_id,
// const uint64_t *keys,
// const float **update_values,
// size_t num) = 0;
protected:
virtual
int32_t
initialize
()
=
0
;
...
...
@@ -222,8 +236,42 @@ class PSClient {
PSEnvironment
*
_env
;
std
::
unordered_map
<
uint32_t
,
std
::
shared_ptr
<
ValueAccessor
>>
_table_accessors
;
std
::
unordered_map
<
int32_t
,
MsgHandlerFunc
>
_msg_handler_map
;
//处理client2client消息
_msg_handler_map
;
// 处理client2client消息
};
template
<
class
T
>
class
AsyncRequestTask
{
public:
AsyncRequestTask
()
:
_promise
(
std
::
make_shared
<
std
::
promise
<
int32_t
>>
())
{}
AsyncRequestTask
(
T
&
data
,
size_t
table_id
,
std
::
shared_ptr
<
CostTimer
>
&
timer
)
:
_table_id
(
table_id
),
_timer
(
timer
),
_promise
(
std
::
make_shared
<
std
::
promise
<
int32_t
>>
())
{
_data
=
std
::
move
(
data
);
}
AsyncRequestTask
(
AsyncRequestTask
&
data
)
// NOLINT
:
_table_id
(
data
.
table_id
()),
_timer
(
data
.
timer
()),
_promise
(
data
.
promise
())
{
_data
=
std
::
move
(
data
.
data
());
}
~
AsyncRequestTask
()
{}
inline
T
&
data
()
{
return
_data
;
}
inline
size_t
table_id
()
{
return
_table_id
;
}
inline
std
::
shared_ptr
<
CostTimer
>
&
timer
()
{
return
_timer
;
}
inline
std
::
future
<
int32_t
>
get_future
()
{
return
_promise
->
get_future
();
}
inline
std
::
shared_ptr
<
std
::
promise
<
int32_t
>>
&
promise
()
{
return
_promise
;
}
private:
T
_data
;
size_t
_table_id
;
std
::
shared_ptr
<
CostTimer
>
_timer
;
std
::
shared_ptr
<
std
::
promise
<
int32_t
>>
_promise
;
};
REGISTER_PSCORE_REGISTERER
(
PSClient
);
class
PSClientFactory
{
...
...
paddle/fluid/distributed/table/accessor.h
浏览文件 @
29c6bcbf
...
...
@@ -17,15 +17,12 @@
#include <stdio.h>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/common/afs_warpper.h"
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
namespace
paddle
{
namespace
distributed
{
struct
FsDataConverter
{
std
::
string
converter
;
std
::
string
deconverter
;
};
struct
Region
{
Region
()
:
data
(
NULL
),
size
(
0
)
{}
...
...
@@ -50,8 +47,8 @@ struct DataConverter {
class
ValueAccessor
{
public:
explicit
ValueAccessor
(){};
virtual
~
ValueAccessor
()
{};
ValueAccessor
()
{}
virtual
~
ValueAccessor
()
{}
virtual
int
configure
(
const
TableAccessorParameter
&
parameter
)
{
_config
=
parameter
;
...
...
paddle/fluid/distributed/table/depends/dense.h
浏览文件 @
29c6bcbf
...
...
@@ -183,5 +183,159 @@ class DAdam : public DenseOptimizer {
float
epsilon
;
};
// adam optimizer for dense tensor
class
DAdamD2Sum
:
public
DenseOptimizer
{
public:
explicit
DAdamD2Sum
(
const
CommonAccessorParameter
&
accessor
,
std
::
vector
<
std
::
vector
<
float
>>*
values
)
{
lr_hardcode
=
5e-6
;
auto
&
names
=
accessor
.
params
();
for
(
int
x
=
0
;
x
<
static_cast
<
int
>
(
names
.
size
());
++
x
)
{
if
(
names
[
x
]
==
"LearningRate"
)
{
learning_rate
=
(
*
values
)[
x
].
data
();
}
if
(
names
[
x
]
==
"Param"
)
{
param
=
(
*
values
)[
x
].
data
();
}
if
(
names
[
x
]
==
"Moment"
)
{
mom_velocity
=
(
*
values
)[
x
].
data
();
}
if
(
names
[
x
]
==
"G2Sum"
)
{
ada_g2sum
=
(
*
values
)[
x
].
data
();
}
if
(
names
[
x
]
==
"D2Sum"
)
{
ada_d2sum
=
(
*
values
)[
x
].
data
();
}
if
(
names
[
x
]
==
"MomentDecayRate"
)
{
mom_decay_rate
=
(
*
values
)[
x
].
data
();
}
if
(
names
[
x
]
==
"AdaDecayRate"
)
{
ada_decay_rate
=
(
*
values
)[
x
].
data
();
}
if
(
names
[
x
]
==
"AdaEpsilon"
)
{
ada_epsilon
=
(
*
values
)[
x
].
data
();
}
}
}
void
update
(
const
float
*
update_values
,
size_t
num
,
int
begin
,
int
end
)
override
{
auto
update_numel
=
end
-
begin
;
/*
// for debug
std::cout << "before update:\n";
for (int i = 0; i < 3; ++ i) {
std::cout << "param: " << i << " " << *(param+begin+i) <<
"grad: " << *(update_values+begin+i) << "\n";
}*/
std
::
vector
<
float
>
grad
,
grad2
,
scale
;
grad
.
resize
(
update_numel
);
grad2
.
resize
(
update_numel
);
scale
.
resize
(
update_numel
);
auto
blas
=
GetBlas
<
float
>
();
// copy grad
blas
.
VCOPY
(
update_numel
,
update_values
+
begin
,
grad
.
data
());
blas
.
VCOPY
(
update_numel
,
update_values
+
begin
,
grad2
.
data
());
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "copy grad: " << i << " " << *(grad.data()+begin+i) <<
"copy grad2: " << *(grad2.data()+begin+i) << "\n";
}
for (int i = 0; i < 3; ++ i) {
std::cout << "d2sum before: " << i << " " << *(ada_d2sum+begin+i) << "\n";
}*/
// d2sum
blas
.
SCAL
(
update_numel
,
ada_decay_rate
[
0
],
ada_d2sum
+
begin
);
ADD
<
float
>
(
update_numel
,
ada_d2sum
+
begin
,
1
,
ada_d2sum
+
begin
);
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "d2sum update: " << i << " " << *(ada_d2sum+begin+i) << "\n";
}
for (int i = 0; i < 3; ++ i) {
std::cout << "g2sum before: " << i << " " << *(ada_g2sum+begin+i) << "\n";
}*/
// g2sum
blas
.
SCAL
(
update_numel
,
ada_decay_rate
[
0
],
ada_g2sum
+
begin
);
blas
.
VSQUARE
(
update_numel
,
grad2
.
data
(),
grad2
.
data
());
blas
.
VADD
(
update_numel
,
ada_g2sum
+
begin
,
grad2
.
data
(),
ada_g2sum
+
begin
);
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "g2sum update: " << i << " " << *(ada_g2sum+begin+i) << "\n";
}
for (int i = 0; i < 3; ++ i) {
std::cout << "mom before: " << i << " " << *(mom_velocity+begin+i) <<
"\n";
}*/
// mom
blas
.
SCAL
(
update_numel
,
mom_decay_rate
[
0
],
mom_velocity
+
begin
);
blas
.
SCAL
(
update_numel
,
1
-
mom_decay_rate
[
0
],
grad
.
data
());
blas
.
VADD
(
update_numel
,
mom_velocity
+
begin
,
grad
.
data
(),
mom_velocity
+
begin
);
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "mom update: " << i << " " << *(mom_velocity+begin+i) <<
"\n";
}
for (int i = 0; i < 3; ++ i) {
std::cout << "scale before: " << i << " " << *(scale.data()+begin+i) <<
"\n";
}*/
// scale
float
*
scale_
=
scale
.
data
();
blas
.
VDIV
(
update_numel
,
ada_g2sum
+
begin
,
ada_d2sum
+
begin
,
scale_
);
ADD
<
float
>
(
update_numel
,
scale_
,
ada_epsilon
[
0
],
scale_
);
DIV
<
float
>
(
update_numel
,
1
+
ada_epsilon
[
0
],
scale_
,
scale_
);
SQRT
<
float
>
(
update_numel
,
scale_
,
scale_
);
/*
for (int i = 0; i < 3; ++ i) {
std::cout << "scale update: " << i << " " << *(scale.data()+begin+i) <<
"\n";
}*/
blas
.
SCAL
(
update_numel
,
learning_rate
[
0
],
scale_
);
// TODO(zhaocaibei123): check if there exists elementwise_multiply in blas
// TODO(zhaocaibei123): blas.VMUL
ELE_MUL
<
float
>
(
update_numel
,
scale_
,
mom_velocity
+
begin
,
scale_
);
/*
for (int i = 0; i < 3; ++ i) {
std::cout << "scale update2: " << i << " " << *(scale.data()+begin+i) <<
"\n";
}*/
blas
.
VSUB
(
update_numel
,
param
+
begin
,
scale_
,
param
+
begin
);
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "param update " << i << " " << *(param+begin+i) << "\n";
}*/
}
float
*
learning_rate
;
float
lr_hardcode
;
float
*
param
;
float
*
mom_velocity
;
float
*
ada_g2sum
;
float
*
ada_d2sum
;
float
*
mom_decay_rate
;
float
*
ada_decay_rate
;
float
*
ada_epsilon
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/framework/distributed_strategy.proto
浏览文件 @
29c6bcbf
...
...
@@ -173,6 +173,68 @@ message TensorParallelConfig {
optional
int32
tensor_init_seed
=
2
[
default
=
-
1
];
}
enum
TableType
{
PS_SPARSE_TABLE
=
0
;
PS_DENSE_TABLE
=
1
;
}
message
TableParameter
{
optional
uint64
table_id
=
1
;
optional
string
table_class
=
2
;
optional
uint64
shard_num
=
3
;
optional
TableType
type
=
4
;
optional
TableAccessorParameter
accessor
=
5
;
}
message
TableAccessorParameter
{
optional
string
accessor_class
=
1
;
optional
SGDParameter
embed_sgd_param
=
2
;
optional
SGDParameter
embedx_sgd_param
=
3
;
optional
uint32
fea_dim
=
4
;
// for sparse table, this means field size of one
// value; for dense table, this means total value
// num
optional
uint32
embedx_dim
=
5
;
// embedx feature size
optional
uint32
embedx_threshold
=
6
;
// embedx feature create threshold
optional
CtrAccessorParameter
ctr_accessor_param
=
7
;
}
// TODO(guanqun): add NaiveSGD/Adam...
message
SGDParameter
{
optional
string
name
=
1
;
optional
SGDRuleParameter
adagrad
=
2
;
}
message
SGDRuleParameter
{
optional
double
learning_rate
=
1
;
optional
double
initial_g2sum
=
2
;
optional
double
initial_range
=
3
[
default
=
0
];
repeated
float
weight_bounds
=
4
;
}
message
CtrAccessorParameter
{
optional
float
nonclk_coeff
=
1
;
// to calculate show_click_score
optional
float
click_coeff
=
2
;
// to calculate show_click_score
optional
float
base_threshold
=
3
;
// show_click_score > base_threshold, this feature can be saved
optional
float
delta_threshold
=
4
;
// delta_score > delta_threshold, this feature can be saved
optional
float
delta_keep_days
=
5
;
// unseen_day < delta_keep_days, this feature can be saved
optional
float
show_click_decay_rate
=
6
;
// show/click will update to
// show/click *
// show_click_decay_rate after a day
optional
float
delete_threshold
=
7
;
// threshold to shrink a feasign
optional
float
delete_after_unseen_days
=
8
;
optional
int32
ssd_unseenday_threshold
=
9
;
}
message
FsClientParameter
{
optional
string
uri
=
1
;
optional
string
user
=
2
;
optional
string
passwd
=
3
;
optional
string
hadoop_bin
=
4
;
}
message
DistributedStrategy
{
// bool options
optional
Mode
mode
=
1
[
default
=
COLLECTIVE
];
...
...
@@ -210,6 +272,7 @@ message DistributedStrategy {
optional
bool
asp
=
33
[
default
=
false
];
optional
bool
fuse_grad_merge
=
34
[
default
=
false
];
optional
bool
semi_auto
=
35
[
default
=
false
];
optional
bool
adam_d2sum
=
36
[
default
=
true
];
optional
RecomputeConfig
recompute_configs
=
101
;
optional
AMPConfig
amp_configs
=
102
;
...
...
@@ -225,6 +288,9 @@ message DistributedStrategy {
optional
HybridConfig
hybrid_configs
=
112
;
optional
TensorParallelConfig
tensor_parallel_configs
=
113
;
optional
TrainerDescConfig
trainer_desc_configs
=
114
;
optional
TableParameter
downpour_table_param
=
115
;
optional
FsClientParameter
fs_client_param
=
116
;
optional
BuildStrategy
build_strategy
=
201
;
optional
ExecutionStrategy
execution_strategy
=
202
;
optional
GradientScaleConfig
gradient_scale_configs
=
203
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录