提交 aaea8a39 编写于 作者: X xiexionghang

merge master

......@@ -69,6 +69,15 @@ int32_t DenseInputAccessor::pull_dense(size_t table_id) {
int32_t DenseInputAccessor::forward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope) {
collect_persistables(scope);
if (_need_async_pull) {
++_pull_request_num;
}
return 0;
}
int32_t DenseInputAccessor::collect_persistables(paddle::framework::Scope* scope) {
// 首次同步pull,之后异步pull
if (_data_buffer == nullptr) {
_pull_mutex.lock();
......@@ -94,7 +103,9 @@ int32_t DenseInputAccessor::forward(SampleInstance* samples, size_t num,
paddle::framework::DDim ddim(shape_ptr, variable.shape.size());
auto* tensor = ScopeHelper::resize_lod_tensor(scope, variable.name, ddim);
auto* grad_tensor = ScopeHelper::resize_lod_tensor(scope, variable.gradient_name, ddim);
VLOG(5) << "fill scope variable:" << variable.name << ", " << variable.gradient_name;
VLOG(5) << "fill scope variable:" << variable.name << ", " << variable.gradient_name
<< ", data_buffer: " << _data_buffer + data_buffer_idx
<< ", dim: " << variable.dim * sizeof(float);
auto* var_data = tensor->mutable_data<float>(_trainer_context->cpu_place);
memcpy(var_data, _data_buffer + data_buffer_idx, variable.dim * sizeof(float));
data_buffer_idx += variable.dim;
......@@ -107,8 +118,12 @@ int32_t DenseInputAccessor::forward(SampleInstance* samples, size_t num,
VLOG(2) << "[Debug][PullDense]" << ScopeHelper::to_string(scope, variable.name);
}
}
if (_need_async_pull) {
++_pull_request_num;
return 0;
}
int32_t DenseInputAccessor::collect_persistables_name(std::vector<std::string>& persistables) {
for (auto& variable : _x_variables) {
persistables.push_back(variable.name);
}
return 0;
}
......
......@@ -38,6 +38,12 @@ public:
// 后向,一般用于更新梯度,在训练网络执行后调用
virtual int32_t backward(SampleInstance* samples, size_t num,
::paddle::framework::Scope* scope) = 0;
// 收集持久化变量的名称, 并将值拷贝到Scope
virtual int32_t collect_persistables_name(std::vector<std::string>& persistables) {return 0;}
// 填充持久化变量的值,用于保存
virtual int32_t collect_persistables(paddle::framework::Scope* scope) {return 0;}
protected:
size_t _table_id = 0;
bool _need_gradient = false;
......@@ -144,6 +150,11 @@ public:
virtual int32_t backward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope);
virtual int32_t collect_persistables_name(std::vector<std::string>& persistables);
virtual int32_t collect_persistables(paddle::framework::Scope* scope);
protected:
virtual int32_t pull_dense(size_t table_id);
......
......@@ -30,6 +30,8 @@ dataset:
pipeline_cmd: './tool/ins_weight.py | awk -f ./tool/format_newcate_hotnews.awk'
parser:
class: AbacusTextDataParser
shuffler:
class: LocalShuffler
epoch:
epoch_class: TimelyEpochAccessor
......
......@@ -31,7 +31,7 @@ int DatasetContainer::initialize(
_data_root_paths = config["root_path"].as<std::vector<std::string>>();
_data_split_interval = config["data_spit_interval"].as<int>();
_data_path_formater = config["data_path_formater"].as<std::string>();
std::string shuffler = config["shuffler"]["name"].as<std::string>();
std::string shuffler = config["shuffler"]["class"].as<std::string>();
_shuffler.reset(CREATE_INSTANCE(Shuffler, shuffler));
_shuffler->initialize(config, context);
std::string data_reader_class = config["data_reader"].as<std::string>();
......
......@@ -2,6 +2,8 @@
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/monitor/monitor.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace custom_trainer {
......@@ -55,6 +57,7 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config,
CHECK(_trainer_context->file_system->exists(model_config_path))
<< "miss model config file:" << model_config_path;
_model_config = YAML::LoadFile(model_config_path);
_persistables.clear();
for (const auto& accessor_config : _model_config["input_accessor"]) {
auto accessor_class = accessor_config["class"].as<std::string>();
auto* accessor_ptr = CREATE_INSTANCE(DataInputAccessor, accessor_class);
......@@ -69,7 +72,10 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config,
_table_to_accessors[table_id] = {accessor_ptr};
}
}
CHECK(accessor_ptr->collect_persistables_name(_persistables) == 0)
<< "collect_persistables Failed, class:" << accessor_class;
}
// std::sort(_persistables.begin(), _persistables.end()); // 持久化变量名一定要排序
// Monitor组件
for (const auto& monitor_config : _model_config["monitor"]) {
......@@ -82,6 +88,27 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config,
return ret;
}
int32_t MultiThreadExecutor::save_persistables(const std::string& filename) {
// auto fs = _trainer_context->file_system;
// fs->mkdir(fs->path_split(filename).first);
auto scope_obj = _scope_obj_pool->get();
for (size_t i = 0; i < _input_accessors.size(); ++i) {
_input_accessors[i]->collect_persistables(scope_obj.get());
}
framework::ProgramDesc prog;
auto* block = prog.MutableBlock(0);
auto* op = block->AppendOp();
op->SetType("save_combine");
op->SetInput("X", _persistables);
op->SetAttr("file_path", filename);
op->CheckAttrs();
platform::CPUPlace place;
framework::Executor exe(place);
exe.Run(prog, scope_obj.get(), 0, true, true);
return 0;
}
paddle::framework::Channel<DataItem> MultiThreadExecutor::run(
paddle::framework::Channel<DataItem> input, const DataParser* parser) {
......
......@@ -47,6 +47,8 @@ public:
virtual paddle::framework::Channel<DataItem> run(
paddle::framework::Channel<DataItem> input, const DataParser* parser);
virtual int32_t save_persistables(const std::string& filename);
virtual bool is_dump_all_model() {
return _need_dump_all_model;
}
......@@ -80,6 +82,7 @@ protected:
std::vector<std::shared_ptr<DataInputAccessor>> _input_accessors;
std::map<uint32_t, std::vector<DataInputAccessor*>> _table_to_accessors;
std::shared_ptr<paddle::ps::ObjectPool<::paddle::framework::Scope>> _scope_obj_pool;
std::vector<std::string> _persistables;
};
} // namespace feed
......
......@@ -25,6 +25,10 @@ public:
virtual bool exists(const std::string& path) = 0;
virtual void mkdir(const std::string& path) = 0;
virtual std::string path_join(const std::string& dir, const std::string& path);
template<class... STRS>
std::string path_join(const std::string& dir, const std::string& path, const STRS&... paths) {
return path_join(path_join(dir, path), paths...);
}
virtual std::pair<std::string, std::string> path_split(const std::string& path);
protected:
};
......
......@@ -27,6 +27,7 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
}
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
auto fs = _context_ptr->file_system;
auto* ps_client = _context_ptr->pslib->ps_client();
auto* environment = _context_ptr->environment.get();
auto* epoch_accessor = _context_ptr->epoch_accessor.get();
......@@ -39,18 +40,21 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
paddle::platform::Timer timer;
timer.Start();
std::set<uint32_t> table_set;
auto model_dir = epoch_accessor->model_save_path(epoch_id, way);
for (auto& executor : _executors) {
const auto& table_accessors = executor->table_accessors();
for (auto& itr : table_accessors) {
table_set.insert(itr.first);
}
auto save_path = fs->path_join(model_dir, executor->train_exe_name() + "_param");
VLOG(2) << "Start save model, save_path:" << save_path;
executor->save_persistables(save_path);
}
int ret_size = 0;
auto table_num = table_set.size();
std::future<int> rets[table_num];
for (auto table_id : table_set) {
VLOG(2) << "Start save model, table_id:" << table_id;
auto model_dir = epoch_accessor->model_save_path(epoch_id, way);
rets[ret_size++] = ps_client->save(table_id, model_dir, std::to_string((int)way));
}
int all_ret = 0;
......
......@@ -124,6 +124,14 @@ class ModelBuilder:
with open(os.path.join(self._save_path, name + '.pbtxt'), 'w') as fout:
fout.write(str(program))
fluid.io.save_inference_model(self._save_path,
[var.name for var in inputs],
outputs,
executor=None,
main_program=test_program,
model_filename='inference_program',
program_only=True)
params = filter(fluid.io.is_parameter, main_program.list_vars())
vars = []
sums=[]
......
#pragma once
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.h"
......@@ -30,7 +29,7 @@ public:
return 0;
}
};
REGIST_CLASS(DataParser, LocalShuffler);
REGIST_CLASS(Shuffler, LocalShuffler);
class GlobalShuffler : public Shuffler {
public:
......@@ -109,7 +108,7 @@ private:
uint32_t _max_concurrent_num = 0;
};
REGIST_CLASS(DataParser, GlobalShuffler);
REGIST_CLASS(Shuffler, GlobalShuffler);
} // namespace feed
} // namespace custom_trainer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册