pslib_warpper.cc 2.5 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#include <fcntl.h>
#include <unistd.h>

#include <fstream>
#include <sstream>

#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include "json2pb/json_to_pb.h"

#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

int PSlib::initialize(const std::string& conf_path, 
X
xiexionghang 已提交
15 16
    RuntimeEnvironment* environment) {
    _environment = environment;
X
xiexionghang 已提交
17 18 19 20 21 22 23 24 25 26 27 28
    init_gflag();    
    int file_descriptor = open(conf_path.c_str(), O_RDONLY);
    if (file_descriptor == -1){
        LOG(ERROR) << "FATAL: cant open " << conf_path;
        return -1;
    }
    google::protobuf::io::FileInputStream fileInput(file_descriptor);
    if (!google::protobuf::TextFormat::Parse(&fileInput, &_ps_param)) {
        LOG(ERROR) << "FATAL: fail to parse " << conf_path;
        return -1;
    }
    close(file_descriptor); 
X
xiexionghang 已提交
29 30
    init_server();
    init_client();
X
xiexionghang 已提交
31 32 33
    return 0;
}
        
X
xiexionghang 已提交
34 35
int PSlib::init_server() {
    if (_environment->is_role(EnvironmentRole::PSERVER)) {
X
xiexionghang 已提交
36 37
        _server_ptr.reset(paddle::ps::PSServerFactory::create(_ps_param));
        _server_ptr->configure(_ps_param, *(_environment->ps_environment()), 
X
xiexionghang 已提交
38
            _environment->rank_id(EnvironmentRole::PSERVER));
X
xiexionghang 已提交
39 40
        _server_ptr->start(); 
    }
41
    _environment->barrier(EnvironmentRole::ALL);
X
xiexionghang 已提交
42 43 44 45
    _environment->ps_environment()->gather_ps_servers();
    return 0;
}

X
xiexionghang 已提交
46
int PSlib::init_client() {
X
xiexionghang 已提交
47 48
    _client_ptr.reset(paddle::ps::PSClientFactory::create(_ps_param));
    _client_ptr->configure(_ps_param, *(_environment->ps_environment()), 
X
xiexionghang 已提交
49
        _environment->rank_id(EnvironmentRole::ALL));
X
xiexionghang 已提交
50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
    return 0;
}

// Non-owning pointer to the server held in _server_ptr. init_server() only
// creates one on PSERVER nodes, so this is presumably null elsewhere.
paddle::ps::PSServer* PSlib::ps_server() {
    auto* const server = _server_ptr.get();
    return server;
}

// Non-owning pointer to the client held in _client_ptr (set by init_client()).
paddle::ps::PSClient* PSlib::ps_client() {
    auto* const client = _client_ptr.get();
    return client;
}

// Mutable, non-owning access to the PS config parsed by initialize().
paddle::PSParameter* PSlib::get_param() {
    paddle::PSParameter& param = _ps_param;
    return &param;
}

// Applies a fixed set of command-line flag overrides by feeding a synthetic
// argv to gflags. The flag names suggest they tune the RPC/bthread layer
// used by pslib (NOTE(review): confirm which library defines them).
//
// ParseCommandLineFlags takes a mutable argc/argv (int*, char***), so each
// argument lives in a writable local char array rather than a string literal.
void PSlib::init_gflag() {
    char p0[] = "exe default";  // argv[0]: placeholder program name
    char p1[] = "-max_body_size=314217728";
    char p2[] = "-bthread_concurrency=40";
    char p3[] = "-socket_max_unwritten_bytes=2048000000";

    // Plain stack array of pointers. A std::shared_ptr<char*> wrapping
    // `new char*[n]` would free it with scalar delete (undefined behavior).
    char* params[] = {p0, p1, p2, p3};
    int cnt = static_cast<int>(sizeof(params) / sizeof(params[0]));
    char** params_ptr = params;

    // remove_flags=true lets gflags rearrange params and shrink cnt. Both are
    // locals, so that is harmless here.
    ::google::ParseCommandLineFlags(&cnt, &params_ptr, true);
}

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle