/*
 *Author: xiexionghang
 *Train samples
 */
#include <omp.h>
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/learner_process.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

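// Initialize the learner from trainer_config. A minimal sketch of the
// relevant YAML keys (values illustrative; only the keys read below are
// confirmed by this file, the executor entries are opaque here):
//
//   dump_cache_model: false             # -> _is_dump_cache_model
//   load_cache_converter: ""            # -> _cache_load_converter
//   startup_dump_inference_base: false  # -> _startup_dump_inference_base
//   executor:                           # one MultiThreadExecutor per entry
//     - ...
//     - ...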
int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
    int ret = Process::initialize(context_ptr);
    auto& config = _context_ptr->trainer_config;
    _is_dump_cache_model = config["dump_cache_model"].as<bool>(false);
    _cache_load_converter = config["load_cache_converter"].as<std::string>("");
    _startup_dump_inference_base = config["startup_dump_inference_base"].as<bool>(false);
    if (config["executor"]) {
        _executors.resize(config["executor"].size());
        for (size_t i = 0; i < _executors.size(); ++i) {
            _executors[i].reset(new MultiThreadExecutor());
            CHECK(_executors[i]->initialize(config["executor"][i], context_ptr) == 0);
        }
    }
    return 0;
}

// Update the CacheModel stored on each node
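// The cache dictionary (_context_ptr->cache_dict) is rebuilt from the
// "<table_id>_cache/part*" files dumped by wait_save_model(); only sparse
// tables carry a cache, and each line is expected to begin with the feature
// sign as a decimal integer (parsed with strtoul below).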
int LearnerProcess::update_cache_model(uint64_t epoch_id, ModelSaveWay way) {
    auto fs = _context_ptr->file_system;
    auto* ps_client = _context_ptr->pslib->ps_client();
    auto* environment = _context_ptr->environment.get();
    auto* epoch_accessor = _context_ptr->epoch_accessor.get();
    if (!epoch_accessor->need_save_model(epoch_id, way)) {
        return 0;
    }
    auto* ps_param = _context_ptr->pslib->get_param();
    if (_is_dump_cache_model && way == ModelSaveWay::ModelSaveInferenceBase) {
        auto model_dir = epoch_accessor->model_save_path(epoch_id, way);
        auto& table_param = ps_param->server_param().downpour_server_param().downpour_table_param();
        for (auto& param : table_param) {
            if (param.type() != paddle::PS_SPARSE_TABLE) {
                continue;
            }
            auto cache_model_path = fs->path_join(
                model_dir, string::format_string("%03d_cache/", param.table_id()));
            if (!fs->exists(cache_model_path)) {
                continue;
            }
            auto& cache_dict = *(_context_ptr->cache_dict.get());
            cache_dict.clear();
            cache_dict.reserve(_cache_sign_max_num);
            auto cache_file_list = fs->list(fs->path_join(cache_model_path, "part*"));
            for (auto& cache_path : cache_file_list) {
                auto cache_file = fs->open_read(cache_path, _cache_load_converter);
                char *buffer = nullptr;
                size_t buffer_size = 0;
                ssize_t line_len = 0;
                while ((line_len = getline(&buffer, &buffer_size, cache_file.get())) != -1) {
                    if (line_len <= 1) {
                        continue;
                    }
                    char* data_ptr = NULL;
                    cache_dict.append(strtoul(buffer, &data_ptr, 10));
                }
                if (buffer != nullptr) {
                    free(buffer);
                } 
            }
            break;
        }
    }
    return 0;
}
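
// Save the model for the given epoch. Only the master node of the WORKER role
// triggers the save: every executor dumps its persistable variables, each
// parameter table is then saved asynchronously through the PS client, and for
// inference dumps with dump_cache_model enabled the sparse caches are shuffled
// and saved as well. Returns the OR of all table save results.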
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump) {
    auto fs = _context_ptr->file_system;
    auto* ps_client = _context_ptr->pslib->ps_client();
    auto* environment = _context_ptr->environment.get();
    auto* epoch_accessor = _context_ptr->epoch_accessor.get();
    if (!environment->is_master_node(EnvironmentRole::WORKER)) {
        return 0;
    }
    if (!is_force_dump && !epoch_accessor->need_save_model(epoch_id, way)) {
        return 0;
    }
    paddle::platform::Timer timer;
    timer.Start();
    std::set<uint32_t> table_set;
    auto model_dir = epoch_accessor->model_save_path(epoch_id, way);
    for (auto& executor : _executors) {
        const auto& table_accessors = executor->table_accessors();
        for (auto& itr : table_accessors) {
            table_set.insert(itr.first);
        }
        auto save_path = fs->path_join(model_dir, executor->train_exe_name() + "_param");
        VLOG(2) << "Start save model, save_path:" << save_path;
        executor->save_persistables(save_path);
    }
    int ret_size = 0;
    auto table_num = table_set.size();
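    // One asynchronous save request is issued per parameter table; the
    // returned futures are collected and their results OR-ed into all_ret.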
    std::vector<std::future<int>> rets(table_num);
    for (auto table_id : table_set) {
        VLOG(2) << "Start save model, table_id:" << table_id;
        rets[ret_size++] = ps_client->save(table_id, model_dir, std::to_string((int)way));
    }
    int all_ret = 0;
    for (int i = 0; i < ret_size; ++i) {
        rets[i].wait();
        all_ret |= rets[i].get();
    }
    timer.Pause();
    VLOG(2) << "Save Model Cost(s):" << timer.ElapsedSec();

    // Save the cache model; only inference models need a cache_model
    auto* ps_param = _context_ptr->pslib->get_param();
    if (_is_dump_cache_model && (way == ModelSaveWay::ModelSaveInferenceBase ||
        way == ModelSaveWay::ModelSaveInferenceDelta)) {
        auto& table_param = ps_param->server_param().downpour_server_param().downpour_table_param();
        for (auto& param : table_param) {
            if (param.type() != paddle::PS_SPARSE_TABLE) {
                continue;
            } 
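            // Per sparse table: query the cache threshold, shuffle the cached
            // features into the model directory, then save them and record the
            // resulting feature count in the meta file written below.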
            double cache_threshold = 0.0;
            auto status = ps_client->get_cache_threshold(param.table_id(), cache_threshold);
            CHECK(status.get() == 0) << "CacheThreshold Get failed!";
            status = ps_client->cache_shuffle(param.table_id(), model_dir, std::to_string((int)way),
                std::to_string(cache_threshold));
            CHECK(status.get() == 0) << "Cache Shuffler Failed";
            status = ps_client->save_cache(param.table_id(), model_dir, std::to_string((int)way));
            auto feature_size = status.get();
            CHECK(feature_size >= 0) << "Cache Save Failed";
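            // The sparse_cache.meta file ends up looking like (values illustrative):
            //   file_prefix:part
            //   part_num:16
            //   key_num:12345678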
            auto cache_model_path = fs->path_join(
                model_dir, string::format_string("%03d_cache/sparse_cache.meta", param.table_id()));
            auto cache_meta_file = fs->open_write(cache_model_path, "");
            auto meta = string::format_string("file_prefix:part\npart_num:%d\nkey_num:%d\n", 
                param.sparse_table_cache_file_num(), feature_size);
            CHECK(fwrite(meta.c_str(), meta.size(), 1, cache_meta_file.get()) == 1) << "Cache Meta Failed";
            if (feature_size > _cache_sign_max_num) {
                _cache_sign_max_num = feature_size;
            }
        }
    }
    _context_ptr->epoch_accessor->update_model_donefile(epoch_id, way);

    return all_ret;
}

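// Load the model for the given epoch from the checkpoint path. Only the WORKER
// master node issues the load; tables whose model directory is missing or
// empty are initialized with default values via their table accessor instead.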
int LearnerProcess::load_model(uint64_t epoch_id) {
    auto* environment = _context_ptr->environment.get();
    if (!environment->is_master_node(EnvironmentRole::WORKER)) {
        return 0;
    }
    auto* fs = _context_ptr->file_system.get();
    std::set<uint32_t> loaded_table_set;
    auto model_dir = _context_ptr->epoch_accessor->checkpoint_path();
    paddle::platform::Timer timer;
    timer.Start();
    for (auto& executor : _executors) {
        const auto& table_accessors = executor->table_accessors();
        for (auto& itr : table_accessors) {
            if (loaded_table_set.count(itr.first)) {
                continue;
            }
            auto table_model_path = fs->path_join(
                model_dir, string::format_string("%03d", itr.first));
            if ((!fs->exists(table_model_path)) || fs->list(table_model_path).size() == 0) {
                VLOG(2) << "miss table_model:" << table_model_path << ", initialize by default";
                auto scope = std::move(executor->fetch_scope());
                CHECK(itr.second[0]->create(scope.get()) == 0);
            } else {
                ENVLOG_WORKER_MASTER_NOTICE("Loading model %s", model_dir.c_str());
                auto status = _context_ptr->ps_client()->load(itr.first, 
                    model_dir, std::to_string((int)ModelSaveWay::ModelSaveTrainCheckpoint));
                CHECK(status.get() == 0) << "table load failed, id:" << itr.first;
            }
            loaded_table_set.insert(itr.first);
        }
    }
    timer.Pause();
    ENVLOG_WORKER_MASTER_NOTICE("Finished loading model, cost:%f", timer.ElapsedSec());
    return 0;
}

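// Main training loop: resume from the current epoch, load (or initialize) the
// model, optionally dump a startup inference base, then for each epoch wait
// for its data, run the executors in sequence, dump the inference/checkpoint
// models, and shrink the tables on the last epoch.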
int LearnerProcess::run() {
    auto* dataset = _context_ptr->dataset.get();
    auto* environment = _context_ptr->environment.get();
    auto* epoch_accessor = _context_ptr->epoch_accessor.get(); 
    uint64_t epoch_id = epoch_accessor->current_epoch_id();

    ENVLOG_WORKER_MASTER_NOTICE("Resume train with epoch_id:%d %s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
    // Try to load the model, or initialize it from scratch
    CHECK(load_model(epoch_id) == 0);
    environment->barrier(EnvironmentRole::WORKER); 

    // Decide whether to dump the inference base model first. TODO
    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase, _startup_dump_inference_base);
    environment->barrier(EnvironmentRole::WORKER); 
    
    while (true) {
        epoch_accessor->next_epoch();
        _context_ptr->monitor_ssm.str(""); 
        bool already_dump_inference_model = false;
        epoch_id = epoch_accessor->current_epoch_id();
        std::string epoch_log_title = paddle::string::format_string(
            "train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str());
        std::string data_path = paddle::string::to_string<std::string>(dataset->epoch_data_path(epoch_id));
        ENVLOG_WORKER_MASTER_NOTICE("    ==== begin %s ====", epoch_accessor->text(epoch_id).c_str());
        // Step1. Wait for the epoch's sample data to be ready
        {
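            // Poll every 30s until the epoch's data is ready; pre_detect_data()
            // is invoked between polls (presumably so the dataset picks up
            // newly arrived input files).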
            ENVLOG_WORKER_MASTER_NOTICE("      %s, wait data ready:%s", epoch_log_title.c_str(), data_path.c_str());
            while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) {
                sleep(30);  
                dataset->pre_detect_data(epoch_id);
                ENVLOG_WORKER_MASTER_NOTICE("      epoch_id:%d data not ready, wait 30s", epoch_id);
            } 
            ENVLOG_WORKER_MASTER_NOTICE("      Start %s, data is ready", epoch_log_title.c_str());
            environment->barrier(EnvironmentRole::WORKER); 
        }
    
        // Step2. Run the training networks
        {
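            // Executors that share the same train_data_name reuse one input
            // channel: the channel returned by a previous executor's run() is
            // cached in backup_input_map and handed to the next executor
            // instead of fetching the epoch data again.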
            std::map<std::string, paddle::framework::Channel<DataItem>> backup_input_map;
            for (auto& executor : _executors) {
                environment->barrier(EnvironmentRole::WORKER); 
                paddle::platform::Timer timer;
                timer.Start();
                ENVLOG_WORKER_MASTER_NOTICE("Start executor:%s", executor->train_exe_name().c_str());
                auto data_name = executor->train_data_name();
                paddle::framework::Channel<DataItem> input_channel;
                if (backup_input_map.count(data_name)) {
                    input_channel = backup_input_map[data_name];
                } else {
                    input_channel = dataset->fetch_data(data_name, epoch_id);
                }
                input_channel = executor->run(input_channel, dataset->data_parser(data_name));
                timer.Pause();
                ENVLOG_WORKER_MASTER_NOTICE("End executor:%s, cost:%f", executor->train_exe_name().c_str(), timer.ElapsedSec());

                // Wait for asynchronous gradient pushes to finish
                _context_ptr->ps_client()->flush();
                environment->barrier(EnvironmentRole::WORKER); 
                if (executor->is_dump_all_model()) {
                    already_dump_inference_model = true;
                    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
                }
                backup_input_map[data_name] = input_channel;
                environment->barrier(EnvironmentRole::WORKER); 
            }
        }

        //Step3. Dump Model For Delta&&Checkpoint
        {
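            // Dump order: inference base -> cache model update -> train
            // checkpoint (checkpoint base on the last epoch), with a WORKER
            // barrier between the steps.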
            wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
            environment->barrier(EnvironmentRole::WORKER); 
            update_cache_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
            environment->barrier(EnvironmentRole::WORKER); 

            if (epoch_accessor->is_last_epoch(epoch_id)) {
                wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpointBase);
            } else {
                wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
            }
            environment->barrier(EnvironmentRole::WORKER); 
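            // On the last epoch, the WORKER master additionally asks every
            // table accessor to shrink its table before finishing the epoch.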
            if (epoch_accessor->is_last_epoch(epoch_id) &&
                environment->is_master_node(EnvironmentRole::WORKER)) {
                paddle::platform::Timer timer;
                timer.Start();
                ENVLOG_WORKER_MASTER_NOTICE("Start shrink table");
                for (auto& executor : _executors) {
                    const auto& table_accessors = executor->table_accessors();
                    for (auto& itr : table_accessors) {
                        CHECK(itr.second[0]->shrink() == 0);
                    }
                } 
                timer.Pause();
                ENVLOG_WORKER_MASTER_NOTICE("End shrink table, cost:%f", timer.ElapsedSec());
            }
            environment->barrier(EnvironmentRole::WORKER); 

            epoch_accessor->epoch_done(epoch_id);
            environment->barrier(EnvironmentRole::WORKER); 
        }
        ENVLOG_WORKER_MASTER_NOTICE("    ==== end %s ====", epoch_accessor->text(epoch_id).c_str());
        //Step4. Output Monitor && RunStatus
        //TODO
    }
    
    return 0;
}

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle