/*
 * learner_process.cc
 * Author: xiexionghang
 * Train samples
 */
#include <omp.h>

#include <future>
#include <thread>
#include <vector>

#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/learner_process.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
    int ret = Process::initialize(context_ptr);
    auto& config = _context_ptr->trainer_config;
    _train_thread_num = config["train_thread_num"].as<int>();
    _threads_executor.resize(_train_thread_num);
    
    if (config["executor"]) {
        _executor_num = config["executor"].size();
        omp_set_num_threads(_train_thread_num);
        #pragma omp parallel for
        for (int i = 0; i < _train_thread_num; ++i) {
            _threads_executor[i].resize(_executor_num);
            for (int e = 0; e < _executor_num; ++e) {
                auto e_class = config["executor"][e]["class"].as<std::string>();
X
xiexionghang 已提交
28
                auto* e_ptr = CREATE_INSTANCE(Executor, e_class);
X
xiexionghang 已提交
29 30 31 32 33 34 35 36 37 38
                _threads_executor[i][e].reset(e_ptr);  
                if (e_ptr->initialize(config["executor"][e], context_ptr) != 0) {
                    ret = -1;
                }
            }
        }
    }
    return 0;
}

X
xiexionghang 已提交
39
/**
 * Starts saving one parameter table for the given epoch and returns a future
 * carrying the result code.
 *
 * The actual save is not implemented yet (see TODO), so the returned future is
 * always fulfilled with 0. The promise must be fulfilled on every path:
 * destroying an unfulfilled std::promise makes future::get() throw
 * std::future_error(broken_promise) in the waiter.
 *
 * @param epoch_id epoch whose model should be saved.
 * @param table_id id of the parameter table to save (unused until the save is
 *                 implemented).
 * @param way      which kind of model dump is requested.
 * @return a future that is already ready and holds 0.
 */
std::future<int> LearnerProcess::save_model(uint64_t epoch_id, int table_id, ModelSaveWay way) {
    std::promise<int> p;
    auto ret = p.get_future();
    if (_context_ptr->epoch_accessor->need_save_model(epoch_id, way)) {
        // TODO: issue the real save request and fulfil the promise with its
        // result, e.g.
        // context_ptr->pslib_client()->save();
        p.set_value(0);
    } else {
        // Nothing to save for this epoch / save way.
        p.set_value(0);
    }
    return ret;
}

X
xiexionghang 已提交
51
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
X
xiexionghang 已提交
52
    auto* environment = _context_ptr->environment.get();
X
xiexionghang 已提交
53
    if (!environment->is_master_node(EnvironmentRole::WORKER)) {
X
xiexionghang 已提交
54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
        return 0;
    }
    int ret_size = 0;
    auto table_num = _context_ptr->params_table_list.size();
    std::future<int> rets[table_num];
    for (int i = 0; i < table_num; ++i) {
        auto table_id = _context_ptr->params_table_list[i].table_id();
        rets[ret_size++] = save_model(epoch_id, table_id, way); 
    }

    int all_ret = 0;
    for (int i = 0; i < ret_size; ++i) {
        rets[i].wait();
        all_ret |= rets[i].get();
    }
    return all_ret;
}

int LearnerProcess::run() {
X
xiexionghang 已提交
73
    auto* dataset = _context_ptr->dataset.get();
X
xiexionghang 已提交
74 75
    auto* environment = _context_ptr->environment.get();
    auto* epoch_accessor = _context_ptr->epoch_accessor.get(); 
X
xiexionghang 已提交
76
    uint64_t epoch_id = epoch_accessor->current_epoch_id();
X
xiexionghang 已提交
77

X
xiexionghang 已提交
78
    environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
79 80 81 82
        "Resume traine with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
    
    //判断是否先dump出base
    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
X
xiexionghang 已提交
83
    environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
84 85 86
    
    while (true) {
        epoch_accessor->next_epoch();
X
xiexionghang 已提交
87
        bool already_dump_inference_model = false;
X
xiexionghang 已提交
88 89 90 91 92
        epoch_id = epoch_accessor->current_epoch_id();
        std::string epoch_log_title= paddle::string::format_string(
            "train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str());
        
        //Step1. 等待样本ready
X
xiexionghang 已提交
93
        {
X
xiexionghang 已提交
94
            environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
95 96 97 98 99 100 101 102 103 104 105 106
                "Start %s, wait data ready", epoch_log_title.c_str());
            while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) {
                sleep(30);  
                dataset->pre_detect_data(epoch_id);
                environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
                    "%s, data not ready, wait 30s", epoch_log_title.c_str());
            } 
            environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
                "%s, data is ready, start traning", epoch_log_title.c_str());
            environment->barrier(EnvironmentRole::WORKER); 
        }
    
X
xiexionghang 已提交
107
        //Step2. 运行训练网络
X
xiexionghang 已提交
108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
        {
            for (int i = 0; i < _executor_num; ++i) {
                std::vector<std::shared_ptr<std::thread>> train_threads(_train_thread_num);
                for (int thread_id = 0; thread_id < _train_thread_num; ++thread_id) {
                    train_threads[i].reset(new std::thread([this](int exe_idx, int thread_idx) {
                        auto* executor = _threads_executor[thread_idx][exe_idx].get();
                        run_executor(executor);                      
                    }, i, thread_id));
                }   
                for (int i = 0; i < _train_thread_num; ++i) {
                    train_threads[i]->join();
                }
                environment->barrier(EnvironmentRole::WORKER); 

                if (_threads_executor[0][i]->is_dump_all_model()) {
                    already_dump_inference_model = true;
                    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
                }
                environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
127
            }
X
xiexionghang 已提交
128
        }
X
xiexionghang 已提交
129

X
xiexionghang 已提交
130 131 132
        //Step3. Dump Model For Delta&&Checkpoint
        {
            if (!already_dump_inference_model) {
X
xiexionghang 已提交
133 134
                already_dump_inference_model = true;
                wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
X
xiexionghang 已提交
135 136
            } 
            wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
X
xiexionghang 已提交
137
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
138
        }
X
xiexionghang 已提交
139
    
X
xiexionghang 已提交
140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
        //Step4. Output Monitor && RunStatus
        //TODO
    }
    
    return 0;
}

/**
 * Runs a single executor. Not implemented yet: the executor is currently
 * ignored and success is reported unconditionally.
 *
 * @param executor executor to run (unused for now).
 * @return always 0.
 */
int LearnerProcess::run_executor(Executor* executor) {
    // TODO: implement the actual executor run.
    static_cast<void>(executor);
    return 0;
}

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle