/*
 * Author: xiexionghang
 * Learner process: trains the model on samples, epoch by epoch.
 */
#include <omp.h>

#include <future>
#include <set>
#include <vector>

#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/learner_process.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
    int ret = Process::initialize(context_ptr);
    auto& config = _context_ptr->trainer_config;
X
xiexionghang 已提交
19
    _startup_dump_inference_base = config["startup_dump_inference_base"].as<bool>(false);
X
xiexionghang 已提交
20
    if (config["executor"]) {
X
xiexionghang 已提交
21 22 23 24
        _executors.resize(config["executor"].size());
        for (size_t i = 0; i < _executors.size(); ++i) {
            _executors[i].reset(new MultiThreadExecutor());
            CHECK(_executors[i]->initialize(config["executor"][i], context_ptr) == 0);
X
xiexionghang 已提交
25 26 27 28 29
        }
    }
    return 0;
}

int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump) {
R
rensilin 已提交
31
    auto fs = _context_ptr->file_system;
32
    auto* ps_client = _context_ptr->pslib->ps_client();
X
xiexionghang 已提交
33
    auto* environment = _context_ptr->environment.get();
34
    auto* epoch_accessor = _context_ptr->epoch_accessor.get();
X
xiexionghang 已提交
35
    if (!environment->is_master_node(EnvironmentRole::WORKER)) {
X
xiexionghang 已提交
36 37
        return 0;
    }
X
xiexionghang 已提交
38
    if (!is_force_dump && !epoch_accessor->need_save_model(epoch_id, way)) {
39 40 41 42
        return 0;
    }
    paddle::platform::Timer timer;
    timer.Start();
X
xiexionghang 已提交
43
    std::set<uint32_t> table_set;
R
rensilin 已提交
44
    auto model_dir = epoch_accessor->model_save_path(epoch_id, way);
X
xiexionghang 已提交
45 46 47 48 49
    for (auto& executor : _executors) {
        const auto& table_accessors = executor->table_accessors();
        for (auto& itr : table_accessors) {
            table_set.insert(itr.first);
        }
R
rensilin 已提交
50 51 52
        auto save_path = fs->path_join(model_dir, executor->train_exe_name() + "_param");
        VLOG(2) << "Start save model, save_path:" << save_path;
        executor->save_persistables(save_path);
X
xiexionghang 已提交
53
    }
X
xiexionghang 已提交
54
    int ret_size = 0;
X
xiexionghang 已提交
55
    auto table_num = table_set.size();
X
xiexionghang 已提交
56
    std::future<int> rets[table_num];
X
xiexionghang 已提交
57
    for (auto table_id : table_set) {
58 59
        VLOG(2) << "Start save model, table_id:" << table_id;
        rets[ret_size++] = ps_client->save(table_id, model_dir, std::to_string((int)way));
X
xiexionghang 已提交
60 61 62 63 64 65
    }
    int all_ret = 0;
    for (int i = 0; i < ret_size; ++i) {
        rets[i].wait();
        all_ret |= rets[i].get();
    }
66 67 68
    timer.Pause();
    VLOG(2) << "Save Model Cost(s):" << timer.ElapsedSec();
    _context_ptr->epoch_accessor->update_model_donefile(epoch_id, way);
X
xiexionghang 已提交
69 70 71
    return all_ret;
}

int LearnerProcess::load_model(uint64_t epoch_id) {
    auto* environment = _context_ptr->environment.get();
    if (!environment->is_master_node(EnvironmentRole::WORKER)) {
        return 0;
    }
X
xiexionghang 已提交
77
    auto* fs = _context_ptr->file_system.get();
X
xiexionghang 已提交
78 79
    std::set<uint32_t> loaded_table_set;
    auto model_dir = _context_ptr->epoch_accessor->checkpoint_path();
L
linan17 已提交
80 81
    paddle::platform::Timer timer;
    timer.Start();
X
xiexionghang 已提交
82 83 84 85 86 87
    for (auto& executor : _executors) {
        const auto& table_accessors = executor->table_accessors();
        for (auto& itr : table_accessors) {
            if (loaded_table_set.count(itr.first)) {
                continue;
            }
X
xiexionghang 已提交
88
            auto table_model_path = fs->path_join(
X
xiexionghang 已提交
89
                model_dir, string::format_string("%03d", itr.first));
X
xiexionghang 已提交
90
            if ((!fs->exists(table_model_path)) || fs->list(table_model_path).size() == 0) {
X
xiexionghang 已提交
91 92 93 94
                VLOG(2) << "miss table_model:" << table_model_path << ", initialize by default";
                auto scope = std::move(executor->fetch_scope());
                CHECK(itr.second[0]->create(scope.get()) == 0);
            } else {
L
linan17 已提交
95
                ENVLOG_WORKER_MASTER_NOTICE("Loading model %s", model_dir.c_str());
X
xiexionghang 已提交
96 97 98 99 100 101 102
                auto status = _context_ptr->ps_client()->load(itr.first, 
                    model_dir, std::to_string((int)ModelSaveWay::ModelSaveTrainCheckpoint));
                CHECK(status.get() == 0) << "table load failed, id:" << itr.first;
            }
            loaded_table_set.insert(itr.first);
        }
    }
L
linan17 已提交
103 104
    timer.Pause();
    ENVLOG_WORKER_MASTER_NOTICE("Finished loading model, cost:%f", timer.ElapsedSec());
X
xiexionghang 已提交
105 106 107
    return 0;
}

int LearnerProcess::run() {
X
xiexionghang 已提交
109
    auto* dataset = _context_ptr->dataset.get();
X
xiexionghang 已提交
110 111
    auto* environment = _context_ptr->environment.get();
    auto* epoch_accessor = _context_ptr->epoch_accessor.get(); 
X
xiexionghang 已提交
112
    uint64_t epoch_id = epoch_accessor->current_epoch_id();
X
xiexionghang 已提交
113

L
linan17 已提交
114
    ENVLOG_WORKER_MASTER_NOTICE("Resume train with epoch_id:%d %s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
115 116 117 118
    //尝试加载模型 or 初始化
    CHECK(load_model(epoch_id) == 0);
    environment->barrier(EnvironmentRole::WORKER); 

X
xiexionghang 已提交
119 120
    //判断是否先dump出base TODO
    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase, _startup_dump_inference_base);
X
xiexionghang 已提交
121
    environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
122 123 124
    
    while (true) {
        epoch_accessor->next_epoch();
125
        _context_ptr->monitor_ssm.str(""); 
X
xiexionghang 已提交
126
        bool already_dump_inference_model = false;
X
xiexionghang 已提交
127
        epoch_id = epoch_accessor->current_epoch_id();
X
xiexionghang 已提交
128
        std::string epoch_log_title = paddle::string::format_string(
X
xiexionghang 已提交
129
            "train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
130
        std::string data_path = paddle::string::to_string<std::string>(dataset->epoch_data_path(epoch_id));
L
linan17 已提交
131
        ENVLOG_WORKER_MASTER_NOTICE("    ==== begin %s ====", epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
132
        //Step1. 等待样本ready
X
xiexionghang 已提交
133
        {
L
linan17 已提交
134
            ENVLOG_WORKER_MASTER_NOTICE("      %s, wait data ready:%s", epoch_log_title.c_str(), data_path.c_str());
X
xiexionghang 已提交
135 136 137
            while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) {
                sleep(30);  
                dataset->pre_detect_data(epoch_id);
L
linan17 已提交
138
                ENVLOG_WORKER_MASTER_NOTICE("      epoch_id:%d data not ready, wait 30s", epoch_id);
X
xiexionghang 已提交
139
            } 
L
linan17 已提交
140
            ENVLOG_WORKER_MASTER_NOTICE("      Start %s, data is ready", epoch_log_title.c_str());
X
xiexionghang 已提交
141 142 143
            environment->barrier(EnvironmentRole::WORKER); 
        }
    
X
xiexionghang 已提交
144
        //Step2. 运行训练网络
X
xiexionghang 已提交
145
        {
X
xiexionghang 已提交
146 147
            std::map<std::string, paddle::framework::Channel<DataItem>> backup_input_map;
            for (auto& executor : _executors) {
X
xiexionghang 已提交
148
                environment->barrier(EnvironmentRole::WORKER); 
149 150
                paddle::platform::Timer timer;
                timer.Start();
L
linan17 已提交
151
                ENVLOG_WORKER_MASTER_NOTICE("Start executor:%s", executor->train_exe_name().c_str());
X
xiexionghang 已提交
152 153 154 155 156 157 158 159
                auto data_name = executor->train_data_name();
                paddle::framework::Channel<DataItem> input_channel;
                if (backup_input_map.count(data_name)) {
                    input_channel = backup_input_map[data_name];
                } else {
                    input_channel = dataset->fetch_data(data_name, epoch_id);
                }
                input_channel = executor->run(input_channel, dataset->data_parser(data_name));
160
                timer.Pause();
L
linan17 已提交
161
                ENVLOG_WORKER_MASTER_NOTICE("End executor:%s, cost:%f", executor->train_exe_name().c_str(), timer.ElapsedSec());
X
xiexionghang 已提交
162 163 164 165

                // 等待异步梯度完成
                _context_ptr->ps_client()->flush();
                environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
166
                if (executor->is_dump_all_model()) {
X
xiexionghang 已提交
167 168 169
                    already_dump_inference_model = true;
                    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
                }
X
xiexionghang 已提交
170
                backup_input_map[data_name] = input_channel;
X
xiexionghang 已提交
171
                environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
172
            }
X
xiexionghang 已提交
173
        }
X
xiexionghang 已提交
174

X
xiexionghang 已提交
175 176
        //Step3. Dump Model For Delta&&Checkpoint
        {
177
            wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
X
xiexionghang 已提交
178
            environment->barrier(EnvironmentRole::WORKER); 
L
linan17 已提交
179 180 181 182 183
            if (epoch_accessor->is_last_epoch(epoch_id)) {
                wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpointBase);
            } else {
                wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
            }
X
xiexionghang 已提交
184
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
185 186 187 188
            if (epoch_accessor->is_last_epoch(epoch_id) &&
                environment->is_master_node(EnvironmentRole::WORKER)) {
                paddle::platform::Timer timer;
                timer.Start();
L
linan17 已提交
189
                ENVLOG_WORKER_MASTER_NOTICE("Start shrink table");
X
xiexionghang 已提交
190 191 192 193 194 195
                for (auto& executor : _executors) {
                    const auto& table_accessors = executor->table_accessors();
                    for (auto& itr : table_accessors) {
                        CHECK(itr.second[0]->shrink() == 0);
                    }
                } 
L
linan17 已提交
196 197
                timer.Pause();
                ENVLOG_WORKER_MASTER_NOTICE("End shrink table, cost:%f", timer.ElapsedSec());
X
xiexionghang 已提交
198 199
            }
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
200

X
xiexionghang 已提交
201 202
            epoch_accessor->epoch_done(epoch_id);
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
203
        }
L
linan17 已提交
204
        ENVLOG_WORKER_MASTER_NOTICE("    ==== end %s ====", epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
205 206 207 208 209 210 211 212 213 214
        //Step4. Output Monitor && RunStatus
        //TODO
    }
    
    return 0;
}

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle