learner_process.cc 13.5 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5
/*
 *Author: xiexionghang
 *Train样本
 */
#include <omp.h>

#include <future>
#include <vector>

#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/learner_process.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
    int ret = Process::initialize(context_ptr);
    auto& config = _context_ptr->trainer_config;
X
xiexionghang 已提交
19 20
    _is_dump_cache_model = config["dump_cache_model"].as<bool>(false);
    _cache_load_converter = config["load_cache_converter"].as<std::string>("");
X
xiexionghang 已提交
21
    _startup_dump_inference_base = config["startup_dump_inference_base"].as<bool>(false);
X
xiexionghang 已提交
22
    if (config["executor"]) {
X
xiexionghang 已提交
23 24 25 26
        _executors.resize(config["executor"].size());
        for (size_t i = 0; i < _executors.size(); ++i) {
            _executors[i].reset(new MultiThreadExecutor());
            CHECK(_executors[i]->initialize(config["executor"][i], context_ptr) == 0);
X
xiexionghang 已提交
27 28 29 30 31
        }
    }
    return 0;
}

X
xiexionghang 已提交
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
// 更新各节点存储的CacheModel
int LearnerProcess::update_cache_model(uint64_t epoch_id, ModelSaveWay way) {
    auto fs = _context_ptr->file_system;
    auto* ps_client = _context_ptr->pslib->ps_client();
    auto* environment = _context_ptr->environment.get();
    auto* epoch_accessor = _context_ptr->epoch_accessor.get();
    if (!epoch_accessor->need_save_model(epoch_id, way)) {
        return 0;
    }
    auto* ps_param = _context_ptr->pslib->get_param();
    if (_is_dump_cache_model && way == ModelSaveWay::ModelSaveInferenceBase) {
        auto model_dir = epoch_accessor->model_save_path(epoch_id, way);
        auto& table_param = ps_param->server_param().downpour_server_param().downpour_table_param();
        for (auto& param : table_param) {
            if (param.type() != paddle::PS_SPARSE_TABLE) {
                continue;
            }
            auto cache_model_path = fs->path_join(
                model_dir, string::format_string("%03d_cache/", param.table_id()));
            if (!fs->exists(cache_model_path)) {
                continue;
            }
            auto& cache_dict = *(_context_ptr->cache_dict.get());
            cache_dict.clear();
            cache_dict.reserve(_cache_sign_max_num);
            auto cache_file_list = fs->list(fs->path_join(cache_model_path, "part*"));
            for (auto& cache_path : cache_file_list) {
                auto cache_file = fs->open_read(cache_path, _cache_load_converter);
                char *buffer = nullptr;
                size_t buffer_size = 0;
                ssize_t line_len = 0;
                while ((line_len = getline(&buffer, &buffer_size, cache_file.get())) != -1) {
                    if (line_len <= 1) {
                        continue;
                    }
X
xiexionghang 已提交
67
                    char* data_ptr = NULL;
X
xiexionghang 已提交
68 69 70 71 72 73 74 75 76 77 78
                    cache_dict.append(strtoul(buffer, &data_ptr, 10));
                }
                if (buffer != nullptr) {
                    free(buffer);
                } 
            }
            break;
        }
    }
    return 0;
}
X
xiexionghang 已提交
79
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump) {
X
xiexionghang 已提交
80
    ContextStatusGurad status_guard(_context_ptr, TrainerStatus::Saving);
R
rensilin 已提交
81
    auto fs = _context_ptr->file_system;
82
    auto* ps_client = _context_ptr->pslib->ps_client();
X
xiexionghang 已提交
83
    auto* environment = _context_ptr->environment.get();
84
    auto* epoch_accessor = _context_ptr->epoch_accessor.get();
X
xiexionghang 已提交
85
    if (!environment->is_master_node(EnvironmentRole::WORKER)) {
X
xiexionghang 已提交
86 87
        return 0;
    }
X
xiexionghang 已提交
88
    if (!is_force_dump && !epoch_accessor->need_save_model(epoch_id, way)) {
89 90 91 92
        return 0;
    }
    paddle::platform::Timer timer;
    timer.Start();
X
xiexionghang 已提交
93
    std::set<uint32_t> table_set;
R
rensilin 已提交
94
    auto model_dir = epoch_accessor->model_save_path(epoch_id, way);
X
xiexionghang 已提交
95 96 97 98 99
    for (auto& executor : _executors) {
        const auto& table_accessors = executor->table_accessors();
        for (auto& itr : table_accessors) {
            table_set.insert(itr.first);
        }
R
rensilin 已提交
100 101 102
        auto save_path = fs->path_join(model_dir, executor->train_exe_name() + "_param");
        VLOG(2) << "Start save model, save_path:" << save_path;
        executor->save_persistables(save_path);
X
xiexionghang 已提交
103
    }
X
xiexionghang 已提交
104
    int ret_size = 0;
X
xiexionghang 已提交
105
    auto table_num = table_set.size();
X
xiexionghang 已提交
106
    std::future<int> rets[table_num];
X
xiexionghang 已提交
107
    for (auto table_id : table_set) {
108 109
        VLOG(2) << "Start save model, table_id:" << table_id;
        rets[ret_size++] = ps_client->save(table_id, model_dir, std::to_string((int)way));
X
xiexionghang 已提交
110 111 112 113 114 115
    }
    int all_ret = 0;
    for (int i = 0; i < ret_size; ++i) {
        rets[i].wait();
        all_ret |= rets[i].get();
    }
116 117
    timer.Pause();
    VLOG(2) << "Save Model Cost(s):" << timer.ElapsedSec();
X
xiexionghang 已提交
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
    
    // save cache model, 只有inference需要cache_model
    auto* ps_param = _context_ptr->pslib->get_param();
    if (_is_dump_cache_model && (way == ModelSaveWay::ModelSaveInferenceBase ||
        way == ModelSaveWay::ModelSaveInferenceDelta)) {
        auto& table_param = ps_param->server_param().downpour_server_param().downpour_table_param();
        for (auto& param : table_param) {
            if (param.type() != paddle::PS_SPARSE_TABLE) {
                continue;
            } 
            double cache_threshold = 0.0;
            auto status = ps_client->get_cache_threshold(param.table_id(), cache_threshold);
            CHECK(status.get() == 0) << "CacheThreshold Get failed!";
            status = ps_client->cache_shuffle(param.table_id(), model_dir, std::to_string((int)way),
                std::to_string(cache_threshold));
            CHECK(status.get() == 0) << "Cache Shuffler Failed";
            status = ps_client->save_cache(param.table_id(), model_dir, std::to_string((int)way));
            auto feature_size = status.get();
            CHECK(feature_size >= 0) << "Cache Save Failed";
            auto cache_model_path = fs->path_join(
                model_dir, string::format_string("%03d_cache/sparse_cache.meta", param.table_id()));
            auto cache_meta_file = fs->open_write(cache_model_path, "");
            auto meta = string::format_string("file_prefix:part\npart_num:%d\nkey_num:%d\n", 
                param.sparse_table_cache_file_num(), feature_size);
            CHECK(fwrite(meta.c_str(), meta.size(), 1, cache_meta_file.get()) == 1) << "Cache Meta Failed";
            if (feature_size > _cache_sign_max_num) {
                _cache_sign_max_num = feature_size;
            }
        }
    }
148
    _context_ptr->epoch_accessor->update_model_donefile(epoch_id, way);
X
xiexionghang 已提交
149

X
xiexionghang 已提交
150 151 152
    return all_ret;
}

X
xiexionghang 已提交
153 154 155 156 157
int LearnerProcess::load_model(uint64_t epoch_id) {
    auto* environment = _context_ptr->environment.get();
    if (!environment->is_master_node(EnvironmentRole::WORKER)) {
        return 0;
    }
X
xiexionghang 已提交
158
    auto* fs = _context_ptr->file_system.get();
X
xiexionghang 已提交
159 160
    std::set<uint32_t> loaded_table_set;
    auto model_dir = _context_ptr->epoch_accessor->checkpoint_path();
L
linan17 已提交
161 162
    paddle::platform::Timer timer;
    timer.Start();
X
xiexionghang 已提交
163 164 165 166 167 168
    for (auto& executor : _executors) {
        const auto& table_accessors = executor->table_accessors();
        for (auto& itr : table_accessors) {
            if (loaded_table_set.count(itr.first)) {
                continue;
            }
X
xiexionghang 已提交
169
            auto table_model_path = fs->path_join(
X
xiexionghang 已提交
170
                model_dir, string::format_string("%03d", itr.first));
X
xiexionghang 已提交
171
            if ((!fs->exists(table_model_path)) || fs->list(table_model_path).size() == 0) {
X
xiexionghang 已提交
172 173 174 175
                VLOG(2) << "miss table_model:" << table_model_path << ", initialize by default";
                auto scope = std::move(executor->fetch_scope());
                CHECK(itr.second[0]->create(scope.get()) == 0);
            } else {
L
linan17 已提交
176
                ENVLOG_WORKER_MASTER_NOTICE("Loading model %s", model_dir.c_str());
X
xiexionghang 已提交
177 178 179 180 181 182 183
                auto status = _context_ptr->ps_client()->load(itr.first, 
                    model_dir, std::to_string((int)ModelSaveWay::ModelSaveTrainCheckpoint));
                CHECK(status.get() == 0) << "table load failed, id:" << itr.first;
            }
            loaded_table_set.insert(itr.first);
        }
    }
L
linan17 已提交
184 185
    timer.Pause();
    ENVLOG_WORKER_MASTER_NOTICE("Finished loading model, cost:%f", timer.ElapsedSec());
X
xiexionghang 已提交
186 187 188
    return 0;
}

X
xiexionghang 已提交
189
int LearnerProcess::run() {
X
xiexionghang 已提交
190
    auto* dataset = _context_ptr->dataset.get();
X
xiexionghang 已提交
191 192
    auto* environment = _context_ptr->environment.get();
    auto* epoch_accessor = _context_ptr->epoch_accessor.get(); 
X
xiexionghang 已提交
193
    uint64_t epoch_id = epoch_accessor->current_epoch_id();
X
xiexionghang 已提交
194

L
linan17 已提交
195
    ENVLOG_WORKER_MASTER_NOTICE("Resume train with epoch_id:%d %s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
196 197 198 199
    //尝试加载模型 or 初始化
    CHECK(load_model(epoch_id) == 0);
    environment->barrier(EnvironmentRole::WORKER); 

X
xiexionghang 已提交
200 201
    //判断是否先dump出base TODO
    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase, _startup_dump_inference_base);
X
xiexionghang 已提交
202
    environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
203 204 205
    
    while (true) {
        epoch_accessor->next_epoch();
206
        _context_ptr->monitor_ssm.str(""); 
X
xiexionghang 已提交
207
        bool already_dump_inference_model = false;
X
xiexionghang 已提交
208
        epoch_id = epoch_accessor->current_epoch_id();
X
xiexionghang 已提交
209
        std::string epoch_log_title = paddle::string::format_string(
X
xiexionghang 已提交
210
            "train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
211
        std::string data_path = paddle::string::to_string<std::string>(dataset->epoch_data_path(epoch_id));
L
linan17 已提交
212
        ENVLOG_WORKER_MASTER_NOTICE("    ==== begin %s ====", epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
213
        //Step1. 等待样本ready
X
xiexionghang 已提交
214
        {
L
linan17 已提交
215
            ENVLOG_WORKER_MASTER_NOTICE("      %s, wait data ready:%s", epoch_log_title.c_str(), data_path.c_str());
X
xiexionghang 已提交
216 217 218
            while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) {
                sleep(30);  
                dataset->pre_detect_data(epoch_id);
L
linan17 已提交
219
                ENVLOG_WORKER_MASTER_NOTICE("      epoch_id:%d data not ready, wait 30s", epoch_id);
X
xiexionghang 已提交
220
            } 
L
linan17 已提交
221
            ENVLOG_WORKER_MASTER_NOTICE("      Start %s, data is ready", epoch_log_title.c_str());
X
xiexionghang 已提交
222 223 224
            environment->barrier(EnvironmentRole::WORKER); 
        }
    
X
xiexionghang 已提交
225
        //Step2. 运行训练网络
X
xiexionghang 已提交
226
        {
X
xiexionghang 已提交
227
            ContextStatusGurad status_guard(_context_ptr, TrainerStatus::Training);
X
xiexionghang 已提交
228 229
            std::map<std::string, paddle::framework::Channel<DataItem>> backup_input_map;
            for (auto& executor : _executors) {
X
xiexionghang 已提交
230
                environment->barrier(EnvironmentRole::WORKER); 
231 232
                paddle::platform::Timer timer;
                timer.Start();
L
linan17 已提交
233
                ENVLOG_WORKER_MASTER_NOTICE("Start executor:%s", executor->train_exe_name().c_str());
X
xiexionghang 已提交
234 235 236 237 238 239 240 241
                auto data_name = executor->train_data_name();
                paddle::framework::Channel<DataItem> input_channel;
                if (backup_input_map.count(data_name)) {
                    input_channel = backup_input_map[data_name];
                } else {
                    input_channel = dataset->fetch_data(data_name, epoch_id);
                }
                input_channel = executor->run(input_channel, dataset->data_parser(data_name));
242
                timer.Pause();
L
linan17 已提交
243
                ENVLOG_WORKER_MASTER_NOTICE("End executor:%s, cost:%f", executor->train_exe_name().c_str(), timer.ElapsedSec());
X
xiexionghang 已提交
244 245 246 247

                // 等待异步梯度完成
                _context_ptr->ps_client()->flush();
                environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
248
                if (executor->is_dump_all_model()) {
X
xiexionghang 已提交
249 250 251
                    already_dump_inference_model = true;
                    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
                }
X
xiexionghang 已提交
252
                backup_input_map[data_name] = input_channel;
X
xiexionghang 已提交
253
                environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
254
            }
X
xiexionghang 已提交
255
        }
X
xiexionghang 已提交
256

X
xiexionghang 已提交
257 258
        //Step3. Dump Model For Delta&&Checkpoint
        {
259
            wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
X
xiexionghang 已提交
260
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
261 262 263
            update_cache_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
            environment->barrier(EnvironmentRole::WORKER); 

L
linan17 已提交
264 265 266 267 268
            if (epoch_accessor->is_last_epoch(epoch_id)) {
                wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpointBase);
            } else {
                wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
            }
X
xiexionghang 已提交
269
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
270 271 272 273
            if (epoch_accessor->is_last_epoch(epoch_id) &&
                environment->is_master_node(EnvironmentRole::WORKER)) {
                paddle::platform::Timer timer;
                timer.Start();
L
linan17 已提交
274
                ENVLOG_WORKER_MASTER_NOTICE("Start shrink table");
X
xiexionghang 已提交
275 276 277 278 279 280
                for (auto& executor : _executors) {
                    const auto& table_accessors = executor->table_accessors();
                    for (auto& itr : table_accessors) {
                        CHECK(itr.second[0]->shrink() == 0);
                    }
                } 
L
linan17 已提交
281 282
                timer.Pause();
                ENVLOG_WORKER_MASTER_NOTICE("End shrink table, cost:%f", timer.ElapsedSec());
X
xiexionghang 已提交
283 284
            }
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
285

X
xiexionghang 已提交
286 287
            epoch_accessor->epoch_done(epoch_id);
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
288
        }
L
linan17 已提交
289
        ENVLOG_WORKER_MASTER_NOTICE("    ==== end %s ====", epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
290 291 292 293 294 295 296 297 298 299
        //Step4. Output Monitor && RunStatus
        //TODO
    }
    
    return 0;
}

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle