提交 b8cf64ab 编写于 作者: X xiexionghang

for async push_gradient

上级 aaea8a39
......@@ -52,7 +52,10 @@ int32_t DenseInputAccessor::create(::paddle::framework::Scope* scope) {
// rpc拉取数据,需保证单线程运行
int32_t DenseInputAccessor::pull_dense(size_t table_id) {
float* data_buffer = new float[_total_dim];
float* data_buffer = _data_buffer;
if (data_buffer == NULL) {
data_buffer = new float[_total_dim];
}
size_t data_buffer_idx = 0;
std::vector<paddle::ps::Region> regions;
for (auto& variable : _x_variables) {
......@@ -128,10 +131,11 @@ int32_t DenseInputAccessor::collect_persistables_name(std::vector<std::string>&
return 0;
}
int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num,
std::future<int32_t> DenseInputAccessor::backward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope) {
std::future<int32_t> ret;
if (!_need_gradient) {
return 0;
return ret;
}
size_t data_buffer_idx = 0;
std::vector<paddle::ps::Region> regions;
......@@ -142,8 +146,7 @@ int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num,
regions.emplace_back(grad_data, variable.dim);
}
auto* ps_client = _trainer_context->pslib->ps_client();
auto push_status = ps_client->push_dense(regions.data(), regions.size(), _table_id);
//push_status.get();
ps_client->push_dense(regions.data(), regions.size(), _table_id);
if (!FLAGS_feed_trainer_debug_dense_name.empty()) {
for (auto& variable : _x_variables) {
if (variable.name != FLAGS_feed_trainer_debug_dense_name) {
......@@ -152,7 +155,8 @@ int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num,
VLOG(2) << "[Debug][PushDense]" << ScopeHelper::to_string(scope, variable.gradient_name);
}
}
return 0;
// not wait dense push
return ret;
}
int32_t EbdVariableInputAccessor::forward(SampleInstance* samples, size_t num,
......@@ -171,10 +175,10 @@ int32_t EbdVariableInputAccessor::forward(SampleInstance* samples, size_t num,
}
return 0;
}
int32_t EbdVariableInputAccessor::backward(SampleInstance* samples, size_t num,
std::future<int32_t> EbdVariableInputAccessor::backward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope) {
return 0;
std::future<int32_t> ret;
return ret;
}
REGIST_CLASS(DataInputAccessor, DenseInputAccessor);
......
......@@ -22,8 +22,10 @@ namespace feed {
}
std::string done_text = fs->tail(_done_file_path);
_done_status = paddle::string::split_string(done_text, std::string("\t"));
_current_epoch_id = get_status<uint64_t>(EpochStatusFiled::EpochIdField);
_last_done_epoch_id = get_status<uint64_t>(EpochStatusFiled::EpochIdField);
_last_checkpoint_epoch_id = get_status<uint64_t>(EpochStatusFiled::CheckpointIdField);
// 训练需要从上一个checkpoint对应的epoch开始
_current_epoch_id = _last_checkpoint_epoch_id;
_last_checkpoint_path = get_status<std::string>(EpochStatusFiled::CheckpointPathField);
_inference_base_model_key = get_status<uint64_t>(EpochStatusFiled::InferenceBaseKeyField);
_inference_model_path = fs->path_join(_model_root_path, config["inference_model_dir"].as<std::string>("xbox"));
......@@ -45,8 +47,14 @@ namespace feed {
set_status(EpochStatusFiled::TimestampField, now.tv_sec);
set_status(EpochStatusFiled::CheckpointIdField, _last_checkpoint_epoch_id);
set_status(EpochStatusFiled::CheckpointPathField, _last_checkpoint_path);
set_status(EpochStatusFiled::DateField, format_timestamp(epoch_id, "%Y%m%d"));
set_status(EpochStatusFiled::DateField, format_timestamp(epoch_id, "%Y%m%d-%H%M"));
set_status(EpochStatusFiled::InferenceBaseKeyField, _inference_base_model_key);
if (epoch_id > _last_done_epoch_id) {
// 保留末尾1000数据
auto fs = _trainer_context->file_system.get();
std::string done_str = paddle::string::join_strings(_done_status, '\t');
fs->append_line(_done_file_path, done_str, 1000);
}
return 0;
}
......@@ -59,20 +67,18 @@ namespace feed {
}
std::string done_str;
std::string donefile;
auto fs = _trainer_context->file_system.get();
auto model_path = model_save_path(epoch_id, save_way);
std::string inference_done_format("{\"id\":\"%lu\",\"key\":\"%lu\",\"input\":\"%s/000\",\"record_count\":\"1\",\"file_format\":\"pb\",\"schema_version\":\"2\",\"partition_type\":\"1\",\"job_name\":\"%s\",\"job_id\":\"%s\",\"mpi_size\":\"%d\",\"monitor_data\":\"%s\"}");
auto id = time(NULL);
switch (save_way) {
case ModelSaveWay::ModelSaveTrainCheckpoint:
donefile = _done_file_path;
done_str = paddle::string::join_strings(_done_status, '\t');
break;
case ModelSaveWay::ModelSaveInferenceDelta:
donefile = _inference_model_delta_done_path;
done_str = string::format_string(inference_done_format.c_str(), id, _inference_base_model_key,
model_path.c_str(), env->job_name().c_str(), env->job_id().c_str(),
env->node_num(EnvironmentRole::PSERVER), _trainer_context->monitor_ssm.str().c_str());
fs->append_line(donefile, done_str, 1000);
break;
case ModelSaveWay::ModelSaveInferenceBase:
donefile = _inference_model_base_done_path;
......@@ -80,30 +86,9 @@ namespace feed {
done_str = string::format_string(inference_done_format.c_str(), id, id,
model_path.c_str(), env->job_name().c_str(), env->job_id().c_str(),
env->node_num(EnvironmentRole::PSERVER), _trainer_context->monitor_ssm.str().c_str());
fs->append_line(donefile, done_str, 1000);
break;
}
// 保留末尾1000数据
std::string tail_done_info;
auto fs = _trainer_context->file_system.get();
if (fs->exists(donefile)) {
tail_done_info = paddle::string::trim_spaces(fs->tail(donefile, 1000));
}
if (tail_done_info.size() > 0) {
tail_done_info = tail_done_info + "\n" + done_str;
} else {
tail_done_info = done_str;
}
VLOG(2) << "Write donefile " << donefile << ", str:" << done_str;
bool write_success = false;
while (true) {
fs->remove(donefile);
auto fp = fs->open_write(donefile, "");
if (fwrite(tail_done_info.c_str(), tail_done_info.length(), 1, &*fp) == 1) {
break;
}
sleep(10);
}
VLOG(2) << "Write donefile " << donefile << "success";
return 0;
}
......@@ -155,7 +140,9 @@ namespace feed {
}
switch (save_way) {
case ModelSaveWay::ModelSaveInferenceDelta:
return delta_id(epoch_id) % 6 == 0;
// 重启训练后,中间的delta不重复dump
return epoch_id > _last_done_epoch_id &&
delta_id(epoch_id) % 6 == 0;
case ModelSaveWay::ModelSaveInferenceBase:
return is_last_epoch(epoch_id);
case ModelSaveWay::ModelSaveTrainCheckpoint:
......
......@@ -73,6 +73,7 @@ protected:
std::string _inference_model_delta_done_path;
uint64_t _current_epoch_id = 0;
std::string _last_checkpoint_path;
uint64_t _last_done_epoch_id = 0;
uint64_t _last_checkpoint_epoch_id = 0;
std::vector<std::string> _done_status; // 当前完成状态,统一存成string
uint64_t _inference_base_model_key = 0; // 预估模型的base-key
......
......@@ -35,8 +35,9 @@ public:
virtual int32_t forward(SampleInstance* samples, size_t num,
::paddle::framework::Scope* scope) = 0;
// 后向,一般用于更新梯度,在训练网络执行后调用
virtual int32_t backward(SampleInstance* samples, size_t num,
// 后向,一般用于更新梯度,在训练网络执行后调用, 由于backward一般是异步,这里返回future,
// TODO 前向接口也改为future返回形式,接口一致性好些
virtual std::future<int32_t> backward(SampleInstance* samples, size_t num,
::paddle::framework::Scope* scope) = 0;
// 收集持久化变量的名称, 并将值拷贝到Scope
......@@ -67,7 +68,7 @@ public:
virtual int32_t forward(SampleInstance* samples, size_t num,
::paddle::framework::Scope* scope);
virtual int32_t backward(SampleInstance* samples, size_t num,
virtual std::future<int32_t> backward(SampleInstance* samples, size_t num,
::paddle::framework::Scope* scope);
protected:
size_t _label_total_dim = 0;
......@@ -108,7 +109,7 @@ public:
virtual void post_process_input(float* var_data, SparseInputVariable&, SampleInstance*, size_t num) = 0;
// backward过程的梯度push
virtual int32_t backward(SampleInstance* samples, size_t num,
virtual std::future<int32_t> backward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope);
// SparseGradValue会被依次调用,用于整理push的梯度
virtual void fill_gradient(float* push_value, const float* gradient_raw,
......@@ -148,7 +149,7 @@ public:
virtual int32_t forward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope);
virtual int32_t backward(SampleInstance* samples, size_t num,
virtual std::future<int32_t> backward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope);
......@@ -175,7 +176,7 @@ public:
virtual int32_t forward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope);
virtual int32_t backward(SampleInstance* samples, size_t num,
virtual std::future<int32_t> backward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope);
};
......
......@@ -45,10 +45,11 @@ int32_t LabelInputAccessor::forward(SampleInstance* samples, size_t num,
return 0;
}
int32_t LabelInputAccessor::backward(SampleInstance* samples, size_t num,
std::future<int32_t> LabelInputAccessor::backward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope) {
std::future<int32_t> ret;
if (num < 1) {
return 0;
return ret;
}
for (size_t i = 0; i < num; ++i) {
auto& sample = samples[i];
......@@ -69,7 +70,7 @@ int32_t LabelInputAccessor::backward(SampleInstance* samples, size_t num,
VLOG(2) << "[Debug][Lable]" << ScopeHelper::to_string(scope, label.label_name) << ScopeHelper::to_string(scope, label.output_name);
}
*/
return 0;
return ret;
}
REGIST_CLASS(DataInputAccessor, LabelInputAccessor);
......
......@@ -136,8 +136,9 @@ int32_t BaseSparseInputAccessor::forward(SampleInstance* samples,
}
// 更新spare数据
int32_t BaseSparseInputAccessor::backward(SampleInstance* samples,
std::future<int32_t> BaseSparseInputAccessor::backward(SampleInstance* samples,
size_t num, paddle::framework::Scope* scope) {
std::future<int32_t> ret;
int64_t runtime_data_for_scope = *ScopeHelper::get_value<int64_t>(
scope, _trainer_context->cpu_place, "sparse_runtime_data");
auto* runtime_data_ptr = (std::vector<SparseVarRuntimeData>*)runtime_data_for_scope;
......@@ -146,7 +147,7 @@ int32_t BaseSparseInputAccessor::backward(SampleInstance* samples,
delete runtime_data_ptr;
});
if (!_need_gradient) {
return 0;
return ret;
}
auto* ps_client = _trainer_context->pslib->ps_client();
auto* value_accessor = ps_client->table_accessor(_table_id);
......@@ -204,11 +205,10 @@ int32_t BaseSparseInputAccessor::backward(SampleInstance* samples,
VLOG(2) << "[DEBUG][sparse_slot_push]" << ssm.str();
}
}
auto push_status = ps_client->push_sparse(_table_id,
ret = ps_client->push_sparse(_table_id,
keys.data(), (const float**)push_values, key_idx);
//auto ret = push_status.get();
delete[] push_values;
return 0;
return ret;
}
class AbacusSparseJoinAccessor : public BaseSparseInputAccessor {
......
......@@ -70,7 +70,6 @@ paddle::PSParameter* PSlib::get_param() {
void PSlib::init_gflag() {
int cnt = 4;
char** params_ptr = new char*[cnt];
std::cout << "alloc_ptr" << params_ptr << std::flush;
char p0[] = "exe default";
char p1[] = "-max_body_size=314217728";
char p2[] = "-bthread_concurrency=40";
......
......@@ -9,6 +9,10 @@ namespace paddle {
namespace custom_trainer {
namespace feed {
std::once_flag MultiThreadExecutor::_async_delete_flag;
std::shared_ptr<std::thread> MultiThreadExecutor::_async_delete_thread;
paddle::framework::Channel<ScopeExecutorContext*> MultiThreadExecutor::_delete_channel;
int MultiThreadExecutor::initialize(YAML::Node exe_config,
std::shared_ptr<TrainerContext> context_ptr) {
int ret = 0;
......@@ -85,6 +89,23 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config,
CHECK(monitor_ptr->initialize(monitor_config, context_ptr) == 0)
<< "Monitor init Failed, class:" << monitor_class;
}
// 异步删除池
std::call_once(_async_delete_flag, [this](){
_delete_channel = paddle::framework::MakeChannel<ScopeExecutorContext*>();
_delete_channel->SetBlockSize(32);
_async_delete_thread.reset(new std::thread([this]{
std::vector<ScopeExecutorContext*> ctxs;
while (true) {
while (_delete_channel->Read(ctxs)) {
for (auto* ctx : ctxs) {
delete ctx;
}
}
usleep(200000); // 200ms
}
}));
});
return ret;
}
......@@ -187,9 +208,10 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run(
auto* samples = scope_ctx->samples();
auto sample_num = scope_ctx->sample_num();
out_items[out_idx] = 0;
scope_ctx->wait_status.resize(_input_accessors.size());
for (size_t i = 0; i < _input_accessors.size(); ++i) {
out_items[out_idx] = _input_accessors[i]->
backward(samples, sample_num, scope);
scope_ctx->wait_status[i] = _input_accessors[i]->backward(samples, sample_num, scope);
}
timer.Pause();
scope_ctx->push_gradient_cost_ms = timer.ElapsedMS();
......@@ -203,7 +225,8 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run(
VLOG(2) << "[Debug][Layer]" << ScopeHelper::to_string(scope, layer_name);
}
}
delete scope_ctx; // 所有pipe完成后,再回收sample
// 所有pipe完成后,再异步回收sample
_delete_channel->Put(scope_ctx);
}
return 0;
});
......
#pragma once
#include <thread>
#include <functional>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
......@@ -18,6 +19,12 @@ public:
_sample_num = sample_num;
}
virtual ~ScopeExecutorContext() {
for (auto& status : wait_status) {
if (!status.valid()) {
continue;
}
status.wait();
}
delete[] _samples;
}
inline SampleInstance* samples() {
......@@ -29,6 +36,7 @@ public:
size_t executor_cost_ms = 0;
size_t prepare_cost_ms = 0;
size_t push_gradient_cost_ms = 0;
std::vector<std::future<int32_t>> wait_status;
private:
size_t _sample_num = 0;
SampleInstance* _samples = NULL;
......@@ -83,6 +91,11 @@ protected:
std::map<uint32_t, std::vector<DataInputAccessor*>> _table_to_accessors;
std::shared_ptr<paddle::ps::ObjectPool<::paddle::framework::Scope>> _scope_obj_pool;
std::vector<std::string> _persistables;
// 异步删除
static std::once_flag _async_delete_flag;
static std::shared_ptr<std::thread> _async_delete_thread;
static paddle::framework::Channel<ScopeExecutorContext*> _delete_channel;
};
} // namespace feed
......
......@@ -23,6 +23,30 @@ std::pair<std::string, std::string> FileSystem::path_split(const std::string& pa
return {path.substr(0, pos), path.substr(pos + 1)};
}
int FileSystem::append_line(const std::string& path,
const std::string& line, size_t reserve_line_num) {
std::string tail_data;
if (exists(path)) {
tail_data = paddle::string::trim_spaces(tail(path, reserve_line_num));
}
if (tail_data.size() > 0) {
tail_data = tail_data + "\n" + line;
} else {
tail_data = line;
}
VLOG(2) << "Append to file:" << path << ", line str:" << line;
while (true) {
remove(path);
auto fp = open_write(path, "");
if (fwrite(tail_data.c_str(), tail_data.length(), 1, &*fp) == 1) {
break;
}
sleep(10);
VLOG(0) << "Retry Append to file:" << path << ", line str:" << line;
}
return 0;
}
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
......@@ -18,6 +18,8 @@ public:
virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;
virtual std::shared_ptr<FILE> open_read(const std::string& path, const std::string& converter) = 0;
virtual std::shared_ptr<FILE> open_write(const std::string& path, const std::string& converter) = 0;
// only support text-file
virtual int append_line(const std::string& path, const std::string& line, size_t reserve_line_num);
virtual int64_t file_size(const std::string& path) = 0;
virtual void remove(const std::string& path) = 0;
virtual std::vector<std::string> list(const std::string& path) = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册