learner_process.cc 9.1 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5
/*
 *Author: xiexionghang
 *Train样本
 */
#include <omp.h>

#include <future>
#include <set>
#include <string>
#include <vector>

#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/learner_process.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
    int ret = Process::initialize(context_ptr);
    auto& config = _context_ptr->trainer_config;
X
xiexionghang 已提交
19
    _startup_dump_inference_base = config["startup_dump_inference_base"].as<bool>(false);
X
xiexionghang 已提交
20
    if (config["executor"]) {
X
xiexionghang 已提交
21 22 23 24
        _executors.resize(config["executor"].size());
        for (size_t i = 0; i < _executors.size(); ++i) {
            _executors[i].reset(new MultiThreadExecutor());
            CHECK(_executors[i]->initialize(config["executor"][i], context_ptr) == 0);
X
xiexionghang 已提交
25 26 27 28 29
        }
    }
    return 0;
}

X
xiexionghang 已提交
30
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump) {
R
rensilin 已提交
31
    auto fs = _context_ptr->file_system;
32
    auto* ps_client = _context_ptr->pslib->ps_client();
X
xiexionghang 已提交
33
    auto* environment = _context_ptr->environment.get();
34
    auto* epoch_accessor = _context_ptr->epoch_accessor.get();
X
xiexionghang 已提交
35
    if (!environment->is_master_node(EnvironmentRole::WORKER)) {
X
xiexionghang 已提交
36 37
        return 0;
    }
X
xiexionghang 已提交
38
    if (!is_force_dump && !epoch_accessor->need_save_model(epoch_id, way)) {
39 40 41 42
        return 0;
    }
    paddle::platform::Timer timer;
    timer.Start();
X
xiexionghang 已提交
43
    std::set<uint32_t> table_set;
R
rensilin 已提交
44
    auto model_dir = epoch_accessor->model_save_path(epoch_id, way);
X
xiexionghang 已提交
45 46 47 48 49
    for (auto& executor : _executors) {
        const auto& table_accessors = executor->table_accessors();
        for (auto& itr : table_accessors) {
            table_set.insert(itr.first);
        }
R
rensilin 已提交
50 51 52
        auto save_path = fs->path_join(model_dir, executor->train_exe_name() + "_param");
        VLOG(2) << "Start save model, save_path:" << save_path;
        executor->save_persistables(save_path);
X
xiexionghang 已提交
53
    }
X
xiexionghang 已提交
54
    int ret_size = 0;
X
xiexionghang 已提交
55
    auto table_num = table_set.size();
X
xiexionghang 已提交
56
    std::future<int> rets[table_num];
X
xiexionghang 已提交
57
    for (auto table_id : table_set) {
58 59
        VLOG(2) << "Start save model, table_id:" << table_id;
        rets[ret_size++] = ps_client->save(table_id, model_dir, std::to_string((int)way));
X
xiexionghang 已提交
60 61 62 63 64 65
    }
    int all_ret = 0;
    for (int i = 0; i < ret_size; ++i) {
        rets[i].wait();
        all_ret |= rets[i].get();
    }
66 67 68
    timer.Pause();
    VLOG(2) << "Save Model Cost(s):" << timer.ElapsedSec();
    _context_ptr->epoch_accessor->update_model_donefile(epoch_id, way);
X
xiexionghang 已提交
69 70 71
    return all_ret;
}

X
xiexionghang 已提交
72 73 74 75 76
int LearnerProcess::load_model(uint64_t epoch_id) {
    auto* environment = _context_ptr->environment.get();
    if (!environment->is_master_node(EnvironmentRole::WORKER)) {
        return 0;
    }
X
xiexionghang 已提交
77
    auto* fs = _context_ptr->file_system.get();
X
xiexionghang 已提交
78 79 80 81 82 83 84 85
    std::set<uint32_t> loaded_table_set;
    auto model_dir = _context_ptr->epoch_accessor->checkpoint_path();
    for (auto& executor : _executors) {
        const auto& table_accessors = executor->table_accessors();
        for (auto& itr : table_accessors) {
            if (loaded_table_set.count(itr.first)) {
                continue;
            }
X
xiexionghang 已提交
86
            auto table_model_path = fs->path_join(
X
xiexionghang 已提交
87
                model_dir, string::format_string("%03d", itr.first));
X
xiexionghang 已提交
88
            if ((!fs->exists(table_model_path)) || fs->list(table_model_path).size() == 0) {
X
xiexionghang 已提交
89 90 91 92 93 94 95 96 97 98 99 100 101 102
                VLOG(2) << "miss table_model:" << table_model_path << ", initialize by default";
                auto scope = std::move(executor->fetch_scope());
                CHECK(itr.second[0]->create(scope.get()) == 0);
            } else {
                auto status = _context_ptr->ps_client()->load(itr.first, 
                    model_dir, std::to_string((int)ModelSaveWay::ModelSaveTrainCheckpoint));
                CHECK(status.get() == 0) << "table load failed, id:" << itr.first;
            }
            loaded_table_set.insert(itr.first);
        }
    }
    return 0;
}

X
xiexionghang 已提交
103
int LearnerProcess::run() {
X
xiexionghang 已提交
104
    auto* dataset = _context_ptr->dataset.get();
X
xiexionghang 已提交
105 106
    auto* environment = _context_ptr->environment.get();
    auto* epoch_accessor = _context_ptr->epoch_accessor.get(); 
X
xiexionghang 已提交
107
    uint64_t epoch_id = epoch_accessor->current_epoch_id();
X
xiexionghang 已提交
108

X
xiexionghang 已提交
109
    environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
110
        "Resume train with epoch_id:%d %s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
111
    
X
xiexionghang 已提交
112 113 114 115
    //尝试加载模型 or 初始化
    CHECK(load_model(epoch_id) == 0);
    environment->barrier(EnvironmentRole::WORKER); 

X
xiexionghang 已提交
116 117
    //判断是否先dump出base TODO
    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase, _startup_dump_inference_base);
X
xiexionghang 已提交
118
    environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
119 120 121
    
    while (true) {
        epoch_accessor->next_epoch();
122
        _context_ptr->monitor_ssm.str(""); 
X
xiexionghang 已提交
123
        bool already_dump_inference_model = false;
X
xiexionghang 已提交
124
        epoch_id = epoch_accessor->current_epoch_id();
X
xiexionghang 已提交
125
        std::string epoch_log_title = paddle::string::format_string(
X
xiexionghang 已提交
126
            "train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
127
        std::string data_path = paddle::string::to_string<std::string>(dataset->epoch_data_path(epoch_id));
X
xiexionghang 已提交
128 129
        
        //Step1. 等待样本ready
X
xiexionghang 已提交
130
        {
X
xiexionghang 已提交
131
            environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
132
                "%s, wait data ready:%s", epoch_log_title.c_str(), data_path.c_str());
X
xiexionghang 已提交
133 134 135 136
            while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) {
                sleep(30);  
                dataset->pre_detect_data(epoch_id);
                environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
137
                "data not ready, wait 30s");
X
xiexionghang 已提交
138 139
            } 
            environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
140
                "Start %s, data is ready", epoch_log_title.c_str());
X
xiexionghang 已提交
141 142 143
            environment->barrier(EnvironmentRole::WORKER); 
        }
    
X
xiexionghang 已提交
144
        //Step2. 运行训练网络
X
xiexionghang 已提交
145
        {
X
xiexionghang 已提交
146 147
            std::map<std::string, paddle::framework::Channel<DataItem>> backup_input_map;
            for (auto& executor : _executors) {
X
xiexionghang 已提交
148
                environment->barrier(EnvironmentRole::WORKER); 
149 150
                paddle::platform::Timer timer;
                timer.Start();
X
xiexionghang 已提交
151
                VLOG(2) << "Start executor:" << executor->train_exe_name();
X
xiexionghang 已提交
152 153 154 155 156 157 158 159
                auto data_name = executor->train_data_name();
                paddle::framework::Channel<DataItem> input_channel;
                if (backup_input_map.count(data_name)) {
                    input_channel = backup_input_map[data_name];
                } else {
                    input_channel = dataset->fetch_data(data_name, epoch_id);
                }
                input_channel = executor->run(input_channel, dataset->data_parser(data_name));
160
                timer.Pause();
X
xiexionghang 已提交
161
                VLOG(2) << "End executor:" << executor->train_exe_name() << ", cost:" << timer.ElapsedSec();
X
xiexionghang 已提交
162 163 164 165

                // 等待异步梯度完成
                _context_ptr->ps_client()->flush();
                environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
166
                if (executor->is_dump_all_model()) {
X
xiexionghang 已提交
167 168 169
                    already_dump_inference_model = true;
                    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
                }
X
xiexionghang 已提交
170
                backup_input_map[data_name] = input_channel;
X
xiexionghang 已提交
171
                environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
172
            }
X
xiexionghang 已提交
173
        }
X
xiexionghang 已提交
174

X
xiexionghang 已提交
175 176
        //Step3. Dump Model For Delta&&Checkpoint
        {
177
            wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
X
xiexionghang 已提交
178
            environment->barrier(EnvironmentRole::WORKER); 
L
linan17 已提交
179 180 181 182 183
            if (epoch_accessor->is_last_epoch(epoch_id)) {
                wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpointBase);
            } else {
                wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
            }
X
xiexionghang 已提交
184
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
185 186 187 188 189 190 191 192 193 194 195
            if (epoch_accessor->is_last_epoch(epoch_id) &&
                environment->is_master_node(EnvironmentRole::WORKER)) {
                paddle::platform::Timer timer;
                timer.Start();
                VLOG(2) << "Start shrink table"; 
                for (auto& executor : _executors) {
                    const auto& table_accessors = executor->table_accessors();
                    for (auto& itr : table_accessors) {
                        CHECK(itr.second[0]->shrink() == 0);
                    }
                } 
X
xiexionghang 已提交
196
                VLOG(2) << "End shrink table, cost:" << timer.ElapsedSec();
X
xiexionghang 已提交
197 198
            }
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
199

X
xiexionghang 已提交
200 201
            epoch_accessor->epoch_done(epoch_id);
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
202
        }
X
xiexionghang 已提交
203
    
X
xiexionghang 已提交
204 205 206 207 208 209 210 211 212 213
        //Step4. Output Monitor && RunStatus
        //TODO
    }
    
    return 0;
}

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle