pslib_warpper.cc 2.8 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#include <fcntl.h>
#include <fstream>
#include <sstream>
#include "json2pb/json_to_pb.h"
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

int PSlib::initialize(const std::string& conf_path, 
X
xiexionghang 已提交
15 16
    RuntimeEnvironment* environment) {
    _environment = environment;
X
xiexionghang 已提交
17 18 19 20 21 22 23 24 25 26 27 28
    init_gflag();    
    int file_descriptor = open(conf_path.c_str(), O_RDONLY);
    if (file_descriptor == -1){
        LOG(ERROR) << "FATAL: cant open " << conf_path;
        return -1;
    }
    google::protobuf::io::FileInputStream fileInput(file_descriptor);
    if (!google::protobuf::TextFormat::Parse(&fileInput, &_ps_param)) {
        LOG(ERROR) << "FATAL: fail to parse " << conf_path;
        return -1;
    }
    close(file_descriptor); 
X
xiexionghang 已提交
29 30
    init_server();
    init_client();
X
xiexionghang 已提交
31 32 33
    return 0;
}
        
X
xiexionghang 已提交
34 35
int PSlib::init_server() {
    if (_environment->is_role(EnvironmentRole::PSERVER)) {
X
xiexionghang 已提交
36 37
        _server_ptr.reset(paddle::ps::PSServerFactory::create(_ps_param));
        _server_ptr->configure(_ps_param, *(_environment->ps_environment()), 
X
xiexionghang 已提交
38
            _environment->rank_id(EnvironmentRole::PSERVER));
X
xiexionghang 已提交
39 40
        _server_ptr->start(); 
    }
41
    _environment->barrier(EnvironmentRole::ALL);
X
xiexionghang 已提交
42 43 44 45
    _environment->ps_environment()->gather_ps_servers();
    return 0;
}

X
xiexionghang 已提交
46
int PSlib::init_client() {
X
xiexionghang 已提交
47
    // 所有节点都启动psclient
X
xiexionghang 已提交
48 49
    _client_ptr.reset(paddle::ps::PSClientFactory::create(_ps_param));
    _client_ptr->configure(_ps_param, *(_environment->ps_environment()), 
X
xiexionghang 已提交
50
        _environment->rank_id(EnvironmentRole::ALL));
X
xiexionghang 已提交
51 52 53 54

    _environment->barrier(EnvironmentRole::ALL);
    _environment->ps_environment()->gather_ps_clients();
    _client_ptr->create_client2client_connection();
X
xiexionghang 已提交
55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
    return 0;
}

/**
 * @return Non-owning pointer to the server held by `_server_ptr`; null when
 *         no server has been created on this node.
 */
paddle::ps::PSServer* PSlib::ps_server() {
    paddle::ps::PSServer* server = _server_ptr.get();
    return server;
}

/**
 * @return Non-owning pointer to the client held by `_client_ptr`; null when
 *         no client has been created yet.
 */
paddle::ps::PSClient* PSlib::ps_client() {
    paddle::ps::PSClient* client = _client_ptr.get();
    return client;
}

/**
 * @return Pointer to the `_ps_param` member (the configuration that
 *         initialize() parses into). The object is owned by this PSlib.
 */
paddle::PSParameter* PSlib::get_param() {
    paddle::PSParameter* param = &_ps_param;
    return param;
}

/**
 * Applies a fixed set of gflags defaults by feeding a synthetic command line
 * to ParseCommandLineFlags.
 *
 * The flags set are max_body_size, bthread_concurrency and
 * socket_max_unwritten_bytes.
 * NOTE(review): these look like brpc/bthread tuning flags - confirm against
 * the linked RPC library.
 *
 * The argument vector lives on the stack, so no heap allocation or manual
 * delete[] is needed. The earlier debug print of the allocation address has
 * been dropped.
 */
void PSlib::init_gflag() {
    // Writable buffers: the parser takes char** and may rearrange argv.
    char p0[] = "exe default";  // stands in for argv[0]
    char p1[] = "-max_body_size=314217728";
    char p2[] = "-bthread_concurrency=40";
    char p3[] = "-socket_max_unwritten_bytes=2048000000";

    char* argv_storage[] = {p0, p1, p2, p3};
    int argc = static_cast<int>(sizeof(argv_storage) / sizeof(argv_storage[0]));

    // ParseCommandLineFlags may modify the argv pointer it is handed, so pass
    // a separate pointer rather than anything derived from the array itself.
    char** argv = argv_storage;
    ::google::ParseCommandLineFlags(&argc, &argv, true);
}

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle