提交 c00354af 编写于 作者: X xiexionghang

for async push_gradient

上级 b8cf64ab
...@@ -31,6 +31,10 @@ int DenseInputAccessor::initialize(YAML::Node config, ...@@ -31,6 +31,10 @@ int DenseInputAccessor::initialize(YAML::Node config,
if (config["async_pull"] && config["async_pull"].as<bool>()) { if (config["async_pull"] && config["async_pull"].as<bool>()) {
_need_async_pull = true; _need_async_pull = true;
} }
_data_buffer_list.resize(6); // 6 buffer顺序循环使用, 降低更新时的写冲突
for (auto*& buffer : _data_buffer_list) {
buffer = new float[_total_dim];
}
return 0; return 0;
} }
...@@ -52,11 +56,8 @@ int32_t DenseInputAccessor::create(::paddle::framework::Scope* scope) { ...@@ -52,11 +56,8 @@ int32_t DenseInputAccessor::create(::paddle::framework::Scope* scope) {
// rpc拉取数据,需保证单线程运行 // rpc拉取数据,需保证单线程运行
int32_t DenseInputAccessor::pull_dense(size_t table_id) { int32_t DenseInputAccessor::pull_dense(size_t table_id) {
float* data_buffer = _data_buffer;
if (data_buffer == NULL) {
data_buffer = new float[_total_dim];
}
size_t data_buffer_idx = 0; size_t data_buffer_idx = 0;
float* data_buffer = backend_data_buffer();
std::vector<paddle::ps::Region> regions; std::vector<paddle::ps::Region> regions;
for (auto& variable : _x_variables) { for (auto& variable : _x_variables) {
regions.emplace_back(data_buffer + data_buffer_idx, variable.dim); regions.emplace_back(data_buffer + data_buffer_idx, variable.dim);
...@@ -66,7 +67,8 @@ int32_t DenseInputAccessor::pull_dense(size_t table_id) { ...@@ -66,7 +67,8 @@ int32_t DenseInputAccessor::pull_dense(size_t table_id) {
auto push_status = ps_client->pull_dense(regions.data(), regions.size(), table_id); auto push_status = ps_client->pull_dense(regions.data(), regions.size(), table_id);
int32_t ret = push_status.get(); int32_t ret = push_status.get();
// TODO 使用双buffer DataBuffer,避免训练期改写,当前异步SGD下,问题不大 // TODO 使用双buffer DataBuffer,避免训练期改写,当前异步SGD下,问题不大
_data_buffer = data_buffer; switch_data_buffer();
_is_data_buffer_init = true;
return ret; return ret;
} }
...@@ -82,9 +84,9 @@ int32_t DenseInputAccessor::forward(SampleInstance* samples, size_t num, ...@@ -82,9 +84,9 @@ int32_t DenseInputAccessor::forward(SampleInstance* samples, size_t num,
int32_t DenseInputAccessor::collect_persistables(paddle::framework::Scope* scope) { int32_t DenseInputAccessor::collect_persistables(paddle::framework::Scope* scope) {
// 首次同步pull,之后异步pull // 首次同步pull,之后异步pull
if (_data_buffer == nullptr) { if (!_is_data_buffer_init) {
_pull_mutex.lock(); _pull_mutex.lock();
if (_data_buffer == nullptr) { if (!_is_data_buffer_init) {
CHECK(pull_dense(_table_id) == 0); CHECK(pull_dense(_table_id) == 0);
_async_pull_thread = std::make_shared<std::thread>( _async_pull_thread = std::make_shared<std::thread>(
[this]() { [this]() {
...@@ -101,16 +103,17 @@ int32_t DenseInputAccessor::collect_persistables(paddle::framework::Scope* scope ...@@ -101,16 +103,17 @@ int32_t DenseInputAccessor::collect_persistables(paddle::framework::Scope* scope
_pull_mutex.unlock(); _pull_mutex.unlock();
} }
size_t data_buffer_idx = 0; size_t data_buffer_idx = 0;
auto* data_buff = data_buffer();
for (auto& variable : _x_variables) { for (auto& variable : _x_variables) {
auto* shape_ptr = &(variable.shape[0]); auto* shape_ptr = &(variable.shape[0]);
paddle::framework::DDim ddim(shape_ptr, variable.shape.size()); paddle::framework::DDim ddim(shape_ptr, variable.shape.size());
auto* tensor = ScopeHelper::resize_lod_tensor(scope, variable.name, ddim); auto* tensor = ScopeHelper::resize_lod_tensor(scope, variable.name, ddim);
auto* grad_tensor = ScopeHelper::resize_lod_tensor(scope, variable.gradient_name, ddim); auto* grad_tensor = ScopeHelper::resize_lod_tensor(scope, variable.gradient_name, ddim);
VLOG(5) << "fill scope variable:" << variable.name << ", " << variable.gradient_name VLOG(5) << "fill scope variable:" << variable.name << ", " << variable.gradient_name
<< ", data_buffer: " << _data_buffer + data_buffer_idx << ", data_buffer: " << data_buff + data_buffer_idx
<< ", dim: " << variable.dim * sizeof(float); << ", dim: " << variable.dim * sizeof(float);
auto* var_data = tensor->mutable_data<float>(_trainer_context->cpu_place); auto* var_data = tensor->mutable_data<float>(_trainer_context->cpu_place);
memcpy(var_data, _data_buffer + data_buffer_idx, variable.dim * sizeof(float)); memcpy(var_data, data_buff + data_buffer_idx, variable.dim * sizeof(float));
data_buffer_idx += variable.dim; data_buffer_idx += variable.dim;
} }
if (!FLAGS_feed_trainer_debug_dense_name.empty()) { if (!FLAGS_feed_trainer_debug_dense_name.empty()) {
......
...@@ -132,8 +132,8 @@ class DenseInputAccessor : public DataInputAccessor { ...@@ -132,8 +132,8 @@ class DenseInputAccessor : public DataInputAccessor {
public: public:
DenseInputAccessor() {} DenseInputAccessor() {}
virtual ~DenseInputAccessor() { virtual ~DenseInputAccessor() {
if (_data_buffer) { for (float* buffer : _data_buffer_list) {
delete[] _data_buffer; delete[] buffer;
} }
_need_async_pull = false; _need_async_pull = false;
if (_async_pull_thread) { if (_async_pull_thread) {
...@@ -141,6 +141,24 @@ public: ...@@ -141,6 +141,24 @@ public:
} }
} }
// 返回当前可用的Dense buffer
inline float* data_buffer() {
return _data_buffer_list[_current_buffer_idx];
}
inline float* backend_data_buffer() {
return _data_buffer_list[next_buffer_idx()];
}
inline void switch_data_buffer() {
_current_buffer_idx = next_buffer_idx();
}
inline size_t next_buffer_idx() {
auto buffer_idx = _current_buffer_idx + 1;
if (buffer_idx >= _data_buffer_list.size()) {
return 0;
}
return buffer_idx;
}
virtual int initialize(YAML::Node config, virtual int initialize(YAML::Node config,
std::shared_ptr<TrainerContext> context_ptr); std::shared_ptr<TrainerContext> context_ptr);
...@@ -158,11 +176,12 @@ public: ...@@ -158,11 +176,12 @@ public:
virtual int32_t collect_persistables(paddle::framework::Scope* scope); virtual int32_t collect_persistables(paddle::framework::Scope* scope);
protected: protected:
virtual int32_t pull_dense(size_t table_id); virtual int32_t pull_dense(size_t table_id);
size_t _total_dim = 0; size_t _total_dim = 0;
std::mutex _pull_mutex; std::mutex _pull_mutex;
bool _need_async_pull = false; bool _need_async_pull = false;
float* _data_buffer = nullptr; bool _is_data_buffer_init = false;
std::vector<float*> _data_buffer_list;
size_t _current_buffer_idx = 0;
std::atomic<int> _pull_request_num; std::atomic<int> _pull_request_num;
std::vector<DenseInputVariable> _x_variables; std::vector<DenseInputVariable> _x_variables;
std::shared_ptr<std::thread> _async_pull_thread; std::shared_ptr<std::thread> _async_pull_thread;
......
...@@ -109,9 +109,10 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config, ...@@ -109,9 +109,10 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config,
return ret; return ret;
} }
int32_t MultiThreadExecutor::save_persistables(const std::string& filename) { int32_t MultiThreadExecutor::save_persistables(const std::string& file_path) {
// auto fs = _trainer_context->file_system; auto fs = _trainer_context->file_system;
// fs->mkdir(fs->path_split(filename).first); auto file_name = fs->path_split(file_path).second;
fs->remove(file_name);
auto scope_obj = _scope_obj_pool->get(); auto scope_obj = _scope_obj_pool->get();
for (size_t i = 0; i < _input_accessors.size(); ++i) { for (size_t i = 0; i < _input_accessors.size(); ++i) {
_input_accessors[i]->collect_persistables(scope_obj.get()); _input_accessors[i]->collect_persistables(scope_obj.get());
...@@ -121,12 +122,14 @@ int32_t MultiThreadExecutor::save_persistables(const std::string& filename) { ...@@ -121,12 +122,14 @@ int32_t MultiThreadExecutor::save_persistables(const std::string& filename) {
auto* op = block->AppendOp(); auto* op = block->AppendOp();
op->SetType("save_combine"); op->SetType("save_combine");
op->SetInput("X", _persistables); op->SetInput("X", _persistables);
op->SetAttr("file_path", filename); op->SetAttr("file_path", file_name);
op->CheckAttrs(); op->CheckAttrs();
platform::CPUPlace place; platform::CPUPlace place;
framework::Executor exe(place); framework::Executor exe(place);
exe.Run(prog, scope_obj.get(), 0, true, true); exe.Run(prog, scope_obj.get(), 0, true, true);
// exe只能将模型产出在本地,这里通过cp方式兼容其他文件系统
fs->copy(file_name, file_path);
return 0; return 0;
} }
......
...@@ -23,6 +23,30 @@ std::pair<std::string, std::string> FileSystem::path_split(const std::string& pa ...@@ -23,6 +23,30 @@ std::pair<std::string, std::string> FileSystem::path_split(const std::string& pa
return {path.substr(0, pos), path.substr(pos + 1)}; return {path.substr(0, pos), path.substr(pos + 1)};
} }
int FileSystem::copy(const std::string& ori_path, const std::string& dest_path) {
if (!exists(ori_path)) {
return -1;
}
remove(dest_path);
auto ori_file = open_read(ori_path, "");
auto dest_file = open_write(dest_path, "");
size_t read_buffer_size = 102400; // 100kb
char* buffer = new char[read_buffer_size];
while (true) {
size_t read_size = fread(buffer, 1, read_buffer_size, ori_file.get());
CHECK(ferror(ori_file.get()) == 0) << " File read Failed:" << ori_path;
if (read_size > 0) {
fwrite(buffer, 1, read_size, dest_file.get());
}
// read done
if (read_size < read_buffer_size) {
break;
}
}
delete[] buffer;
return 0;
}
int FileSystem::append_line(const std::string& path, int FileSystem::append_line(const std::string& path,
const std::string& line, size_t reserve_line_num) { const std::string& line, size_t reserve_line_num) {
std::string tail_data; std::string tail_data;
...@@ -37,10 +61,12 @@ int FileSystem::append_line(const std::string& path, ...@@ -37,10 +61,12 @@ int FileSystem::append_line(const std::string& path,
VLOG(2) << "Append to file:" << path << ", line str:" << line; VLOG(2) << "Append to file:" << path << ", line str:" << line;
while (true) { while (true) {
remove(path); remove(path);
{
auto fp = open_write(path, ""); auto fp = open_write(path, "");
if (fwrite(tail_data.c_str(), tail_data.length(), 1, &*fp) == 1) { if (fwrite(tail_data.c_str(), tail_data.length(), 1, &*fp) == 1) {
break; break;
} }
}
sleep(10); sleep(10);
VLOG(0) << "Retry Append to file:" << path << ", line str:" << line; VLOG(0) << "Retry Append to file:" << path << ", line str:" << line;
} }
......
...@@ -21,6 +21,7 @@ public: ...@@ -21,6 +21,7 @@ public:
// only support text-file // only support text-file
virtual int append_line(const std::string& path, const std::string& line, size_t reserve_line_num); virtual int append_line(const std::string& path, const std::string& line, size_t reserve_line_num);
virtual int64_t file_size(const std::string& path) = 0; virtual int64_t file_size(const std::string& path) = 0;
virtual int copy(const std::string& ori_path, const std::string& dest_path);
virtual void remove(const std::string& path) = 0; virtual void remove(const std::string& path) = 0;
virtual std::vector<std::string> list(const std::string& path) = 0; virtual std::vector<std::string> list(const std::string& path) = 0;
virtual std::string tail(const std::string& path, size_t tail_num = 1) = 0; virtual std::string tail(const std::string& path, size_t tail_num = 1) = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册