Commit c00354af authored by xiexionghang

for async push_gradient

Parent b8cf64ab
@@ -31,6 +31,10 @@ int DenseInputAccessor::initialize(YAML::Node config,
if (config["async_pull"] && config["async_pull"].as<bool>()) {
_need_async_pull = true;
}
_data_buffer_list.resize(6); // 6 buffers cycled in order, to reduce write conflicts during updates
for (auto*& buffer : _data_buffer_list) {
buffer = new float[_total_dim];
}
return 0;
}
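A minimal, self-contained sketch of the rotation scheme introduced above (illustrative names; only the index arithmetic mirrors the real class). The writer always fills the slot after the current one and then publishes by advancing the index, so the slot readers are using is never written concurrently:

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    std::vector<int> slots(6);           // matches the 6-buffer list above
    size_t cur = 0;
    auto next = [&] { return (cur + 1) % slots.size(); };
    for (int step = 0; step < 12; ++step) {
        slots[next()] = step;            // fill the slot readers are not using
        cur = next();                    // publish it
        assert(slots[cur] == step);
    }
    return 0;                            // after six publishes the index wraps to slot 0
}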
@@ -52,11 +56,8 @@ int32_t DenseInputAccessor::create(::paddle::framework::Scope* scope) {
// Pulls data over RPC; must be guaranteed to run single-threaded
int32_t DenseInputAccessor::pull_dense(size_t table_id) {
float* data_buffer = _data_buffer;
if (data_buffer == NULL) {
data_buffer = new float[_total_dim];
}
size_t data_buffer_idx = 0;
float* data_buffer = backend_data_buffer();
std::vector<paddle::ps::Region> regions;
for (auto& variable : _x_variables) {
regions.emplace_back(data_buffer + data_buffer_idx, variable.dim);
@@ -66,7 +67,8 @@ int32_t DenseInputAccessor::pull_dense(size_t table_id) {
auto push_status = ps_client->pull_dense(regions.data(), regions.size(), table_id);
int32_t ret = push_status.get();
// TODO: use a double-buffered DataBuffer to avoid overwrites during training; under the current async SGD this is not a big problem
_data_buffer = data_buffer;
switch_data_buffer();
_is_data_buffer_init = true;
return ret;
}
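In miniature, pull_dense now writes into the inactive buffer, blocks on the RPC future, and only then flips the published index. A hedged analogue, with std::async standing in for ps_client->pull_dense:

#include <algorithm>
#include <future>
#include <vector>

std::vector<float> buf_a(4), buf_b(4);
bool front_is_a = true;

void refresh(float value) {
    auto& back = front_is_a ? buf_b : buf_a;          // the inactive buffer
    auto status = std::async(std::launch::async, [&back, value] {
        std::fill(back.begin(), back.end(), value);   // stands in for the RPC pull
    });
    status.get();                                     // block, as push_status.get() does
    front_is_a = !front_is_a;                         // publish only after completion
}

int main() {
    refresh(1.0f);
    refresh(2.0f);
    return 0;
}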
@@ -82,9 +84,9 @@ int32_t DenseInputAccessor::forward(SampleInstance* samples, size_t num,
int32_t DenseInputAccessor::collect_persistables(paddle::framework::Scope* scope) {
// First pull is synchronous; subsequent pulls are asynchronous
if (_data_buffer == nullptr) {
if (!_is_data_buffer_init) {
_pull_mutex.lock();
if (_data_buffer == nullptr) {
if (!_is_data_buffer_init) {
CHECK(pull_dense(_table_id) == 0);
_async_pull_thread = std::make_shared<std::thread>(
[this]() {
@@ -101,16 +103,17 @@ int32_t DenseInputAccessor::collect_persistables(paddle::framework::Scope* scope)
_pull_mutex.unlock();
}
size_t data_buffer_idx = 0;
auto* data_buff = data_buffer();
for (auto& variable : _x_variables) {
auto* shape_ptr = &(variable.shape[0]);
paddle::framework::DDim ddim(shape_ptr, variable.shape.size());
auto* tensor = ScopeHelper::resize_lod_tensor(scope, variable.name, ddim);
auto* grad_tensor = ScopeHelper::resize_lod_tensor(scope, variable.gradient_name, ddim);
VLOG(5) << "fill scope variable:" << variable.name << ", " << variable.gradient_name
<< ", data_buffer: " << _data_buffer + data_buffer_idx
<< ", data_buffer: " << data_buff + data_buffer_idx
<< ", dim: " << variable.dim * sizeof(float);
auto* var_data = tensor->mutable_data<float>(_trainer_context->cpu_place);
memcpy(var_data, _data_buffer + data_buffer_idx, variable.dim * sizeof(float));
memcpy(var_data, data_buff + data_buffer_idx, variable.dim * sizeof(float));
data_buffer_idx += variable.dim;
}
if (!FLAGS_feed_trainer_debug_dense_name.empty()) {
......
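The first-use path above is classic double-checked locking: test the flag, take the mutex, re-test under the lock, then do the one-time synchronous pull and spawn the async pull thread. A condensed sketch of the pattern (hypothetical names; note that, strictly, the flag should be atomic for the unlocked fast path to be data-race-free in standard C++):

#include <atomic>
#include <mutex>

std::atomic<bool> inited{false};
std::mutex mu;

void ensure_first_pull() {
    if (!inited.load()) {                       // fast path, no lock taken
        std::lock_guard<std::mutex> guard(mu);
        if (!inited.load()) {                   // re-check under the lock
            // one-time work: synchronous pull, start the async pull thread
            inited.store(true);
        }
    }
}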
@@ -132,14 +132,32 @@ class DenseInputAccessor : public DataInputAccessor {
public:
DenseInputAccessor() {}
virtual ~DenseInputAccessor() {
if (_data_buffer) {
delete[] _data_buffer;
for (float* buffer : _data_buffer_list) {
delete[] buffer;
}
_need_async_pull = false;
if (_async_pull_thread) {
_async_pull_thread->join();
}
}
// Returns the dense buffer currently visible to readers
inline float* data_buffer() {
return _data_buffer_list[_current_buffer_idx];
}
inline float* backend_data_buffer() {
return _data_buffer_list[next_buffer_idx()];
}
inline void switch_data_buffer() {
_current_buffer_idx = next_buffer_idx();
}
inline size_t next_buffer_idx() {
auto buffer_idx = _current_buffer_idx + 1;
if (buffer_idx >= _data_buffer_list.size()) {
return 0;
}
return buffer_idx;
}
virtual int initialize(YAML::Node config,
std::shared_ptr<TrainerContext> context_ptr);
@@ -158,11 +176,12 @@ public:
virtual int32_t collect_persistables(paddle::framework::Scope* scope);
protected:
virtual int32_t pull_dense(size_t table_id);
size_t _total_dim = 0;
std::mutex _pull_mutex;
bool _need_async_pull = false;
float* _data_buffer = nullptr;
bool _is_data_buffer_init = false;
std::vector<float*> _data_buffer_list;
size_t _current_buffer_idx = 0;
std::atomic<int> _pull_request_num;
std::vector<DenseInputVariable> _x_variables;
std::shared_ptr<std::thread> _async_pull_thread;
......
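Taken together, the helpers define a single-writer ring: the async pull thread fills backend_data_buffer() and publishes with switch_data_buffer(), while trainer threads only read data_buffer(). A runnable sketch of that interaction; the demo uses an atomic index to keep the illustration race-free, whereas the real class uses a plain size_t, a momentarily stale read being tolerable under async SGD, with the six slots giving in-flight readers a grace window before a slot is reused:

#include <algorithm>
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    std::vector<std::vector<float>> slots(6, std::vector<float>(4));
    std::atomic<size_t> cur{0};

    std::thread writer([&] {                        // plays the async pull thread
        for (int step = 1; step <= 100; ++step) {
            size_t nxt = (cur.load() + 1) % slots.size();
            std::fill(slots[nxt].begin(), slots[nxt].end(), float(step));
            cur.store(nxt);                         // switch_data_buffer()
        }
    });
    std::thread reader([&] {                        // plays a trainer thread
        float sink = 0;
        for (int i = 0; i < 100; ++i) {
            sink += slots[cur.load()][0];           // data_buffer()
        }
        std::printf("read sum: %f\n", sink);
    });
    writer.join();
    reader.join();
    return 0;
}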
@@ -109,9 +109,10 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config,
return ret;
}
int32_t MultiThreadExecutor::save_persistables(const std::string& filename) {
// auto fs = _trainer_context->file_system;
// fs->mkdir(fs->path_split(filename).first);
int32_t MultiThreadExecutor::save_persistables(const std::string& file_path) {
auto fs = _trainer_context->file_system;
auto file_name = fs->path_split(file_path).second;
fs->remove(file_name);
auto scope_obj = _scope_obj_pool->get();
for (size_t i = 0; i < _input_accessors.size(); ++i) {
_input_accessors[i]->collect_persistables(scope_obj.get());
@@ -121,12 +122,14 @@ int32_t MultiThreadExecutor::save_persistables(const std::string& filename) {
auto* op = block->AppendOp();
op->SetType("save_combine");
op->SetInput("X", _persistables);
op->SetAttr("file_path", filename);
op->SetAttr("file_path", file_name);
op->CheckAttrs();
platform::CPUPlace place;
framework::Executor exe(place);
exe.Run(prog, scope_obj.get(), 0, true, true);
// The executor can only produce the model locally; copy it out to stay compatible with other file systems
fs->copy(file_name, file_path);
return 0;
}
......
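The save path leans on FileSystem::path_split (next hunk) to separate directory from file name: save_combine writes the bare file name locally, and fs->copy then moves it to the full destination. A hedged reconstruction of path_split from its return statement, exercised on a hypothetical path:

#include <iostream>
#include <string>
#include <utility>

// assumed shape, based on the return statement shown below
std::pair<std::string, std::string> path_split(const std::string& path) {
    auto pos = path.find_last_of('/');
    return {path.substr(0, pos), path.substr(pos + 1)};
}

int main() {
    auto parts = path_split("afs:/user/feed/batch_model/model.0");  // hypothetical
    std::cout << parts.first << " | " << parts.second << "\n";
    // prints: afs:/user/feed/batch_model | model.0
    return 0;
}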
@@ -23,6 +23,30 @@ std::pair<std::string, std::string> FileSystem::path_split(const std::string& path)
return {path.substr(0, pos), path.substr(pos + 1)};
}
int FileSystem::copy(const std::string& ori_path, const std::string& dest_path) {
if (!exists(ori_path)) {
return -1;
}
remove(dest_path);
auto ori_file = open_read(ori_path, "");
auto dest_file = open_write(dest_path, "");
size_t read_buffer_size = 102400; // 100 KB read buffer
char* buffer = new char[read_buffer_size];
while (true) {
size_t read_size = fread(buffer, 1, read_buffer_size, ori_file.get());
CHECK(ferror(ori_file.get()) == 0) << " File read failed:" << ori_path;
if (read_size > 0) {
fwrite(buffer, 1, read_size, dest_file.get());
}
// a short read with no error set means EOF: the copy is done
if (read_size < read_buffer_size) {
break;
}
}
delete[] buffer;
return 0;
}
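The exit condition relies on C stdio semantics: once ferror is ruled out, a short fread can only mean end-of-file, so read_size < read_buffer_size is a safe termination test. A tiny self-check of that guarantee:

#include <cassert>
#include <cstdio>

int main() {
    std::FILE* f = std::tmpfile();
    std::fputs("hello", f);                          // 5 bytes
    std::rewind(f);
    char buf[16];
    size_t n = std::fread(buf, 1, sizeof(buf), f);   // request 16 bytes
    assert(n == 5 && std::feof(f) && !std::ferror(f));
    std::fclose(f);
    return 0;
}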
int FileSystem::append_line(const std::string& path,
const std::string& line, size_t reserve_line_num) {
std::string tail_data;
@@ -37,9 +61,11 @@ int FileSystem::append_line(const std::string& path,
VLOG(2) << "Append to file:" << path << ", line str:" << line;
while (true) {
remove(path);
auto fp = open_write(path, "");
if (fwrite(tail_data.c_str(), tail_data.length(), 1, &*fp) == 1) {
break;
{
auto fp = open_write(path, "");
if (fwrite(tail_data.c_str(), tail_data.length(), 1, &*fp) == 1) {
break;
}
}
sleep(10);
VLOG(0) << "Retry Append to file:" << path << ", line str:" << line;
......
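The new braces around fp are the point of this hunk: they scope the write handle so it is closed, and therefore flushed, before the sleep-and-retry, instead of remaining open across the remove() on the next iteration. The same effect in a standalone RAII sketch (hypothetical names and path):

#include <cstdio>
#include <memory>
#include <string>

using FileHandle = std::unique_ptr<std::FILE, int (*)(std::FILE*)>;

bool write_once(const std::string& path, const std::string& data) {
    FileHandle fp(std::fopen(path.c_str(), "w"), &std::fclose);
    if (!fp) return false;
    return std::fwrite(data.data(), 1, data.size(), fp.get()) == data.size();
}   // fp is closed here, before any retry sleep in the caller

int main() {
    while (!write_once("/tmp/append_line_demo.txt", "tail\n")) {
        // the real code sleeps 10 seconds between attempts
    }
    return 0;
}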
@@ -21,6 +21,7 @@ public:
// only support text-file
virtual int append_line(const std::string& path, const std::string& line, size_t reserve_line_num);
virtual int64_t file_size(const std::string& path) = 0;
virtual int copy(const std::string& ori_path, const std::string& dest_path);
virtual void remove(const std::string& path) = 0;
virtual std::vector<std::string> list(const std::string& path) = 0;
virtual std::string tail(const std::string& path, size_t tail_num = 1) = 0;
......