提交 234a3fd4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4170 dlopen cpu mpi adapter

Merge pull request !4170 from kisnwang/disable-host-mpi
......@@ -140,15 +140,6 @@ if (ENABLE_MPI)
COMPONENT mindspore
)
endif ()
file(GLOB_RECURSE MPI_LIB_LIST
${ompi_LIBPATH}/libmpi${CMAKE_SHARED_LIBRARY_SUFFIX}*
${ompi_LIBPATH}/libopen*${CMAKE_SHARED_LIBRARY_SUFFIX}*
)
install(
FILES ${MPI_LIB_LIST}
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
endif ()
if (ENABLE_GPU)
......
......@@ -155,11 +155,7 @@ if (ENABLE_DEBUGGER)
endif()
target_link_libraries(mindspore proto_input)
if (ENABLE_MPI AND ENABLE_CPU)
target_link_libraries(mindspore securec mindspore::flatbuffers mpi_adapter)
else ()
target_link_libraries(mindspore securec mindspore::flatbuffers)
endif ()
target_link_libraries(mindspore securec mindspore::flatbuffers)
if (NOT WIN32)
target_link_libraries(mindspore dl)
......
......@@ -15,7 +15,7 @@
*/
#include "backend/kernel_compiler/cpu/allgather_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "runtime/device/cpu/mpi/mpi_adapter.h"
#include "runtime/device/cpu/mpi/mpi_interface.h"
#include "utils/log_adapter.h"
namespace mindspore {
......@@ -45,9 +45,7 @@ bool AllGatherCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
auto input_data_num = inputs[0]->size / sizeof(float);
auto mpi_instance = device::cpu::MPIAdapter::Instance();
MS_EXCEPTION_IF_NULL(mpi_instance);
return mpi_instance->AllGather(input_addr, output_addr, ranks_group_, input_data_num);
return MPIAllGather(input_addr, output_addr, ranks_group_, input_data_num);
}
} // namespace kernel
} // namespace mindspore
......@@ -16,7 +16,7 @@
#include <thread>
#include "backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "runtime/device/cpu/mpi/mpi_adapter.h"
#include "runtime/device/cpu/mpi/mpi_interface.h"
namespace mindspore {
namespace kernel {
......@@ -49,11 +49,8 @@ bool EmbeddingLookUpCommGradCPUKernel::Launch(const std::vector<kernel::AddressP
const std::vector<int> &rank_group = {0, 1, 2, 3, 4, 5, 6, 7};
size_t input_split_lens = input_size / split_num_ / sizeof(float_t);
size_t output_split_lens = output_size / split_num_ / sizeof(float_t);
auto mpi_instance = device::cpu::MPIAdapter::Instance();
MS_EXCEPTION_IF_NULL(mpi_instance);
for (int i = 0; i < split_num_; i++) {
mpi_instance->AllGather(input_addr + i * input_split_lens, output_addr + i * output_split_lens, rank_group,
input_split_lens);
MPIAllGather(input_addr + i * input_split_lens, output_addr + i * output_split_lens, rank_group, input_split_lens);
}
#if defined(_WIN32) || defined(_WIN64)
auto end_time = std::chrono::steady_clock::now();
......
......@@ -15,7 +15,7 @@
*/
#include "backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "runtime/device/cpu/mpi/mpi_adapter.h"
#include "runtime/device/cpu/mpi/mpi_interface.h"
#include "ir/primitive.h"
namespace mindspore {
......@@ -24,7 +24,7 @@ namespace {
constexpr auto kRanksGroup = "group";
} // namespace
ReduceScatterCPUKernel::ReduceScatterCPUKernel() : op_type_(device::cpu::kOpTypeSum) {}
ReduceScatterCPUKernel::ReduceScatterCPUKernel() : op_type_(kMPIOpTypeSum) {}
void ReduceScatterCPUKernel::InitKernel(const CNodePtr &kernel_node) {
auto op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op");
......@@ -46,9 +46,7 @@ bool ReduceScatterCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
auto output_data_num = outputs[0]->size / sizeof(float);
auto mpi_instance = device::cpu::MPIAdapter::Instance();
MS_EXCEPTION_IF_NULL(mpi_instance);
return mpi_instance->ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num, op_type_);
return MPIReduceScatter(input_addr, output_addr, ranks_group_, output_data_num, op_type_);
}
} // namespace kernel
} // namespace mindspore
......@@ -14,12 +14,12 @@ endif ()
if (ENABLE_CPU)
file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc", "cpu/mpi/mpi_export.cc")
endif ()
if (ENABLE_MPI)
if (ENABLE_CPU)
file(GLOB_RECURSE MPI_SRC_LIST "cpu/mpi/mpi_adapter.cc")
file(GLOB_RECURSE MPI_SRC_LIST "cpu/mpi/mpi_adapter.cc", "cpu/mpi/mpi_export.cc")
set_property(SOURCE ${MPI_SRC_LIST}
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
add_library(mpi_adapter SHARED ${MPI_SRC_LIST})
......
......@@ -22,7 +22,7 @@
#include <exception>
#include <algorithm>
#include "runtime/device/ascend/ascend_device_address.h"
#include "runtime/device/cpu/mpi/mpi_adapter.h"
#include "runtime/device/cpu/mpi/mpi_interface.h"
#include "utils/ms_context.h"
#include "utils/context/context_extends.h"
#include "utils/mpi/mpi_config.h"
......@@ -64,9 +64,7 @@ std::string GetRankId() {
auto mpi_config_ptr = MpiConfig::GetInstance();
MS_EXCEPTION_IF_NULL(mpi_config_ptr);
if (mpi_config_ptr->enable_mpi()) {
auto mpi_instance = device::cpu::MPIAdapter::Instance();
MS_EXCEPTION_IF_NULL(mpi_instance);
int rank_id = mpi_instance->GetRankId();
int rank_id = GetMPIRankId();
const char *offset = std::getenv("RANK_OFFSET");
if (offset != nullptr) {
try {
......
......@@ -14,11 +14,11 @@
* limitations under the License.
*/
#include "runtime/device/cpu/mpi/mpi_adapter.h"
#ifdef ENABLE_MPI
#include <algorithm>
#include <sstream>
#include <vector>
#include <string>
#include "pybind11/pybind11.h"
#endif // ENABLE_MPI
#include "utils/log_adapter.h"
namespace mindspore {
......@@ -33,8 +33,6 @@ std::shared_ptr<MPIAdapter> MPIAdapter::Instance() {
return instance_;
}
#ifdef ENABLE_MPI
#define RAISE_EXCEPTION(message) \
{ \
std::ostringstream oss; \
......@@ -271,7 +269,6 @@ bool MPIAdapter::AllGather(const float *input, float *output, const std::vector<
}
return true;
}
#endif // ENABLE_MPI
} // namespace cpu
} // namespace device
} // namespace mindspore
......@@ -16,13 +16,11 @@
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_ADAPTER_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_ADAPTER_H_
#ifdef ENABLE_MPI
#include <mpi.h>
#include <vector>
#include <map>
#include <string>
#include <mutex>
#endif // ENABLE_MPI
#include <memory>
namespace mindspore {
......@@ -31,27 +29,19 @@ namespace cpu {
#ifndef FUNC_EXPORT
#define FUNC_EXPORT __attribute__((visibility("default")))
#endif
constexpr auto kOpTypeSum = "sum";
class MPIAdapter {
public:
FUNC_EXPORT static std::shared_ptr<MPIAdapter> Instance();
FUNC_EXPORT int GetRankId() const { return rank_id_; }
FUNC_EXPORT int GetRankSize() const { return rank_size_; }
#ifdef ENABLE_MPI
FUNC_EXPORT ~MPIAdapter();
FUNC_EXPORT bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group,
size_t data_num, const std::string &op_type = kOpTypeSum);
size_t data_num, const std::string &op_type);
FUNC_EXPORT bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t in_data_num,
size_t output_size, const std::string &op_type = kOpTypeSum,
float *output = nullptr);
size_t output_size, const std::string &op_type, float *output);
FUNC_EXPORT bool AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
#else
FUNC_EXPORT ~MPIAdapter() = default;
#endif // ENABLE_MPI
private:
#ifdef ENABLE_MPI
MPIAdapter();
void Init();
MPI_Group AddGroup(const std::vector<int> &ranks);
......@@ -60,9 +50,6 @@ class MPIAdapter {
// key:ranks group, value: mpi group
std::map<std::vector<int>, MPI_Group> ranks_group_;
std::mutex group_mutex_;
#else
MPIAdapter() = default;
#endif // ENABLE_MPI
int rank_id_{-1};
int rank_size_{0};
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/cpu/mpi/mpi_export.h"
#include <vector>
#include <string>
#include "runtime/device/cpu/mpi/mpi_adapter.h"
int GetMPIRankId() {
auto inst = mindspore::device::cpu::MPIAdapter::Instance();
if (inst == nullptr) {
return 0;
}
return inst->GetRankId();
}
int GetMPIRankSize() {
auto inst = mindspore::device::cpu::MPIAdapter::Instance();
if (inst == nullptr) {
return 0;
}
return inst->GetRankSize();
}
bool MPIReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
const std::string &op_type) {
auto inst = mindspore::device::cpu::MPIAdapter::Instance();
if (inst == nullptr) {
return false;
}
return inst->ReduceScatter(input, output, ranks_group, data_num, op_type);
}
bool MPIReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t in_data_num,
size_t output_size, const std::string &op_type, float *output) {
auto inst = mindspore::device::cpu::MPIAdapter::Instance();
if (inst == nullptr) {
return false;
}
return inst->ReduceScatterOverwriteInput(input, ranks_group, in_data_num, output_size, op_type, output);
}
bool MPIAllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
auto inst = mindspore::device::cpu::MPIAdapter::Instance();
if (inst == nullptr) {
return false;
}
return inst->AllGather(input, output, ranks_group, data_num);
}
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_EXPORT_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_EXPORT_H_
#include <vector>
#include <string>
#ifndef FUNC_EXPORT
#define FUNC_EXPORT __attribute__((visibility("default")))
#endif
extern "C" FUNC_EXPORT FUNC_EXPORT int GetMPIRankId();
extern "C" FUNC_EXPORT FUNC_EXPORT int GetMPIRankSize();
extern "C" FUNC_EXPORT bool MPIReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group,
size_t data_num, const std::string &op_type);
extern "C" FUNC_EXPORT bool MPIReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group,
size_t in_data_num, size_t output_size,
const std::string &op_type, float *output);
extern "C" FUNC_EXPORT bool MPIAllGather(const float *input, float *output, const std::vector<int> &ranks_group,
size_t data_num);
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_EXPORT_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/cpu/mpi/mpi_interface.h"
#ifdef ENABLE_MPI
#include <dlfcn.h>
#include <vector>
#include <string>
#include "utils/log_adapter.h"
inline void *LoadLibrary(const char *name) {
auto handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
if (handle == nullptr) {
MS_LOG(EXCEPTION) << "Load lib " << name << " failed, make sure you have installed it!";
}
return handle;
}
inline void *GetMPIAdapterHandle() {
static void *handle = LoadLibrary("mpi_adapter.so");
return handle;
}
inline void *GetMPIAdapterFunc(const char *name) {
static void *handle = GetMPIAdapterHandle();
if (handle == nullptr) {
MS_LOG(EXCEPTION) << "Load lib " << name << " failed, make sure you have installed it!";
}
void *func = dlsym(handle, name);
if (func == nullptr) {
MS_LOG(EXCEPTION) << "Load func " << name << " failed, make sure you have implied it!";
}
return func;
}
typedef int (*GetMPIRankIdFunc)();
typedef int (*GetMPIRankSizeFunc)();
typedef bool (*MPIReduceScatterFunc)(const float *input, float *output, const std::vector<int> &ranks_group,
size_t data_num, const std::string &op_type);
typedef bool (*MPIReduceScatterOverwriteInputFunc)(float *input, const std::vector<int> &ranks_group,
size_t in_data_num, size_t output_size, const std::string &op_type,
float *output);
typedef bool (*MPIAllGatherFunc)(const float *input, float *output, const std::vector<int> &ranks_group,
size_t data_num);
int GetMPIRankId() {
static GetMPIRankIdFunc func = reinterpret_cast<GetMPIRankIdFunc>(GetMPIAdapterFunc("GetMPIRankId"));
return func();
}
int GetMPIRankSize() {
static GetMPIRankIdFunc func = reinterpret_cast<GetMPIRankSizeFunc>(GetMPIAdapterFunc("GetMPIRankSize"));
return func();
}
bool MPIReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
const std::string &op_type) {
static MPIReduceScatterFunc func = reinterpret_cast<MPIReduceScatterFunc>(GetMPIAdapterFunc("MPIReduceScatter"));
return func(input, output, ranks_group, data_num, op_type);
}
bool MPIReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t in_data_num,
size_t output_size, const std::string &op_type, float *output) {
static MPIReduceScatterOverwriteInputFunc func =
reinterpret_cast<MPIReduceScatterOverwriteInputFunc>(GetMPIAdapterFunc("MPIReduceScatterOverwriteInput"));
return func(input, ranks_group, in_data_num, output_size, op_type, output);
}
bool MPIAllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
static MPIAllGatherFunc func = reinterpret_cast<MPIAllGatherFunc>(GetMPIAdapterFunc("MPIAllGather"));
return func(input, output, ranks_group, data_num);
}
#endif // ENABLE_MPI
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_INTERFACE_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_INTERFACE_H_
#include <vector>
#include <string>
#ifndef FUNC_EXPORT
#define FUNC_EXPORT __attribute__((visibility("default")))
#endif
constexpr auto kMPIOpTypeSum = "sum";
#ifdef ENABLE_MPI
int GetMPIRankId();
int GetMPIRankSize();
bool MPIReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
const std::string &op_type = kMPIOpTypeSum);
bool MPIReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t in_data_num,
size_t output_size, const std::string &op_type = kMPIOpTypeSum,
float *output = nullptr);
bool MPIAllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
#endif // ENABLE_MPI
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_INTERFACE_H_
......@@ -19,6 +19,7 @@
#include <mpi.h>
#include <pybind11/operators.h>
#include <iostream>
#include <vector>
#include <string>
namespace mindspore {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册