Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
b8cf64ab
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b8cf64ab
编写于
9月 10, 2019
作者:
X
xiexionghang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
for async push_gradient
上级
aaea8a39
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
112 addition
and
57 deletion
+112
-57
paddle/fluid/train/custom_trainer/feed/accessor/dense_input_accessor.cc
...rain/custom_trainer/feed/accessor/dense_input_accessor.cc
+13
-9
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
...luid/train/custom_trainer/feed/accessor/epoch_accessor.cc
+16
-29
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
...fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
+1
-0
paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h
.../train/custom_trainer/feed/accessor/input_data_accessor.h
+7
-6
paddle/fluid/train/custom_trainer/feed/accessor/label_input_accessor.cc
...rain/custom_trainer/feed/accessor/label_input_accessor.cc
+4
-3
paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc
...ain/custom_trainer/feed/accessor/sparse_input_accessor.cc
+6
-6
paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.cc
...e/fluid/train/custom_trainer/feed/common/pslib_warpper.cc
+0
-1
paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc
...ain/custom_trainer/feed/executor/multi_thread_executor.cc
+26
-3
paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.h
...rain/custom_trainer/feed/executor/multi_thread_executor.h
+13
-0
paddle/fluid/train/custom_trainer/feed/io/file_system.cc
paddle/fluid/train/custom_trainer/feed/io/file_system.cc
+24
-0
paddle/fluid/train/custom_trainer/feed/io/file_system.h
paddle/fluid/train/custom_trainer/feed/io/file_system.h
+2
-0
未找到文件。
paddle/fluid/train/custom_trainer/feed/accessor/dense_input_accessor.cc
浏览文件 @
b8cf64ab
...
...
@@ -52,7 +52,10 @@ int32_t DenseInputAccessor::create(::paddle::framework::Scope* scope) {
// rpc拉取数据,需保证单线程运行
int32_t
DenseInputAccessor
::
pull_dense
(
size_t
table_id
)
{
float
*
data_buffer
=
new
float
[
_total_dim
];
float
*
data_buffer
=
_data_buffer
;
if
(
data_buffer
==
NULL
)
{
data_buffer
=
new
float
[
_total_dim
];
}
size_t
data_buffer_idx
=
0
;
std
::
vector
<
paddle
::
ps
::
Region
>
regions
;
for
(
auto
&
variable
:
_x_variables
)
{
...
...
@@ -128,10 +131,11 @@ int32_t DenseInputAccessor::collect_persistables_name(std::vector<std::string>&
return
0
;
}
int32_t
DenseInputAccessor
::
backward
(
SampleInstance
*
samples
,
size_t
num
,
std
::
future
<
int32_t
>
DenseInputAccessor
::
backward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
)
{
std
::
future
<
int32_t
>
ret
;
if
(
!
_need_gradient
)
{
return
0
;
return
ret
;
}
size_t
data_buffer_idx
=
0
;
std
::
vector
<
paddle
::
ps
::
Region
>
regions
;
...
...
@@ -142,8 +146,7 @@ int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num,
regions
.
emplace_back
(
grad_data
,
variable
.
dim
);
}
auto
*
ps_client
=
_trainer_context
->
pslib
->
ps_client
();
auto
push_status
=
ps_client
->
push_dense
(
regions
.
data
(),
regions
.
size
(),
_table_id
);
//push_status.get();
ps_client
->
push_dense
(
regions
.
data
(),
regions
.
size
(),
_table_id
);
if
(
!
FLAGS_feed_trainer_debug_dense_name
.
empty
())
{
for
(
auto
&
variable
:
_x_variables
)
{
if
(
variable
.
name
!=
FLAGS_feed_trainer_debug_dense_name
)
{
...
...
@@ -152,7 +155,8 @@ int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num,
VLOG
(
2
)
<<
"[Debug][PushDense]"
<<
ScopeHelper
::
to_string
(
scope
,
variable
.
gradient_name
);
}
}
return
0
;
// not wait dense push
return
ret
;
}
int32_t
EbdVariableInputAccessor
::
forward
(
SampleInstance
*
samples
,
size_t
num
,
...
...
@@ -171,10 +175,10 @@ int32_t EbdVariableInputAccessor::forward(SampleInstance* samples, size_t num,
}
return
0
;
}
int32_t
EbdVariableInputAccessor
::
backward
(
SampleInstance
*
samples
,
size_t
num
,
std
::
future
<
int32_t
>
EbdVariableInputAccessor
::
backward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
)
{
return
0
;
std
::
future
<
int32_t
>
ret
;
return
ret
;
}
REGIST_CLASS
(
DataInputAccessor
,
DenseInputAccessor
);
...
...
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
浏览文件 @
b8cf64ab
...
...
@@ -22,8 +22,10 @@ namespace feed {
}
std
::
string
done_text
=
fs
->
tail
(
_done_file_path
);
_done_status
=
paddle
::
string
::
split_string
(
done_text
,
std
::
string
(
"
\t
"
));
_
current
_epoch_id
=
get_status
<
uint64_t
>
(
EpochStatusFiled
::
EpochIdField
);
_
last_done
_epoch_id
=
get_status
<
uint64_t
>
(
EpochStatusFiled
::
EpochIdField
);
_last_checkpoint_epoch_id
=
get_status
<
uint64_t
>
(
EpochStatusFiled
::
CheckpointIdField
);
// 训练需要从上一个checkpoint对应的epoch开始
_current_epoch_id
=
_last_checkpoint_epoch_id
;
_last_checkpoint_path
=
get_status
<
std
::
string
>
(
EpochStatusFiled
::
CheckpointPathField
);
_inference_base_model_key
=
get_status
<
uint64_t
>
(
EpochStatusFiled
::
InferenceBaseKeyField
);
_inference_model_path
=
fs
->
path_join
(
_model_root_path
,
config
[
"inference_model_dir"
].
as
<
std
::
string
>
(
"xbox"
));
...
...
@@ -45,8 +47,14 @@ namespace feed {
set_status
(
EpochStatusFiled
::
TimestampField
,
now
.
tv_sec
);
set_status
(
EpochStatusFiled
::
CheckpointIdField
,
_last_checkpoint_epoch_id
);
set_status
(
EpochStatusFiled
::
CheckpointPathField
,
_last_checkpoint_path
);
set_status
(
EpochStatusFiled
::
DateField
,
format_timestamp
(
epoch_id
,
"%Y%m%d"
));
set_status
(
EpochStatusFiled
::
DateField
,
format_timestamp
(
epoch_id
,
"%Y%m%d
-%H%M
"
));
set_status
(
EpochStatusFiled
::
InferenceBaseKeyField
,
_inference_base_model_key
);
if
(
epoch_id
>
_last_done_epoch_id
)
{
// 保留末尾1000数据
auto
fs
=
_trainer_context
->
file_system
.
get
();
std
::
string
done_str
=
paddle
::
string
::
join_strings
(
_done_status
,
'\t'
);
fs
->
append_line
(
_done_file_path
,
done_str
,
1000
);
}
return
0
;
}
...
...
@@ -59,20 +67,18 @@ namespace feed {
}
std
::
string
done_str
;
std
::
string
donefile
;
auto
fs
=
_trainer_context
->
file_system
.
get
();
auto
model_path
=
model_save_path
(
epoch_id
,
save_way
);
std
::
string
inference_done_format
(
"{
\"
id
\"
:
\"
%lu
\"
,
\"
key
\"
:
\"
%lu
\"
,
\"
input
\"
:
\"
%s/000
\"
,
\"
record_count
\"
:
\"
1
\"
,
\"
file_format
\"
:
\"
pb
\"
,
\"
schema_version
\"
:
\"
2
\"
,
\"
partition_type
\"
:
\"
1
\"
,
\"
job_name
\"
:
\"
%s
\"
,
\"
job_id
\"
:
\"
%s
\"
,
\"
mpi_size
\"
:
\"
%d
\"
,
\"
monitor_data
\"
:
\"
%s
\"
}"
);
auto
id
=
time
(
NULL
);
switch
(
save_way
)
{
case
ModelSaveWay
::
ModelSaveTrainCheckpoint
:
donefile
=
_done_file_path
;
done_str
=
paddle
::
string
::
join_strings
(
_done_status
,
'\t'
);
break
;
case
ModelSaveWay
::
ModelSaveInferenceDelta
:
donefile
=
_inference_model_delta_done_path
;
done_str
=
string
::
format_string
(
inference_done_format
.
c_str
(),
id
,
_inference_base_model_key
,
model_path
.
c_str
(),
env
->
job_name
().
c_str
(),
env
->
job_id
().
c_str
(),
env
->
node_num
(
EnvironmentRole
::
PSERVER
),
_trainer_context
->
monitor_ssm
.
str
().
c_str
());
fs
->
append_line
(
donefile
,
done_str
,
1000
);
break
;
case
ModelSaveWay
::
ModelSaveInferenceBase
:
donefile
=
_inference_model_base_done_path
;
...
...
@@ -80,30 +86,9 @@ namespace feed {
done_str
=
string
::
format_string
(
inference_done_format
.
c_str
(),
id
,
id
,
model_path
.
c_str
(),
env
->
job_name
().
c_str
(),
env
->
job_id
().
c_str
(),
env
->
node_num
(
EnvironmentRole
::
PSERVER
),
_trainer_context
->
monitor_ssm
.
str
().
c_str
());
fs
->
append_line
(
donefile
,
done_str
,
1000
);
break
;
}
// 保留末尾1000数据
std
::
string
tail_done_info
;
auto
fs
=
_trainer_context
->
file_system
.
get
();
if
(
fs
->
exists
(
donefile
))
{
tail_done_info
=
paddle
::
string
::
trim_spaces
(
fs
->
tail
(
donefile
,
1000
));
}
if
(
tail_done_info
.
size
()
>
0
)
{
tail_done_info
=
tail_done_info
+
"
\n
"
+
done_str
;
}
else
{
tail_done_info
=
done_str
;
}
VLOG
(
2
)
<<
"Write donefile "
<<
donefile
<<
", str:"
<<
done_str
;
bool
write_success
=
false
;
while
(
true
)
{
fs
->
remove
(
donefile
);
auto
fp
=
fs
->
open_write
(
donefile
,
""
);
if
(
fwrite
(
tail_done_info
.
c_str
(),
tail_done_info
.
length
(),
1
,
&*
fp
)
==
1
)
{
break
;
}
sleep
(
10
);
}
VLOG
(
2
)
<<
"Write donefile "
<<
donefile
<<
"success"
;
return
0
;
}
...
...
@@ -155,7 +140,9 @@ namespace feed {
}
switch
(
save_way
)
{
case
ModelSaveWay
::
ModelSaveInferenceDelta
:
return
delta_id
(
epoch_id
)
%
6
==
0
;
// 重启训练后,中间的delta不重复dump
return
epoch_id
>
_last_done_epoch_id
&&
delta_id
(
epoch_id
)
%
6
==
0
;
case
ModelSaveWay
::
ModelSaveInferenceBase
:
return
is_last_epoch
(
epoch_id
);
case
ModelSaveWay
::
ModelSaveTrainCheckpoint
:
...
...
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
浏览文件 @
b8cf64ab
...
...
@@ -73,6 +73,7 @@ protected:
std
::
string
_inference_model_delta_done_path
;
uint64_t
_current_epoch_id
=
0
;
std
::
string
_last_checkpoint_path
;
uint64_t
_last_done_epoch_id
=
0
;
uint64_t
_last_checkpoint_epoch_id
=
0
;
std
::
vector
<
std
::
string
>
_done_status
;
// 当前完成状态,统一存成string
uint64_t
_inference_base_model_key
=
0
;
// 预估模型的base-key
...
...
paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h
浏览文件 @
b8cf64ab
...
...
@@ -35,8 +35,9 @@ public:
virtual
int32_t
forward
(
SampleInstance
*
samples
,
size_t
num
,
::
paddle
::
framework
::
Scope
*
scope
)
=
0
;
// 后向,一般用于更新梯度,在训练网络执行后调用
virtual
int32_t
backward
(
SampleInstance
*
samples
,
size_t
num
,
// 后向,一般用于更新梯度,在训练网络执行后调用, 由于backward一般是异步,这里返回future,
// TODO 前向接口也改为future返回形式,接口一致性好些
virtual
std
::
future
<
int32_t
>
backward
(
SampleInstance
*
samples
,
size_t
num
,
::
paddle
::
framework
::
Scope
*
scope
)
=
0
;
// 收集持久化变量的名称, 并将值拷贝到Scope
...
...
@@ -67,7 +68,7 @@ public:
virtual
int32_t
forward
(
SampleInstance
*
samples
,
size_t
num
,
::
paddle
::
framework
::
Scope
*
scope
);
virtual
int32_t
backward
(
SampleInstance
*
samples
,
size_t
num
,
virtual
std
::
future
<
int32_t
>
backward
(
SampleInstance
*
samples
,
size_t
num
,
::
paddle
::
framework
::
Scope
*
scope
);
protected:
size_t
_label_total_dim
=
0
;
...
...
@@ -108,7 +109,7 @@ public:
virtual
void
post_process_input
(
float
*
var_data
,
SparseInputVariable
&
,
SampleInstance
*
,
size_t
num
)
=
0
;
// backward过程的梯度push
virtual
int32_t
backward
(
SampleInstance
*
samples
,
size_t
num
,
virtual
std
::
future
<
int32_t
>
backward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
);
// SparseGradValue会被依次调用,用于整理push的梯度
virtual
void
fill_gradient
(
float
*
push_value
,
const
float
*
gradient_raw
,
...
...
@@ -148,7 +149,7 @@ public:
virtual
int32_t
forward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
);
virtual
int32_t
backward
(
SampleInstance
*
samples
,
size_t
num
,
virtual
std
::
future
<
int32_t
>
backward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
);
...
...
@@ -175,7 +176,7 @@ public:
virtual
int32_t
forward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
);
virtual
int32_t
backward
(
SampleInstance
*
samples
,
size_t
num
,
virtual
std
::
future
<
int32_t
>
backward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
);
};
...
...
paddle/fluid/train/custom_trainer/feed/accessor/label_input_accessor.cc
浏览文件 @
b8cf64ab
...
...
@@ -45,10 +45,11 @@ int32_t LabelInputAccessor::forward(SampleInstance* samples, size_t num,
return
0
;
}
int32_t
LabelInputAccessor
::
backward
(
SampleInstance
*
samples
,
size_t
num
,
std
::
future
<
int32_t
>
LabelInputAccessor
::
backward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
)
{
std
::
future
<
int32_t
>
ret
;
if
(
num
<
1
)
{
return
0
;
return
ret
;
}
for
(
size_t
i
=
0
;
i
<
num
;
++
i
)
{
auto
&
sample
=
samples
[
i
];
...
...
@@ -69,7 +70,7 @@ int32_t LabelInputAccessor::backward(SampleInstance* samples, size_t num,
VLOG(2) << "[Debug][Lable]" << ScopeHelper::to_string(scope, label.label_name) << ScopeHelper::to_string(scope, label.output_name);
}
*/
return
0
;
return
ret
;
}
REGIST_CLASS
(
DataInputAccessor
,
LabelInputAccessor
);
...
...
paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc
浏览文件 @
b8cf64ab
...
...
@@ -136,8 +136,9 @@ int32_t BaseSparseInputAccessor::forward(SampleInstance* samples,
}
// 更新spare数据
int32_t
BaseSparseInputAccessor
::
backward
(
SampleInstance
*
samples
,
std
::
future
<
int32_t
>
BaseSparseInputAccessor
::
backward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
)
{
std
::
future
<
int32_t
>
ret
;
int64_t
runtime_data_for_scope
=
*
ScopeHelper
::
get_value
<
int64_t
>
(
scope
,
_trainer_context
->
cpu_place
,
"sparse_runtime_data"
);
auto
*
runtime_data_ptr
=
(
std
::
vector
<
SparseVarRuntimeData
>*
)
runtime_data_for_scope
;
...
...
@@ -146,7 +147,7 @@ int32_t BaseSparseInputAccessor::backward(SampleInstance* samples,
delete
runtime_data_ptr
;
});
if
(
!
_need_gradient
)
{
return
0
;
return
ret
;
}
auto
*
ps_client
=
_trainer_context
->
pslib
->
ps_client
();
auto
*
value_accessor
=
ps_client
->
table_accessor
(
_table_id
);
...
...
@@ -204,11 +205,10 @@ int32_t BaseSparseInputAccessor::backward(SampleInstance* samples,
VLOG
(
2
)
<<
"[DEBUG][sparse_slot_push]"
<<
ssm
.
str
();
}
}
auto
push_status
=
ps_client
->
push_sparse
(
_table_id
,
keys
.
data
(),
(
const
float
**
)
push_values
,
key_idx
);
//auto ret = push_status.get();
ret
=
ps_client
->
push_sparse
(
_table_id
,
keys
.
data
(),
(
const
float
**
)
push_values
,
key_idx
);
delete
[]
push_values
;
return
0
;
return
ret
;
}
class
AbacusSparseJoinAccessor
:
public
BaseSparseInputAccessor
{
...
...
paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.cc
浏览文件 @
b8cf64ab
...
...
@@ -70,7 +70,6 @@ paddle::PSParameter* PSlib::get_param() {
void
PSlib
::
init_gflag
()
{
int
cnt
=
4
;
char
**
params_ptr
=
new
char
*
[
cnt
];
std
::
cout
<<
"alloc_ptr"
<<
params_ptr
<<
std
::
flush
;
char
p0
[]
=
"exe default"
;
char
p1
[]
=
"-max_body_size=314217728"
;
char
p2
[]
=
"-bthread_concurrency=40"
;
...
...
paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc
浏览文件 @
b8cf64ab
...
...
@@ -9,6 +9,10 @@ namespace paddle {
namespace
custom_trainer
{
namespace
feed
{
std
::
once_flag
MultiThreadExecutor
::
_async_delete_flag
;
std
::
shared_ptr
<
std
::
thread
>
MultiThreadExecutor
::
_async_delete_thread
;
paddle
::
framework
::
Channel
<
ScopeExecutorContext
*>
MultiThreadExecutor
::
_delete_channel
;
int
MultiThreadExecutor
::
initialize
(
YAML
::
Node
exe_config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
{
int
ret
=
0
;
...
...
@@ -85,6 +89,23 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config,
CHECK
(
monitor_ptr
->
initialize
(
monitor_config
,
context_ptr
)
==
0
)
<<
"Monitor init Failed, class:"
<<
monitor_class
;
}
// 异步删除池
std
::
call_once
(
_async_delete_flag
,
[
this
](){
_delete_channel
=
paddle
::
framework
::
MakeChannel
<
ScopeExecutorContext
*>
();
_delete_channel
->
SetBlockSize
(
32
);
_async_delete_thread
.
reset
(
new
std
::
thread
([
this
]{
std
::
vector
<
ScopeExecutorContext
*>
ctxs
;
while
(
true
)
{
while
(
_delete_channel
->
Read
(
ctxs
))
{
for
(
auto
*
ctx
:
ctxs
)
{
delete
ctx
;
}
}
usleep
(
200000
);
// 200ms
}
}));
});
return
ret
;
}
...
...
@@ -187,9 +208,10 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run(
auto
*
samples
=
scope_ctx
->
samples
();
auto
sample_num
=
scope_ctx
->
sample_num
();
out_items
[
out_idx
]
=
0
;
scope_ctx
->
wait_status
.
resize
(
_input_accessors
.
size
());
for
(
size_t
i
=
0
;
i
<
_input_accessors
.
size
();
++
i
)
{
out_items
[
out_idx
]
=
_input_accessors
[
i
]
->
backward
(
samples
,
sample_num
,
scope
);
scope_ctx
->
wait_status
[
i
]
=
_input_accessors
[
i
]
->
backward
(
samples
,
sample_num
,
scope
);
}
timer
.
Pause
();
scope_ctx
->
push_gradient_cost_ms
=
timer
.
ElapsedMS
();
...
...
@@ -203,7 +225,8 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run(
VLOG
(
2
)
<<
"[Debug][Layer]"
<<
ScopeHelper
::
to_string
(
scope
,
layer_name
);
}
}
delete
scope_ctx
;
// 所有pipe完成后,再回收sample
// 所有pipe完成后,再异步回收sample
_delete_channel
->
Put
(
scope_ctx
);
}
return
0
;
});
...
...
paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.h
浏览文件 @
b8cf64ab
#pragma once
#include <thread>
#include <functional>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
...
...
@@ -18,6 +19,12 @@ public:
_sample_num
=
sample_num
;
}
virtual
~
ScopeExecutorContext
()
{
for
(
auto
&
status
:
wait_status
)
{
if
(
!
status
.
valid
())
{
continue
;
}
status
.
wait
();
}
delete
[]
_samples
;
}
inline
SampleInstance
*
samples
()
{
...
...
@@ -29,6 +36,7 @@ public:
size_t
executor_cost_ms
=
0
;
size_t
prepare_cost_ms
=
0
;
size_t
push_gradient_cost_ms
=
0
;
std
::
vector
<
std
::
future
<
int32_t
>>
wait_status
;
private:
size_t
_sample_num
=
0
;
SampleInstance
*
_samples
=
NULL
;
...
...
@@ -83,6 +91,11 @@ protected:
std
::
map
<
uint32_t
,
std
::
vector
<
DataInputAccessor
*>>
_table_to_accessors
;
std
::
shared_ptr
<
paddle
::
ps
::
ObjectPool
<::
paddle
::
framework
::
Scope
>>
_scope_obj_pool
;
std
::
vector
<
std
::
string
>
_persistables
;
// 异步删除
static
std
::
once_flag
_async_delete_flag
;
static
std
::
shared_ptr
<
std
::
thread
>
_async_delete_thread
;
static
paddle
::
framework
::
Channel
<
ScopeExecutorContext
*>
_delete_channel
;
};
}
// namespace feed
...
...
paddle/fluid/train/custom_trainer/feed/io/file_system.cc
浏览文件 @
b8cf64ab
...
...
@@ -23,6 +23,30 @@ std::pair<std::string, std::string> FileSystem::path_split(const std::string& pa
return
{
path
.
substr
(
0
,
pos
),
path
.
substr
(
pos
+
1
)};
}
int
FileSystem
::
append_line
(
const
std
::
string
&
path
,
const
std
::
string
&
line
,
size_t
reserve_line_num
)
{
std
::
string
tail_data
;
if
(
exists
(
path
))
{
tail_data
=
paddle
::
string
::
trim_spaces
(
tail
(
path
,
reserve_line_num
));
}
if
(
tail_data
.
size
()
>
0
)
{
tail_data
=
tail_data
+
"
\n
"
+
line
;
}
else
{
tail_data
=
line
;
}
VLOG
(
2
)
<<
"Append to file:"
<<
path
<<
", line str:"
<<
line
;
while
(
true
)
{
remove
(
path
);
auto
fp
=
open_write
(
path
,
""
);
if
(
fwrite
(
tail_data
.
c_str
(),
tail_data
.
length
(),
1
,
&*
fp
)
==
1
)
{
break
;
}
sleep
(
10
);
VLOG
(
0
)
<<
"Retry Append to file:"
<<
path
<<
", line str:"
<<
line
;
}
return
0
;
}
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/io/file_system.h
浏览文件 @
b8cf64ab
...
...
@@ -18,6 +18,8 @@ public:
virtual
int
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
)
=
0
;
virtual
std
::
shared_ptr
<
FILE
>
open_read
(
const
std
::
string
&
path
,
const
std
::
string
&
converter
)
=
0
;
virtual
std
::
shared_ptr
<
FILE
>
open_write
(
const
std
::
string
&
path
,
const
std
::
string
&
converter
)
=
0
;
// only support text-file
virtual
int
append_line
(
const
std
::
string
&
path
,
const
std
::
string
&
line
,
size_t
reserve_line_num
);
virtual
int64_t
file_size
(
const
std
::
string
&
path
)
=
0
;
virtual
void
remove
(
const
std
::
string
&
path
)
=
0
;
virtual
std
::
vector
<
std
::
string
>
list
(
const
std
::
string
&
path
)
=
0
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录