Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
SummerGao.
Paddle
提交
29c6bcbf
P
Paddle
项目概览
SummerGao.
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
29c6bcbf
编写于
11月 01, 2021
作者:
Z
zhaocaibei123
提交者:
GitHub
11月 01, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
memory sparse table & brpc communication upgrade dependency (#36734)
上级
249081b6
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
640 addition
and
16 deletion
+640
-16
paddle/fluid/distributed/CMakeLists.txt
paddle/fluid/distributed/CMakeLists.txt
+1
-0
paddle/fluid/distributed/common/CMakeLists.txt
paddle/fluid/distributed/common/CMakeLists.txt
+4
-0
paddle/fluid/distributed/common/afs_warpper.cc
paddle/fluid/distributed/common/afs_warpper.cc
+89
-0
paddle/fluid/distributed/common/afs_warpper.h
paddle/fluid/distributed/common/afs_warpper.h
+156
-0
paddle/fluid/distributed/common/cost_timer.h
paddle/fluid/distributed/common/cost_timer.h
+93
-0
paddle/fluid/distributed/common/utils.h
paddle/fluid/distributed/common/utils.h
+15
-0
paddle/fluid/distributed/service/env.h
paddle/fluid/distributed/service/env.h
+4
-3
paddle/fluid/distributed/service/ps_client.h
paddle/fluid/distributed/service/ps_client.h
+55
-7
paddle/fluid/distributed/table/accessor.h
paddle/fluid/distributed/table/accessor.h
+3
-6
paddle/fluid/distributed/table/depends/dense.h
paddle/fluid/distributed/table/depends/dense.h
+154
-0
paddle/fluid/framework/distributed_strategy.proto
paddle/fluid/framework/distributed_strategy.proto
+66
-0
未找到文件。
paddle/fluid/distributed/CMakeLists.txt
浏览文件 @
29c6bcbf
...
...
@@ -11,6 +11,7 @@ if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
"
${
DISTRIBUTE_COMPILE_FLAGS
}
-faligned-new"
)
endif
()
add_subdirectory
(
common
)
add_subdirectory
(
service
)
add_subdirectory
(
table
)
add_subdirectory
(
test
)
...
...
paddle/fluid/distributed/common/CMakeLists.txt
0 → 100644
浏览文件 @
29c6bcbf
cc_library
(
afs_wrapper SRCS afs_warpper.cc DEPS fs ps_framework_proto
)
#set_property(GLOBAL PROPERTY COMMON_DEPS afs_warpper)
paddle/fluid/distributed/common/afs_warpper.cc
0 → 100644
浏览文件 @
29c6bcbf
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/common/afs_warpper.h"
#include "paddle/fluid/framework/io/fs.h"
namespace
paddle
{
namespace
distributed
{
// AfsClient impl
int
AfsClient
::
initialize
(
const
FsClientParameter
&
fs_client_param
)
{
// temporarily implemented with hdfs-client
return
initialize
(
fs_client_param
.
hadoop_bin
(),
fs_client_param
.
uri
(),
fs_client_param
.
user
(),
fs_client_param
.
passwd
(),
fs_client_param
.
buffer_size
());
}
int
AfsClient
::
initialize
(
const
std
::
string
&
hadoop_bin
,
const
std
::
string
&
uri
,
const
std
::
string
&
user
,
const
std
::
string
&
passwd
,
int
buffer_size_param
)
{
return
initialize
(
hadoop_bin
,
uri
,
paddle
::
string
::
format_string
(
"%s,%s"
,
user
.
c_str
(),
passwd
.
c_str
()),
buffer_size_param
);
}
int
AfsClient
::
initialize
(
const
std
::
string
&
hadoop_bin
,
const
std
::
string
&
uri
,
const
std
::
string
&
ugi
,
int
buffer_size_param
)
{
// temporarily implemented with hdfs-client
size_t
buffer_size
=
1L
<<
25
;
// 32MB
if
(
buffer_size_param
>
static_cast
<
int
>
(
buffer_size
))
{
buffer_size
=
buffer_size_param
;
}
paddle
::
framework
::
hdfs_set_buffer_size
(
buffer_size
);
paddle
::
framework
::
hdfs_set_command
(
paddle
::
string
::
format_string
(
"2>>./hdfs_err.log %s fs -Dfs.default.name=%s -Dhadoop.job.ugi=%s "
"-Ddfs.client.block.write.retries=15 -Ddfs.rpc.timeout=300000"
,
hadoop_bin
.
c_str
(),
uri
.
c_str
(),
ugi
.
c_str
()));
return
0
;
}
// open file in 'w' or 'r'
std
::
shared_ptr
<
FsReadChannel
>
AfsClient
::
open_r
(
const
FsChannelConfig
&
config
,
uint32_t
buffer_size
,
int
*
err_no
)
{
std
::
shared_ptr
<
FsReadChannel
>
channel
=
std
::
make_shared
<
FsReadChannel
>
(
buffer_size
);
std
::
shared_ptr
<
FILE
>
fp
=
paddle
::
framework
::
fs_open_read
(
config
.
path
,
err_no
,
config
.
deconverter
);
channel
->
open
(
fp
,
config
);
return
channel
;
}
std
::
shared_ptr
<
FsWriteChannel
>
AfsClient
::
open_w
(
const
FsChannelConfig
&
config
,
uint32_t
buffer_size
,
int
*
err_no
)
{
std
::
shared_ptr
<
FsWriteChannel
>
channel
=
std
::
make_shared
<
FsWriteChannel
>
(
buffer_size
);
std
::
shared_ptr
<
FILE
>
fp
=
paddle
::
framework
::
fs_open_write
(
config
.
path
,
err_no
,
config
.
converter
);
channel
->
open
(
fp
,
config
);
return
channel
;
}
// remove file in path, path maybe a reg, such as 'part-000-*'
void
AfsClient
::
remove
(
const
std
::
string
&
path
)
{
return
paddle
::
framework
::
fs_remove
(
path
);
}
void
AfsClient
::
remove_dir
(
const
std
::
string
&
dir
)
{
return
paddle
::
framework
::
fs_remove
(
dir
);
}
// list files in path, path maybe a dir with reg
std
::
vector
<
std
::
string
>
AfsClient
::
list
(
const
std
::
string
&
path
)
{
return
paddle
::
framework
::
fs_list
(
path
);
}
// exist or not
bool
AfsClient
::
exist
(
const
std
::
string
&
dir
)
{
return
paddle
::
framework
::
fs_exists
(
dir
);
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/common/afs_warpper.h
0 → 100644
浏览文件 @
29c6bcbf
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/string/string_helper.h"
namespace
paddle
{
namespace
distributed
{
struct
FsDataConverter
{
std
::
string
converter
;
std
::
string
deconverter
;
};
struct
FsChannelConfig
{
std
::
string
path
;
// path of file
std
::
string
converter
;
// data converter
std
::
string
deconverter
;
};
class
FsReadChannel
{
public:
FsReadChannel
()
:
_buffer_size
(
0
)
{}
explicit
FsReadChannel
(
uint32_t
buffer_size
)
:
_buffer_size
(
buffer_size
)
{}
virtual
~
FsReadChannel
()
{}
FsReadChannel
(
FsReadChannel
&&
)
=
delete
;
FsReadChannel
(
const
FsReadChannel
&
)
=
delete
;
int
open
(
std
::
shared_ptr
<
FILE
>
fp
,
const
FsChannelConfig
&
config
)
{
_file
=
fp
;
return
0
;
}
inline
int
close
()
{
_file
.
reset
();
return
0
;
}
inline
uint32_t
read_line
(
std
::
string
&
line_data
)
{
// NOLINT
line_data
.
clear
();
char
buffer
=
'\0'
;
size_t
read_count
=
0
;
while
(
1
==
fread
(
&
buffer
,
1
,
1
,
_file
.
get
())
&&
buffer
!=
'\n'
)
{
++
read_count
;
line_data
.
append
(
&
buffer
,
1
);
}
if
(
read_count
==
0
&&
buffer
!=
'\n'
)
{
return
-
1
;
}
return
0
;
}
private:
uint32_t
_buffer_size
;
FsChannelConfig
_config
;
std
::
shared_ptr
<
FILE
>
_file
;
};
class
FsWriteChannel
{
public:
FsWriteChannel
()
:
_buffer_size
(
0
)
{}
explicit
FsWriteChannel
(
uint32_t
buffer_size
)
:
_buffer_size
(
buffer_size
)
{}
virtual
~
FsWriteChannel
()
{}
FsWriteChannel
(
FsWriteChannel
&&
)
=
delete
;
FsWriteChannel
(
const
FsWriteChannel
&
)
=
delete
;
int
open
(
std
::
shared_ptr
<
FILE
>
fp
,
const
FsChannelConfig
&
config
)
{
_file
=
fp
;
// the buffer has set in fs.cc
// if (_buffer_size != 0) {
// _buffer = std::shared_ptr<char>(new char[_buffer_size]);
// CHECK(0 == setvbuf(&*_file, _buffer.get(), _IOFBF, _buffer_size));
//}
return
0
;
}
inline
void
flush
()
{
return
;
}
inline
int
close
()
{
flush
();
_file
.
reset
();
return
0
;
}
inline
uint32_t
write_line
(
const
char
*
data
,
uint32_t
size
)
{
size_t
write_count
=
fwrite_unlocked
(
data
,
1
,
size
,
_file
.
get
());
if
(
write_count
!=
size
)
{
return
-
1
;
}
write_count
=
fwrite_unlocked
(
"
\n
"
,
1
,
1
,
_file
.
get
());
if
(
write_count
!=
1
)
{
return
-
1
;
}
return
0
;
}
inline
uint32_t
write_line
(
const
std
::
string
&
data
)
{
return
write_line
(
data
.
c_str
(),
data
.
size
());
}
private:
uint32_t
_buffer_size
;
FsChannelConfig
_config
;
std
::
shared_ptr
<
FILE
>
_file
;
std
::
shared_ptr
<
char
>
_buffer
;
};
class
AfsClient
{
public:
AfsClient
()
{}
virtual
~
AfsClient
()
{}
AfsClient
(
AfsClient
&&
)
=
delete
;
AfsClient
(
const
AfsClient
&
)
=
delete
;
int
initialize
(
const
FsClientParameter
&
fs_client_param
);
int
initialize
(
const
std
::
string
&
hadoop_bin
,
const
std
::
string
&
uri
,
const
std
::
string
&
user
,
const
std
::
string
&
passwd
,
int
buffer_size_param
=
(
1L
<<
25
));
int
initialize
(
const
std
::
string
&
hadoop_bin
,
const
std
::
string
&
uri
,
const
std
::
string
&
ugi
,
int
buffer_size_param
=
(
1L
<<
25
));
// open file in 'w' or 'r'
std
::
shared_ptr
<
FsReadChannel
>
open_r
(
const
FsChannelConfig
&
config
,
uint32_t
buffer_size
=
0
,
int
*
err_no
=
nullptr
);
std
::
shared_ptr
<
FsWriteChannel
>
open_w
(
const
FsChannelConfig
&
config
,
uint32_t
buffer_size
=
0
,
int
*
err_no
=
nullptr
);
// remove file in path, path maybe a reg, such as 'part-000-*'
void
remove
(
const
std
::
string
&
path
);
void
remove_dir
(
const
std
::
string
&
dir
);
// list files in path, path maybe a dir with reg
std
::
vector
<
std
::
string
>
list
(
const
std
::
string
&
path
);
// exist or not
bool
exist
(
const
std
::
string
&
dir
);
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/common/cost_timer.h
0 → 100644
浏览文件 @
29c6bcbf
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include "butil/time.h"
#include "bvar/latency_recorder.h"
#include "glog/logging.h"
namespace
paddle
{
namespace
distributed
{
struct
CostProfilerNode
{
std
::
shared_ptr
<
bvar
::
LatencyRecorder
>
recorder
;
};
class
CostProfiler
{
public:
~
CostProfiler
()
{}
static
CostProfiler
&
instance
()
{
static
CostProfiler
profiler
;
return
profiler
;
}
void
register_profiler
(
const
std
::
string
&
label
)
{
if
(
_cost_profiler_map
.
find
(
label
)
!=
_cost_profiler_map
.
end
())
{
return
;
}
auto
profiler_node
=
std
::
make_shared
<
CostProfilerNode
>
();
profiler_node
->
recorder
.
reset
(
new
bvar
::
LatencyRecorder
(
"cost_profiler"
,
label
));
_cost_profiler_map
[
label
]
=
profiler_node
;
}
CostProfilerNode
*
profiler
(
const
std
::
string
&
label
)
{
auto
itr
=
_cost_profiler_map
.
find
(
label
);
if
(
itr
!=
_cost_profiler_map
.
end
())
{
return
itr
->
second
.
get
();
}
return
NULL
;
}
private:
CostProfiler
()
{}
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
CostProfilerNode
>>
_cost_profiler_map
;
};
class
CostTimer
{
public:
explicit
CostTimer
(
const
std
::
string
&
label
)
{
_label
=
label
;
auto
&
profiler
=
CostProfiler
::
instance
();
_profiler_node
=
profiler
.
profiler
(
label
);
// 如果不在profiler中,则使用log输出耗时信息
_is_print_cost
=
_profiler_node
==
NULL
;
_start_time_ms
=
butil
::
gettimeofday_ms
();
}
explicit
CostTimer
(
CostProfilerNode
&
profiler_node
)
{
// NOLINT
_is_print_cost
=
false
;
_profiler_node
=
&
profiler_node
;
_start_time_ms
=
butil
::
gettimeofday_ms
();
}
~
CostTimer
()
{
if
(
_is_print_cost
)
{
LOG
(
INFO
)
<<
"CostTimer label:"
<<
_label
<<
", cost:"
<<
butil
::
gettimeofday_ms
()
-
_start_time_ms
<<
"ms"
;
}
else
{
*
(
_profiler_node
->
recorder
)
<<
butil
::
gettimeofday_ms
()
-
_start_time_ms
;
}
}
private:
std
::
string
_label
;
bool
_is_print_cost
;
uint64_t
_start_time_ms
;
CostProfilerNode
*
_profiler_node
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/common/utils.h
浏览文件 @
29c6bcbf
...
...
@@ -52,6 +52,20 @@ inline void ADD(int n, const T* x, const T y, T* z) {
}
}
template
<
typename
T
>
inline
void
DIV
(
int
n
,
const
T
x
,
const
T
*
y
,
T
*
z
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
/
y
[
i
];
}
}
template
<
typename
T
>
inline
void
ELE_MUL
(
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
*
y
[
i
];
}
}
static
bool
StartWith
(
const
std
::
string
&
str
,
const
std
::
string
&
substr
)
{
return
str
.
find
(
substr
)
==
0
;
}
...
...
@@ -91,5 +105,6 @@ inline double GetCurrentUS() {
gettimeofday
(
&
time
,
NULL
);
return
1e+6
*
time
.
tv_sec
+
time
.
tv_usec
;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/service/env.h
浏览文件 @
29c6bcbf
...
...
@@ -144,8 +144,8 @@ class PSEnvironment {
virtual
std
::
vector
<
uint64_t
>
get_client_info
()
{
std
::
vector
<
uint64_t
>
client_info
;
for
(
auto
&
i
:
_ps_client_
sign_se
t
)
{
client_info
.
push_back
(
i
);
for
(
auto
&
i
:
_ps_client_
lis
t
)
{
client_info
.
push_back
(
i
.
serialize_to_uint64
()
);
}
return
client_info
;
}
...
...
@@ -250,7 +250,7 @@ class PaddlePSEnvironment : public PSEnvironment {
return
0
;
}
virtual
int32_t
set_ps_clients
(
std
::
vector
<
std
::
string
>
*
host_sign_list
,
virtual
int32_t
set_ps_clients
(
const
std
::
vector
<
std
::
string
>
*
host_sign_list
,
int
node_num
)
{
_ps_client_list
.
clear
();
_ps_client_sign_set
.
clear
();
...
...
@@ -265,6 +265,7 @@ class PaddlePSEnvironment : public PSEnvironment {
std
::
sort
(
_ps_client_list
.
begin
(),
_ps_client_list
.
end
(),
[](
const
PSHost
&
h1
,
const
PSHost
&
h2
)
{
return
h1
.
rank
<
h2
.
rank
;
});
VLOG
(
1
)
<<
"env.set_ps_clients done
\n
"
;
return
0
;
}
...
...
paddle/fluid/distributed/service/ps_client.h
浏览文件 @
29c6bcbf
...
...
@@ -20,11 +20,13 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/service/env.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/table/accessor.h"
#include "paddle/fluid/distributed/table/graph/graph_node.h"
#include "paddle/fluid/platform/timer.h"
namespace
paddle
{
namespace
distributed
{
...
...
@@ -35,7 +37,7 @@ using paddle::distributed::PsResponseMessage;
typedef
std
::
function
<
void
(
void
*
)
>
PSClientCallBack
;
class
PSClientClosure
:
public
google
::
protobuf
::
Closure
{
public:
PSClientClosure
(
PSClientCallBack
callback
)
:
_callback
(
callback
)
{}
explicit
PSClientClosure
(
PSClientCallBack
callback
)
:
_callback
(
callback
)
{}
virtual
~
PSClientClosure
()
{}
virtual
void
set_promise_value
(
int
value
)
{
for
(
auto
&
promise
:
_promises
)
{
...
...
@@ -43,12 +45,17 @@ class PSClientClosure : public google::protobuf::Closure {
}
}
void
add_promise
(
std
::
shared_ptr
<
std
::
promise
<
int32_t
>>
&
promise
)
{
void
add_promise
(
std
::
shared_ptr
<
std
::
promise
<
int32_t
>>
&
promise
)
{
// NOLINT
_promises
.
push_back
(
promise
);
}
void
add_timer
(
std
::
shared_ptr
<
CostTimer
>
&
timer
)
{
// NOLINT
_timers
.
push_back
(
timer
);
}
protected:
PSClientCallBack
_callback
;
std
::
vector
<
std
::
shared_ptr
<
CostTimer
>>
_timers
;
std
::
vector
<
std
::
shared_ptr
<
std
::
promise
<
int32_t
>>>
_promises
;
};
...
...
@@ -59,11 +66,11 @@ class PSClient {
PSClient
(
PSClient
&&
)
=
delete
;
PSClient
(
const
PSClient
&
)
=
delete
;
virtual
int32_t
configure
(
virtual
int32_t
configure
(
// NOLINT
const
PSParameter
&
config
,
const
std
::
map
<
uint64_t
,
std
::
vector
<
paddle
::
distributed
::
Region
>>
&
regions
,
PSEnvironment
&
_env
,
size_t
client_id
)
final
;
PSEnvironment
&
_env
,
size_t
client_id
)
final
;
// NOLINT
virtual
int32_t
create_client2client_connection
(
int
pserver_timeout_ms
,
int
pserver_connect_timeout_ms
,
...
...
@@ -86,7 +93,7 @@ class PSClient {
virtual
std
::
future
<
int32_t
>
save
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
=
0
;
//清空table数据
//
清空table数据
virtual
std
::
future
<
int32_t
>
clear
()
=
0
;
virtual
std
::
future
<
int32_t
>
clear
(
uint32_t
table_id
)
=
0
;
...
...
@@ -98,7 +105,7 @@ class PSClient {
// server将参数区块中配置的某一维提取返回
// 返回数据解包后填充到累计的多个buffer中
virtual
std
::
future
<
int32_t
>
pull_dense
(
Region
*
regions
,
size_t
region_num
,
size_t
table_id
)
=
0
;
//保留
size_t
table_id
)
=
0
;
//
保留
// firstly push dense param for parameter server
// this is neccessary because dense weight initialized in trainer on cold
...
...
@@ -107,6 +114,9 @@ class PSClient {
size_t
region_num
,
size_t
table_id
)
=
0
;
// virtual std::future<int32_t> push_dense(const Region *regions,
// size_t region_num,
// size_t table_id) = 0;
// 使用keys进行pull请求,结果填充values
// keys和values的个数均为num个,每个value占用select_size空间
// future结束前keys和values缓冲区不能再次使用
...
...
@@ -212,6 +222,10 @@ class PSClient {
const
uint64_t
*
keys
,
const
float
**
update_values
,
size_t
num
,
void
*
done
)
=
0
;
// virtual std::future<int32_t> push_sparse(size_t table_id,
// const uint64_t *keys,
// const float **update_values,
// size_t num) = 0;
protected:
virtual
int32_t
initialize
()
=
0
;
...
...
@@ -222,8 +236,42 @@ class PSClient {
PSEnvironment
*
_env
;
std
::
unordered_map
<
uint32_t
,
std
::
shared_ptr
<
ValueAccessor
>>
_table_accessors
;
std
::
unordered_map
<
int32_t
,
MsgHandlerFunc
>
_msg_handler_map
;
//处理client2client消息
_msg_handler_map
;
// 处理client2client消息
};
template
<
class
T
>
class
AsyncRequestTask
{
public:
AsyncRequestTask
()
:
_promise
(
std
::
make_shared
<
std
::
promise
<
int32_t
>>
())
{}
AsyncRequestTask
(
T
&
data
,
size_t
table_id
,
std
::
shared_ptr
<
CostTimer
>
&
timer
)
:
_table_id
(
table_id
),
_timer
(
timer
),
_promise
(
std
::
make_shared
<
std
::
promise
<
int32_t
>>
())
{
_data
=
std
::
move
(
data
);
}
AsyncRequestTask
(
AsyncRequestTask
&
data
)
// NOLINT
:
_table_id
(
data
.
table_id
()),
_timer
(
data
.
timer
()),
_promise
(
data
.
promise
())
{
_data
=
std
::
move
(
data
.
data
());
}
~
AsyncRequestTask
()
{}
inline
T
&
data
()
{
return
_data
;
}
inline
size_t
table_id
()
{
return
_table_id
;
}
inline
std
::
shared_ptr
<
CostTimer
>
&
timer
()
{
return
_timer
;
}
inline
std
::
future
<
int32_t
>
get_future
()
{
return
_promise
->
get_future
();
}
inline
std
::
shared_ptr
<
std
::
promise
<
int32_t
>>
&
promise
()
{
return
_promise
;
}
private:
T
_data
;
size_t
_table_id
;
std
::
shared_ptr
<
CostTimer
>
_timer
;
std
::
shared_ptr
<
std
::
promise
<
int32_t
>>
_promise
;
};
REGISTER_PSCORE_REGISTERER
(
PSClient
);
class
PSClientFactory
{
...
...
paddle/fluid/distributed/table/accessor.h
浏览文件 @
29c6bcbf
...
...
@@ -17,15 +17,12 @@
#include <stdio.h>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/common/afs_warpper.h"
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
namespace
paddle
{
namespace
distributed
{
struct
FsDataConverter
{
std
::
string
converter
;
std
::
string
deconverter
;
};
struct
Region
{
Region
()
:
data
(
NULL
),
size
(
0
)
{}
...
...
@@ -50,8 +47,8 @@ struct DataConverter {
class
ValueAccessor
{
public:
explicit
ValueAccessor
(){};
virtual
~
ValueAccessor
()
{};
ValueAccessor
()
{}
virtual
~
ValueAccessor
()
{}
virtual
int
configure
(
const
TableAccessorParameter
&
parameter
)
{
_config
=
parameter
;
...
...
paddle/fluid/distributed/table/depends/dense.h
浏览文件 @
29c6bcbf
...
...
@@ -183,5 +183,159 @@ class DAdam : public DenseOptimizer {
float
epsilon
;
};
// adam optimizer for dense tensor
class
DAdamD2Sum
:
public
DenseOptimizer
{
public:
explicit
DAdamD2Sum
(
const
CommonAccessorParameter
&
accessor
,
std
::
vector
<
std
::
vector
<
float
>>*
values
)
{
lr_hardcode
=
5e-6
;
auto
&
names
=
accessor
.
params
();
for
(
int
x
=
0
;
x
<
static_cast
<
int
>
(
names
.
size
());
++
x
)
{
if
(
names
[
x
]
==
"LearningRate"
)
{
learning_rate
=
(
*
values
)[
x
].
data
();
}
if
(
names
[
x
]
==
"Param"
)
{
param
=
(
*
values
)[
x
].
data
();
}
if
(
names
[
x
]
==
"Moment"
)
{
mom_velocity
=
(
*
values
)[
x
].
data
();
}
if
(
names
[
x
]
==
"G2Sum"
)
{
ada_g2sum
=
(
*
values
)[
x
].
data
();
}
if
(
names
[
x
]
==
"D2Sum"
)
{
ada_d2sum
=
(
*
values
)[
x
].
data
();
}
if
(
names
[
x
]
==
"MomentDecayRate"
)
{
mom_decay_rate
=
(
*
values
)[
x
].
data
();
}
if
(
names
[
x
]
==
"AdaDecayRate"
)
{
ada_decay_rate
=
(
*
values
)[
x
].
data
();
}
if
(
names
[
x
]
==
"AdaEpsilon"
)
{
ada_epsilon
=
(
*
values
)[
x
].
data
();
}
}
}
void
update
(
const
float
*
update_values
,
size_t
num
,
int
begin
,
int
end
)
override
{
auto
update_numel
=
end
-
begin
;
/*
// for debug
std::cout << "before update:\n";
for (int i = 0; i < 3; ++ i) {
std::cout << "param: " << i << " " << *(param+begin+i) <<
"grad: " << *(update_values+begin+i) << "\n";
}*/
std
::
vector
<
float
>
grad
,
grad2
,
scale
;
grad
.
resize
(
update_numel
);
grad2
.
resize
(
update_numel
);
scale
.
resize
(
update_numel
);
auto
blas
=
GetBlas
<
float
>
();
// copy grad
blas
.
VCOPY
(
update_numel
,
update_values
+
begin
,
grad
.
data
());
blas
.
VCOPY
(
update_numel
,
update_values
+
begin
,
grad2
.
data
());
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "copy grad: " << i << " " << *(grad.data()+begin+i) <<
"copy grad2: " << *(grad2.data()+begin+i) << "\n";
}
for (int i = 0; i < 3; ++ i) {
std::cout << "d2sum before: " << i << " " << *(ada_d2sum+begin+i) << "\n";
}*/
// d2sum
blas
.
SCAL
(
update_numel
,
ada_decay_rate
[
0
],
ada_d2sum
+
begin
);
ADD
<
float
>
(
update_numel
,
ada_d2sum
+
begin
,
1
,
ada_d2sum
+
begin
);
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "d2sum update: " << i << " " << *(ada_d2sum+begin+i) << "\n";
}
for (int i = 0; i < 3; ++ i) {
std::cout << "g2sum before: " << i << " " << *(ada_g2sum+begin+i) << "\n";
}*/
// g2sum
blas
.
SCAL
(
update_numel
,
ada_decay_rate
[
0
],
ada_g2sum
+
begin
);
blas
.
VSQUARE
(
update_numel
,
grad2
.
data
(),
grad2
.
data
());
blas
.
VADD
(
update_numel
,
ada_g2sum
+
begin
,
grad2
.
data
(),
ada_g2sum
+
begin
);
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "g2sum update: " << i << " " << *(ada_g2sum+begin+i) << "\n";
}
for (int i = 0; i < 3; ++ i) {
std::cout << "mom before: " << i << " " << *(mom_velocity+begin+i) <<
"\n";
}*/
// mom
blas
.
SCAL
(
update_numel
,
mom_decay_rate
[
0
],
mom_velocity
+
begin
);
blas
.
SCAL
(
update_numel
,
1
-
mom_decay_rate
[
0
],
grad
.
data
());
blas
.
VADD
(
update_numel
,
mom_velocity
+
begin
,
grad
.
data
(),
mom_velocity
+
begin
);
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "mom update: " << i << " " << *(mom_velocity+begin+i) <<
"\n";
}
for (int i = 0; i < 3; ++ i) {
std::cout << "scale before: " << i << " " << *(scale.data()+begin+i) <<
"\n";
}*/
// scale
float
*
scale_
=
scale
.
data
();
blas
.
VDIV
(
update_numel
,
ada_g2sum
+
begin
,
ada_d2sum
+
begin
,
scale_
);
ADD
<
float
>
(
update_numel
,
scale_
,
ada_epsilon
[
0
],
scale_
);
DIV
<
float
>
(
update_numel
,
1
+
ada_epsilon
[
0
],
scale_
,
scale_
);
SQRT
<
float
>
(
update_numel
,
scale_
,
scale_
);
/*
for (int i = 0; i < 3; ++ i) {
std::cout << "scale update: " << i << " " << *(scale.data()+begin+i) <<
"\n";
}*/
blas
.
SCAL
(
update_numel
,
learning_rate
[
0
],
scale_
);
// TODO(zhaocaibei123): check if there exists elementwise_multiply in blas
// TODO(zhaocaibei123): blas.VMUL
ELE_MUL
<
float
>
(
update_numel
,
scale_
,
mom_velocity
+
begin
,
scale_
);
/*
for (int i = 0; i < 3; ++ i) {
std::cout << "scale update2: " << i << " " << *(scale.data()+begin+i) <<
"\n";
}*/
blas
.
VSUB
(
update_numel
,
param
+
begin
,
scale_
,
param
+
begin
);
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "param update " << i << " " << *(param+begin+i) << "\n";
}*/
}
float
*
learning_rate
;
float
lr_hardcode
;
float
*
param
;
float
*
mom_velocity
;
float
*
ada_g2sum
;
float
*
ada_d2sum
;
float
*
mom_decay_rate
;
float
*
ada_decay_rate
;
float
*
ada_epsilon
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/framework/distributed_strategy.proto
浏览文件 @
29c6bcbf
...
...
@@ -173,6 +173,68 @@ message TensorParallelConfig {
optional
int32
tensor_init_seed
=
2
[
default
=
-
1
];
}
enum
TableType
{
PS_SPARSE_TABLE
=
0
;
PS_DENSE_TABLE
=
1
;
}
message
TableParameter
{
optional
uint64
table_id
=
1
;
optional
string
table_class
=
2
;
optional
uint64
shard_num
=
3
;
optional
TableType
type
=
4
;
optional
TableAccessorParameter
accessor
=
5
;
}
message
TableAccessorParameter
{
optional
string
accessor_class
=
1
;
optional
SGDParameter
embed_sgd_param
=
2
;
optional
SGDParameter
embedx_sgd_param
=
3
;
optional
uint32
fea_dim
=
4
;
// for sparse table, this means field size of one
// value; for dense table, this means total value
// num
optional
uint32
embedx_dim
=
5
;
// embedx feature size
optional
uint32
embedx_threshold
=
6
;
// embedx feature create threshold
optional
CtrAccessorParameter
ctr_accessor_param
=
7
;
}
// TODO(guanqun): add NaiveSGD/Adam...
message
SGDParameter
{
optional
string
name
=
1
;
optional
SGDRuleParameter
adagrad
=
2
;
}
message
SGDRuleParameter
{
optional
double
learning_rate
=
1
;
optional
double
initial_g2sum
=
2
;
optional
double
initial_range
=
3
[
default
=
0
];
repeated
float
weight_bounds
=
4
;
}
message
CtrAccessorParameter
{
optional
float
nonclk_coeff
=
1
;
// to calculate show_click_score
optional
float
click_coeff
=
2
;
// to calculate show_click_score
optional
float
base_threshold
=
3
;
// show_click_score > base_threshold, this feature can be saved
optional
float
delta_threshold
=
4
;
// delta_score > delta_threshold, this feature can be saved
optional
float
delta_keep_days
=
5
;
// unseen_day < delta_keep_days, this feature can be saved
optional
float
show_click_decay_rate
=
6
;
// show/click will update to
// show/click *
// show_click_decay_rate after a day
optional
float
delete_threshold
=
7
;
// threshold to shrink a feasign
optional
float
delete_after_unseen_days
=
8
;
optional
int32
ssd_unseenday_threshold
=
9
;
}
message
FsClientParameter
{
optional
string
uri
=
1
;
optional
string
user
=
2
;
optional
string
passwd
=
3
;
optional
string
hadoop_bin
=
4
;
}
message
DistributedStrategy
{
// bool options
optional
Mode
mode
=
1
[
default
=
COLLECTIVE
];
...
...
@@ -210,6 +272,7 @@ message DistributedStrategy {
optional
bool
asp
=
33
[
default
=
false
];
optional
bool
fuse_grad_merge
=
34
[
default
=
false
];
optional
bool
semi_auto
=
35
[
default
=
false
];
optional
bool
adam_d2sum
=
36
[
default
=
true
];
optional
RecomputeConfig
recompute_configs
=
101
;
optional
AMPConfig
amp_configs
=
102
;
...
...
@@ -225,6 +288,9 @@ message DistributedStrategy {
optional
HybridConfig
hybrid_configs
=
112
;
optional
TensorParallelConfig
tensor_parallel_configs
=
113
;
optional
TrainerDescConfig
trainer_desc_configs
=
114
;
optional
TableParameter
downpour_table_param
=
115
;
optional
FsClientParameter
fs_client_param
=
116
;
optional
BuildStrategy
build_strategy
=
201
;
optional
ExecutionStrategy
execution_strategy
=
202
;
optional
GradientScaleConfig
gradient_scale_configs
=
203
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录