runtime_environment.cc 6.5 KB
Newer Older
X
xiexionghang 已提交
1
#include <mpi.h>
X
xiexionghang 已提交
2 3 4 5 6 7
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

Y
yaopenghui 已提交
8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
// Compile-time lookup from a C++ arithmetic type to the MPI datatype handle
// describing it: mpi_type_trait<T>::type().
//
// The primary template is deliberately empty, so requesting a type that has
// no specialization below fails to compile at the `::type()` call.
//
// NOTE(review): int64_t/uint64_t and long long/unsigned long long are
// specialized separately, which requires them to be distinct types. On a
// platform where int64_t is an alias of long long these would collide -
// confirm the supported targets.
template <typename T>
struct mpi_type_trait {};

// Floating point.
template <>
struct mpi_type_trait<double> {
    static MPI_Datatype type() { return MPI_DOUBLE; }
};

template <>
struct mpi_type_trait<float> {
    static MPI_Datatype type() { return MPI_FLOAT; }
};

// Fixed-width 32-bit integers.
template <>
struct mpi_type_trait<int32_t> {
    static MPI_Datatype type() { return MPI_INT; }
};

template <>
struct mpi_type_trait<uint32_t> {
    static MPI_Datatype type() { return MPI_UNSIGNED; }
};

// Fixed-width 64-bit integers.
template <>
struct mpi_type_trait<int64_t> {
    static MPI_Datatype type() { return MPI_LONG_LONG; }
};

template <>
struct mpi_type_trait<uint64_t> {
    static MPI_Datatype type() { return MPI_UNSIGNED_LONG_LONG; }
};

// Built-in long long flavours.
template <>
struct mpi_type_trait<long long> {
    static MPI_Datatype type() { return MPI_LONG_LONG; }
};

template <>
struct mpi_type_trait<unsigned long long> {
    static MPI_Datatype type() { return MPI_UNSIGNED_LONG_LONG; }
};
X
xiexionghang 已提交
59 60 61 62 63 64 65 66 67 68 69 70
// Nothing to set up in the base class; members use their own defaults.
RuntimeEnvironment::RuntimeEnvironment() = default;
// Nothing to release in the base class.
RuntimeEnvironment::~RuntimeEnvironment() = default;
// Reports whether this process is the master for `role`, i.e. whether
// rank_id(role) is 0.
bool RuntimeEnvironment::is_master_node(EnvironmentRole role) {
    const auto rank_in_role = rank_id(role);
    return 0 == rank_in_role;
}
/// Renders `time` as local time using a strftime-style `format` string.
///
/// @param time    Calendar time to format.
/// @param format  strftime format specification; a null pointer yields "".
/// @return The formatted text, or an empty string when the time cannot be
///         converted to local time or strftime reports 0 bytes written
///         (which includes output that does not fit the 64-byte buffer).
std::string format_timestamp(time_t time, const char* format) {
    std::string result;
    if (format == nullptr) {
        return result;
    }
    // localtime_r fills caller-owned storage, whereas localtime() hands back
    // a shared static buffer that concurrent callers can overwrite. Checking
    // its result also avoids dereferencing NULL on a failed conversion.
    struct tm local_tm;
    if (localtime_r(&time, &local_tm) == nullptr) {
        return result;
    }
    char time_str_buffer[64];
    const size_t size =
        strftime(time_str_buffer, sizeof(time_str_buffer), format, &local_tm);
    if (size > 0) {
        result.assign(time_str_buffer, size);
    }
    return result;
}

// Per-role MPI bookkeeping: the communicator for the role together with this
// process's rank in it and the communicator's size.
struct MpiNodeInfo {
    // Rank inside mpi_comm; negative means the role has not been set up yet.
    int rank_id = -1;
    // Number of processes in mpi_comm; 0 until the role is set up.
    int node_num = 0;
    // Starts as the null communicator so that using a role that was never set
    // up is a well-defined MPI error rather than a read of an uninitialized
    // handle.
    MPI_Comm mpi_comm = MPI_COMM_NULL;
};

class MPIRuntimeEnvironment : public RuntimeEnvironment {
public:
    MPIRuntimeEnvironment() {}
    virtual ~MPIRuntimeEnvironment() {}
    virtual int initialize(YAML::Node config) {
X
xiexionghang 已提交
86 87
        return 0;
    }
X
xiexionghang 已提交
88
    virtual int wireup() {
89 90 91
        int argc = 0;
        char** argv = NULL;
        int hr = MPI_Init(&argc, &argv);
X
xiexionghang 已提交
92 93 94 95 96
        if (MPI_SUCCESS != hr) {
            LOG(FATAL) << "MPI_init failed with error code" << hr; 
            return -1;
        }
        _roles_node_info.resize(static_cast<int>(EnvironmentRole::ALL) + 1);
X
xiexionghang 已提交
97
        add_role(EnvironmentRole::ALL);
98 99 100 101 102 103 104 105 106
    
        char* value = getenv("JOB_ID");
        if (value) {
            _job_id = value;
        }
        value = getenv("JOB_NAME");
        if (value) {
            _job_name = value;
        }
X
xiexionghang 已提交
107 108
        return 0;
    }
X
xiexionghang 已提交
109 110 111 112 113
    
    virtual paddle::ps::PSEnvironment* ps_environment() {
        static paddle::ps::MpiPSEnvironment ps_environment;
        return &ps_environment;
    }
X
xiexionghang 已提交
114 115 116 117 118 119

    virtual uint32_t rank_id(EnvironmentRole role) {
        return mpi_node_info(role).rank_id;
    }
    virtual uint32_t node_num(EnvironmentRole role) {
        return mpi_node_info(role).node_num;
X
xiexionghang 已提交
120
    }
X
xiexionghang 已提交
121
    virtual int add_role(EnvironmentRole role) {
X
xiexionghang 已提交
122 123 124 125 126 127 128 129 130 131
        auto& node_info = mpi_node_info(role);
        if (node_info.rank_id < 0) {
            if (role == EnvironmentRole::ALL) {
                node_info.mpi_comm = MPI_COMM_WORLD;
            } else {
                MPI_Comm_split(MPI_COMM_WORLD, static_cast<int>(role), 
                    mpi_node_info(EnvironmentRole::ALL).rank_id, &(node_info.mpi_comm));
            }
            MPI_Comm_rank(node_info.mpi_comm, &(node_info.rank_id));
            MPI_Comm_size(node_info.mpi_comm, &(node_info.node_num));
X
xiexionghang 已提交
132
        }
X
xiexionghang 已提交
133
        _role_set.insert(role);
X
xiexionghang 已提交
134
        return 0;
X
xiexionghang 已提交
135
    }
X
xiexionghang 已提交
136 137 138
    virtual bool is_role(EnvironmentRole role) {
        return _role_set.count(role) > 0;
    }
X
xiexionghang 已提交
139 140 141 142

    virtual void barrier(EnvironmentRole role) {
        MPI_Barrier(mpi_node_info(role).mpi_comm);
    }
X
xiexionghang 已提交
143

X
xiexionghang 已提交
144 145
    virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) {
        auto& node_info = mpi_node_info(role);
X
xiexionghang 已提交
146
        int len = (int)ar.Length();
X
xiexionghang 已提交
147
        MPI_Bcast(&len, 1, MPI_INT, root_id, node_info.mpi_comm);
X
xiexionghang 已提交
148 149 150
        ar.Resize(len);
        ar.SetCursor(ar.Buffer());
        MPI_Bcast(ar.Buffer(), len, MPI_BYTE, root_id, node_info.mpi_comm);
X
xiexionghang 已提交
151
    }
X
xiexionghang 已提交
152 153 154 155 156 157 158
    virtual void all_reduce_in_place(double* x, int n, ReduceOperator op, EnvironmentRole role) {
        auto& node_info = mpi_node_info(role);
        if (op == ReduceOperator::SUM) {
            MPI_Allreduce(MPI_IN_PLACE, x, n, MPI_DOUBLE, MPI_SUM, node_info.mpi_comm);
        } else {
            CHECK(false) << "unsupport operator";
        }
Y
yaopenghui 已提交
159 160
    }

X
xiexionghang 已提交
161
protected:
X
xiexionghang 已提交
162 163 164
    virtual void print_log(EnvironmentRole role, EnvironmentLogType type, 
        EnvironmentLogLevel level,  const std::string& log_str) {
        if (type == EnvironmentLogType::MASTER_LOG && !is_master_node(role)) {
X
xiexionghang 已提交
165 166
            return;
        }
X
xiexionghang 已提交
167
        VLOG(static_cast<int>(level)) << log_str;
168 169 170 171 172
        /*
        static std::mutex mtx;
        std::lock_guard<std::mutex> guard(mtx);
        std::err << log_str;
        */
X
xiexionghang 已提交
173 174
    }

X
xiexionghang 已提交
175 176 177
    inline MpiNodeInfo& mpi_node_info(EnvironmentRole role) {
        return _roles_node_info[static_cast<int>(role)];
    }
X
xiexionghang 已提交
178

X
xiexionghang 已提交
179
private:
X
xiexionghang 已提交
180
    std::set<EnvironmentRole> _role_set;
X
xiexionghang 已提交
181 182
    std::vector<MpiNodeInfo> _roles_node_info;
};
X
xiexionghang 已提交
183
// Hooks MPIRuntimeEnvironment up as a RuntimeEnvironment implementation via
// the project's REGIST_CLASS macro. NOTE(review): presumably a name-based
// factory registration; the macro is defined outside this file - confirm.
REGIST_CLASS(RuntimeEnvironment, MPIRuntimeEnvironment);
X
xiexionghang 已提交
184 185 186 187 188 189 190 191 192 193 194 195

//用于本地模式单机训练
class LocalRuntimeEnvironment : public RuntimeEnvironment {
public:
    LocalRuntimeEnvironment() {}
    virtual ~LocalRuntimeEnvironment() {}
    virtual int initialize(YAML::Node config) {
        return 0;
    }
    virtual int wireup() {
        return 0;
    }
X
xiexionghang 已提交
196 197 198 199
    virtual paddle::ps::PSEnvironment* ps_environment() {
        static paddle::ps::LocalPSEnvironment ps_environment;
        return &ps_environment;
    }
X
xiexionghang 已提交
200 201 202 203 204 205
    virtual uint32_t rank_id(EnvironmentRole role) {
        return 0;
    }
    virtual uint32_t node_num(EnvironmentRole role) {
        return 1;
    }
X
xiexionghang 已提交
206
    virtual int add_role(EnvironmentRole role) {
X
xiexionghang 已提交
207 208
        return 0;
    }
X
xiexionghang 已提交
209 210 211
    virtual bool is_role(EnvironmentRole role) {
        return true;
    }
X
xiexionghang 已提交
212
    virtual void barrier(EnvironmentRole role) {
X
xiexionghang 已提交
213 214
        return;
    }
X
xiexionghang 已提交
215 216 217
    virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) {
        return;
    }
X
xiexionghang 已提交
218
    virtual void all_reduce_in_place(double* x, int n, ReduceOperator op, EnvironmentRole role) {
Y
yaopenghui 已提交
219 220
        return;
    }
X
xiexionghang 已提交
221 222 223 224 225 226
protected:
    virtual void print_log(EnvironmentRole role, EnvironmentLogType type, 
        EnvironmentLogLevel level,  const std::string& log_str) {
        VLOG(static_cast<int>(level)) << log_str;
    }
};
X
xiexionghang 已提交
227
// Hooks LocalRuntimeEnvironment up as a RuntimeEnvironment implementation via
// the project's REGIST_CLASS macro. NOTE(review): presumably a name-based
// factory registration; the macro is defined outside this file - confirm.
REGIST_CLASS(RuntimeEnvironment, LocalRuntimeEnvironment);
X
xiexionghang 已提交
228 229 230 231

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle