//
// Created by tangwei12 on 2018/3/27.
//

#include <stdio.h>
#include <string.h>

#include <mpi.h>
#include "mpi_utils.h"

// Maximum length in bytes of a worker name, including the NUL terminator.
#define max_worker_name_length 128
// Fixed message tag used by this module's MPI send/recv pairs.
// BUG FIX: was `#define mpi_tag = 2008`, which expands to `= 2008` and
// breaks any expression that uses the macro.
#define mpi_tag 2008

namespace paddle {
    namespace operators {
        namespace detail {
            // Builds the worker-name -> rank map: every process broadcasts its
            // name via MPI_Allgather over MPI_COMM_WORLD, then each process
            // records name -> rank for all ranks in name_to_id_.
            MPIUtils::MPIUtils(const std::string &worker_name) {
                InitMPI();

                int rank = 0, size = 1;
                // BUG FIX: was `max_work_group_size`, which is not defined
                // anywhere; the buffer holds one worker name.
                char my_name[max_worker_name_length];
                MPI_Comm_rank(MPI_COMM_WORLD, &rank);
                MPI_Comm_size(MPI_COMM_WORLD, &size);
                // BUG FIX: never pass external data as the format string
                // (format-string vulnerability / crash on '%'); use "%s" to
                // copy, truncating to the buffer size.
                snprintf(my_name, max_worker_name_length, "%s", worker_name.c_str());

                std::vector<char> worker_names(size * max_worker_name_length);
                MPI_Allgather(my_name, max_worker_name_length, MPI_CHAR, &worker_names[0],
                              max_worker_name_length, MPI_CHAR, MPI_COMM_WORLD);
                // BUG FIX: was `number_of_procs` (undefined) and a hard-coded
                // 128; use the communicator size and the shared constant.
                for (int i = 0; i < size; i++) {
                    name_to_id_[std::string(&worker_names[i * max_worker_name_length])] = i;
                }
            }
            void MPIUtils::InitMPI() {
                int flag = 0;
                MPI_CHECK(MPI_Initialized(&flag));
T
tangwei12 已提交
37

T
tangwei12 已提交
38 39 40
                if (!flag) {
                    int rank = 0, size = 1, len = -1;
                    char host_name[max_worker_name_length];
T
tangwei12 已提交
41

T
tangwei12 已提交
42 43 44 45 46 47
                    MPI_Init(0, 0);
                    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
                    MPI_Comm_size(MPI_COMM_WORLD, &size);
                    MPI_Get_processor_name(host_name, &len);
                }
            };
T
tangwei12 已提交
48 49


T
tangwei12 已提交
50 51 52 53 54
            // Construct a send task for `meta`. The payload has not been
            // handed to MPI yet: done2_ only flips once Send() completes.
            MPISend::MPISend(const Meta &meta) {
                this->meta = meta;
                done1_ = 1;
                done2_ = 0;
            }
            MPISend::Send() {
                MPI_Send(&meta.request, meta.count, meta.datatype, meta.dst, meta.tag,
                         MPI_COMM_WORLD);
                done2_ = 1;
            }
T
tangwei12 已提交
61

T
tangwei12 已提交
62 63 64
            // Always ready: this implementation performs a blocking send, so
            // there is never pending state to wait on before calling Send().
            bool MPISend::IsReady() { return true; }
            // Finished once construction completed (done1_) and the blocking
            // send has returned (done2_).
            bool MPISend::IsFinished() {
                return done1_ && done2_;
            }
            // BUG FIX: the destructor called MPI_Free_mem(meta), but `meta`
            // is a by-value member (copied in the constructor) that was never
            // obtained from MPI_Alloc_mem — freeing it is invalid, and the
            // call does not even compile for a struct argument. Nothing here
            // owns an MPI-allocated resource, so the destructor is empty.
            MPISend::~MPISend() {}
            // Construct a receive task; `meta` describes the message this
            // task expects to receive.
            MPIRecv::MPIRecv(const Meta &meta) { this->meta = meta; }
            MPIRecv::Recv() {}
T
tangwei12 已提交
76

T
tangwei12 已提交
77 78 79
            // Always ready; mirrors MPISend::IsReady for this stub
            // implementation.
            bool MPIRecv::IsReady() { return true; }
            MPIRecv::IsFinished() {}
T
tangwei12 已提交
82

T
tangwei12 已提交
83 84 85
            // BUG FIX: as in ~MPISend, MPI_Free_mem(meta) was invalid —
            // `meta` is a by-value member never allocated via MPI_Alloc_mem,
            // so there is nothing for this destructor to release.
            MPIRecv::~MPIRecv() {}
        }  // namespace detail
    }  // namespace operators
}  // namespace paddle