/*
 * Author: xiexionghang
 * Train-sample learner process.
 */
#include <omp.h>

#include <vector>

#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/learner_process.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
    int ret = Process::initialize(context_ptr);
    auto& config = _context_ptr->trainer_config;
X
xiexionghang 已提交
19 20
    _is_dump_cache_model = config["dump_cache_model"].as<bool>(false);
    _cache_load_converter = config["load_cache_converter"].as<std::string>("");
X
xiexionghang 已提交
21
    _startup_dump_inference_base = config["startup_dump_inference_base"].as<bool>(false);
X
xiexionghang 已提交
22
    if (config["executor"]) {
X
xiexionghang 已提交
23 24 25 26
        _executors.resize(config["executor"].size());
        for (size_t i = 0; i < _executors.size(); ++i) {
            _executors[i].reset(new MultiThreadExecutor());
            CHECK(_executors[i]->initialize(config["executor"][i], context_ptr) == 0);
X
xiexionghang 已提交
27 28 29 30 31
        }
    }
    return 0;
}

X
xiexionghang 已提交
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
// 更新各节点存储的CacheModel
int LearnerProcess::update_cache_model(uint64_t epoch_id, ModelSaveWay way) {
    auto fs = _context_ptr->file_system;
    auto* ps_client = _context_ptr->pslib->ps_client();
    auto* environment = _context_ptr->environment.get();
    auto* epoch_accessor = _context_ptr->epoch_accessor.get();
    if (!epoch_accessor->need_save_model(epoch_id, way)) {
        return 0;
    }
    auto* ps_param = _context_ptr->pslib->get_param();
    if (_is_dump_cache_model && way == ModelSaveWay::ModelSaveInferenceBase) {
        auto model_dir = epoch_accessor->model_save_path(epoch_id, way);
        auto& table_param = ps_param->server_param().downpour_server_param().downpour_table_param();
        for (auto& param : table_param) {
            if (param.type() != paddle::PS_SPARSE_TABLE) {
                continue;
            }
            auto cache_model_path = fs->path_join(
                model_dir, string::format_string("%03d_cache/", param.table_id()));
            if (!fs->exists(cache_model_path)) {
                continue;
            }
            auto& cache_dict = *(_context_ptr->cache_dict.get());
            cache_dict.clear();
            cache_dict.reserve(_cache_sign_max_num);
            auto cache_file_list = fs->list(fs->path_join(cache_model_path, "part*"));
            for (auto& cache_path : cache_file_list) {
                auto cache_file = fs->open_read(cache_path, _cache_load_converter);
                char *buffer = nullptr;
                size_t buffer_size = 0;
                ssize_t line_len = 0;
                while ((line_len = getline(&buffer, &buffer_size, cache_file.get())) != -1) {
                    if (line_len <= 1) {
                        continue;
                    }
                    char* data_ptr;
                    cache_dict.append(strtoul(buffer, &data_ptr, 10));
                }
                if (buffer != nullptr) {
                    free(buffer);
                } 
            }
            break;
        }
    }
    return 0;
}
X
xiexionghang 已提交
79
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump) {
R
rensilin 已提交
80
    auto fs = _context_ptr->file_system;
81
    auto* ps_client = _context_ptr->pslib->ps_client();
X
xiexionghang 已提交
82
    auto* environment = _context_ptr->environment.get();
83
    auto* epoch_accessor = _context_ptr->epoch_accessor.get();
X
xiexionghang 已提交
84
    if (!environment->is_master_node(EnvironmentRole::WORKER)) {
X
xiexionghang 已提交
85 86
        return 0;
    }
X
xiexionghang 已提交
87
    if (!is_force_dump && !epoch_accessor->need_save_model(epoch_id, way)) {
88 89 90 91
        return 0;
    }
    paddle::platform::Timer timer;
    timer.Start();
X
xiexionghang 已提交
92
    std::set<uint32_t> table_set;
R
rensilin 已提交
93
    auto model_dir = epoch_accessor->model_save_path(epoch_id, way);
X
xiexionghang 已提交
94 95 96 97 98
    for (auto& executor : _executors) {
        const auto& table_accessors = executor->table_accessors();
        for (auto& itr : table_accessors) {
            table_set.insert(itr.first);
        }
R
rensilin 已提交
99 100 101
        auto save_path = fs->path_join(model_dir, executor->train_exe_name() + "_param");
        VLOG(2) << "Start save model, save_path:" << save_path;
        executor->save_persistables(save_path);
X
xiexionghang 已提交
102
    }
X
xiexionghang 已提交
103
    int ret_size = 0;
X
xiexionghang 已提交
104
    auto table_num = table_set.size();
X
xiexionghang 已提交
105
    std::future<int> rets[table_num];
X
xiexionghang 已提交
106
    for (auto table_id : table_set) {
107 108
        VLOG(2) << "Start save model, table_id:" << table_id;
        rets[ret_size++] = ps_client->save(table_id, model_dir, std::to_string((int)way));
X
xiexionghang 已提交
109 110 111 112 113 114
    }
    int all_ret = 0;
    for (int i = 0; i < ret_size; ++i) {
        rets[i].wait();
        all_ret |= rets[i].get();
    }
115 116
    timer.Pause();
    VLOG(2) << "Save Model Cost(s):" << timer.ElapsedSec();
X
xiexionghang 已提交
117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
    
    // save cache model, 只有inference需要cache_model
    auto* ps_param = _context_ptr->pslib->get_param();
    if (_is_dump_cache_model && (way == ModelSaveWay::ModelSaveInferenceBase ||
        way == ModelSaveWay::ModelSaveInferenceDelta)) {
        auto& table_param = ps_param->server_param().downpour_server_param().downpour_table_param();
        for (auto& param : table_param) {
            if (param.type() != paddle::PS_SPARSE_TABLE) {
                continue;
            } 
            double cache_threshold = 0.0;
            auto status = ps_client->get_cache_threshold(param.table_id(), cache_threshold);
            CHECK(status.get() == 0) << "CacheThreshold Get failed!";
            status = ps_client->cache_shuffle(param.table_id(), model_dir, std::to_string((int)way),
                std::to_string(cache_threshold));
            CHECK(status.get() == 0) << "Cache Shuffler Failed";
            status = ps_client->save_cache(param.table_id(), model_dir, std::to_string((int)way));
            auto feature_size = status.get();
            CHECK(feature_size >= 0) << "Cache Save Failed";
            auto cache_model_path = fs->path_join(
                model_dir, string::format_string("%03d_cache/sparse_cache.meta", param.table_id()));
            auto cache_meta_file = fs->open_write(cache_model_path, "");
            auto meta = string::format_string("file_prefix:part\npart_num:%d\nkey_num:%d\n", 
                param.sparse_table_cache_file_num(), feature_size);
            CHECK(fwrite(meta.c_str(), meta.size(), 1, cache_meta_file.get()) == 1) << "Cache Meta Failed";
            if (feature_size > _cache_sign_max_num) {
                _cache_sign_max_num = feature_size;
            }
        }
    }
147
    _context_ptr->epoch_accessor->update_model_donefile(epoch_id, way);
X
xiexionghang 已提交
148

X
xiexionghang 已提交
149 150 151
    return all_ret;
}

X
xiexionghang 已提交
152 153 154 155 156
int LearnerProcess::load_model(uint64_t epoch_id) {
    auto* environment = _context_ptr->environment.get();
    if (!environment->is_master_node(EnvironmentRole::WORKER)) {
        return 0;
    }
X
xiexionghang 已提交
157
    auto* fs = _context_ptr->file_system.get();
X
xiexionghang 已提交
158 159 160 161 162 163 164 165
    std::set<uint32_t> loaded_table_set;
    auto model_dir = _context_ptr->epoch_accessor->checkpoint_path();
    for (auto& executor : _executors) {
        const auto& table_accessors = executor->table_accessors();
        for (auto& itr : table_accessors) {
            if (loaded_table_set.count(itr.first)) {
                continue;
            }
X
xiexionghang 已提交
166
            auto table_model_path = fs->path_join(
X
xiexionghang 已提交
167
                model_dir, string::format_string("%03d", itr.first));
X
xiexionghang 已提交
168
            if ((!fs->exists(table_model_path)) || fs->list(table_model_path).size() == 0) {
X
xiexionghang 已提交
169 170 171 172 173 174 175 176 177 178 179 180 181 182
                VLOG(2) << "miss table_model:" << table_model_path << ", initialize by default";
                auto scope = std::move(executor->fetch_scope());
                CHECK(itr.second[0]->create(scope.get()) == 0);
            } else {
                auto status = _context_ptr->ps_client()->load(itr.first, 
                    model_dir, std::to_string((int)ModelSaveWay::ModelSaveTrainCheckpoint));
                CHECK(status.get() == 0) << "table load failed, id:" << itr.first;
            }
            loaded_table_set.insert(itr.first);
        }
    }
    return 0;
}

X
xiexionghang 已提交
183
int LearnerProcess::run() {
X
xiexionghang 已提交
184
    auto* dataset = _context_ptr->dataset.get();
X
xiexionghang 已提交
185 186
    auto* environment = _context_ptr->environment.get();
    auto* epoch_accessor = _context_ptr->epoch_accessor.get(); 
X
xiexionghang 已提交
187
    uint64_t epoch_id = epoch_accessor->current_epoch_id();
X
xiexionghang 已提交
188

X
xiexionghang 已提交
189
    environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
190
        "Resume train with epoch_id:%d %s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
191
    
X
xiexionghang 已提交
192 193 194 195
    //尝试加载模型 or 初始化
    CHECK(load_model(epoch_id) == 0);
    environment->barrier(EnvironmentRole::WORKER); 

X
xiexionghang 已提交
196 197
    //判断是否先dump出base TODO
    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase, _startup_dump_inference_base);
X
xiexionghang 已提交
198
    environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
199 200 201
    
    while (true) {
        epoch_accessor->next_epoch();
202
        _context_ptr->monitor_ssm.str(""); 
X
xiexionghang 已提交
203
        bool already_dump_inference_model = false;
X
xiexionghang 已提交
204
        epoch_id = epoch_accessor->current_epoch_id();
X
xiexionghang 已提交
205
        std::string epoch_log_title = paddle::string::format_string(
X
xiexionghang 已提交
206
            "train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
207
        std::string data_path = paddle::string::to_string<std::string>(dataset->epoch_data_path(epoch_id));
X
xiexionghang 已提交
208 209
        
        //Step1. 等待样本ready
X
xiexionghang 已提交
210
        {
X
xiexionghang 已提交
211
            environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
212
                "%s, wait data ready:%s", epoch_log_title.c_str(), data_path.c_str());
X
xiexionghang 已提交
213 214 215 216
            while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) {
                sleep(30);  
                dataset->pre_detect_data(epoch_id);
                environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
217
                "data not ready, wait 30s");
X
xiexionghang 已提交
218 219
            } 
            environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
220
                "Start %s, data is ready", epoch_log_title.c_str());
X
xiexionghang 已提交
221 222 223
            environment->barrier(EnvironmentRole::WORKER); 
        }
    
X
xiexionghang 已提交
224
        //Step2. 运行训练网络
X
xiexionghang 已提交
225
        {
X
xiexionghang 已提交
226 227
            std::map<std::string, paddle::framework::Channel<DataItem>> backup_input_map;
            for (auto& executor : _executors) {
X
xiexionghang 已提交
228
                environment->barrier(EnvironmentRole::WORKER); 
229 230
                paddle::platform::Timer timer;
                timer.Start();
X
xiexionghang 已提交
231
                VLOG(2) << "Start executor:" << executor->train_exe_name();
X
xiexionghang 已提交
232 233 234 235 236 237 238 239
                auto data_name = executor->train_data_name();
                paddle::framework::Channel<DataItem> input_channel;
                if (backup_input_map.count(data_name)) {
                    input_channel = backup_input_map[data_name];
                } else {
                    input_channel = dataset->fetch_data(data_name, epoch_id);
                }
                input_channel = executor->run(input_channel, dataset->data_parser(data_name));
240
                timer.Pause();
X
xiexionghang 已提交
241
                VLOG(2) << "End executor:" << executor->train_exe_name() << ", cost:" << timer.ElapsedSec();
X
xiexionghang 已提交
242 243 244 245

                // 等待异步梯度完成
                _context_ptr->ps_client()->flush();
                environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
246
                if (executor->is_dump_all_model()) {
X
xiexionghang 已提交
247 248 249
                    already_dump_inference_model = true;
                    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
                }
X
xiexionghang 已提交
250
                backup_input_map[data_name] = input_channel;
X
xiexionghang 已提交
251
                environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
252
            }
X
xiexionghang 已提交
253
        }
X
xiexionghang 已提交
254

X
xiexionghang 已提交
255 256
        //Step3. Dump Model For Delta&&Checkpoint
        {
257
            wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
X
xiexionghang 已提交
258
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
259 260 261
            update_cache_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
            environment->barrier(EnvironmentRole::WORKER); 

L
linan17 已提交
262 263 264 265 266
            if (epoch_accessor->is_last_epoch(epoch_id)) {
                wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpointBase);
            } else {
                wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
            }
X
xiexionghang 已提交
267
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
268 269 270 271 272 273 274 275 276 277 278
            if (epoch_accessor->is_last_epoch(epoch_id) &&
                environment->is_master_node(EnvironmentRole::WORKER)) {
                paddle::platform::Timer timer;
                timer.Start();
                VLOG(2) << "Start shrink table"; 
                for (auto& executor : _executors) {
                    const auto& table_accessors = executor->table_accessors();
                    for (auto& itr : table_accessors) {
                        CHECK(itr.second[0]->shrink() == 0);
                    }
                } 
X
xiexionghang 已提交
279
                VLOG(2) << "End shrink table, cost:" << timer.ElapsedSec();
X
xiexionghang 已提交
280 281
            }
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
282

X
xiexionghang 已提交
283 284
            epoch_accessor->epoch_done(epoch_id);
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
285
        }
X
xiexionghang 已提交
286
    
X
xiexionghang 已提交
287 288 289 290 291 292 293 294 295 296
        //Step4. Output Monitor && RunStatus
        //TODO
    }
    
    return 0;
}

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle