/*
 * Author: xiexionghang
 * Learner process: trains on samples epoch by epoch.
 */
#include <omp.h>

#include <cstdlib>
#include <future>
#include <set>
#include <string>
#include <vector>

#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/learner_process.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
    int ret = Process::initialize(context_ptr);
    auto& config = _context_ptr->trainer_config;
X
xiexionghang 已提交
19 20
    _is_dump_cache_model = config["dump_cache_model"].as<bool>(false);
    _cache_load_converter = config["load_cache_converter"].as<std::string>("");
X
xiexionghang 已提交
21
    _startup_dump_inference_base = config["startup_dump_inference_base"].as<bool>(false);
X
xiexionghang 已提交
22
    if (config["executor"]) {
X
xiexionghang 已提交
23 24 25 26
        _executors.resize(config["executor"].size());
        for (size_t i = 0; i < _executors.size(); ++i) {
            _executors[i].reset(new MultiThreadExecutor());
            CHECK(_executors[i]->initialize(config["executor"][i], context_ptr) == 0);
X
xiexionghang 已提交
27 28 29 30 31
        }
    }
    return 0;
}

// 更新各节点存储的CacheModel
int LearnerProcess::update_cache_model(uint64_t epoch_id, ModelSaveWay way) {
    auto fs = _context_ptr->file_system;
    auto* ps_client = _context_ptr->pslib->ps_client();
    auto* environment = _context_ptr->environment.get();
    auto* epoch_accessor = _context_ptr->epoch_accessor.get();
    if (!epoch_accessor->need_save_model(epoch_id, way)) {
        return 0;
    }
    auto* ps_param = _context_ptr->pslib->get_param();
    if (_is_dump_cache_model && way == ModelSaveWay::ModelSaveInferenceBase) {
        auto model_dir = epoch_accessor->model_save_path(epoch_id, way);
        auto& table_param = ps_param->server_param().downpour_server_param().downpour_table_param();
        for (auto& param : table_param) {
            if (param.type() != paddle::PS_SPARSE_TABLE) {
                continue;
            }
            auto cache_model_path = fs->path_join(
                model_dir, string::format_string("%03d_cache/", param.table_id()));
            if (!fs->exists(cache_model_path)) {
                continue;
            }
            auto& cache_dict = *(_context_ptr->cache_dict.get());
            cache_dict.clear();
            cache_dict.reserve(_cache_sign_max_num);
            auto cache_file_list = fs->list(fs->path_join(cache_model_path, "part*"));
            for (auto& cache_path : cache_file_list) {
                auto cache_file = fs->open_read(cache_path, _cache_load_converter);
                char *buffer = nullptr;
                size_t buffer_size = 0;
                ssize_t line_len = 0;
                while ((line_len = getline(&buffer, &buffer_size, cache_file.get())) != -1) {
                    if (line_len <= 1) {
                        continue;
                    }
X
xiexionghang 已提交
67
                    char* data_ptr = NULL;
X
xiexionghang 已提交
68 69 70 71 72 73 74 75 76 77 78
                    cache_dict.append(strtoul(buffer, &data_ptr, 10));
                }
                if (buffer != nullptr) {
                    free(buffer);
                } 
            }
            break;
        }
    }
    return 0;
}
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump) {
X
xiexionghang 已提交
80
    ContextStatusGurad status_guard(_context_ptr, TrainerStatus::Saving);
R
rensilin 已提交
81
    auto fs = _context_ptr->file_system;
82
    auto* ps_client = _context_ptr->pslib->ps_client();
X
xiexionghang 已提交
83
    auto* environment = _context_ptr->environment.get();
84
    auto* epoch_accessor = _context_ptr->epoch_accessor.get();
X
xiexionghang 已提交
85
    if (!environment->is_master_node(EnvironmentRole::WORKER)) {
X
xiexionghang 已提交
86 87
        return 0;
    }
X
xiexionghang 已提交
88
    if (!is_force_dump && !epoch_accessor->need_save_model(epoch_id, way)) {
89 90 91 92
        return 0;
    }
    paddle::platform::Timer timer;
    timer.Start();
X
xiexionghang 已提交
93
    std::set<uint32_t> table_set;
R
rensilin 已提交
94
    auto model_dir = epoch_accessor->model_save_path(epoch_id, way);
X
xiexionghang 已提交
95 96 97 98 99
    for (auto& executor : _executors) {
        const auto& table_accessors = executor->table_accessors();
        for (auto& itr : table_accessors) {
            table_set.insert(itr.first);
        }
R
rensilin 已提交
100 101 102
        auto save_path = fs->path_join(model_dir, executor->train_exe_name() + "_param");
        VLOG(2) << "Start save model, save_path:" << save_path;
        executor->save_persistables(save_path);
X
xiexionghang 已提交
103
    }
X
xiexionghang 已提交
104
    int ret_size = 0;
X
xiexionghang 已提交
105
    auto table_num = table_set.size();
X
xiexionghang 已提交
106
    std::future<int> rets[table_num];
X
xiexionghang 已提交
107
    for (auto table_id : table_set) {
108 109
        VLOG(2) << "Start save model, table_id:" << table_id;
        rets[ret_size++] = ps_client->save(table_id, model_dir, std::to_string((int)way));
X
xiexionghang 已提交
110 111 112 113 114 115
    }
    int all_ret = 0;
    for (int i = 0; i < ret_size; ++i) {
        rets[i].wait();
        all_ret |= rets[i].get();
    }
116 117
    timer.Pause();
    VLOG(2) << "Save Model Cost(s):" << timer.ElapsedSec();
X
xiexionghang 已提交
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
    
    // save cache model, 只有inference需要cache_model
    auto* ps_param = _context_ptr->pslib->get_param();
    if (_is_dump_cache_model && (way == ModelSaveWay::ModelSaveInferenceBase ||
        way == ModelSaveWay::ModelSaveInferenceDelta)) {
        auto& table_param = ps_param->server_param().downpour_server_param().downpour_table_param();
        for (auto& param : table_param) {
            if (param.type() != paddle::PS_SPARSE_TABLE) {
                continue;
            } 
            double cache_threshold = 0.0;
            auto status = ps_client->get_cache_threshold(param.table_id(), cache_threshold);
            CHECK(status.get() == 0) << "CacheThreshold Get failed!";
            status = ps_client->cache_shuffle(param.table_id(), model_dir, std::to_string((int)way),
                std::to_string(cache_threshold));
            CHECK(status.get() == 0) << "Cache Shuffler Failed";
            status = ps_client->save_cache(param.table_id(), model_dir, std::to_string((int)way));
            auto feature_size = status.get();
            CHECK(feature_size >= 0) << "Cache Save Failed";
            auto cache_model_path = fs->path_join(
                model_dir, string::format_string("%03d_cache/sparse_cache.meta", param.table_id()));
            auto cache_meta_file = fs->open_write(cache_model_path, "");
            auto meta = string::format_string("file_prefix:part\npart_num:%d\nkey_num:%d\n", 
                param.sparse_table_cache_file_num(), feature_size);
            CHECK(fwrite(meta.c_str(), meta.size(), 1, cache_meta_file.get()) == 1) << "Cache Meta Failed";
            if (feature_size > _cache_sign_max_num) {
                _cache_sign_max_num = feature_size;
            }
        }
    }
148
    _context_ptr->epoch_accessor->update_model_donefile(epoch_id, way);
X
xiexionghang 已提交
149

X
xiexionghang 已提交
150 151 152
    return all_ret;
}

int LearnerProcess::load_model(uint64_t epoch_id) {
    auto* environment = _context_ptr->environment.get();
    if (!environment->is_master_node(EnvironmentRole::WORKER)) {
        return 0;
    }
X
xiexionghang 已提交
158
    auto* fs = _context_ptr->file_system.get();
X
xiexionghang 已提交
159 160
    std::set<uint32_t> loaded_table_set;
    auto model_dir = _context_ptr->epoch_accessor->checkpoint_path();
L
linan17 已提交
161 162
    paddle::platform::Timer timer;
    timer.Start();
X
xiexionghang 已提交
163 164 165 166 167 168
    for (auto& executor : _executors) {
        const auto& table_accessors = executor->table_accessors();
        for (auto& itr : table_accessors) {
            if (loaded_table_set.count(itr.first)) {
                continue;
            }
X
xiexionghang 已提交
169
            auto table_model_path = fs->path_join(
X
xiexionghang 已提交
170
                model_dir, string::format_string("%03d", itr.first));
X
xiexionghang 已提交
171
            if ((!fs->exists(table_model_path)) || fs->list(table_model_path).size() == 0) {
X
xiexionghang 已提交
172 173 174 175
                VLOG(2) << "miss table_model:" << table_model_path << ", initialize by default";
                auto scope = std::move(executor->fetch_scope());
                CHECK(itr.second[0]->create(scope.get()) == 0);
            } else {
L
linan17 已提交
176
                ENVLOG_WORKER_MASTER_NOTICE("Loading model %s", model_dir.c_str());
X
xiexionghang 已提交
177 178 179 180 181 182 183
                auto status = _context_ptr->ps_client()->load(itr.first, 
                    model_dir, std::to_string((int)ModelSaveWay::ModelSaveTrainCheckpoint));
                CHECK(status.get() == 0) << "table load failed, id:" << itr.first;
            }
            loaded_table_set.insert(itr.first);
        }
    }
L
linan17 已提交
184 185
    timer.Pause();
    ENVLOG_WORKER_MASTER_NOTICE("Finished loading model, cost:%f", timer.ElapsedSec());
X
xiexionghang 已提交
186 187 188
    return 0;
}

int LearnerProcess::run() {
X
xiexionghang 已提交
190
    auto* dataset = _context_ptr->dataset.get();
X
xiexionghang 已提交
191 192
    auto* environment = _context_ptr->environment.get();
    auto* epoch_accessor = _context_ptr->epoch_accessor.get(); 
X
xiexionghang 已提交
193
    uint64_t epoch_id = epoch_accessor->current_epoch_id();
X
xiexionghang 已提交
194

L
linan17 已提交
195
    ENVLOG_WORKER_MASTER_NOTICE("Resume train with epoch_id:%d %s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
196 197 198 199
    //尝试加载模型 or 初始化
    CHECK(load_model(epoch_id) == 0);
    environment->barrier(EnvironmentRole::WORKER); 

X
xiexionghang 已提交
200
    //判断是否先dump出base TODO
X
xiexionghang 已提交
201 202 203 204 205
    if (_startup_dump_inference_base) {
        wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase, _startup_dump_inference_base);
        environment->barrier(EnvironmentRole::WORKER); 
    }

X
xiexionghang 已提交
206 207
    while (true) {
        epoch_accessor->next_epoch();
208
        _context_ptr->monitor_ssm.str(""); 
X
xiexionghang 已提交
209
        bool already_dump_inference_model = false;
X
xiexionghang 已提交
210
        epoch_id = epoch_accessor->current_epoch_id();
X
xiexionghang 已提交
211
        std::string epoch_log_title = paddle::string::format_string(
X
xiexionghang 已提交
212
            "train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
213
        std::string data_path = paddle::string::to_string<std::string>(dataset->epoch_data_path(epoch_id));
L
linan17 已提交
214
        ENVLOG_WORKER_MASTER_NOTICE("    ==== begin %s ====", epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
215
        //Step1. 等待样本ready
X
xiexionghang 已提交
216
        {
L
linan17 已提交
217
            ENVLOG_WORKER_MASTER_NOTICE("      %s, wait data ready:%s", epoch_log_title.c_str(), data_path.c_str());
X
xiexionghang 已提交
218
            dataset->pre_detect_data(epoch_id);
X
xiexionghang 已提交
219 220 221
            while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) {
                sleep(30);  
                dataset->pre_detect_data(epoch_id);
L
linan17 已提交
222
                ENVLOG_WORKER_MASTER_NOTICE("      epoch_id:%d data not ready, wait 30s", epoch_id);
X
xiexionghang 已提交
223
            } 
L
linan17 已提交
224
            ENVLOG_WORKER_MASTER_NOTICE("      Start %s, data is ready", epoch_log_title.c_str());
X
xiexionghang 已提交
225 226 227
            environment->barrier(EnvironmentRole::WORKER); 
        }
    
X
xiexionghang 已提交
228
        //Step2. 运行训练网络
X
xiexionghang 已提交
229
        {
X
xiexionghang 已提交
230
            ContextStatusGurad status_guard(_context_ptr, TrainerStatus::Training);
X
xiexionghang 已提交
231 232
            std::map<std::string, paddle::framework::Channel<DataItem>> backup_input_map;
            for (auto& executor : _executors) {
X
xiexionghang 已提交
233
                environment->barrier(EnvironmentRole::WORKER); 
234 235
                paddle::platform::Timer timer;
                timer.Start();
L
linan17 已提交
236
                ENVLOG_WORKER_MASTER_NOTICE("Start executor:%s", executor->train_exe_name().c_str());
X
xiexionghang 已提交
237 238 239 240 241 242 243 244
                auto data_name = executor->train_data_name();
                paddle::framework::Channel<DataItem> input_channel;
                if (backup_input_map.count(data_name)) {
                    input_channel = backup_input_map[data_name];
                } else {
                    input_channel = dataset->fetch_data(data_name, epoch_id);
                }
                input_channel = executor->run(input_channel, dataset->data_parser(data_name));
245
                timer.Pause();
L
linan17 已提交
246
                ENVLOG_WORKER_MASTER_NOTICE("End executor:%s, cost:%f", executor->train_exe_name().c_str(), timer.ElapsedSec());
X
xiexionghang 已提交
247 248 249 250

                // 等待异步梯度完成
                _context_ptr->ps_client()->flush();
                environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
251
                if (executor->is_dump_all_model()) {
X
xiexionghang 已提交
252 253 254
                    already_dump_inference_model = true;
                    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
                }
X
xiexionghang 已提交
255
                backup_input_map[data_name] = input_channel;
X
xiexionghang 已提交
256
                environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
257
            }
X
xiexionghang 已提交
258
        }
X
xiexionghang 已提交
259

X
xiexionghang 已提交
260 261
        //Step3. Dump Model For Delta&&Checkpoint
        {
262
            wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
X
xiexionghang 已提交
263
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
264 265 266
            update_cache_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
            environment->barrier(EnvironmentRole::WORKER); 

L
linan17 已提交
267 268 269 270 271
            if (epoch_accessor->is_last_epoch(epoch_id)) {
                wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpointBase);
            } else {
                wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
            }
X
xiexionghang 已提交
272
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
273 274 275 276
            if (epoch_accessor->is_last_epoch(epoch_id) &&
                environment->is_master_node(EnvironmentRole::WORKER)) {
                paddle::platform::Timer timer;
                timer.Start();
L
linan17 已提交
277
                ENVLOG_WORKER_MASTER_NOTICE("Start shrink table");
X
xiexionghang 已提交
278 279 280 281 282 283
                for (auto& executor : _executors) {
                    const auto& table_accessors = executor->table_accessors();
                    for (auto& itr : table_accessors) {
                        CHECK(itr.second[0]->shrink() == 0);
                    }
                } 
L
linan17 已提交
284 285
                timer.Pause();
                ENVLOG_WORKER_MASTER_NOTICE("End shrink table, cost:%f", timer.ElapsedSec());
X
xiexionghang 已提交
286 287
            }
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
288

X
xiexionghang 已提交
289 290
            epoch_accessor->epoch_done(epoch_id);
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
291
        }
L
linan17 已提交
292
        ENVLOG_WORKER_MASTER_NOTICE("    ==== end %s ====", epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
293 294 295 296 297 298 299 300 301 302
        //Step4. Output Monitor && RunStatus
        //TODO
    }
    
    return 0;
}

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle