learner_process.cc 8.7 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5
/*
 * Author: xiexionghang
 * Learner process: trains the model on samples ("Train样本" = train samples).
 */
#include <omp.h>

#include <future>
#include <map>
#include <set>
#include <string>
#include <vector>
6
#include "paddle/fluid/platform/timer.h"
X
xiexionghang 已提交
7
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
X
xiexionghang 已提交
8
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"
X
xiexionghang 已提交
9 10 11 12 13 14 15 16 17 18 19
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/learner_process.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

// Initializes the base Process, then creates and initializes one
// MultiThreadExecutor per entry of the "executor" list in the trainer config.
//
// Returns the non-zero status of Process::initialize if that fails.
// Aborts (CHECK) if any executor fails to initialize.
// Otherwise returns 0.
int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
    int ret = Process::initialize(context_ptr);
    if (ret != 0) {
        // The base status used to be computed and then dropped. A failed base
        // initialization was reported as success, and _context_ptr (presumably
        // set by Process::initialize — confirm) was dereferenced anyway.
        return ret;
    }
    auto& config = _context_ptr->trainer_config;
    if (config["executor"]) {
        _executors.resize(config["executor"].size());
        for (size_t i = 0; i < _executors.size(); ++i) {
            _executors[i].reset(new MultiThreadExecutor());
            CHECK(_executors[i]->initialize(config["executor"][i], context_ptr) == 0);
        }
    }
    return 0;
}

X
xiexionghang 已提交
29
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
R
rensilin 已提交
30
    auto fs = _context_ptr->file_system;
31
    auto* ps_client = _context_ptr->pslib->ps_client();
X
xiexionghang 已提交
32
    auto* environment = _context_ptr->environment.get();
33
    auto* epoch_accessor = _context_ptr->epoch_accessor.get();
X
xiexionghang 已提交
34
    if (!environment->is_master_node(EnvironmentRole::WORKER)) {
X
xiexionghang 已提交
35 36
        return 0;
    }
37 38 39 40 41
    if (!epoch_accessor->need_save_model(epoch_id, way)) {
        return 0;
    }
    paddle::platform::Timer timer;
    timer.Start();
X
xiexionghang 已提交
42
    std::set<uint32_t> table_set;
R
rensilin 已提交
43
    auto model_dir = epoch_accessor->model_save_path(epoch_id, way);
X
xiexionghang 已提交
44 45 46 47 48
    for (auto& executor : _executors) {
        const auto& table_accessors = executor->table_accessors();
        for (auto& itr : table_accessors) {
            table_set.insert(itr.first);
        }
R
rensilin 已提交
49 50 51
        auto save_path = fs->path_join(model_dir, executor->train_exe_name() + "_param");
        VLOG(2) << "Start save model, save_path:" << save_path;
        executor->save_persistables(save_path);
X
xiexionghang 已提交
52
    }
X
xiexionghang 已提交
53
    int ret_size = 0;
X
xiexionghang 已提交
54
    auto table_num = table_set.size();
X
xiexionghang 已提交
55
    std::future<int> rets[table_num];
X
xiexionghang 已提交
56
    for (auto table_id : table_set) {
57 58
        VLOG(2) << "Start save model, table_id:" << table_id;
        rets[ret_size++] = ps_client->save(table_id, model_dir, std::to_string((int)way));
X
xiexionghang 已提交
59 60 61 62 63 64
    }
    int all_ret = 0;
    for (int i = 0; i < ret_size; ++i) {
        rets[i].wait();
        all_ret |= rets[i].get();
    }
65 66 67
    timer.Pause();
    VLOG(2) << "Save Model Cost(s):" << timer.ElapsedSec();
    _context_ptr->epoch_accessor->update_model_donefile(epoch_id, way);
X
xiexionghang 已提交
68 69 70
    return all_ret;
}

X
xiexionghang 已提交
71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
// Restores every table used by the executors from the checkpoint, or creates
// it with default values when no checkpoint data exists for it.
//
// Only the WORKER master node does the work; other nodes return 0 at once.
// Each distinct table id is handled once, even when several executors share it:
//   - if <checkpoint_path>/<table id formatted as "%03d"> lists no files, the
//     table is created through its first accessor on the executor's scope;
//   - otherwise the PS client loads it from the checkpoint directory using the
//     ModelSaveTrainCheckpoint save way.
// Any failure aborts via CHECK. `epoch_id` is currently not used.
int LearnerProcess::load_model(uint64_t epoch_id) {
    if (!_context_ptr->environment->is_master_node(EnvironmentRole::WORKER)) {
        return 0;
    }
    auto& fs = _context_ptr->file_system;
    auto model_dir = _context_ptr->epoch_accessor->checkpoint_path();
    std::set<uint32_t> handled_tables;
    for (auto& executor : _executors) {
        for (auto& entry : executor->table_accessors()) {
            const auto table_id = entry.first;
            // Skip tables already restored through another executor.
            if (!handled_tables.insert(table_id).second) {
                continue;
            }
            auto table_model_path = fs->path_join(
                model_dir, string::format_string("%03d", table_id));
            if (!fs->list(table_model_path).empty()) {
                auto status = _context_ptr->ps_client()->load(table_id,
                    model_dir, std::to_string(static_cast<int>(ModelSaveWay::ModelSaveTrainCheckpoint)));
                CHECK(status.get() == 0) << "table load failed, id:" << table_id;
                continue;
            }
            VLOG(2) << "miss table_model:" << table_model_path << ", initialize by default";
            // NOTE(review): std::move on the call result is redundant if
            // fetch_scope() returns by value. If it returns a reference, this
            // moves the executor's scope out. Confirm the intent.
            auto scope = std::move(executor->fetch_scope());
            CHECK(entry.second[0]->create(scope.get()) == 0);
        }
    }
    return 0;
}

X
xiexionghang 已提交
101
int LearnerProcess::run() {
X
xiexionghang 已提交
102
    auto* dataset = _context_ptr->dataset.get();
X
xiexionghang 已提交
103 104
    auto* environment = _context_ptr->environment.get();
    auto* epoch_accessor = _context_ptr->epoch_accessor.get(); 
X
xiexionghang 已提交
105
    uint64_t epoch_id = epoch_accessor->current_epoch_id();
X
xiexionghang 已提交
106

X
xiexionghang 已提交
107
    environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
108
        "Resume train with epoch_id:%d %s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
109
    
X
xiexionghang 已提交
110 111 112 113
    //尝试加载模型 or 初始化
    CHECK(load_model(epoch_id) == 0);
    environment->barrier(EnvironmentRole::WORKER); 

X
xiexionghang 已提交
114 115
    //判断是否先dump出base
    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
X
xiexionghang 已提交
116
    environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
117 118 119
    
    while (true) {
        epoch_accessor->next_epoch();
120
        _context_ptr->monitor_ssm.str(""); 
X
xiexionghang 已提交
121
        bool already_dump_inference_model = false;
X
xiexionghang 已提交
122
        epoch_id = epoch_accessor->current_epoch_id();
X
xiexionghang 已提交
123
        std::string epoch_log_title = paddle::string::format_string(
X
xiexionghang 已提交
124
            "train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
125
        std::string data_path = paddle::string::to_string<std::string>(dataset->epoch_data_path(epoch_id));
X
xiexionghang 已提交
126 127
        
        //Step1. 等待样本ready
X
xiexionghang 已提交
128
        {
X
xiexionghang 已提交
129
            environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
130
                "%s, wait data ready:%s", epoch_log_title.c_str(), data_path.c_str());
X
xiexionghang 已提交
131 132 133 134
            while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) {
                sleep(30);  
                dataset->pre_detect_data(epoch_id);
                environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
135
                "data not ready, wait 30s");
X
xiexionghang 已提交
136 137
            } 
            environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
138
                "Start %s, data is ready", epoch_log_title.c_str());
X
xiexionghang 已提交
139 140 141
            environment->barrier(EnvironmentRole::WORKER); 
        }
    
X
xiexionghang 已提交
142
        //Step2. 运行训练网络
X
xiexionghang 已提交
143
        {
X
xiexionghang 已提交
144 145
            std::map<std::string, paddle::framework::Channel<DataItem>> backup_input_map;
            for (auto& executor : _executors) {
X
xiexionghang 已提交
146
                environment->barrier(EnvironmentRole::WORKER); 
147 148
                paddle::platform::Timer timer;
                timer.Start();
X
xiexionghang 已提交
149
                VLOG(2) << "Start executor:" << executor->train_exe_name();
X
xiexionghang 已提交
150 151 152 153 154 155 156 157
                auto data_name = executor->train_data_name();
                paddle::framework::Channel<DataItem> input_channel;
                if (backup_input_map.count(data_name)) {
                    input_channel = backup_input_map[data_name];
                } else {
                    input_channel = dataset->fetch_data(data_name, epoch_id);
                }
                input_channel = executor->run(input_channel, dataset->data_parser(data_name));
158 159
                timer.Pause();
                VLOG(2) << "End executor:" << executor->train_exe_name() << ", cost" << timer.ElapsedSec();
X
xiexionghang 已提交
160 161 162 163

                // 等待异步梯度完成
                _context_ptr->ps_client()->flush();
                environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
164
                if (executor->is_dump_all_model()) {
X
xiexionghang 已提交
165 166 167
                    already_dump_inference_model = true;
                    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
                }
X
xiexionghang 已提交
168
                backup_input_map[data_name] = input_channel;
X
xiexionghang 已提交
169
                environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
170
            }
X
xiexionghang 已提交
171
        }
X
xiexionghang 已提交
172

X
xiexionghang 已提交
173 174
        //Step3. Dump Model For Delta&&Checkpoint
        {
175
            wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
X
xiexionghang 已提交
176
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
177
            wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
X
xiexionghang 已提交
178
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
179 180 181 182 183 184 185 186 187 188 189 190 191 192
            if (epoch_accessor->is_last_epoch(epoch_id) &&
                environment->is_master_node(EnvironmentRole::WORKER)) {
                paddle::platform::Timer timer;
                timer.Start();
                VLOG(2) << "Start shrink table"; 
                for (auto& executor : _executors) {
                    const auto& table_accessors = executor->table_accessors();
                    for (auto& itr : table_accessors) {
                        CHECK(itr.second[0]->shrink() == 0);
                    }
                } 
                VLOG(2) << "End shrink table, cost" << timer.ElapsedSec();
            }
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
193

X
xiexionghang 已提交
194 195
            epoch_accessor->epoch_done(epoch_id);
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
196
        }
X
xiexionghang 已提交
197
    
X
xiexionghang 已提交
198 199 200 201 202 203 204 205 206 207
        //Step4. Output Monitor && RunStatus
        //TODO
    }
    
    return 0;
}

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle