learner_process.cc 5.3 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5
/*
 * Author: xiexionghang
 * Learner process: trains on samples.
 */
#include <omp.h>

#include <future>
#include <vector>
X
xiexionghang 已提交
6
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"
X
xiexionghang 已提交
7 8 9 10 11 12 13 14 15 16 17
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/learner_process.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

// Initializes the base Process, then builds one MultiThreadExecutor for every
// entry found under the "executor" key of the trainer config.
//
// @param context_ptr  shared trainer context; forwarded to the base class and
//                     to every executor's initialize().
// @return 0 on success, or the non-zero code reported by Process::initialize.
//         A failing executor initialization aborts through CHECK.
int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
    int ret = Process::initialize(context_ptr);
    if (ret != 0) {
        // Do not touch _context_ptr or build executors on top of a base
        // class that failed to initialize; report the failure to the caller.
        return ret;
    }
    auto& config = _context_ptr->trainer_config;
    if (config["executor"]) {
        _executors.resize(config["executor"].size());
        for (size_t i = 0; i < _executors.size(); ++i) {
            _executors[i].reset(new MultiThreadExecutor());
            CHECK(_executors[i]->initialize(config["executor"][i], context_ptr) == 0);
        }
    }
    return 0;
}

X
xiexionghang 已提交
27
std::future<int> LearnerProcess::save_model(uint64_t epoch_id, int table_id, ModelSaveWay way) {
X
xiexionghang 已提交
28 29 30
    std::promise<int> p;
    auto ret = p.get_future();
    if (_context_ptr->epoch_accessor->need_save_model(epoch_id, way)) {
X
xiexionghang 已提交
31
        LOG(NOTICE) << "save table, table_id:" << table_id;
X
xiexionghang 已提交
32 33 34 35 36 37
    } else {
        p.set_value(0);
    }
    return ret;
}

X
xiexionghang 已提交
38
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
X
xiexionghang 已提交
39
    auto* environment = _context_ptr->environment.get();
X
xiexionghang 已提交
40
    if (!environment->is_master_node(EnvironmentRole::WORKER)) {
X
xiexionghang 已提交
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
        return 0;
    }
    int ret_size = 0;
    auto table_num = _context_ptr->params_table_list.size();
    std::future<int> rets[table_num];
    for (int i = 0; i < table_num; ++i) {
        auto table_id = _context_ptr->params_table_list[i].table_id();
        rets[ret_size++] = save_model(epoch_id, table_id, way); 
    }

    int all_ret = 0;
    for (int i = 0; i < ret_size; ++i) {
        rets[i].wait();
        all_ret |= rets[i].get();
    }
    return all_ret;
}

int LearnerProcess::run() {
X
xiexionghang 已提交
60
    auto* dataset = _context_ptr->dataset.get();
X
xiexionghang 已提交
61 62
    auto* environment = _context_ptr->environment.get();
    auto* epoch_accessor = _context_ptr->epoch_accessor.get(); 
X
xiexionghang 已提交
63
    uint64_t epoch_id = epoch_accessor->current_epoch_id();
X
xiexionghang 已提交
64

X
xiexionghang 已提交
65
    environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
66 67 68 69
        "Resume traine with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
    
    //判断是否先dump出base
    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
X
xiexionghang 已提交
70
    environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
71 72 73
    
    while (true) {
        epoch_accessor->next_epoch();
X
xiexionghang 已提交
74
        bool already_dump_inference_model = false;
X
xiexionghang 已提交
75 76 77 78 79
        epoch_id = epoch_accessor->current_epoch_id();
        std::string epoch_log_title= paddle::string::format_string(
            "train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str());
        
        //Step1. 等待样本ready
X
xiexionghang 已提交
80
        {
X
xiexionghang 已提交
81
            environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
82 83 84 85 86 87 88 89 90 91 92 93
                "Start %s, wait data ready", epoch_log_title.c_str());
            while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) {
                sleep(30);  
                dataset->pre_detect_data(epoch_id);
                environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
                    "%s, data not ready, wait 30s", epoch_log_title.c_str());
            } 
            environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
                "%s, data is ready, start traning", epoch_log_title.c_str());
            environment->barrier(EnvironmentRole::WORKER); 
        }
    
X
xiexionghang 已提交
94
        //Step2. 运行训练网络
X
xiexionghang 已提交
95
        {
X
xiexionghang 已提交
96 97
            std::map<std::string, paddle::framework::Channel<DataItem>> backup_input_map;
            for (auto& executor : _executors) {
X
xiexionghang 已提交
98
                environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
99 100 101 102 103 104 105 106 107
                auto data_name = executor->train_data_name();
                paddle::framework::Channel<DataItem> input_channel;
                if (backup_input_map.count(data_name)) {
                    input_channel = backup_input_map[data_name];
                } else {
                    input_channel = dataset->fetch_data(data_name, epoch_id);
                }
                input_channel = executor->run(input_channel, dataset->data_parser(data_name));
                if (executor->is_dump_all_model()) {
X
xiexionghang 已提交
108 109 110
                    already_dump_inference_model = true;
                    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
                }
X
xiexionghang 已提交
111
                backup_input_map[data_name] = input_channel;
X
xiexionghang 已提交
112
                environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
113
            }
X
xiexionghang 已提交
114
        }
X
xiexionghang 已提交
115

X
xiexionghang 已提交
116 117 118
        //Step3. Dump Model For Delta&&Checkpoint
        {
            if (!already_dump_inference_model) {
X
xiexionghang 已提交
119 120
                already_dump_inference_model = true;
                wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
X
xiexionghang 已提交
121 122
            } 
            wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
X
xiexionghang 已提交
123
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
124
        }
X
xiexionghang 已提交
125
    
X
xiexionghang 已提交
126 127 128 129 130 131 132 133 134 135
        //Step4. Output Monitor && RunStatus
        //TODO
    }
    
    return 0;
}

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle