未验证 提交 e3a457d3 编写于 作者: X xujiaqi01 提交者: GitHub

add collective communication library in fleet (#22211)

* add collective communication library in fleet to replace mpi
* test=develop
上级 05ee05e2
......@@ -62,6 +62,10 @@ if(WITH_PSLIB)
add_definitions(-DPADDLE_WITH_PSLIB)
endif()
if(WITH_GLOO)
add_definitions(-DPADDLE_WITH_GLOO)
endif()
if(WITH_BOX_PS)
add_definitions(-DPADDLE_WITH_BOX_PS)
endif()
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE(ExternalProject)
SET(GLOO_PROJECT "extern_gloo")
IF((NOT DEFINED GLOO_VER) OR (NOT DEFINED GLOO_URL))
MESSAGE(STATUS "use pre defined download url")
SET(GLOO_VER "master" CACHE STRING "" FORCE)
SET(GLOO_NAME "gloo" CACHE STRING "" FORCE)
SET(GLOO_URL "https://pslib.bj.bcebos.com/gloo.tar.gz" CACHE STRING "" FORCE)
ENDIF()
MESSAGE(STATUS "GLOO_NAME: ${GLOO_NAME}, GLOO_URL: ${GLOO_URL}")
SET(GLOO_SOURCE_DIR "${THIRD_PARTY_PATH}/gloo")
SET(GLOO_DOWNLOAD_DIR "${GLOO_SOURCE_DIR}/src/${GLOO_PROJECT}")
SET(GLOO_DST_DIR "gloo")
SET(GLOO_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(GLOO_INSTALL_DIR ${GLOO_INSTALL_ROOT}/${GLOO_DST_DIR})
SET(GLOO_ROOT ${GLOO_INSTALL_DIR})
SET(GLOO_INC_DIR ${GLOO_ROOT}/include)
SET(GLOO_LIB_DIR ${GLOO_ROOT}/lib)
SET(GLOO_LIB ${GLOO_LIB_DIR}/libgloo.a)
#SET(GLOO_IOMP_LIB ${GLOO_LIB_DIR}/libiomp5.so) #todo what is this
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${GLOO_ROOT}/lib")
INCLUDE_DIRECTORIES(${GLOO_INC_DIR})
FILE(WRITE ${GLOO_DOWNLOAD_DIR}/CMakeLists.txt
"PROJECT(GLOO)\n"
"cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${GLOO_NAME}/include ${GLOO_NAME}/lib \n"
" DESTINATION ${GLOO_DST_DIR})\n")
ExternalProject_Add(
${GLOO_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${GLOO_SOURCE_DIR}
DOWNLOAD_DIR ${GLOO_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${GLOO_URL} -c -q -O ${GLOO_NAME}.tar.gz
&& tar zxvf ${GLOO_NAME}.tar.gz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${GLOO_INSTALL_ROOT}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GLOO_INSTALL_ROOT}
)
ADD_LIBRARY(gloo SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET gloo PROPERTY IMPORTED_LOCATION ${GLOO_LIB})
ADD_DEPENDENCIES(gloo ${GLOO_PROJECT})
......@@ -241,6 +241,11 @@ if(WITH_PSLIB)
endif()
endif(WITH_PSLIB)
if(NOT WIN32 AND NOT APPLE)
include(external/gloo)
list(APPEND third_party_deps extern_gloo)
endif()
if(WITH_BOX_PS)
include(external/box_ps)
list(APPEND third_party_deps extern_box_ps)
......
......@@ -10,3 +10,11 @@ if(WITH_BOX_PS)
else()
cc_library(box_wrapper SRCS box_wrapper.cc DEPS framework_proto lod_tensor)
endif(WITH_BOX_PS)
if(WITH_GLOO)
cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope gloo)
else()
cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope)
endif(WITH_GLOO)
cc_test(test_fleet SRCS test_fleet.cc DEPS fleet_wrapper gloo_wrapper fs shell)
......@@ -107,6 +107,16 @@ uint64_t FleetWrapper::RunServer() {
#endif
}
uint64_t FleetWrapper::RunServer(const std::string& ip, uint32_t port) {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to run server with ip " << ip << " port " << port;
auto ret = pslib_ptr_->run_server(ip, port);
return ret;
#else
return 0;
#endif
}
void FleetWrapper::GatherServers(const std::vector<uint64_t>& host_sign_list,
int node_num) {
#ifdef PADDLE_WITH_PSLIB
......
......@@ -151,6 +151,8 @@ class FleetWrapper {
void FinalizeWorker();
// run server
uint64_t RunServer();
// run server with ip port
uint64_t RunServer(const std::string& ip, uint32_t port);
// gather server ip
void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
// gather client ip
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#include <vector>
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/platform/errors.h"
namespace gloo {
namespace rendezvous {
HdfsStore::HdfsStore(const std::string& path) {
path_ = path;
wait_sleep_ms_ = 3000;
wait_timeout_ = std::chrono::seconds(999999999);
}
void HdfsStore::set(const std::string& key, const std::vector<char>& data) {
#ifdef PADDLE_WITH_GLOO
auto tmp = TmpPath(key);
auto path = ObjectPath(key);
bool is_exists = paddle::framework::fs_exists(path);
if (is_exists) {
LOG(WARNING) << "path exists, will be removed: " << path;
paddle::framework::fs_remove(path);
}
int err_no = 0;
std::shared_ptr<FILE> fp = paddle::framework::fs_open_write(tmp, &err_no, "");
size_t write_count = fwrite_unlocked(data.data(), 1, data.size(), fp.get());
VLOG(3) << "HdfsStore::set write_count=" << write_count << " key " << key;
fp.reset();
paddle::framework::fs_mv(tmp, path);
#endif
}
std::vector<char> HdfsStore::get(const std::string& key) {
auto path = ObjectPath(key);
std::vector<char> result;
#ifdef PADDLE_WITH_GLOO
// block until key is set
wait({key});
bool is_exists = paddle::framework::fs_exists(path);
PADDLE_ENFORCE_EQ(is_exists, true,
paddle::platform::errors::NotFound(
"HdfsStore::get, path not exists: " + path));
int err_no = 0;
std::shared_ptr<FILE> fp = paddle::framework::fs_open_read(path, &err_no, "");
char buffer = '\0';
size_t read_count = 0;
while (fread(&buffer, 1, 1, fp.get()) == 1) {
++read_count;
result.push_back(buffer);
}
VLOG(3) << "HdfsStore::get read_count " << read_count;
#endif
return result;
}
void HdfsStore::wait(const std::vector<std::string>& keys) {
#ifdef PADDLE_WITH_GLOO
wait(keys, wait_timeout_); // NOLINT
#endif
}
void HdfsStore::wait(const std::vector<std::string>& keys,
const std::chrono::milliseconds&) { // NOLINT
#ifdef PADDLE_WITH_GLOO
auto start = std::chrono::steady_clock::now();
while (!Check(keys)) {
auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::steady_clock::now() - start);
if (wait_timeout_ != gloo::kNoTimeout && elapsed > wait_timeout_) {
PADDLE_ENFORCE_EQ(0, 1, paddle::platform::errors::ExecutionTimeout(
"HdfsStore::wait, Wait timeout for key(s): " +
::gloo::MakeString(keys)));
}
std::this_thread::sleep_for(std::chrono::milliseconds(wait_sleep_ms_));
}
#endif
}
std::string HdfsStore::EncodeName(const std::string& name) {
thread_local std::hash<std::string> hash_func;
return std::to_string(hash_func(name));
}
std::string HdfsStore::TmpPath(const std::string& name) {
return path_ + "/" + EncodeName(name) + "_tmp";
}
std::string HdfsStore::ObjectPath(const std::string& name) {
return path_ + "/" + EncodeName(name);
}
bool HdfsStore::Check(const std::vector<std::string>& keys) {
#ifdef PADDLE_WITH_GLOO
std::vector<std::string> paths;
for (const auto& key : keys) {
paths.push_back(ObjectPath(key));
}
for (const auto& path : paths) {
bool is_exists = paddle::framework::fs_exists(path);
VLOG(3) << "HdfsStore::Check " << is_exists << " path " << path;
if (!is_exists) {
return false;
}
}
#endif
return true;
}
} // namespace rendezvous
} // namespace gloo
namespace paddle {
namespace framework {
void GlooWrapper::Init(int rank, int size, const std::string& path,
const std::string& fs_name, const std::string& fs_ugi,
const std::string& iface, const std::string& prefix) {
if (is_initialized_) {
return;
}
rank_ = rank;
size_ = size;
std::string cmd = std::string("hadoop fs");
cmd += " -D fs.default.name=" + fs_name;
cmd += " -D hadoop.job.ugi=" + fs_ugi;
paddle::framework::hdfs_set_command(cmd);
#ifdef PADDLE_WITH_GLOO
gloo::transport::tcp::attr attr;
attr.iface = iface;
auto file_store = gloo::rendezvous::HdfsStore(path);
auto prefix_store = gloo::rendezvous::PrefixStore(prefix, file_store);
auto dev = gloo::transport::tcp::CreateDevice(attr);
auto context = std::make_shared<gloo::rendezvous::Context>(rank, size);
context->setTimeout(file_store.wait_timeout_);
context->connectFullMesh(prefix_store, dev);
context_ = std::move(context);
#endif
is_initialized_ = true;
}
template void GlooWrapper::AllReduce<int64_t>(
std::vector<int64_t>& sendbuf, // NOLINT
std::vector<int64_t>& recvbuf, // NOLINT
const std::string& mode);
template void GlooWrapper::AllReduce<double>(
std::vector<double>& sendbuf, // NOLINT
std::vector<double>& recvbuf, // NOLINT
const std::string& mode);
template std::vector<int64_t> GlooWrapper::AllGather<int64_t>(
int64_t& input); // NOLINT
template std::vector<double> GlooWrapper::AllGather<double>(
double& input); // NOLINT
} // namespace framework
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
#ifdef _LINUX
#include <sys/types.h>
#include <unistd.h>
#endif
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_GLOO
#include <gloo/allgather.h>
#include <gloo/allreduce.h>
#include <gloo/barrier.h>
#include <gloo/rendezvous/context.h>
#include <gloo/rendezvous/file_store.h>
#include <gloo/rendezvous/prefix_store.h>
#include <gloo/rendezvous/store.h>
#include <gloo/transport/tcp/device.h>
#endif
#include "paddle/fluid/framework/variable_helper.h"
namespace gloo {
namespace rendezvous {
#ifdef PADDLE_WITH_GLOO
class HdfsStore : public gloo::rendezvous::Store {
#else
class HdfsStore {
#endif
public: // NOLINT
explicit HdfsStore(const std::string& path);
virtual ~HdfsStore() {}
virtual void set(const std::string& key, const std::vector<char>& data);
virtual std::vector<char> get(const std::string& key);
virtual void wait(const std::vector<std::string>& keys);
virtual void wait(const std::vector<std::string>& keys,
const std::chrono::milliseconds& timeout);
std::string EncodeName(const std::string& name);
std::string TmpPath(const std::string& name);
std::string ObjectPath(const std::string& name);
bool Check(const std::vector<std::string>& keys);
std::string path_;
int wait_sleep_ms_;
std::chrono::seconds wait_timeout_;
};
} // namespace rendezvous
} // namespace gloo
namespace paddle {
namespace framework {
class GlooWrapper {
public:
GlooWrapper() {}
virtual ~GlooWrapper() {}
void Init(int rank, int size, const std::string& path,
const std::string& fs_name, const std::string& fs_ugi,
const std::string& iface, const std::string& prefix);
int Rank() {
CHECK_EQ(is_initialized_, true);
return rank_;
}
int Size() {
CHECK_EQ(is_initialized_, true);
return size_;
}
void Barrier() {
CHECK_EQ(is_initialized_, true);
#ifdef PADDLE_WITH_GLOO
gloo::BarrierOptions opts(context_);
gloo::barrier(opts);
#endif
}
template <typename T>
void AllReduce(std::vector<T>& sendbuf, std::vector<T>& recvbuf, // NOLINT
const std::string& mode = "sum") {
CHECK_EQ(is_initialized_, true);
CHECK_EQ(sendbuf.size() == recvbuf.size(), true);
#ifdef PADDLE_WITH_GLOO
gloo::AllreduceOptions opts(context_);
opts.setInput(sendbuf.data(), sendbuf.size());
opts.setOutput(recvbuf.data(), recvbuf.size());
if (mode == "sum") {
opts.setReduceFunction(
static_cast<void (*)(void*, const void*, const void*, size_t)>(
&gloo::sum<T>));
} else if (mode == "max") {
opts.setReduceFunction(
static_cast<void (*)(void*, const void*, const void*, size_t)>(
&gloo::max<T>));
} else if (mode == "min") {
opts.setReduceFunction(
static_cast<void (*)(void*, const void*, const void*, size_t)>(
&gloo::min<T>));
} else {
PADDLE_ENFORCE_EQ(0, 1, paddle::platform::errors::InvalidArgument(
"AllReduce mode not known: " + mode));
}
gloo::allreduce(opts);
#endif
}
template <typename T>
std::vector<T> AllGather(T& input) { // NOLINT
CHECK_EQ(is_initialized_, true);
std::vector<T> ret(size_, T());
#ifdef PADDLE_WITH_GLOO
gloo::AllgatherOptions opts(context_);
opts.setInput(&input, 1);
opts.setOutput(ret.data(), size_);
gloo::allgather(opts);
#endif
return std::move(ret);
}
protected:
bool is_initialized_ = false;
#ifdef PADDLE_WITH_GLOO
std::shared_ptr<gloo::Context> context_ = nullptr;
#endif
int rank_ = 0;
int size_ = 0;
};
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <fstream>
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#include "paddle/fluid/framework/io/fs.h"
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
TEST(TEST_GLOO, store_1) {
#ifdef _LINUX
#ifdef PADDLE_WITH_GLOO
#else
auto store = gloo::rendezvous::HdfsStore("./test_gllo_store");
store.set("1", std::vector<char>{'t', 'e', 's', 't'});
store.get("1");
try {
store.get("2");
} catch (...) {
VLOG(3) << "catch expected error of not found";
}
store.wait(std::vector<std::string>{"test"});
store.wait(std::vector<std::string>{"test"}, std::chrono::milliseconds(0));
store.EncodeName("1");
store.TmpPath("1");
store.ObjectPath("1");
store.Check(std::vector<std::string>{"test"});
auto gw = paddle::framework::GlooWrapper();
gw.Init(0, 1, "", "", "", "", "");
gw.Init(0, 1, "", "", "", "", "");
gw.Rank();
gw.Size();
gw.Barrier();
std::vector<double> input;
std::vector<double> output;
gw.AllReduce(input, output);
int64_t t;
gw.AllGather(t);
#endif
#endif
}
TEST(TEST_FLEET, fleet_1) {
auto fleet = paddle::framework::FleetWrapper::GetInstance();
#ifdef PADDLE_WITH_PSLIB
#else
fleet->RunServer("", 0);
#endif
}
cc_library(fs SRCS fs.cc DEPS string_helper glog boost)
cc_library(shell SRCS shell.cc DEPS string_helper glog)
cc_test(test_fs SRCS test_fs.cc DEPS fs shell)
......@@ -196,6 +196,13 @@ void localfs_mkdir(const std::string& path) {
shell_execute(string::format_string("mkdir -p %s", path.c_str()));
}
void localfs_mv(const std::string& src, const std::string& dest) {
if (src == "" || dest == "") {
return;
}
shell_execute(string::format_string("mv %s %s", src.c_str(), dest.c_str()));
}
static size_t& hdfs_buffer_size_internal() {
static size_t x = 0;
return x;
......@@ -314,6 +321,14 @@ void hdfs_mkdir(const std::string& path) {
hdfs_command().c_str(), path.c_str()));
}
void hdfs_mv(const std::string& src, const std::string& dest) {
if (src == "" || dest == "") {
return;
}
shell_execute(string::format_string(
"%s -mv %s %s; true", hdfs_command().c_str(), src.c_str(), dest.c_str()));
}
int fs_select_internal(const std::string& path) {
if (fs_begin_with_internal(path, "hdfs:")) {
return 1;
......@@ -452,5 +467,19 @@ void fs_mkdir(const std::string& path) {
LOG(FATAL) << "Not supported";
}
}
void fs_mv(const std::string& src, const std::string& dest) {
int s = fs_select_internal(src);
int d = fs_select_internal(dest);
CHECK_EQ(s, d);
switch (s) {
case 0:
return localfs_mv(src, dest);
case 1:
return hdfs_mv(src, dest);
}
}
} // end namespace framework
} // end namespace paddle
......@@ -50,6 +50,8 @@ extern bool localfs_exists(const std::string& path);
extern void localfs_mkdir(const std::string& path);
extern void localfs_mv(const std::string& src, const std::string& dest);
// hdfs
extern size_t hdfs_buffer_size();
......@@ -75,6 +77,8 @@ extern bool hdfs_exists(const std::string& path);
extern void hdfs_mkdir(const std::string& path);
extern void hdfs_mv(const std::string& src, const std::string& dest);
// aut-detect fs
extern std::shared_ptr<FILE> fs_open_read(const std::string& path, int* err_no,
const std::string& converter);
......@@ -97,5 +101,8 @@ extern std::string fs_tail(const std::string& path);
extern bool fs_exists(const std::string& path);
extern void fs_mkdir(const std::string& path);
extern void fs_mv(const std::string& src, const std::string& dest);
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <fstream>
#include "paddle/fluid/framework/io/fs.h"
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
TEST(FS, mv) {
#ifdef _LINUX
std::ofstream out("src.txt");
out.close();
paddle::framework::fs_mv("src.txt", "dest.txt");
paddle::framework::hdfs_mv("", "");
paddle::framework::localfs_mv("", "");
try {
paddle::framework::hdfs_mv("afs:/none", "afs:/none");
} catch (...) {
VLOG(3) << "test hdfs_mv, catch expected errors of unknown path";
}
try {
paddle::framework::fs_mv("afs:/none", "afs:/none");
} catch (...) {
VLOG(3) << "test hdfs_mv, catch expected errors of unknown path";
}
try {
paddle::framework::hdfs_mv("unknown:/none", "unknown:/none");
} catch (...) {
VLOG(3) << "test hdfs_mv, catch expected errors of unknown prefix";
}
#endif
}
set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper nccl_wrapper prune
feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context)
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper)
if(NOT WIN32)
set(PYBIND_DEPS ${PYBIND_DEPS} nccl_context)
......@@ -22,6 +23,7 @@ set(PYBIND_SRCS
global_value_getter_setter.cc
reader_py.cc
fleet_wrapper_py.cc
gloo_wrapper_py.cc
box_helper_py.cc
nccl_wrapper_py.cc
data_set_py.cc
......
......@@ -46,7 +46,11 @@ void BindFleetWrapper(py::module* m) {
.def("push_dense", &framework::FleetWrapper::PushDenseVarsSync)
.def("pull_dense", &framework::FleetWrapper::PullDenseVarsSync)
.def("init_server", &framework::FleetWrapper::InitServer)
.def("run_server", &framework::FleetWrapper::RunServer)
.def("run_server", (uint64_t (framework::FleetWrapper::*)(void)) &
framework::FleetWrapper::RunServer)
.def("run_server", (uint64_t (framework::FleetWrapper::*)( // NOLINT
const std::string&, uint32_t)) & // NOLINT
framework::FleetWrapper::RunServer)
.def("init_worker", &framework::FleetWrapper::InitWorker)
.def("init_model", &framework::FleetWrapper::PushDenseParamSync)
.def("save_model", &framework::FleetWrapper::SaveModel)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif
#include <string>
#include <vector>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/gloo_wrapper_py.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindGlooWrapper(py::module* m) {
py::class_<framework::GlooWrapper>(*m, "Gloo")
.def(py::init())
.def("init", &framework::GlooWrapper::Init)
.def("rank", &framework::GlooWrapper::Rank)
.def("size", &framework::GlooWrapper::Size)
.def("barrier", &framework::GlooWrapper::Barrier)
.def("all_reduce", &framework::GlooWrapper::AllReduce<int64_t>)
.def("all_reduce", &framework::GlooWrapper::AllReduce<double>)
.def("all_gather", &framework::GlooWrapper::AllGather<int64_t>)
.def("all_gather", &framework::GlooWrapper::AllGather<double>)
.def("Allreduce", &framework::GlooWrapper::AllReduce<int64_t>)
.def("Allreduce", &framework::GlooWrapper::AllReduce<double>);
} // end BindGlooWrapper
} // end namespace pybind
} // end namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindGlooWrapper(py::module* m);
} // namespace pybind
} // namespace paddle
......@@ -62,6 +62,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/fleet_wrapper_py.h"
#include "paddle/fluid/pybind/global_value_getter_setter.h"
#include "paddle/fluid/pybind/gloo_wrapper_py.h"
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/pybind/inference_api.h"
#include "paddle/fluid/pybind/ir.h"
......@@ -2204,6 +2205,7 @@ All parameter, weight, gradient are variables in Paddle.
.def("device_count", &ParallelExecutor::DeviceCount);
BindFleetWrapper(&m);
BindGlooWrapper(&m);
BindBoxHelper(&m);
#ifndef _WIN32
BindNCCLWrapper(&m);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册