/*
 *Author: xiexionghang
 *Training samples
 */
#include <omp.h>
#include <unistd.h>
#include <future>
#include <memory>
#include <thread>
#include <vector>
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/learner_process.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

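//Read trainer_config and build the [train_thread][executor] matrix:
//every train thread owns its own instance of each configured executor.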
int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
    int ret = Process::initialize(context_ptr);
    auto& config = _context_ptr->trainer_config;
    _train_thread_num = config["train_thread_num"].as<int>();
    _threads_executor.resize(_train_thread_num);
    
    if (config["executor"]) {
        _executor_num = config["executor"].size();
        omp_set_num_threads(_train_thread_num);
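        //construct and initialize each thread's private executor instances in parallel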
        #pragma omp parallel for
        for (int i = 0; i < _train_thread_num; ++i) {
            _threads_executor[i].resize(_executor_num);
            for (int e = 0; e < _executor_num; ++e) {
                auto e_class = config["executor"][e]["class"].as<std::string>();
                auto* e_ptr = CREATE_CLASS(Executor, e_class);
                _threads_executor[i][e].reset(e_ptr);  
                if (e_ptr->initialize(config["executor"][e], context_ptr) != 0) {
                    ret = -1;
                }
            }
        }
    }
    return ret;  //propagate Process/executor initialize failures
}

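//Asynchronously save one parameter table for the given epoch; the returned
//future yields 0 on success. The actual pslib save call is still a TODO.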
std::future<int> LearnerProcess::save_model(int epoch_id, int table_id, ModelSaveWay way) {
    std::promise<int> p;
    auto ret = p.get_future();
    if (_context_ptr->epoch_accessor->need_save_model(epoch_id, way)) {
        //TODO
        //context_ptr->pslib_client()->save();
        p.set_value(0);  //placeholder until the save call above is implemented;
                         //without it the promise is destroyed unset and get() throws broken_promise
    } else {
        p.set_value(0);
    }
    return ret;
}

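//On the master node, start a save for every parameter table and block until
//all of them finish, OR-ing the per-table return codes. Other nodes return immediately.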
int LearnerProcess::wait_save_model(int epoch_id, ModelSaveWay way) {
    auto* environment = _context_ptr->environment.get();
    if (!environment->is_master_node()) {
        return 0;
    }
    size_t ret_size = 0;
    auto table_num = _context_ptr->params_table_list.size();
    //use a std::vector instead of a variable-length array, which is non-standard C++
    std::vector<std::future<int>> rets(table_num);
    for (size_t i = 0; i < table_num; ++i) {
        auto table_id = _context_ptr->params_table_list[i].table_id();
        rets[ret_size++] = save_model(epoch_id, table_id, way);
    }

    int all_ret = 0;
    for (size_t i = 0; i < ret_size; ++i) {
        rets[i].wait();
        all_ret |= rets[i].get();
    }
    return all_ret;
}

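//Main loop: resume at the persisted epoch, optionally dump a base inference
//model, then per epoch: wait for data, train, and dump delta/checkpoint models.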
int LearnerProcess::run() {
    auto* environment = _context_ptr->environment.get();
    auto* epoch_accessor = _context_ptr->epoch_accessor.get(); 
    int epoch_id = epoch_accessor->current_epoch_id();

    environment->log(EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
        "Resume training with epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str());
    
    //decide whether to dump the base inference model first
    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
    environment->barrier_all(); 
    
    while (true) {
        epoch_accessor->next_epoch();
        epoch_id = epoch_accessor->current_epoch_id();
        std::string epoch_log_title = paddle::string::format_string(
            "train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str());
        
        //Step1. Wait for the training samples to be ready
        environment->log(EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
            "Start %s, wait data ready", epoch_log_title.c_str());
        while (!epoch_accessor->data_ready(epoch_id)) {
            sleep(30);  
            environment->log(EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, 
                "%s, data not ready, wait 30s", epoch_log_title.c_str());
        } 
        environment->log(EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
            "%s, data is ready, start training", epoch_log_title.c_str());
        environment->barrier_all();

        //Step2. Run the training network: executors run in sequence, each driven concurrently by all train threads
        bool already_dump_inference_model = false;
        for (int i = 0; i < _executor_num; ++i) {
            std::vector<std::shared_ptr<std::thread>> train_threads(_train_thread_num);
            for (int thread_id = 0; thread_id < _train_thread_num; ++thread_id) {
                //index by thread_id, not the executor index i, so each thread gets its own slot
                train_threads[thread_id].reset(new std::thread([this](int exe_idx, int thread_idx) {
                    auto* executor = _threads_executor[thread_idx][exe_idx].get();
                    run_executor(executor);
                }, i, thread_id));
            }
            for (int thread_id = 0; thread_id < _train_thread_num; ++thread_id) {
                train_threads[thread_id]->join();
            }
            environment->barrier_all();

            if (_threads_executor[0][i]->is_dump_all_model()) {
                already_dump_inference_model = true;
                wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
            }
            environment->barrier_all();
        }

        //Step3. Dump Model For Delta&&Checkpoint
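        //An executor with is_dump_all_model() already dumped this epoch's delta in
        //Step2; otherwise make sure one is written here before the checkpoint.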
        if (!already_dump_inference_model) {
            already_dump_inference_model = true;
            wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
        } 
        wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
        environment->barrier_all(); 
        
        //Step4. Output Monitor && RunStatus
        //TODO
    }
    
    return 0;
}

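//Runs one executor inside a train thread; the training pass itself is still a TODO.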
int LearnerProcess::run_executor(Executor* executor) {
    //TODO
    return 0;
}

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle