learner_process.cc 5.6 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5
/*
 *Author: xiexionghang
 *Training on samples
 */
#include <omp.h>
X
xiexionghang 已提交
6
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"
X
xiexionghang 已提交
7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/learner_process.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

// Initializes the learner process.
//
// Runs the base Process initialization, reads `train_thread_num` from the
// trainer config and, when an `executor` section is configured, creates and
// initializes one executor per (training thread, executor config entry) pair.
//
// Returns 0 on success, the base-class error code when Process::initialize
// fails, and -1 when any executor cannot be created or initialized.
int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
    int ret = Process::initialize(context_ptr);
    if (ret != 0) {
        // Do not touch _context_ptr / config if the base class failed to set up.
        return ret;
    }
    auto& config = _context_ptr->trainer_config;
    _train_thread_num = config["train_thread_num"].as<int>();
    _threads_executor.resize(_train_thread_num);

    if (config["executor"]) {
        _executor_num = config["executor"].size();
        // Failures are counted through an OpenMP reduction so that worker
        // threads never write to a shared variable concurrently.
        int failed_count = 0;
        omp_set_num_threads(_train_thread_num);
        // NOTE(review): `config` is read from several threads at once below;
        // confirm that the config node type tolerates concurrent lookups.
        #pragma omp parallel for reduction(+:failed_count)
        for (int i = 0; i < _train_thread_num; ++i) {
            _threads_executor[i].resize(_executor_num);
            for (int e = 0; e < _executor_num; ++e) {
                auto e_class = config["executor"][e]["class"].as<std::string>();
                auto* e_ptr = CREATE_CLASS(Executor, e_class);
                _threads_executor[i][e].reset(e_ptr);
                // A null executor (e.g. unknown class name) counts as a failure
                // instead of being dereferenced.
                if (e_ptr == nullptr ||
                        e_ptr->initialize(config["executor"][e], context_ptr) != 0) {
                    ++failed_count;
                }
            }
        }
        if (failed_count != 0) {
            ret = -1;
        }
    }
    return ret;
}

X
xiexionghang 已提交
39
// Requests saving of one parameter table for the given epoch.
//
// epoch_id: epoch whose model should be saved.
// table_id: id of the table to save (unused until the save call is implemented).
// way:      which kind of model dump is requested.
//
// Returns a future that yields the save status (0 on success). The actual
// save is still a TODO, so the future is currently always fulfilled with 0.
std::future<int> LearnerProcess::save_model(uint64_t epoch_id, int table_id, ModelSaveWay way) {
    std::promise<int> p;
    auto ret = p.get_future();
    if (_context_ptr->epoch_accessor->need_save_model(epoch_id, way)) {
        //TODO
        //context_ptr->pslib_client()->save();
    }
    // Always fulfil the promise before it goes out of scope. Destroying an
    // unfulfilled promise makes future::get() throw std::future_error
    // (broken_promise), which would surface in wait_save_model().
    p.set_value(0);
    return ret;
}

X
xiexionghang 已提交
51
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
X
xiexionghang 已提交
52
    auto* environment = _context_ptr->environment.get();
X
xiexionghang 已提交
53
    if (!environment->is_master_node(EnvironmentRole::WORKER)) {
X
xiexionghang 已提交
54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
        return 0;
    }
    int ret_size = 0;
    auto table_num = _context_ptr->params_table_list.size();
    std::future<int> rets[table_num];
    for (int i = 0; i < table_num; ++i) {
        auto table_id = _context_ptr->params_table_list[i].table_id();
        rets[ret_size++] = save_model(epoch_id, table_id, way); 
    }

    int all_ret = 0;
    for (int i = 0; i < ret_size; ++i) {
        rets[i].wait();
        all_ret |= rets[i].get();
    }
    return all_ret;
}

int LearnerProcess::run() {
X
xiexionghang 已提交
73
    auto* dataset = _context_ptr->dataset.get();
X
xiexionghang 已提交
74 75
    auto* environment = _context_ptr->environment.get();
    auto* epoch_accessor = _context_ptr->epoch_accessor.get(); 
X
xiexionghang 已提交
76
    uint64_t epoch_id = epoch_accessor->current_epoch_id();
X
xiexionghang 已提交
77

X
xiexionghang 已提交
78
    environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
79 80 81 82
        "Resume traine with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
    
    //判断是否先dump出base
    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
X
xiexionghang 已提交
83
    environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
84 85 86 87 88 89 90 91
    
    while (true) {
        epoch_accessor->next_epoch();
        epoch_id = epoch_accessor->current_epoch_id();
        std::string epoch_log_title= paddle::string::format_string(
            "train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str());
        
        //Step1. 等待样本ready
X
xiexionghang 已提交
92
        environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
93
            "Start %s, wait data ready", epoch_log_title.c_str());
X
xiexionghang 已提交
94
        while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) {
X
xiexionghang 已提交
95
            sleep(30);  
X
xiexionghang 已提交
96 97
            dataset->pre_detect_data(epoch_id);
            environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
98 99
                "%s, data not ready, wait 30s", epoch_log_title.c_str());
        } 
X
xiexionghang 已提交
100
        environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
X
xiexionghang 已提交
101
            "%s, data is ready, start traning", epoch_log_title.c_str());
X
xiexionghang 已提交
102
        environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
103 104 105 106 107 108 109 110 111 112 113 114 115 116

        //Step2. 运行训练网络
        bool already_dump_inference_model = false;
        for (int i = 0; i < _executor_num; ++i) {
            std::vector<std::shared_ptr<std::thread>> train_threads(_train_thread_num);
            for (int thread_id = 0; thread_id < _train_thread_num; ++thread_id) {
                train_threads[i].reset(new std::thread([this](int exe_idx, int thread_idx) {
                    auto* executor = _threads_executor[thread_idx][exe_idx].get();
                    run_executor(executor);                      
                }, i, thread_id));
            }   
            for (int i = 0; i < _train_thread_num; ++i) {
                train_threads[i]->join();
            }
X
xiexionghang 已提交
117
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
118 119 120 121 122

            if (_threads_executor[0][i]->is_dump_all_model()) {
                already_dump_inference_model = true;
                wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
            }
X
xiexionghang 已提交
123
            environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
124 125 126 127 128 129 130 131
        }

        //Step3. Dump Model For Delta&&Checkpoint
        if (!already_dump_inference_model) {
            already_dump_inference_model = true;
            wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
        } 
        wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
X
xiexionghang 已提交
132
        environment->barrier(EnvironmentRole::WORKER); 
X
xiexionghang 已提交
133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
        
        //Step4. Output Monitor && RunStatus
        //TODO
    }
    
    return 0;
}

// Placeholder for running a single executor.
//
// executor: the executor instance selected by the caller; currently ignored.
//
// Returns 0 unconditionally until the real implementation is written.
int LearnerProcess::run_executor(Executor* executor) {
    //TODO
    const int status = 0;
    return status;
}

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle