Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleRec
提交
b8cf64ab
P
PaddleRec
项目概览
PaddlePaddle
/
PaddleRec
通知
68
Star
12
Fork
5
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
27
列表
看板
标记
里程碑
合并请求
10
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
27
Issue
27
列表
看板
标记
里程碑
合并请求
10
合并请求
10
Pages
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b8cf64ab
编写于
9月 10, 2019
作者:
X
xiexionghang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
for async push_gradient
上级
aaea8a39
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
112 addition
and
57 deletion
+112
-57
paddle/fluid/train/custom_trainer/feed/accessor/dense_input_accessor.cc
...rain/custom_trainer/feed/accessor/dense_input_accessor.cc
+13
-9
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
...luid/train/custom_trainer/feed/accessor/epoch_accessor.cc
+16
-29
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
...fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
+1
-0
paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h
.../train/custom_trainer/feed/accessor/input_data_accessor.h
+7
-6
paddle/fluid/train/custom_trainer/feed/accessor/label_input_accessor.cc
...rain/custom_trainer/feed/accessor/label_input_accessor.cc
+4
-3
paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc
...ain/custom_trainer/feed/accessor/sparse_input_accessor.cc
+6
-6
paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.cc
...e/fluid/train/custom_trainer/feed/common/pslib_warpper.cc
+0
-1
paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc
...ain/custom_trainer/feed/executor/multi_thread_executor.cc
+26
-3
paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.h
...rain/custom_trainer/feed/executor/multi_thread_executor.h
+13
-0
paddle/fluid/train/custom_trainer/feed/io/file_system.cc
paddle/fluid/train/custom_trainer/feed/io/file_system.cc
+24
-0
paddle/fluid/train/custom_trainer/feed/io/file_system.h
paddle/fluid/train/custom_trainer/feed/io/file_system.h
+2
-0
未找到文件。
paddle/fluid/train/custom_trainer/feed/accessor/dense_input_accessor.cc
浏览文件 @
b8cf64ab
...
...
@@ -52,7 +52,10 @@ int32_t DenseInputAccessor::create(::paddle::framework::Scope* scope) {
// rpc拉取数据,需保证单线程运行
int32_t
DenseInputAccessor
::
pull_dense
(
size_t
table_id
)
{
float
*
data_buffer
=
new
float
[
_total_dim
];
float
*
data_buffer
=
_data_buffer
;
if
(
data_buffer
==
NULL
)
{
data_buffer
=
new
float
[
_total_dim
];
}
size_t
data_buffer_idx
=
0
;
std
::
vector
<
paddle
::
ps
::
Region
>
regions
;
for
(
auto
&
variable
:
_x_variables
)
{
...
...
@@ -128,10 +131,11 @@ int32_t DenseInputAccessor::collect_persistables_name(std::vector<std::string>&
return
0
;
}
int32_t
DenseInputAccessor
::
backward
(
SampleInstance
*
samples
,
size_t
num
,
std
::
future
<
int32_t
>
DenseInputAccessor
::
backward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
)
{
std
::
future
<
int32_t
>
ret
;
if
(
!
_need_gradient
)
{
return
0
;
return
ret
;
}
size_t
data_buffer_idx
=
0
;
std
::
vector
<
paddle
::
ps
::
Region
>
regions
;
...
...
@@ -142,8 +146,7 @@ int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num,
regions
.
emplace_back
(
grad_data
,
variable
.
dim
);
}
auto
*
ps_client
=
_trainer_context
->
pslib
->
ps_client
();
auto
push_status
=
ps_client
->
push_dense
(
regions
.
data
(),
regions
.
size
(),
_table_id
);
//push_status.get();
ps_client
->
push_dense
(
regions
.
data
(),
regions
.
size
(),
_table_id
);
if
(
!
FLAGS_feed_trainer_debug_dense_name
.
empty
())
{
for
(
auto
&
variable
:
_x_variables
)
{
if
(
variable
.
name
!=
FLAGS_feed_trainer_debug_dense_name
)
{
...
...
@@ -152,7 +155,8 @@ int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num,
VLOG
(
2
)
<<
"[Debug][PushDense]"
<<
ScopeHelper
::
to_string
(
scope
,
variable
.
gradient_name
);
}
}
return
0
;
// not wait dense push
return
ret
;
}
int32_t
EbdVariableInputAccessor
::
forward
(
SampleInstance
*
samples
,
size_t
num
,
...
...
@@ -171,10 +175,10 @@ int32_t EbdVariableInputAccessor::forward(SampleInstance* samples, size_t num,
}
return
0
;
}
int32_t
EbdVariableInputAccessor
::
backward
(
SampleInstance
*
samples
,
size_t
num
,
std
::
future
<
int32_t
>
EbdVariableInputAccessor
::
backward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
)
{
return
0
;
std
::
future
<
int32_t
>
ret
;
return
ret
;
}
REGIST_CLASS
(
DataInputAccessor
,
DenseInputAccessor
);
...
...
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
浏览文件 @
b8cf64ab
...
...
@@ -22,8 +22,10 @@ namespace feed {
}
std
::
string
done_text
=
fs
->
tail
(
_done_file_path
);
_done_status
=
paddle
::
string
::
split_string
(
done_text
,
std
::
string
(
"
\t
"
));
_
current
_epoch_id
=
get_status
<
uint64_t
>
(
EpochStatusFiled
::
EpochIdField
);
_
last_done
_epoch_id
=
get_status
<
uint64_t
>
(
EpochStatusFiled
::
EpochIdField
);
_last_checkpoint_epoch_id
=
get_status
<
uint64_t
>
(
EpochStatusFiled
::
CheckpointIdField
);
// 训练需要从上一个checkpoint对应的epoch开始
_current_epoch_id
=
_last_checkpoint_epoch_id
;
_last_checkpoint_path
=
get_status
<
std
::
string
>
(
EpochStatusFiled
::
CheckpointPathField
);
_inference_base_model_key
=
get_status
<
uint64_t
>
(
EpochStatusFiled
::
InferenceBaseKeyField
);
_inference_model_path
=
fs
->
path_join
(
_model_root_path
,
config
[
"inference_model_dir"
].
as
<
std
::
string
>
(
"xbox"
));
...
...
@@ -45,8 +47,14 @@ namespace feed {
set_status
(
EpochStatusFiled
::
TimestampField
,
now
.
tv_sec
);
set_status
(
EpochStatusFiled
::
CheckpointIdField
,
_last_checkpoint_epoch_id
);
set_status
(
EpochStatusFiled
::
CheckpointPathField
,
_last_checkpoint_path
);
set_status
(
EpochStatusFiled
::
DateField
,
format_timestamp
(
epoch_id
,
"%Y%m%d"
));
set_status
(
EpochStatusFiled
::
DateField
,
format_timestamp
(
epoch_id
,
"%Y%m%d
-%H%M
"
));
set_status
(
EpochStatusFiled
::
InferenceBaseKeyField
,
_inference_base_model_key
);
if
(
epoch_id
>
_last_done_epoch_id
)
{
// 保留末尾1000数据
auto
fs
=
_trainer_context
->
file_system
.
get
();
std
::
string
done_str
=
paddle
::
string
::
join_strings
(
_done_status
,
'\t'
);
fs
->
append_line
(
_done_file_path
,
done_str
,
1000
);
}
return
0
;
}
...
...
@@ -59,20 +67,18 @@ namespace feed {
}
std
::
string
done_str
;
std
::
string
donefile
;
auto
fs
=
_trainer_context
->
file_system
.
get
();
auto
model_path
=
model_save_path
(
epoch_id
,
save_way
);
std
::
string
inference_done_format
(
"{
\"
id
\"
:
\"
%lu
\"
,
\"
key
\"
:
\"
%lu
\"
,
\"
input
\"
:
\"
%s/000
\"
,
\"
record_count
\"
:
\"
1
\"
,
\"
file_format
\"
:
\"
pb
\"
,
\"
schema_version
\"
:
\"
2
\"
,
\"
partition_type
\"
:
\"
1
\"
,
\"
job_name
\"
:
\"
%s
\"
,
\"
job_id
\"
:
\"
%s
\"
,
\"
mpi_size
\"
:
\"
%d
\"
,
\"
monitor_data
\"
:
\"
%s
\"
}"
);
auto
id
=
time
(
NULL
);
switch
(
save_way
)
{
case
ModelSaveWay
::
ModelSaveTrainCheckpoint
:
donefile
=
_done_file_path
;
done_str
=
paddle
::
string
::
join_strings
(
_done_status
,
'\t'
);
break
;
case
ModelSaveWay
::
ModelSaveInferenceDelta
:
donefile
=
_inference_model_delta_done_path
;
done_str
=
string
::
format_string
(
inference_done_format
.
c_str
(),
id
,
_inference_base_model_key
,
model_path
.
c_str
(),
env
->
job_name
().
c_str
(),
env
->
job_id
().
c_str
(),
env
->
node_num
(
EnvironmentRole
::
PSERVER
),
_trainer_context
->
monitor_ssm
.
str
().
c_str
());
fs
->
append_line
(
donefile
,
done_str
,
1000
);
break
;
case
ModelSaveWay
::
ModelSaveInferenceBase
:
donefile
=
_inference_model_base_done_path
;
...
...
@@ -80,30 +86,9 @@ namespace feed {
done_str
=
string
::
format_string
(
inference_done_format
.
c_str
(),
id
,
id
,
model_path
.
c_str
(),
env
->
job_name
().
c_str
(),
env
->
job_id
().
c_str
(),
env
->
node_num
(
EnvironmentRole
::
PSERVER
),
_trainer_context
->
monitor_ssm
.
str
().
c_str
());
fs
->
append_line
(
donefile
,
done_str
,
1000
);
break
;
}
// 保留末尾1000数据
std
::
string
tail_done_info
;
auto
fs
=
_trainer_context
->
file_system
.
get
();
if
(
fs
->
exists
(
donefile
))
{
tail_done_info
=
paddle
::
string
::
trim_spaces
(
fs
->
tail
(
donefile
,
1000
));
}
if
(
tail_done_info
.
size
()
>
0
)
{
tail_done_info
=
tail_done_info
+
"
\n
"
+
done_str
;
}
else
{
tail_done_info
=
done_str
;
}
VLOG
(
2
)
<<
"Write donefile "
<<
donefile
<<
", str:"
<<
done_str
;
bool
write_success
=
false
;
while
(
true
)
{
fs
->
remove
(
donefile
);
auto
fp
=
fs
->
open_write
(
donefile
,
""
);
if
(
fwrite
(
tail_done_info
.
c_str
(),
tail_done_info
.
length
(),
1
,
&*
fp
)
==
1
)
{
break
;
}
sleep
(
10
);
}
VLOG
(
2
)
<<
"Write donefile "
<<
donefile
<<
"success"
;
return
0
;
}
...
...
@@ -155,7 +140,9 @@ namespace feed {
}
switch
(
save_way
)
{
case
ModelSaveWay
::
ModelSaveInferenceDelta
:
return
delta_id
(
epoch_id
)
%
6
==
0
;
// 重启训练后,中间的delta不重复dump
return
epoch_id
>
_last_done_epoch_id
&&
delta_id
(
epoch_id
)
%
6
==
0
;
case
ModelSaveWay
::
ModelSaveInferenceBase
:
return
is_last_epoch
(
epoch_id
);
case
ModelSaveWay
::
ModelSaveTrainCheckpoint
:
...
...
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
浏览文件 @
b8cf64ab
...
...
@@ -73,6 +73,7 @@ protected:
std
::
string
_inference_model_delta_done_path
;
uint64_t
_current_epoch_id
=
0
;
std
::
string
_last_checkpoint_path
;
uint64_t
_last_done_epoch_id
=
0
;
uint64_t
_last_checkpoint_epoch_id
=
0
;
std
::
vector
<
std
::
string
>
_done_status
;
// 当前完成状态,统一存成string
uint64_t
_inference_base_model_key
=
0
;
// 预估模型的base-key
...
...
paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h
浏览文件 @
b8cf64ab
...
...
@@ -35,8 +35,9 @@ public:
virtual
int32_t
forward
(
SampleInstance
*
samples
,
size_t
num
,
::
paddle
::
framework
::
Scope
*
scope
)
=
0
;
// 后向,一般用于更新梯度,在训练网络执行后调用
virtual
int32_t
backward
(
SampleInstance
*
samples
,
size_t
num
,
// 后向,一般用于更新梯度,在训练网络执行后调用, 由于backward一般是异步,这里返回future,
// TODO 前向接口也改为future返回形式,接口一致性好些
virtual
std
::
future
<
int32_t
>
backward
(
SampleInstance
*
samples
,
size_t
num
,
::
paddle
::
framework
::
Scope
*
scope
)
=
0
;
// 收集持久化变量的名称, 并将值拷贝到Scope
...
...
@@ -67,7 +68,7 @@ public:
virtual
int32_t
forward
(
SampleInstance
*
samples
,
size_t
num
,
::
paddle
::
framework
::
Scope
*
scope
);
virtual
int32_t
backward
(
SampleInstance
*
samples
,
size_t
num
,
virtual
std
::
future
<
int32_t
>
backward
(
SampleInstance
*
samples
,
size_t
num
,
::
paddle
::
framework
::
Scope
*
scope
);
protected:
size_t
_label_total_dim
=
0
;
...
...
@@ -108,7 +109,7 @@ public:
virtual
void
post_process_input
(
float
*
var_data
,
SparseInputVariable
&
,
SampleInstance
*
,
size_t
num
)
=
0
;
// backward过程的梯度push
virtual
int32_t
backward
(
SampleInstance
*
samples
,
size_t
num
,
virtual
std
::
future
<
int32_t
>
backward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
);
// SparseGradValue会被依次调用,用于整理push的梯度
virtual
void
fill_gradient
(
float
*
push_value
,
const
float
*
gradient_raw
,
...
...
@@ -148,7 +149,7 @@ public:
virtual
int32_t
forward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
);
virtual
int32_t
backward
(
SampleInstance
*
samples
,
size_t
num
,
virtual
std
::
future
<
int32_t
>
backward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
);
...
...
@@ -175,7 +176,7 @@ public:
virtual
int32_t
forward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
);
virtual
int32_t
backward
(
SampleInstance
*
samples
,
size_t
num
,
virtual
std
::
future
<
int32_t
>
backward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
);
};
...
...
paddle/fluid/train/custom_trainer/feed/accessor/label_input_accessor.cc
浏览文件 @
b8cf64ab
...
...
@@ -45,10 +45,11 @@ int32_t LabelInputAccessor::forward(SampleInstance* samples, size_t num,
return
0
;
}
int32_t
LabelInputAccessor
::
backward
(
SampleInstance
*
samples
,
size_t
num
,
std
::
future
<
int32_t
>
LabelInputAccessor
::
backward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
)
{
std
::
future
<
int32_t
>
ret
;
if
(
num
<
1
)
{
return
0
;
return
ret
;
}
for
(
size_t
i
=
0
;
i
<
num
;
++
i
)
{
auto
&
sample
=
samples
[
i
];
...
...
@@ -69,7 +70,7 @@ int32_t LabelInputAccessor::backward(SampleInstance* samples, size_t num,
VLOG(2) << "[Debug][Lable]" << ScopeHelper::to_string(scope, label.label_name) << ScopeHelper::to_string(scope, label.output_name);
}
*/
return
0
;
return
ret
;
}
REGIST_CLASS
(
DataInputAccessor
,
LabelInputAccessor
);
...
...
paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc
浏览文件 @
b8cf64ab
...
...
@@ -136,8 +136,9 @@ int32_t BaseSparseInputAccessor::forward(SampleInstance* samples,
}
// 更新spare数据
int32_t
BaseSparseInputAccessor
::
backward
(
SampleInstance
*
samples
,
std
::
future
<
int32_t
>
BaseSparseInputAccessor
::
backward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
)
{
std
::
future
<
int32_t
>
ret
;
int64_t
runtime_data_for_scope
=
*
ScopeHelper
::
get_value
<
int64_t
>
(
scope
,
_trainer_context
->
cpu_place
,
"sparse_runtime_data"
);
auto
*
runtime_data_ptr
=
(
std
::
vector
<
SparseVarRuntimeData
>*
)
runtime_data_for_scope
;
...
...
@@ -146,7 +147,7 @@ int32_t BaseSparseInputAccessor::backward(SampleInstance* samples,
delete
runtime_data_ptr
;
});
if
(
!
_need_gradient
)
{
return
0
;
return
ret
;
}
auto
*
ps_client
=
_trainer_context
->
pslib
->
ps_client
();
auto
*
value_accessor
=
ps_client
->
table_accessor
(
_table_id
);
...
...
@@ -204,11 +205,10 @@ int32_t BaseSparseInputAccessor::backward(SampleInstance* samples,
VLOG
(
2
)
<<
"[DEBUG][sparse_slot_push]"
<<
ssm
.
str
();
}
}
auto
push_status
=
ps_client
->
push_sparse
(
_table_id
,
keys
.
data
(),
(
const
float
**
)
push_values
,
key_idx
);
//auto ret = push_status.get();
ret
=
ps_client
->
push_sparse
(
_table_id
,
keys
.
data
(),
(
const
float
**
)
push_values
,
key_idx
);
delete
[]
push_values
;
return
0
;
return
ret
;
}
class
AbacusSparseJoinAccessor
:
public
BaseSparseInputAccessor
{
...
...
paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.cc
浏览文件 @
b8cf64ab
...
...
@@ -70,7 +70,6 @@ paddle::PSParameter* PSlib::get_param() {
void
PSlib
::
init_gflag
()
{
int
cnt
=
4
;
char
**
params_ptr
=
new
char
*
[
cnt
];
std
::
cout
<<
"alloc_ptr"
<<
params_ptr
<<
std
::
flush
;
char
p0
[]
=
"exe default"
;
char
p1
[]
=
"-max_body_size=314217728"
;
char
p2
[]
=
"-bthread_concurrency=40"
;
...
...
paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc
浏览文件 @
b8cf64ab
...
...
@@ -9,6 +9,10 @@ namespace paddle {
namespace
custom_trainer
{
namespace
feed
{
// Out-of-class definitions of MultiThreadExecutor's static members used for
// asynchronous deletion of ScopeExecutorContext objects:
//   _async_delete_flag   - once_flag passed to std::call_once in initialize().
//   _async_delete_thread - thread created inside that call_once; it reads
//                          contexts from _delete_channel and deletes them.
//   _delete_channel      - channel that run() puts finished contexts into.
std
::
once_flag
MultiThreadExecutor
::
_async_delete_flag
;
std
::
shared_ptr
<
std
::
thread
>
MultiThreadExecutor
::
_async_delete_thread
;
paddle
::
framework
::
Channel
<
ScopeExecutorContext
*>
MultiThreadExecutor
::
_delete_channel
;
int
MultiThreadExecutor
::
initialize
(
YAML
::
Node
exe_config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
{
int
ret
=
0
;
...
...
@@ -85,6 +89,23 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config,
CHECK
(
monitor_ptr
->
initialize
(
monitor_config
,
context_ptr
)
==
0
)
<<
"Monitor init Failed, class:"
<<
monitor_class
;
}
// 异步删除池
std
::
call_once
(
_async_delete_flag
,
[
this
](){
_delete_channel
=
paddle
::
framework
::
MakeChannel
<
ScopeExecutorContext
*>
();
_delete_channel
->
SetBlockSize
(
32
);
_async_delete_thread
.
reset
(
new
std
::
thread
([
this
]{
std
::
vector
<
ScopeExecutorContext
*>
ctxs
;
while
(
true
)
{
while
(
_delete_channel
->
Read
(
ctxs
))
{
for
(
auto
*
ctx
:
ctxs
)
{
delete
ctx
;
}
}
usleep
(
200000
);
// 200ms
}
}));
});
return
ret
;
}
...
...
@@ -187,9 +208,10 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run(
auto
*
samples
=
scope_ctx
->
samples
();
auto
sample_num
=
scope_ctx
->
sample_num
();
out_items
[
out_idx
]
=
0
;
scope_ctx
->
wait_status
.
resize
(
_input_accessors
.
size
());
for
(
size_t
i
=
0
;
i
<
_input_accessors
.
size
();
++
i
)
{
out_items
[
out_idx
]
=
_input_accessors
[
i
]
->
backward
(
samples
,
sample_num
,
scope
);
scope_ctx
->
wait_status
[
i
]
=
_input_accessors
[
i
]
->
backward
(
samples
,
sample_num
,
scope
);
}
timer
.
Pause
();
scope_ctx
->
push_gradient_cost_ms
=
timer
.
ElapsedMS
();
...
...
@@ -203,7 +225,8 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run(
VLOG
(
2
)
<<
"[Debug][Layer]"
<<
ScopeHelper
::
to_string
(
scope
,
layer_name
);
}
}
delete
scope_ctx
;
// 所有pipe完成后,再回收sample
// 所有pipe完成后,再异步回收sample
_delete_channel
->
Put
(
scope_ctx
);
}
return
0
;
});
...
...
paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.h
浏览文件 @
b8cf64ab
#pragma once
#include <thread>
#include <functional>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
...
...
@@ -18,6 +19,12 @@ public:
_sample_num
=
sample_num
;
}
// Destructor: waits for every valid future stored in wait_status, then
// releases the sample array. Entries that are not valid (for example
// default-constructed futures) are skipped, because waiting on them is not
// allowed.
virtual ~ScopeExecutorContext() {
    for (auto& pending : wait_status) {
        if (pending.valid()) {
            pending.wait();
        }
    }
    delete[] _samples;
}
inline
SampleInstance
*
samples
()
{
...
...
@@ -29,6 +36,7 @@ public:
size_t
executor_cost_ms
=
0
;
size_t
prepare_cost_ms
=
0
;
size_t
push_gradient_cost_ms
=
0
;
std
::
vector
<
std
::
future
<
int32_t
>>
wait_status
;
private:
size_t
_sample_num
=
0
;
SampleInstance
*
_samples
=
NULL
;
...
...
@@ -83,6 +91,11 @@ protected:
std
::
map
<
uint32_t
,
std
::
vector
<
DataInputAccessor
*>>
_table_to_accessors
;
std
::
shared_ptr
<
paddle
::
ps
::
ObjectPool
<::
paddle
::
framework
::
Scope
>>
_scope_obj_pool
;
std
::
vector
<
std
::
string
>
_persistables
;
// 异步删除
static
std
::
once_flag
_async_delete_flag
;
static
std
::
shared_ptr
<
std
::
thread
>
_async_delete_thread
;
static
paddle
::
framework
::
Channel
<
ScopeExecutorContext
*>
_delete_channel
;
};
}
// namespace feed
...
...
paddle/fluid/train/custom_trainer/feed/io/file_system.cc
浏览文件 @
b8cf64ab
...
...
@@ -23,6 +23,30 @@ std::pair<std::string, std::string> FileSystem::path_split(const std::string& pa
return
{
path
.
substr
(
0
,
pos
),
path
.
substr
(
pos
+
1
)};
}
// Appends `line` to the text file at `path` by rewriting the whole file.
//
// If the file exists, its content is fetched with tail(path, reserve_line_num)
// and trimmed with trim_spaces; the new file content is then
// "<trimmed tail>\n<line>", or just "<line>" when there was nothing to keep.
// The old file is removed and the new content written through open_write().
//
// @param path              file to rewrite (text files only).
// @param line              text to place at the end of the file.
// @param reserve_line_num  forwarded unchanged to tail(); presumably the number
//                          of trailing lines to keep -- confirm against tail().
// @return always 0; the function retries every 10 seconds until the write
//         succeeds, so it blocks indefinitely on a persistent failure.
int FileSystem::append_line(const std::string& path,
        const std::string& line, size_t reserve_line_num) {
    std::string tail_data;
    if (exists(path)) {
        tail_data = paddle::string::trim_spaces(tail(path, reserve_line_num));
    }
    if (tail_data.size() > 0) {
        tail_data = tail_data + "\n" + line;
    } else {
        tail_data = line;
    }
    VLOG(2) << "Append to file:" << path << ", line str:" << line;
    while (true) {
        remove(path);
        auto fp = open_write(path, "");
        // open_write() returns a shared_ptr<FILE>; if it is empty the open
        // failed, so skip the write instead of dereferencing a null pointer
        // and fall through to the retry path.
        if (fp) {
            // Count bytes (element size 1) rather than one element of
            // `length` bytes: fwrite returns 0 for a zero-sized element, so
            // an empty payload could never satisfy the old `== 1` check and
            // the loop would never terminate.
            const size_t expected = tail_data.length();
            if (fwrite(tail_data.c_str(), 1, expected, &*fp) == expected) {
                break;
            }
        }
        sleep(10);
        VLOG(0) << "Retry Append to file:" << path << ", line str:" << line;
    }
    return 0;
}
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/io/file_system.h
浏览文件 @
b8cf64ab
...
...
@@ -18,6 +18,8 @@ public:
virtual
int
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
)
=
0
;
virtual
std
::
shared_ptr
<
FILE
>
open_read
(
const
std
::
string
&
path
,
const
std
::
string
&
converter
)
=
0
;
virtual
std
::
shared_ptr
<
FILE
>
open_write
(
const
std
::
string
&
path
,
const
std
::
string
&
converter
)
=
0
;
// only support text-file
virtual
int
append_line
(
const
std
::
string
&
path
,
const
std
::
string
&
line
,
size_t
reserve_line_num
);
virtual
int64_t
file_size
(
const
std
::
string
&
path
)
=
0
;
virtual
void
remove
(
const
std
::
string
&
path
)
=
0
;
virtual
std
::
vector
<
std
::
string
>
list
(
const
std
::
string
&
path
)
=
0
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录