learner_process.cc 8.5 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5
/*
 *Author: xiexionghang
 *Train样本
 */
#include <omp.h>
#include <future>
#include <set>
#include <string>
#include <vector>
6
#include "paddle/fluid/platform/timer.h"
X
xiexionghang 已提交
7
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
X
xiexionghang 已提交
8
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"
X
xiexionghang 已提交
9 10 11 12 13 14 15 16 17 18 19
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/learner_process.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

/**
 * Initialize the learner.
 *
 * Runs the base Process initialization, then creates one MultiThreadExecutor
 * per entry of the "executor" list in the trainer config and initializes it
 * with that entry.
 *
 * @param context_ptr trainer context; forwarded to Process::initialize and to
 *        every executor's initialize().
 * @return the non-zero status of Process::initialize when it fails, otherwise
 *         0. An executor that fails to initialize aborts via CHECK.
 */
int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
    int ret = Process::initialize(context_ptr);
    if (ret != 0) {
        // This status used to be discarded; do not build executors on top of
        // a failed base initialization.
        return ret;
    }
    auto& config = _context_ptr->trainer_config;
    if (config["executor"]) {
        _executors.resize(config["executor"].size());
        for (size_t i = 0; i < _executors.size(); ++i) {
            _executors[i].reset(new MultiThreadExecutor());
            CHECK(_executors[i]->initialize(config["executor"][i], context_ptr) == 0);
        }
    }
    return 0;
}

X
xiexionghang 已提交
29
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
30
    auto* ps_client = _context_ptr->pslib->ps_client();
X
xiexionghang 已提交
31
    auto* environment = _context_ptr->environment.get();
32
    auto* epoch_accessor = _context_ptr->epoch_accessor.get();
X
xiexionghang 已提交
33
    if (!environment->is_master_node(EnvironmentRole::WORKER)) {
X
xiexionghang 已提交
34 35
        return 0;
    }
36 37 38 39 40
    if (!epoch_accessor->need_save_model(epoch_id, way)) {
        return 0;
    }
    paddle::platform::Timer timer;
    timer.Start();
X
xiexionghang 已提交
41 42 43 44 45 46 47
    std::set<uint32_t> table_set;
    for (auto& executor : _executors) {
        const auto& table_accessors = executor->table_accessors();
        for (auto& itr : table_accessors) {
            table_set.insert(itr.first);
        }
    }
X
xiexionghang 已提交
48
    int ret_size = 0;
X
xiexionghang 已提交
49
    auto table_num = table_set.size();
X
xiexionghang 已提交
50
    std::future<int> rets[table_num];
X
xiexionghang 已提交
51
    for (auto table_id : table_set) {
52 53 54
        VLOG(2) << "Start save model, table_id:" << table_id;
        auto model_dir = epoch_accessor->model_save_path(epoch_id, way);
        rets[ret_size++] = ps_client->save(table_id, model_dir, std::to_string((int)way));
X
xiexionghang 已提交
55 56 57 58 59 60
    }
    int all_ret = 0;
    for (int i = 0; i < ret_size; ++i) {
        rets[i].wait();
        all_ret |= rets[i].get();
    }
61 62 63
    timer.Pause();
    VLOG(2) << "Save Model Cost(s):" << timer.ElapsedSec();
    _context_ptr->epoch_accessor->update_model_donefile(epoch_id, way);
X
xiexionghang 已提交
64 65 66
    return all_ret;
}

X
xiexionghang 已提交
67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
/**
 * Restore the tables used by the executors from the training checkpoint.
 *
 * Non-master workers return immediately. On the master, each distinct table
 * id is processed once, in this way:
 *   - if `<checkpoint_path>/<id formatted as %03d>` lists no entries, the
 *     table is created through its first accessor inside a scope fetched
 *     from the executor;
 *   - otherwise the PS client loads it from the checkpoint directory in
 *     ModelSaveTrainCheckpoint mode.
 * Any failure aborts via CHECK.
 *
 * @param epoch_id not used by the current implementation.
 * @return always 0.
 */
int LearnerProcess::load_model(uint64_t epoch_id) {
    auto* env = _context_ptr->environment.get();
    if (!env->is_master_node(EnvironmentRole::WORKER)) {
        return 0;
    }

    auto checkpoint_dir = _context_ptr->epoch_accessor->checkpoint_path();
    const std::string checkpoint_mode =
        std::to_string(static_cast<int>(ModelSaveWay::ModelSaveTrainCheckpoint));

    // Several executors may reference the same table; handle each id once.
    std::set<uint32_t> handled_tables;
    for (auto& executor : _executors) {
        const auto& accessors = executor->table_accessors();
        for (auto& entry : accessors) {
            const auto table_id = entry.first;
            if (handled_tables.count(table_id) > 0) {
                continue;
            }

            auto table_dir = _context_ptr->file_system->path_join(
                checkpoint_dir, string::format_string("%03d", table_id));
            const bool table_missing =
                _context_ptr->file_system->list(table_dir).size() == 0;

            if (!table_missing) {
                auto load_status = _context_ptr->ps_client()->load(
                    table_id, checkpoint_dir, checkpoint_mode);
                CHECK(load_status.get() == 0) << "table load failed, id:" << table_id;
            } else {
                VLOG(2) << "miss table_model:" << table_dir << ", initialize by default";
                auto scope = std::move(executor->fetch_scope());
                CHECK(entry.second[0]->create(scope.get()) == 0);
            }
            handled_tables.insert(table_id);
        }
    }
    return 0;
}

X
xiexionghang 已提交
97
int LearnerProcess::run() {
X
xiexionghang 已提交
98
    auto* dataset = _context_ptr->dataset.get();
X
xiexionghang 已提交
99 100
    auto* environment = _context_ptr->environment.get();
    auto* epoch_accessor = _context_ptr->epoch_accessor.get(); 
X
xiexionghang 已提交
101
    uint64_t epoch_id = epoch_accessor->current_epoch_id();
X
xiexionghang 已提交
102

X
xiexionghang 已提交
103
    environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
104
        "Resume train with epoch_id:%d %s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
105
    
X
xiexionghang 已提交
106 107 108 109
    //尝试加载模型 or 初始化
    CHECK(load_model(epoch_id) == 0);
    environment->barrier(EnvironmentRole::WORKER); 

X
xiexionghang 已提交
110 111
    //判断是否先dump出base
    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
X
xiexionghang 已提交
112
    environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
113 114 115
    
    while (true) {
        epoch_accessor->next_epoch();
116
        _context_ptr->monitor_ssm.str(""); 
X
xiexionghang 已提交
117
        bool already_dump_inference_model = false;
X
xiexionghang 已提交
118
        epoch_id = epoch_accessor->current_epoch_id();
X
xiexionghang 已提交
119
        std::string epoch_log_title = paddle::string::format_string(
X
xiexionghang 已提交
120
            "train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
121
        std::string data_path = paddle::string::to_string<std::string>(dataset->epoch_data_path(epoch_id));
X
xiexionghang 已提交
122 123
        
        //Step1. 等待样本ready
X
xiexionghang 已提交
124
        {
X
xiexionghang 已提交
125
            environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
126
                "%s, wait data ready:%s", epoch_log_title.c_str(), data_path.c_str());
X
xiexionghang 已提交
127 128 129 130
            while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) {
                sleep(30);  
                dataset->pre_detect_data(epoch_id);
                environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
131
                "data not ready, wait 30s");
X
xiexionghang 已提交
132 133
            } 
            environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
134
                "Start %s, data is ready", epoch_log_title.c_str());
X
xiexionghang 已提交
135 136 137
            environment->barrier(EnvironmentRole::WORKER); 
        }
    
X
xiexionghang 已提交
138
        //Step2. 运行训练网络
X
xiexionghang 已提交
139
        {
X
xiexionghang 已提交
140 141
            std::map<std::string, paddle::framework::Channel<DataItem>> backup_input_map;
            for (auto& executor : _executors) {
X
xiexionghang 已提交
142
                environment->barrier(EnvironmentRole::WORKER); 
143 144
                paddle::platform::Timer timer;
                timer.Start();
X
xiexionghang 已提交
145
                VLOG(2) << "Start executor:" << executor->train_exe_name();
X
xiexionghang 已提交
146 147 148 149 150 151 152 153
                auto data_name = executor->train_data_name();
                paddle::framework::Channel<DataItem> input_channel;
                if (backup_input_map.count(data_name)) {
                    input_channel = backup_input_map[data_name];
                } else {
                    input_channel = dataset->fetch_data(data_name, epoch_id);
                }
                input_channel = executor->run(input_channel, dataset->data_parser(data_name));
154 155
                timer.Pause();
                VLOG(2) << "End executor:" << executor->train_exe_name() << ", cost" << timer.ElapsedSec();
X
xiexionghang 已提交
156 157 158 159

                // 等待异步梯度完成
                _context_ptr->ps_client()->flush();
                environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
160
                if (executor->is_dump_all_model()) {
X
xiexionghang 已提交
161 162 163
                    already_dump_inference_model = true;
                    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
                }
X
xiexionghang 已提交
164
                backup_input_map[data_name] = input_channel;
X
xiexionghang 已提交
165
                environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
166
            }
X
xiexionghang 已提交
167
        }
X
xiexionghang 已提交
168

X
xiexionghang 已提交
169 170
        //Step3. Dump Model For Delta&&Checkpoint
        {
171
            wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
X
xiexionghang 已提交
172
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
173
            wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
X
xiexionghang 已提交
174
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
175 176 177 178 179 180 181 182 183 184 185 186 187 188
            if (epoch_accessor->is_last_epoch(epoch_id) &&
                environment->is_master_node(EnvironmentRole::WORKER)) {
                paddle::platform::Timer timer;
                timer.Start();
                VLOG(2) << "Start shrink table"; 
                for (auto& executor : _executors) {
                    const auto& table_accessors = executor->table_accessors();
                    for (auto& itr : table_accessors) {
                        CHECK(itr.second[0]->shrink() == 0);
                    }
                } 
                VLOG(2) << "End shrink table, cost" << timer.ElapsedSec();
            }
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
189

X
xiexionghang 已提交
190 191
            epoch_accessor->epoch_done(epoch_id);
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
192
        }
X
xiexionghang 已提交
193
    
X
xiexionghang 已提交
194 195 196 197 198 199 200 201 202 203
        //Step4. Output Monitor && RunStatus
        //TODO
    }
    
    return 0;
}

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle