learner_process.cc 7.7 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5
/*
 * Author: xiexionghang
 * Train samples
 */
#include <future>
#include <set>
#include <vector>

#include <omp.h>
X
xiexionghang 已提交
6
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
X
xiexionghang 已提交
7
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"
X
xiexionghang 已提交
8 9 10 11 12 13 14 15 16 17 18
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/learner_process.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

/**
 * Initializes the learner process.
 *
 * Runs the base Process initialization, then creates one MultiThreadExecutor
 * per entry of the "executor" list in the trainer config and initializes it
 * with that entry.
 *
 * @param context_ptr trainer context; forwarded to Process::initialize and to
 *                    every executor's initialize().
 * @return 0 on success, or the non-zero code returned by Process::initialize.
 *         A failing executor initialization aborts via CHECK.
 */
int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
    int ret = Process::initialize(context_ptr);
    if (ret != 0) {
        // Propagate the base-class failure instead of reporting success.
        // The code below reads _context_ptr, which the base initialization is
        // presumably responsible for setting up.
        return ret;
    }
    auto& config = _context_ptr->trainer_config;
    if (config["executor"]) {
        _executors.resize(config["executor"].size());
        for (size_t i = 0; i < _executors.size(); ++i) {
            _executors[i].reset(new MultiThreadExecutor());
            CHECK(_executors[i]->initialize(config["executor"][i], context_ptr) == 0);
        }
    }
    return 0;
}

X
xiexionghang 已提交
28
/**
 * Starts an asynchronous save of a single parameter-server table.
 *
 * The epoch accessor decides whether a save of kind `way` is due for
 * `epoch_id`. If it is, the table is saved to the accessor's model save path
 * and the client's future is returned. Otherwise nothing is saved and an
 * already-satisfied future holding 0 is returned.
 *
 * @param epoch_id epoch the model belongs to.
 * @param table_id parameter-server table to save.
 * @param way      kind of model to save; also passed to the client as its
 *                 integer value in string form.
 * @return future carrying the save result (0 when no save was needed).
 */
std::future<int> LearnerProcess::save_model(uint64_t epoch_id, int table_id, ModelSaveWay way) {
    auto* client = _context_ptr->pslib->ps_client();
    auto* accessor = _context_ptr->epoch_accessor.get();

    if (!accessor->need_save_model(epoch_id, way)) {
        // Nothing to do: hand back a future that is already resolved to 0.
        std::promise<int> skipped;
        skipped.set_value(0);
        return skipped.get_future();
    }

    VLOG(2) << "Start save model, table_id:" << table_id;
    auto save_dir = accessor->model_save_path(epoch_id, way);
    const auto mode = std::to_string(static_cast<int>(way));
    return client->save(table_id, save_dir, mode);
}

X
xiexionghang 已提交
43
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
X
xiexionghang 已提交
44
    auto* environment = _context_ptr->environment.get();
X
xiexionghang 已提交
45
    if (!environment->is_master_node(EnvironmentRole::WORKER)) {
X
xiexionghang 已提交
46 47
        return 0;
    }
X
xiexionghang 已提交
48 49 50 51 52 53 54
    std::set<uint32_t> table_set;
    for (auto& executor : _executors) {
        const auto& table_accessors = executor->table_accessors();
        for (auto& itr : table_accessors) {
            table_set.insert(itr.first);
        }
    }
X
xiexionghang 已提交
55
    int ret_size = 0;
X
xiexionghang 已提交
56
    auto table_num = table_set.size();
X
xiexionghang 已提交
57
    std::future<int> rets[table_num];
X
xiexionghang 已提交
58
    for (auto table_id : table_set) {
X
xiexionghang 已提交
59 60 61 62 63 64 65 66 67 68
        rets[ret_size++] = save_model(epoch_id, table_id, way); 
    }
    int all_ret = 0;
    for (int i = 0; i < ret_size; ++i) {
        rets[i].wait();
        all_ret |= rets[i].get();
    }
    return all_ret;
}

X
xiexionghang 已提交
69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
/**
 * Restores every table used by the executors from the checkpoint directory.
 *
 * Only the worker master node does any work; other nodes return 0 immediately.
 * For each distinct table id:
 *  - if `<checkpoint_path>/<table_id as %03d>` lists no entries, the table is
 *    created with defaults by its first accessor;
 *  - otherwise the table is loaded through the PS client in
 *    ModelSaveTrainCheckpoint mode.
 * Any creation or load failure aborts via CHECK.
 *
 * @param epoch_id currently unused; the directory comes from
 *                 epoch_accessor->checkpoint_path().
 * @return always 0.
 */
int LearnerProcess::load_model(uint64_t epoch_id) {
    auto* env = _context_ptr->environment.get();
    if (!env->is_master_node(EnvironmentRole::WORKER)) {
        return 0;
    }

    auto checkpoint_dir = _context_ptr->epoch_accessor->checkpoint_path();
    auto& fs = _context_ptr->file_system;
    std::set<uint32_t> handled_tables;

    for (auto& executor : _executors) {
        const auto& accessors = executor->table_accessors();
        for (auto& entry : accessors) {
            const auto table_id = entry.first;
            // A table shared by several executors is handled only once.
            if (handled_tables.count(table_id)) {
                continue;
            }

            auto table_dir = fs->path_join(
                checkpoint_dir, string::format_string("%03d", table_id));
            const bool no_checkpoint = fs->list(table_dir).size() == 0;

            if (no_checkpoint) {
                VLOG(2) << "miss table_model:" << table_dir << ", initialize by default";
                auto scope = std::move(executor->fetch_scope());
                CHECK(entry.second[0]->create(scope.get()) == 0);
            } else {
                const auto mode = std::to_string(
                    static_cast<int>(ModelSaveWay::ModelSaveTrainCheckpoint));
                auto status = _context_ptr->ps_client()->load(table_id, checkpoint_dir, mode);
                CHECK(status.get() == 0) << "table load failed, id:" << table_id;
            }
            handled_tables.insert(table_id);
        }
    }
    return 0;
}

X
xiexionghang 已提交
99
int LearnerProcess::run() {
X
xiexionghang 已提交
100
    auto* dataset = _context_ptr->dataset.get();
X
xiexionghang 已提交
101 102
    auto* environment = _context_ptr->environment.get();
    auto* epoch_accessor = _context_ptr->epoch_accessor.get(); 
X
xiexionghang 已提交
103
    uint64_t epoch_id = epoch_accessor->current_epoch_id();
X
xiexionghang 已提交
104

X
xiexionghang 已提交
105
    environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
106
        "Resume train with epoch_id:%d %s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
107
    
X
xiexionghang 已提交
108 109 110 111
    //尝试加载模型 or 初始化
    CHECK(load_model(epoch_id) == 0);
    environment->barrier(EnvironmentRole::WORKER); 

X
xiexionghang 已提交
112 113
    //判断是否先dump出base
    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
X
xiexionghang 已提交
114
    environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
115 116 117
    
    while (true) {
        epoch_accessor->next_epoch();
X
xiexionghang 已提交
118
        bool already_dump_inference_model = false;
X
xiexionghang 已提交
119
        epoch_id = epoch_accessor->current_epoch_id();
X
xiexionghang 已提交
120
        std::string epoch_log_title = paddle::string::format_string(
X
xiexionghang 已提交
121
            "train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str());
X
xiexionghang 已提交
122
        std::string data_path = paddle::string::to_string<std::string>(dataset->epoch_data_path(epoch_id));
X
xiexionghang 已提交
123 124
        
        //Step1. 等待样本ready
X
xiexionghang 已提交
125
        {
X
xiexionghang 已提交
126
            environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
127
                "%s, wait data ready:%s", epoch_log_title.c_str(), data_path.c_str());
X
xiexionghang 已提交
128 129 130 131
            while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) {
                sleep(30);  
                dataset->pre_detect_data(epoch_id);
                environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
132
                "data not ready, wait 30s");
X
xiexionghang 已提交
133 134
            } 
            environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
135
                "Start %s, data is ready", epoch_log_title.c_str());
X
xiexionghang 已提交
136 137 138
            environment->barrier(EnvironmentRole::WORKER); 
        }
    
X
xiexionghang 已提交
139
        //Step2. 运行训练网络
X
xiexionghang 已提交
140
        {
X
xiexionghang 已提交
141 142
            std::map<std::string, paddle::framework::Channel<DataItem>> backup_input_map;
            for (auto& executor : _executors) {
X
xiexionghang 已提交
143
                environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
144
                VLOG(2) << "Start executor:" << executor->train_exe_name();
X
xiexionghang 已提交
145 146 147 148 149 150 151 152
                auto data_name = executor->train_data_name();
                paddle::framework::Channel<DataItem> input_channel;
                if (backup_input_map.count(data_name)) {
                    input_channel = backup_input_map[data_name];
                } else {
                    input_channel = dataset->fetch_data(data_name, epoch_id);
                }
                input_channel = executor->run(input_channel, dataset->data_parser(data_name));
X
xiexionghang 已提交
153 154 155 156 157 158
                VLOG(2) << "End executor:" << executor->train_exe_name();

                // 等待异步梯度完成
                _context_ptr->ps_client()->flush();
                environment->barrier(EnvironmentRole::WORKER); 

X
xiexionghang 已提交
159
                if (executor->is_dump_all_model()) {
X
xiexionghang 已提交
160 161 162
                    already_dump_inference_model = true;
                    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
                }
X
xiexionghang 已提交
163
                backup_input_map[data_name] = input_channel;
X
xiexionghang 已提交
164
                environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
165
            }
X
xiexionghang 已提交
166
        }
X
xiexionghang 已提交
167

X
xiexionghang 已提交
168 169 170
        //Step3. Dump Model For Delta&&Checkpoint
        {
            if (!already_dump_inference_model) {
X
xiexionghang 已提交
171 172
                already_dump_inference_model = true;
                wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
X
xiexionghang 已提交
173 174
            } 
            wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
X
xiexionghang 已提交
175
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
176

X
xiexionghang 已提交
177 178 179
            epoch_accessor->epoch_done(epoch_id);
            environment->barrier(EnvironmentRole::WORKER); 

X
xiexionghang 已提交
180
        }
X
xiexionghang 已提交
181
    
X
xiexionghang 已提交
182 183 184 185 186 187 188 189 190 191
        //Step4. Output Monitor && RunStatus
        //TODO
    }
    
    return 0;
}

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle