提交 b8cf64ab 编写于 作者: X xiexionghang

for async push_gradient

上级 aaea8a39
...@@ -52,7 +52,10 @@ int32_t DenseInputAccessor::create(::paddle::framework::Scope* scope) { ...@@ -52,7 +52,10 @@ int32_t DenseInputAccessor::create(::paddle::framework::Scope* scope) {
// rpc拉取数据,需保证单线程运行 // rpc拉取数据,需保证单线程运行
int32_t DenseInputAccessor::pull_dense(size_t table_id) { int32_t DenseInputAccessor::pull_dense(size_t table_id) {
float* data_buffer = new float[_total_dim]; float* data_buffer = _data_buffer;
if (data_buffer == NULL) {
data_buffer = new float[_total_dim];
}
size_t data_buffer_idx = 0; size_t data_buffer_idx = 0;
std::vector<paddle::ps::Region> regions; std::vector<paddle::ps::Region> regions;
for (auto& variable : _x_variables) { for (auto& variable : _x_variables) {
...@@ -128,10 +131,11 @@ int32_t DenseInputAccessor::collect_persistables_name(std::vector<std::string>& ...@@ -128,10 +131,11 @@ int32_t DenseInputAccessor::collect_persistables_name(std::vector<std::string>&
return 0; return 0;
} }
int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num, std::future<int32_t> DenseInputAccessor::backward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope) { paddle::framework::Scope* scope) {
std::future<int32_t> ret;
if (!_need_gradient) { if (!_need_gradient) {
return 0; return ret;
} }
size_t data_buffer_idx = 0; size_t data_buffer_idx = 0;
std::vector<paddle::ps::Region> regions; std::vector<paddle::ps::Region> regions;
...@@ -142,8 +146,7 @@ int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num, ...@@ -142,8 +146,7 @@ int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num,
regions.emplace_back(grad_data, variable.dim); regions.emplace_back(grad_data, variable.dim);
} }
auto* ps_client = _trainer_context->pslib->ps_client(); auto* ps_client = _trainer_context->pslib->ps_client();
auto push_status = ps_client->push_dense(regions.data(), regions.size(), _table_id); ps_client->push_dense(regions.data(), regions.size(), _table_id);
//push_status.get();
if (!FLAGS_feed_trainer_debug_dense_name.empty()) { if (!FLAGS_feed_trainer_debug_dense_name.empty()) {
for (auto& variable : _x_variables) { for (auto& variable : _x_variables) {
if (variable.name != FLAGS_feed_trainer_debug_dense_name) { if (variable.name != FLAGS_feed_trainer_debug_dense_name) {
...@@ -152,7 +155,8 @@ int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num, ...@@ -152,7 +155,8 @@ int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num,
VLOG(2) << "[Debug][PushDense]" << ScopeHelper::to_string(scope, variable.gradient_name); VLOG(2) << "[Debug][PushDense]" << ScopeHelper::to_string(scope, variable.gradient_name);
} }
} }
return 0; // not wait dense push
return ret;
} }
int32_t EbdVariableInputAccessor::forward(SampleInstance* samples, size_t num, int32_t EbdVariableInputAccessor::forward(SampleInstance* samples, size_t num,
...@@ -171,10 +175,10 @@ int32_t EbdVariableInputAccessor::forward(SampleInstance* samples, size_t num, ...@@ -171,10 +175,10 @@ int32_t EbdVariableInputAccessor::forward(SampleInstance* samples, size_t num,
} }
return 0; return 0;
} }
std::future<int32_t> EbdVariableInputAccessor::backward(SampleInstance* samples, size_t num,
int32_t EbdVariableInputAccessor::backward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope) { paddle::framework::Scope* scope) {
return 0; std::future<int32_t> ret;
return ret;
} }
REGIST_CLASS(DataInputAccessor, DenseInputAccessor); REGIST_CLASS(DataInputAccessor, DenseInputAccessor);
......
...@@ -22,8 +22,10 @@ namespace feed { ...@@ -22,8 +22,10 @@ namespace feed {
} }
std::string done_text = fs->tail(_done_file_path); std::string done_text = fs->tail(_done_file_path);
_done_status = paddle::string::split_string(done_text, std::string("\t")); _done_status = paddle::string::split_string(done_text, std::string("\t"));
_current_epoch_id = get_status<uint64_t>(EpochStatusFiled::EpochIdField); _last_done_epoch_id = get_status<uint64_t>(EpochStatusFiled::EpochIdField);
_last_checkpoint_epoch_id = get_status<uint64_t>(EpochStatusFiled::CheckpointIdField); _last_checkpoint_epoch_id = get_status<uint64_t>(EpochStatusFiled::CheckpointIdField);
// 训练需要从上一个checkpoint对应的epoch开始
_current_epoch_id = _last_checkpoint_epoch_id;
_last_checkpoint_path = get_status<std::string>(EpochStatusFiled::CheckpointPathField); _last_checkpoint_path = get_status<std::string>(EpochStatusFiled::CheckpointPathField);
_inference_base_model_key = get_status<uint64_t>(EpochStatusFiled::InferenceBaseKeyField); _inference_base_model_key = get_status<uint64_t>(EpochStatusFiled::InferenceBaseKeyField);
_inference_model_path = fs->path_join(_model_root_path, config["inference_model_dir"].as<std::string>("xbox")); _inference_model_path = fs->path_join(_model_root_path, config["inference_model_dir"].as<std::string>("xbox"));
...@@ -45,8 +47,14 @@ namespace feed { ...@@ -45,8 +47,14 @@ namespace feed {
set_status(EpochStatusFiled::TimestampField, now.tv_sec); set_status(EpochStatusFiled::TimestampField, now.tv_sec);
set_status(EpochStatusFiled::CheckpointIdField, _last_checkpoint_epoch_id); set_status(EpochStatusFiled::CheckpointIdField, _last_checkpoint_epoch_id);
set_status(EpochStatusFiled::CheckpointPathField, _last_checkpoint_path); set_status(EpochStatusFiled::CheckpointPathField, _last_checkpoint_path);
set_status(EpochStatusFiled::DateField, format_timestamp(epoch_id, "%Y%m%d")); set_status(EpochStatusFiled::DateField, format_timestamp(epoch_id, "%Y%m%d-%H%M"));
set_status(EpochStatusFiled::InferenceBaseKeyField, _inference_base_model_key); set_status(EpochStatusFiled::InferenceBaseKeyField, _inference_base_model_key);
if (epoch_id > _last_done_epoch_id) {
// 保留末尾1000数据
auto fs = _trainer_context->file_system.get();
std::string done_str = paddle::string::join_strings(_done_status, '\t');
fs->append_line(_done_file_path, done_str, 1000);
}
return 0; return 0;
} }
...@@ -59,20 +67,18 @@ namespace feed { ...@@ -59,20 +67,18 @@ namespace feed {
} }
std::string done_str; std::string done_str;
std::string donefile; std::string donefile;
auto fs = _trainer_context->file_system.get();
auto model_path = model_save_path(epoch_id, save_way); auto model_path = model_save_path(epoch_id, save_way);
std::string inference_done_format("{\"id\":\"%lu\",\"key\":\"%lu\",\"input\":\"%s/000\",\"record_count\":\"1\",\"file_format\":\"pb\",\"schema_version\":\"2\",\"partition_type\":\"1\",\"job_name\":\"%s\",\"job_id\":\"%s\",\"mpi_size\":\"%d\",\"monitor_data\":\"%s\"}"); std::string inference_done_format("{\"id\":\"%lu\",\"key\":\"%lu\",\"input\":\"%s/000\",\"record_count\":\"1\",\"file_format\":\"pb\",\"schema_version\":\"2\",\"partition_type\":\"1\",\"job_name\":\"%s\",\"job_id\":\"%s\",\"mpi_size\":\"%d\",\"monitor_data\":\"%s\"}");
auto id = time(NULL); auto id = time(NULL);
switch (save_way) { switch (save_way) {
case ModelSaveWay::ModelSaveTrainCheckpoint:
donefile = _done_file_path;
done_str = paddle::string::join_strings(_done_status, '\t');
break;
case ModelSaveWay::ModelSaveInferenceDelta: case ModelSaveWay::ModelSaveInferenceDelta:
donefile = _inference_model_delta_done_path; donefile = _inference_model_delta_done_path;
done_str = string::format_string(inference_done_format.c_str(), id, _inference_base_model_key, done_str = string::format_string(inference_done_format.c_str(), id, _inference_base_model_key,
model_path.c_str(), env->job_name().c_str(), env->job_id().c_str(), model_path.c_str(), env->job_name().c_str(), env->job_id().c_str(),
env->node_num(EnvironmentRole::PSERVER), _trainer_context->monitor_ssm.str().c_str()); env->node_num(EnvironmentRole::PSERVER), _trainer_context->monitor_ssm.str().c_str());
fs->append_line(donefile, done_str, 1000);
break; break;
case ModelSaveWay::ModelSaveInferenceBase: case ModelSaveWay::ModelSaveInferenceBase:
donefile = _inference_model_base_done_path; donefile = _inference_model_base_done_path;
...@@ -80,30 +86,9 @@ namespace feed { ...@@ -80,30 +86,9 @@ namespace feed {
done_str = string::format_string(inference_done_format.c_str(), id, id, done_str = string::format_string(inference_done_format.c_str(), id, id,
model_path.c_str(), env->job_name().c_str(), env->job_id().c_str(), model_path.c_str(), env->job_name().c_str(), env->job_id().c_str(),
env->node_num(EnvironmentRole::PSERVER), _trainer_context->monitor_ssm.str().c_str()); env->node_num(EnvironmentRole::PSERVER), _trainer_context->monitor_ssm.str().c_str());
fs->append_line(donefile, done_str, 1000);
break; break;
} }
// 保留末尾1000数据
std::string tail_done_info;
auto fs = _trainer_context->file_system.get();
if (fs->exists(donefile)) {
tail_done_info = paddle::string::trim_spaces(fs->tail(donefile, 1000));
}
if (tail_done_info.size() > 0) {
tail_done_info = tail_done_info + "\n" + done_str;
} else {
tail_done_info = done_str;
}
VLOG(2) << "Write donefile " << donefile << ", str:" << done_str;
bool write_success = false;
while (true) {
fs->remove(donefile);
auto fp = fs->open_write(donefile, "");
if (fwrite(tail_done_info.c_str(), tail_done_info.length(), 1, &*fp) == 1) {
break;
}
sleep(10);
}
VLOG(2) << "Write donefile " << donefile << "success";
return 0; return 0;
} }
...@@ -155,7 +140,9 @@ namespace feed { ...@@ -155,7 +140,9 @@ namespace feed {
} }
switch (save_way) { switch (save_way) {
case ModelSaveWay::ModelSaveInferenceDelta: case ModelSaveWay::ModelSaveInferenceDelta:
return delta_id(epoch_id) % 6 == 0; // 重启训练后,中间的delta不重复dump
return epoch_id > _last_done_epoch_id &&
delta_id(epoch_id) % 6 == 0;
case ModelSaveWay::ModelSaveInferenceBase: case ModelSaveWay::ModelSaveInferenceBase:
return is_last_epoch(epoch_id); return is_last_epoch(epoch_id);
case ModelSaveWay::ModelSaveTrainCheckpoint: case ModelSaveWay::ModelSaveTrainCheckpoint:
......
...@@ -73,6 +73,7 @@ protected: ...@@ -73,6 +73,7 @@ protected:
std::string _inference_model_delta_done_path; std::string _inference_model_delta_done_path;
uint64_t _current_epoch_id = 0; uint64_t _current_epoch_id = 0;
std::string _last_checkpoint_path; std::string _last_checkpoint_path;
uint64_t _last_done_epoch_id = 0;
uint64_t _last_checkpoint_epoch_id = 0; uint64_t _last_checkpoint_epoch_id = 0;
std::vector<std::string> _done_status; // 当前完成状态,统一存成string std::vector<std::string> _done_status; // 当前完成状态,统一存成string
uint64_t _inference_base_model_key = 0; // 预估模型的base-key uint64_t _inference_base_model_key = 0; // 预估模型的base-key
......
...@@ -35,8 +35,9 @@ public: ...@@ -35,8 +35,9 @@ public:
virtual int32_t forward(SampleInstance* samples, size_t num, virtual int32_t forward(SampleInstance* samples, size_t num,
::paddle::framework::Scope* scope) = 0; ::paddle::framework::Scope* scope) = 0;
// 后向,一般用于更新梯度,在训练网络执行后调用 // 后向,一般用于更新梯度,在训练网络执行后调用, 由于backward一般是异步,这里返回future,
virtual int32_t backward(SampleInstance* samples, size_t num, // TODO 前向接口也改为future返回形式,接口一致性好些
virtual std::future<int32_t> backward(SampleInstance* samples, size_t num,
::paddle::framework::Scope* scope) = 0; ::paddle::framework::Scope* scope) = 0;
// 收集持久化变量的名称, 并将值拷贝到Scope // 收集持久化变量的名称, 并将值拷贝到Scope
...@@ -67,7 +68,7 @@ public: ...@@ -67,7 +68,7 @@ public:
virtual int32_t forward(SampleInstance* samples, size_t num, virtual int32_t forward(SampleInstance* samples, size_t num,
::paddle::framework::Scope* scope); ::paddle::framework::Scope* scope);
virtual int32_t backward(SampleInstance* samples, size_t num, virtual std::future<int32_t> backward(SampleInstance* samples, size_t num,
::paddle::framework::Scope* scope); ::paddle::framework::Scope* scope);
protected: protected:
size_t _label_total_dim = 0; size_t _label_total_dim = 0;
...@@ -108,7 +109,7 @@ public: ...@@ -108,7 +109,7 @@ public:
virtual void post_process_input(float* var_data, SparseInputVariable&, SampleInstance*, size_t num) = 0; virtual void post_process_input(float* var_data, SparseInputVariable&, SampleInstance*, size_t num) = 0;
// backward过程的梯度push // backward过程的梯度push
virtual int32_t backward(SampleInstance* samples, size_t num, virtual std::future<int32_t> backward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope); paddle::framework::Scope* scope);
// SparseGradValue会被依次调用,用于整理push的梯度 // SparseGradValue会被依次调用,用于整理push的梯度
virtual void fill_gradient(float* push_value, const float* gradient_raw, virtual void fill_gradient(float* push_value, const float* gradient_raw,
...@@ -148,7 +149,7 @@ public: ...@@ -148,7 +149,7 @@ public:
virtual int32_t forward(SampleInstance* samples, size_t num, virtual int32_t forward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope); paddle::framework::Scope* scope);
virtual int32_t backward(SampleInstance* samples, size_t num, virtual std::future<int32_t> backward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope); paddle::framework::Scope* scope);
...@@ -175,7 +176,7 @@ public: ...@@ -175,7 +176,7 @@ public:
virtual int32_t forward(SampleInstance* samples, size_t num, virtual int32_t forward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope); paddle::framework::Scope* scope);
virtual int32_t backward(SampleInstance* samples, size_t num, virtual std::future<int32_t> backward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope); paddle::framework::Scope* scope);
}; };
......
...@@ -45,10 +45,11 @@ int32_t LabelInputAccessor::forward(SampleInstance* samples, size_t num, ...@@ -45,10 +45,11 @@ int32_t LabelInputAccessor::forward(SampleInstance* samples, size_t num,
return 0; return 0;
} }
int32_t LabelInputAccessor::backward(SampleInstance* samples, size_t num, std::future<int32_t> LabelInputAccessor::backward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope) { paddle::framework::Scope* scope) {
std::future<int32_t> ret;
if (num < 1) { if (num < 1) {
return 0; return ret;
} }
for (size_t i = 0; i < num; ++i) { for (size_t i = 0; i < num; ++i) {
auto& sample = samples[i]; auto& sample = samples[i];
...@@ -69,7 +70,7 @@ int32_t LabelInputAccessor::backward(SampleInstance* samples, size_t num, ...@@ -69,7 +70,7 @@ int32_t LabelInputAccessor::backward(SampleInstance* samples, size_t num,
VLOG(2) << "[Debug][Lable]" << ScopeHelper::to_string(scope, label.label_name) << ScopeHelper::to_string(scope, label.output_name); VLOG(2) << "[Debug][Lable]" << ScopeHelper::to_string(scope, label.label_name) << ScopeHelper::to_string(scope, label.output_name);
} }
*/ */
return 0; return ret;
} }
REGIST_CLASS(DataInputAccessor, LabelInputAccessor); REGIST_CLASS(DataInputAccessor, LabelInputAccessor);
......
...@@ -136,8 +136,9 @@ int32_t BaseSparseInputAccessor::forward(SampleInstance* samples, ...@@ -136,8 +136,9 @@ int32_t BaseSparseInputAccessor::forward(SampleInstance* samples,
} }
// 更新spare数据 // 更新spare数据
int32_t BaseSparseInputAccessor::backward(SampleInstance* samples, std::future<int32_t> BaseSparseInputAccessor::backward(SampleInstance* samples,
size_t num, paddle::framework::Scope* scope) { size_t num, paddle::framework::Scope* scope) {
std::future<int32_t> ret;
int64_t runtime_data_for_scope = *ScopeHelper::get_value<int64_t>( int64_t runtime_data_for_scope = *ScopeHelper::get_value<int64_t>(
scope, _trainer_context->cpu_place, "sparse_runtime_data"); scope, _trainer_context->cpu_place, "sparse_runtime_data");
auto* runtime_data_ptr = (std::vector<SparseVarRuntimeData>*)runtime_data_for_scope; auto* runtime_data_ptr = (std::vector<SparseVarRuntimeData>*)runtime_data_for_scope;
...@@ -146,7 +147,7 @@ int32_t BaseSparseInputAccessor::backward(SampleInstance* samples, ...@@ -146,7 +147,7 @@ int32_t BaseSparseInputAccessor::backward(SampleInstance* samples,
delete runtime_data_ptr; delete runtime_data_ptr;
}); });
if (!_need_gradient) { if (!_need_gradient) {
return 0; return ret;
} }
auto* ps_client = _trainer_context->pslib->ps_client(); auto* ps_client = _trainer_context->pslib->ps_client();
auto* value_accessor = ps_client->table_accessor(_table_id); auto* value_accessor = ps_client->table_accessor(_table_id);
...@@ -204,11 +205,10 @@ int32_t BaseSparseInputAccessor::backward(SampleInstance* samples, ...@@ -204,11 +205,10 @@ int32_t BaseSparseInputAccessor::backward(SampleInstance* samples,
VLOG(2) << "[DEBUG][sparse_slot_push]" << ssm.str(); VLOG(2) << "[DEBUG][sparse_slot_push]" << ssm.str();
} }
} }
auto push_status = ps_client->push_sparse(_table_id, ret = ps_client->push_sparse(_table_id,
keys.data(), (const float**)push_values, key_idx); keys.data(), (const float**)push_values, key_idx);
//auto ret = push_status.get();
delete[] push_values; delete[] push_values;
return 0; return ret;
} }
class AbacusSparseJoinAccessor : public BaseSparseInputAccessor { class AbacusSparseJoinAccessor : public BaseSparseInputAccessor {
......
...@@ -70,7 +70,6 @@ paddle::PSParameter* PSlib::get_param() { ...@@ -70,7 +70,6 @@ paddle::PSParameter* PSlib::get_param() {
void PSlib::init_gflag() { void PSlib::init_gflag() {
int cnt = 4; int cnt = 4;
char** params_ptr = new char*[cnt]; char** params_ptr = new char*[cnt];
std::cout << "alloc_ptr" << params_ptr << std::flush;
char p0[] = "exe default"; char p0[] = "exe default";
char p1[] = "-max_body_size=314217728"; char p1[] = "-max_body_size=314217728";
char p2[] = "-bthread_concurrency=40"; char p2[] = "-bthread_concurrency=40";
......
...@@ -9,6 +9,10 @@ namespace paddle { ...@@ -9,6 +9,10 @@ namespace paddle {
namespace custom_trainer { namespace custom_trainer {
namespace feed { namespace feed {
std::once_flag MultiThreadExecutor::_async_delete_flag;
std::shared_ptr<std::thread> MultiThreadExecutor::_async_delete_thread;
paddle::framework::Channel<ScopeExecutorContext*> MultiThreadExecutor::_delete_channel;
int MultiThreadExecutor::initialize(YAML::Node exe_config, int MultiThreadExecutor::initialize(YAML::Node exe_config,
std::shared_ptr<TrainerContext> context_ptr) { std::shared_ptr<TrainerContext> context_ptr) {
int ret = 0; int ret = 0;
...@@ -85,6 +89,23 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config, ...@@ -85,6 +89,23 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config,
CHECK(monitor_ptr->initialize(monitor_config, context_ptr) == 0) CHECK(monitor_ptr->initialize(monitor_config, context_ptr) == 0)
<< "Monitor init Failed, class:" << monitor_class; << "Monitor init Failed, class:" << monitor_class;
} }
// 异步删除池
std::call_once(_async_delete_flag, [this](){
_delete_channel = paddle::framework::MakeChannel<ScopeExecutorContext*>();
_delete_channel->SetBlockSize(32);
_async_delete_thread.reset(new std::thread([this]{
std::vector<ScopeExecutorContext*> ctxs;
while (true) {
while (_delete_channel->Read(ctxs)) {
for (auto* ctx : ctxs) {
delete ctx;
}
}
usleep(200000); // 200ms
}
}));
});
return ret; return ret;
} }
...@@ -187,9 +208,10 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run( ...@@ -187,9 +208,10 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run(
auto* samples = scope_ctx->samples(); auto* samples = scope_ctx->samples();
auto sample_num = scope_ctx->sample_num(); auto sample_num = scope_ctx->sample_num();
out_items[out_idx] = 0;
scope_ctx->wait_status.resize(_input_accessors.size());
for (size_t i = 0; i < _input_accessors.size(); ++i) { for (size_t i = 0; i < _input_accessors.size(); ++i) {
out_items[out_idx] = _input_accessors[i]-> scope_ctx->wait_status[i] = _input_accessors[i]->backward(samples, sample_num, scope);
backward(samples, sample_num, scope);
} }
timer.Pause(); timer.Pause();
scope_ctx->push_gradient_cost_ms = timer.ElapsedMS(); scope_ctx->push_gradient_cost_ms = timer.ElapsedMS();
...@@ -203,7 +225,8 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run( ...@@ -203,7 +225,8 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run(
VLOG(2) << "[Debug][Layer]" << ScopeHelper::to_string(scope, layer_name); VLOG(2) << "[Debug][Layer]" << ScopeHelper::to_string(scope, layer_name);
} }
} }
delete scope_ctx; // 所有pipe完成后,再回收sample // 所有pipe完成后,再异步回收sample
_delete_channel->Put(scope_ctx);
} }
return 0; return 0;
}); });
......
#pragma once #pragma once
#include <thread>
#include <functional> #include <functional>
#include "paddle/fluid/framework/channel.h" #include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h" #include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
...@@ -18,6 +19,12 @@ public: ...@@ -18,6 +19,12 @@ public:
_sample_num = sample_num; _sample_num = sample_num;
} }
virtual ~ScopeExecutorContext() { virtual ~ScopeExecutorContext() {
for (auto& status : wait_status) {
if (!status.valid()) {
continue;
}
status.wait();
}
delete[] _samples; delete[] _samples;
} }
inline SampleInstance* samples() { inline SampleInstance* samples() {
...@@ -29,6 +36,7 @@ public: ...@@ -29,6 +36,7 @@ public:
size_t executor_cost_ms = 0; size_t executor_cost_ms = 0;
size_t prepare_cost_ms = 0; size_t prepare_cost_ms = 0;
size_t push_gradient_cost_ms = 0; size_t push_gradient_cost_ms = 0;
std::vector<std::future<int32_t>> wait_status;
private: private:
size_t _sample_num = 0; size_t _sample_num = 0;
SampleInstance* _samples = NULL; SampleInstance* _samples = NULL;
...@@ -83,6 +91,11 @@ protected: ...@@ -83,6 +91,11 @@ protected:
std::map<uint32_t, std::vector<DataInputAccessor*>> _table_to_accessors; std::map<uint32_t, std::vector<DataInputAccessor*>> _table_to_accessors;
std::shared_ptr<paddle::ps::ObjectPool<::paddle::framework::Scope>> _scope_obj_pool; std::shared_ptr<paddle::ps::ObjectPool<::paddle::framework::Scope>> _scope_obj_pool;
std::vector<std::string> _persistables; std::vector<std::string> _persistables;
// 异步删除
static std::once_flag _async_delete_flag;
static std::shared_ptr<std::thread> _async_delete_thread;
static paddle::framework::Channel<ScopeExecutorContext*> _delete_channel;
}; };
} // namespace feed } // namespace feed
......
...@@ -23,6 +23,30 @@ std::pair<std::string, std::string> FileSystem::path_split(const std::string& pa ...@@ -23,6 +23,30 @@ std::pair<std::string, std::string> FileSystem::path_split(const std::string& pa
return {path.substr(0, pos), path.substr(pos + 1)}; return {path.substr(0, pos), path.substr(pos + 1)};
} }
int FileSystem::append_line(const std::string& path,
const std::string& line, size_t reserve_line_num) {
std::string tail_data;
if (exists(path)) {
tail_data = paddle::string::trim_spaces(tail(path, reserve_line_num));
}
if (tail_data.size() > 0) {
tail_data = tail_data + "\n" + line;
} else {
tail_data = line;
}
VLOG(2) << "Append to file:" << path << ", line str:" << line;
while (true) {
remove(path);
auto fp = open_write(path, "");
if (fwrite(tail_data.c_str(), tail_data.length(), 1, &*fp) == 1) {
break;
}
sleep(10);
VLOG(0) << "Retry Append to file:" << path << ", line str:" << line;
}
return 0;
}
} // namespace feed } // namespace feed
} // namespace custom_trainer } // namespace custom_trainer
} // namespace paddle } // namespace paddle
...@@ -18,6 +18,8 @@ public: ...@@ -18,6 +18,8 @@ public:
virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0; virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;
virtual std::shared_ptr<FILE> open_read(const std::string& path, const std::string& converter) = 0; virtual std::shared_ptr<FILE> open_read(const std::string& path, const std::string& converter) = 0;
virtual std::shared_ptr<FILE> open_write(const std::string& path, const std::string& converter) = 0; virtual std::shared_ptr<FILE> open_write(const std::string& path, const std::string& converter) = 0;
// only support text-file
virtual int append_line(const std::string& path, const std::string& line, size_t reserve_line_num);
virtual int64_t file_size(const std::string& path) = 0; virtual int64_t file_size(const std::string& path) = 0;
virtual void remove(const std::string& path) = 0; virtual void remove(const std::string& path) = 0;
virtual std::vector<std::string> list(const std::string& path) = 0; virtual std::vector<std::string> list(const std::string& path) = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册