Unverified commit 1c681383, authored by lilong12, committed by GitHub

[api 2.0] add collective op for cpu using gloo and paddle.distributed.* apis (#26552)

add collective op for cpu using gloo and paddle.distributed.* apis
Parent 07973c57
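For orientation, a minimal sketch of how the new CPU/gloo path is brought up, mirroring the GlooParallelStrategy/GlooParallelContext bindings and the test harness added in this change (requires a build with WITH_GLOO=ON; the rank, size, and store path values are illustrative):

import paddle.fluid as fluid

strategy = fluid.core.GlooParallelStrategy()
strategy.rank = 0                 # this process's rank (illustrative)
strategy.rank_num = 1             # world size (illustrative)
strategy.prefix = ""
strategy.iface = "lo"             # network interface used by gloo
strategy.init_seconds = 999999
strategy.run_seconds = 999999
strategy.path = "/tmp/gloo_demo"  # shared file store; hypothetical path
gloo = fluid.core.GlooParallelContext(strategy)
gloo.init()  # initializes the GlooWrapper singleton used by the CPU kernels

# With the singleton initialized, paddle.distributed.barrier(), broadcast(),
# all_reduce(), reduce(), all_gather() and scatter() can run on CPUPlace
# via the gloo-backed kernels introduced in this commit.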
......@@ -105,6 +105,11 @@ enum GlooStoreType { HDFS, HTTP };
class GlooWrapper {
public:
static std::shared_ptr<GlooWrapper> GetInstance() {
static auto s_instance = std::make_shared<GlooWrapper>();
return s_instance;
}
GlooWrapper() {}
virtual ~GlooWrapper() {}
......@@ -153,6 +158,11 @@ class GlooWrapper {
#endif
}
bool IsInitialized() { return is_initialized_; }
#ifdef PADDLE_WITH_GLOO
std::shared_ptr<gloo::Context> GetContext() { return context_; }
#endif
template <typename T>
std::vector<T> AllReduce(std::vector<T>& sendbuf, // NOLINT
const std::string& mode = "sum") { // NOLINT
......
......@@ -35,5 +35,9 @@ if(WITH_NCCL)
op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS} nccl_common)
endif()
if(WITH_GLOO)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} gloo_wrapper)
endif()
set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COLLECTIVE_DEPS} PARENT_SCOPE)
set(GLOB_COLLECTIVE_DEPS ${COLLECTIVE_DEPS} CACHE INTERNAL "collective dependency")
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/barrier_op.h"
#include <memory>
namespace paddle {
namespace operators {
class BarrierOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {}
};
class BarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) Input data (only used in CUDAKernel).");
AddOutput("Out", "(Tensor) Output data (only used in CUDAKernel).");
AddAttr<int>("ring_id", "(int default 0) communication ring id.")
.SetDefault(0);
AddComment(R"DOC(
Barrier Operator - Barrier among all pariticapitors.)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(barrier, ops::BarrierOp, ops::BarrierOpMaker);
REGISTER_OP_CPU_KERNEL(barrier, ops::BarrierOpCPUKernel<int>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/barrier_op.h"
#include <memory>
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename T>
class BarrierOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_NCCL)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
auto place = ctx.GetPlace();
ncclDataType_t dtype = platform::ToNCCLDataType(in->type());
int64_t numel = in->numel();
const void* sendbuff = in->data<void>();
void* recvbuff = out->mutable_data<T>(place);
int rid = ctx.Attr<int>("ring_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
auto stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
ncclRedOp_t nccl_red_type = ncclSum;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream));
auto comm_stream =
platform::NCCLCommContext::Instance().Get(rid, place)->stream();
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(comm_stream));
#else
PADDLE_THROW(platform::errors::Unavailable(
"PaddlePaddle should compile with NCCL."));
#endif
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(barrier, ops::BarrierOpCUDAKernel<int>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_GLOO)
#include <gloo/barrier.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
namespace paddle {
namespace operators {
template <typename T>
class BarrierOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_GLOO)
auto gloo = paddle::framework::GlooWrapper::GetInstance();
PADDLE_ENFORCE_EQ(
gloo->IsInitialized(), true,
platform::errors::PreconditionNotMet(
"You must initialize the gloo environment first to use it."));
gloo::BarrierOptions opts(gloo->GetContext());
gloo::barrier(opts);
#else
PADDLE_THROW(platform::errors::Unavailable(
"PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON"));
#endif
}
};
} // namespace operators
} // namespace paddle
......@@ -23,6 +23,11 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_GLOO)
#include <gloo/allgather.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
namespace paddle {
namespace operators {
......@@ -30,7 +35,31 @@ template <typename T>
class CAllGatherOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW("unimplemented cpu kernel for CAllGatherOp.");
#if defined(PADDLE_WITH_GLOO)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
framework::DDim out_dims = in->dims();
auto place = ctx.GetPlace();
auto gloo = paddle::framework::GlooWrapper::GetInstance();
auto nranks = gloo->Size();
out_dims[0] *= nranks;
int64_t send_numel = in->numel();
const T* send_buff = in->data<T>();
T* recv_buff = out->mutable_data<T>(out_dims, place);
PADDLE_ENFORCE_EQ(
gloo->IsInitialized(), true,
platform::errors::PreconditionNotMet(
"You must initialize the gloo environment first to use it."));
gloo::AllgatherOptions opts(gloo->GetContext());
opts.setInput(const_cast<T*>(send_buff), send_numel);
opts.setOutput(recv_buff, send_numel * nranks);
gloo::allgather(opts);
#else
PADDLE_THROW(platform::errors::Unavailable(
"PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON"));
#endif
}
};
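As a reading aid, a NumPy sketch (illustrative only, not part of this change) of the semantics the gloo allgather call implements: every rank receives the concatenation of all ranks' inputs along dim 0, which is why out_dims[0] is scaled by nranks above.

import numpy as np

def allgather_reference(local_buffers):
    # One same-shaped array per rank in; every rank gets the axis-0
    # concatenation out, matching out_dims[0] *= nranks in the kernel.
    gathered = np.concatenate(local_buffers, axis=0)
    return [gathered.copy() for _ in local_buffers]

outs = allgather_reference([np.ones((10, 1000)), np.zeros((10, 1000))])
assert outs[0].shape == (20, 1000)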
......
......@@ -25,6 +25,11 @@ limitations under the License. */
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#if defined(PADDLE_WITH_GLOO)
#include <gloo/allreduce.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
namespace paddle {
namespace operators {
......@@ -50,7 +55,53 @@ template <ReduceType red_type, typename T>
class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW("CAllReduce op do not support CPUKernel for now.");
#if defined(PADDLE_WITH_GLOO)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
auto place = ctx.GetPlace();
int64_t send_numel = in->numel();
const T* send_buff = in->data<T>();
T* recv_buff = out->mutable_data<T>(in->dims(), place);
auto gloo = paddle::framework::GlooWrapper::GetInstance();
PADDLE_ENFORCE_EQ(
gloo->IsInitialized(), true,
platform::errors::PreconditionNotMet(
"You must initialize the gloo environment first to use it."));
gloo::AllreduceOptions opts(gloo->GetContext());
opts.setInput(const_cast<T*>(send_buff), send_numel);
opts.setOutput(recv_buff, send_numel);
switch (red_type) {
case kRedSum:
opts.setReduceFunction(
static_cast<void (*)(void*, const void*, const void*, size_t)>(
&gloo::sum<T>));
break;
case kRedMax:
opts.setReduceFunction(
static_cast<void (*)(void*, const void*, const void*, size_t)>(
&gloo::max<T>));
break;
case kRedMin:
opts.setReduceFunction(
static_cast<void (*)(void*, const void*, const void*, size_t)>(
&gloo::min<T>));
break;
case kRedProd:
opts.setReduceFunction(
static_cast<void (*)(void*, const void*, const void*, size_t)>(
&gloo::product<T>));
break;
default:
PADDLE_ENFORCE_EQ(true, false,
platform::errors::InvalidArgument(
"Invalid reduce type: %d.", red_type));
}
gloo::allreduce(opts);
#else
PADDLE_THROW(platform::errors::Unavailable(
"PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON"));
#endif
}
};
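The switch above selects gloo's elementwise reduction function; a NumPy sketch of the resulting semantics (illustrative, the helper names are hypothetical):

import numpy as np

# Elementwise-across-ranks reductions, mirroring the red_type switch
# (gloo::sum / gloo::max / gloo::min / gloo::product).
REDUCE_FNS = {
    "sum": lambda stacked: stacked.sum(axis=0),
    "max": lambda stacked: stacked.max(axis=0),
    "min": lambda stacked: stacked.min(axis=0),
    "prod": lambda stacked: stacked.prod(axis=0),
}

def allreduce_reference(local_buffers, mode="sum"):
    reduced = REDUCE_FNS[mode](np.stack(local_buffers))
    return [reduced.copy() for _ in local_buffers]  # every rank gets the result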
......
......@@ -22,6 +22,11 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_GLOO)
#include <gloo/broadcast.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
namespace paddle {
namespace operators {
......@@ -29,7 +34,27 @@ template <typename T>
class CBroadcastOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW("Unimplemented cpu kernel for CBroadcastOp.");
#if defined(PADDLE_WITH_GLOO)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
auto root = ctx.Attr<int>("root");
auto place = ctx.GetPlace();
int64_t send_numel = in->numel();
T* recv_buff = out->mutable_data<T>(in->dims(), place);
auto gloo = paddle::framework::GlooWrapper::GetInstance();
PADDLE_ENFORCE_EQ(
gloo->IsInitialized(), true,
platform::errors::PreconditionNotMet(
"You must initialize the gloo environment first to use it."));
gloo::BroadcastOptions opts(gloo->GetContext());
opts.setOutput(recv_buff, send_numel);
opts.setRoot(root);
gloo::broadcast(opts);
#else
PADDLE_THROW(platform::errors::Unavailable(
"PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON"));
#endif
}
};
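Note that the kernel registers only the output buffer with gloo; the Python wrapper later in this change passes the same variable as X and Out, so the op runs in place and the root rank's data is what gets sent. A NumPy sketch of the semantics (illustrative):

import numpy as np

def broadcast_reference(buffers, root):
    # Every rank's buffer is overwritten with the root's contents,
    # matching gloo::broadcast with setOutput + setRoot above.
    src = buffers[root].copy()
    return [src.copy() for _ in buffers]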
......
......@@ -28,6 +28,10 @@ limitations under the License. */
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#if defined(PADDLE_WITH_GLOO)
#include <gloo/reduce.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
namespace paddle {
namespace operators {
......@@ -54,9 +58,55 @@ template <ReduceType red_type, typename T>
class CReduceOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_GLOO)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
auto root_id = ctx.Attr<int>("root_id");
auto place = ctx.GetPlace();
int64_t send_numel = in->numel();
const T* send_buff = in->data<T>();
T* recv_buff = out->mutable_data<T>(in->dims(), place);
auto gloo = paddle::framework::GlooWrapper::GetInstance();
PADDLE_ENFORCE_EQ(true, false, platform::errors::Unavailable(
    "Unimplemented CReduceOpCPUKernel now."));
PADDLE_ENFORCE_EQ(
    gloo->IsInitialized(), true,
    platform::errors::PreconditionNotMet(
        "You must initialize the gloo environment first to use it."));
gloo::ReduceOptions opts(gloo->GetContext());
opts.setInput(const_cast<T*>(send_buff), send_numel);
opts.setOutput(recv_buff, send_numel);
opts.setRoot(root_id);
switch (red_type) {
case kRedSum:
opts.setReduceFunction(
static_cast<void (*)(void*, const void*, const void*, size_t)>(
&gloo::sum<T>));
break;
case kRedMax:
opts.setReduceFunction(
static_cast<void (*)(void*, const void*, const void*, size_t)>(
&gloo::max<T>));
break;
case kRedMin:
opts.setReduceFunction(
static_cast<void (*)(void*, const void*, const void*, size_t)>(
&gloo::min<T>));
break;
case kRedProd:
opts.setReduceFunction(
static_cast<void (*)(void*, const void*, const void*, size_t)>(
&gloo::product<T>));
break;
default:
PADDLE_ENFORCE_EQ(true, false,
platform::errors::InvalidArgument(
"Invalid reduce type: %d.", red_type));
}
gloo::reduce(opts);
#else
PADDLE_THROW(platform::errors::Unavailable(
"PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON"));
#endif
}
};
......
......@@ -22,6 +22,11 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_GLOO)
#include <gloo/scatter.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
namespace paddle {
namespace operators {
......@@ -29,9 +34,39 @@ template <typename T>
class CScatterOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(true, false,
platform::errors::Unavailable(
"Unimplemented cpu kernel for CScatterOp."));
#if defined(PADDLE_WITH_GLOO)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
auto root_id = ctx.Attr<int>("root");
auto gloo = paddle::framework::GlooWrapper::GetInstance();
PADDLE_ENFORCE_EQ(
gloo->IsInitialized(), true,
platform::errors::PreconditionNotMet(
"You must initialize the gloo environment first to use it."));
int64_t send_numel = out->numel();
auto nranks = gloo->Size();
auto rank = gloo->Rank();
T* recv_buff = out->data<T>();
gloo::ScatterOptions opts(gloo->GetContext());
if (root_id == rank) {
T* send_buff = const_cast<T*>(in->data<T>());
std::vector<T*> ptrs(nranks);
for (int i = 0; i < nranks; ++i) {
ptrs[i] = send_buff;
send_buff += send_numel;
}
opts.setInputs(ptrs, send_numel);
}
opts.setOutput(recv_buff, send_numel);
opts.setRoot(root_id);
gloo::scatter(opts);
#else
PADDLE_THROW(platform::errors::Unavailable(
"PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON"));
#endif
}
};
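On the root, the pointer loop above views the input as nranks contiguous chunks of out->numel() elements each, and rank i receives chunk i. A NumPy sketch, assuming dim 0 divides evenly (illustrative):

import numpy as np

def scatter_reference(root_buffer, nranks):
    # For a contiguous row-major array, equal splits along axis 0 match
    # the flat send_buff += send_numel chunking in the kernel above.
    chunks = np.split(root_buffer, nranks, axis=0)
    return chunks  # chunks[i] is what rank i's Out receives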
......
......@@ -88,6 +88,10 @@ ELSE()
set(STREAM_CALLBACK_DEPS)
ENDIF()
if(WITH_GLOO)
cc_library(gloo_context SRCS gloo_context.cc DEPS framework_proto gloo_wrapper enforce)
endif()
cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost)
# memcpy depends on device_context, here add deps individually for
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/gloo_context.h"
namespace paddle {
namespace platform {
#if defined(PADDLE_WITH_GLOO)
void GlooParallelContext::Init() {
auto gloo_ptr = paddle::framework::GlooWrapper::GetInstance();
gloo_ptr->SetRank(strategy_.rank);
gloo_ptr->SetSize(strategy_.rank_num);
gloo_ptr->SetPrefix(strategy_.prefix);
gloo_ptr->SetIface(strategy_.iface);
gloo_ptr->SetTimeoutSeconds(strategy_.init_seconds, strategy_.run_seconds);
gloo_ptr->SetHdfsStore(strategy_.path, strategy_.fs_name, strategy_.fs_ugi);
gloo_ptr->Init();
}
#endif
} // namespace platform
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
namespace paddle {
namespace platform {
#if defined(PADDLE_WITH_GLOO)
struct GlooParallelStrategy {
int rank{0};
int rank_num{1};
std::string iface;
std::string prefix;
int init_seconds{9999999};
int run_seconds{9999999};
std::string path;
std::string fs_name;
std::string fs_ugi;
};
class GlooParallelContext {
public:
explicit GlooParallelContext(const GlooParallelStrategy& strategy)
: strategy_(strategy) {}
virtual ~GlooParallelContext() {}
virtual void Init();
protected:
GlooParallelStrategy strategy_;
};
#endif
} // namespace platform
} // namespace paddle
......@@ -40,6 +40,11 @@ set(PYBIND_SRCS
inference_api.cc
generator_py.cc)
if(WITH_GLOO)
set(PYBIND_DEPS ${PYBIND_DEPS} gloo_context)
set(PYBIND_SRCS ${PYBIND_SRCS} gloo_context_py.cc)
endif(WITH_GLOO)
if (WITH_CRYPTO)
set(PYBIND_DEPS ${PYBIND_DEPS} paddle_crypto)
set(PYBIND_SRCS ${PYBIND_SRCS} crypto.cc)
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/pybind/gloo_context_py.h"
#include <Python.h>
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
#include "paddle/fluid/platform/gloo_context.h"
namespace paddle {
namespace pybind {
namespace py = ::pybind11;
// Bind Methods
void BindGlooContext(py::module *m) {
// define parallel context for gloo
#if defined(PADDLE_WITH_GLOO)
py::class_<platform::GlooParallelStrategy> gloo_parallel_strategy(
*m, "GlooParallelStrategy", "");
gloo_parallel_strategy.def(py::init())
.def_property("rank_num",
[](const platform::GlooParallelStrategy &self) {
return self.rank_num;
},
[](platform::GlooParallelStrategy &self, int nranks) {
self.rank_num = nranks;
})
.def_property(
"rank",
[](const platform::GlooParallelStrategy &self) { return self.rank; },
[](platform::GlooParallelStrategy &self, int rank) {
self.rank = rank;
})
.def_property(
"iface",
[](const platform::GlooParallelStrategy &self) { return self.iface; },
[](platform::GlooParallelStrategy &self, const std::string &iface) {
self.iface = iface;
})
.def_property("prefix",
[](const platform::GlooParallelStrategy &self) {
return self.prefix;
},
[](platform::GlooParallelStrategy &self,
const std::string &prefix) { self.prefix = prefix; })
.def_property("init_seconds",
[](const platform::GlooParallelStrategy &self) {
return self.init_seconds;
},
[](platform::GlooParallelStrategy &self, int init_seconds) {
self.init_seconds = init_seconds;
})
.def_property("run_seconds",
[](const platform::GlooParallelStrategy &self) {
return self.run_seconds;
},
[](platform::GlooParallelStrategy &self, int run_seconds) {
self.run_seconds = run_seconds;
})
.def_property(
"path",
[](const platform::GlooParallelStrategy &self) { return self.path; },
[](platform::GlooParallelStrategy &self, const std::string &path) {
self.path = path;
})
.def_property("fs_name",
[](const platform::GlooParallelStrategy &self) {
return self.fs_name;
},
[](platform::GlooParallelStrategy &self,
const std::string &fs_name) { self.fs_name = fs_name; })
.def_property("fs_ugi",
[](const platform::GlooParallelStrategy &self) {
return self.fs_ugi;
},
[](platform::GlooParallelStrategy &self,
const std::string &fs_ugi) { self.fs_ugi = fs_ugi; });
py::class_<platform::GlooParallelContext> gloo_ctx(*m, "GlooParallelContext");
gloo_ctx.def(py::init<const platform::GlooParallelStrategy &>())
.def("init", [](platform::GlooParallelContext &self) { self.Init(); });
#endif
}
} // namespace pybind
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <Python.h>
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace paddle {
namespace pybind {
void BindGlooContext(pybind11::module* m);
} // namespace pybind
} // namespace paddle
......@@ -86,6 +86,19 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"accuracy", {"Correct", "Total"}},
{"fill_constant", {"Out"}},
{"matmul", {"Out"}},
{"c_broadcast", {"Out"}},
{"c_allreduce_sum", {"Out"}},
{"c_allreduce_max", {"Out"}},
{"c_allreduce_min", {"Out"}},
{"c_allreduce_prod", {"Out"}},
{"c_reduce_sum", {"Out"}},
{"c_reduce_max", {"Out"}},
{"c_reduce_min", {"Out"}},
{"c_reduce_prod", {"Out"}},
{"c_reduce", {"Out"}},
{"c_allgather", {"Out"}},
{"c_scatter", {"Out"}},
{"barrier", {"Out"}},
{"fake_quantize_dequantize_moving_average_abs_max",
{"Out", "OutScale", "OutAccum", "OutState"}},
{"fake_quantize_dequantize_abs_max", {"Out", "OutScale"}},
......
......@@ -66,6 +66,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/fleet_wrapper_py.h"
#include "paddle/fluid/pybind/generator_py.h"
#include "paddle/fluid/pybind/global_value_getter_setter.h"
#include "paddle/fluid/pybind/gloo_context_py.h"
#include "paddle/fluid/pybind/gloo_wrapper_py.h"
#include "paddle/fluid/pybind/heter_wrapper_py.h"
#include "paddle/fluid/pybind/imperative.h"
......@@ -2611,6 +2612,9 @@ All parameter, weight, gradient are variables in Paddle.
#endif
#ifdef PADDLE_WITH_NCCL
BindNCCLWrapper(&m);
#endif
#ifdef PADDLE_WITH_GLOO
BindGlooContext(&m);
#endif
BindGraph(&m);
BindNode(&m);
......
......@@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .collective import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import Variable, OpProtoHolder, in_dygraph_mode, convert_np_dtype_to_dtype_
from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from ..fluid.layers.tensor import fill_constant
from ..fluid.layers import utils
from ..fluid.dygraph.parallel import prepare_context
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
__all__ = [
'broadcast',
'all_reduce',
'reduce',
'all_gather',
'scatter',
'barrier',
'ReduceOp',
]
class ReduceOp:
"""Reduce Operation"""
SUM = 0
MAX = 1
MIN = 2
PROD = 3
class _Group():
"""The abstract representation of group."""
def __init__(self, rank, rank_num):
self.rank = rank
self.nranks = rank_num
_default_group = _Group(
int(os.getenv("PADDLE_TRAINER_ID", "0")),
int(os.getenv("PADDLE_TRAINERS_NUM", "1")))
def broadcast(tensor, src, group=0):
"""
Broadcast a tensor from the source to all others.
Args:
tensor (Tensor): The Tensor to send if current rank is the source, or the tensor to receive otherwise. Its data type
should be float16, float32, float64, int32 or int64.
src (int): The source rank.
group (int): Optional. The process group to work on.
Returns:
None.
Examples:
.. code-block:: python
import numpy as np
import paddle
from paddle.fluid.dygraph.parallel import prepare_context
paddle.disable_static()
paddle.set_device('gpu:%d'%paddle.ParallelEnv().dev_id)
prepare_context()
if paddle.ParallelEnv().local_rank == 0:
np_data = np.array([[4, 5, 6], [4, 5, 6]])
else:
np_data = np.array([[1, 2, 3], [1, 2, 3]])
data = paddle.to_tensor(np_data)
paddle.distributed.broadcast(data, 1)
out = data.numpy()
# [[1, 2, 3], [1, 2, 3]]
"""
if in_dygraph_mode():
return core.ops.c_broadcast(tensor, tensor, 'root', src,
'use_calc_stream', True, 'ring_id', group)
op_type = 'c_broadcast'
check_variable_and_dtype(
tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
'broadcast')
if not isinstance(src, int) or not isinstance(group, int):
raise ValueError("Both the type of 'src' and 'group' for broadcast "
"should be int.")
helper = LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
attrs={
'root': src,
'use_calc_stream': True,
'ring_id': group,
})
def all_reduce(tensor, op=ReduceOp.SUM, group=0):
"""
Reduce a tensor over all ranks so that all get the result.
Args:
tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
should be float16, float32, float64, int32 or int64.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used.
group (int): Optional. The process group to work on.
Returns:
None.
Examples:
.. code-block:: python
import numpy as np
import paddle
from paddle.distributed import ReduceOp
from paddle.fluid.dygraph.parallel import prepare_context
paddle.disable_static()
paddle.set_device('gpu:%d'%paddle.ParallelEnv().dev_id)
prepare_context()
if paddle.ParallelEnv().local_rank == 0:
np_data = np.array([[4, 5, 6], [4, 5, 6]])
else:
np_data = np.array([[1, 2, 3], [1, 2, 3]])
data = paddle.to_tensor(np_data)
paddle.distributed.all_reduce(data)
out = data.numpy()
# [[5, 7, 9], [5, 7, 9]]
"""
if in_dygraph_mode():
if op == ReduceOp.SUM:
return core.ops.c_allreduce_sum(tensor, tensor, 'use_calc_stream',
True, 'ring_id', group)
elif op == ReduceOp.MAX:
return core.ops.c_allreduce_max(tensor, tensor, 'use_calc_stream',
True, 'ring_id', group)
elif op == ReduceOp.MIN:
return core.ops.c_allreduce_min(tensor, tensor, 'use_calc_stream',
True, 'ring_id', group)
elif op == ReduceOp.PROD:
return core.ops.c_allreduce_prod(tensor, tensor, 'use_calc_stream',
True, 'ring_id', group)
else:
raise ValueError("Unknown parameter: {}.".format(op))
check_variable_and_dtype(
tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
'all_reduce')
if op not in [ReduceOp.SUM, ReduceOp.MAX, ReduceOp.MIN, ReduceOp.PROD]:
    raise ValueError("The op for all_reduce must be one of ReduceOp.PROD, "
                     "ReduceOp.SUM, ReduceOp.MAX, ReduceOp.MIN.")
if op == ReduceOp.SUM:
op_type = 'c_allreduce_sum'
elif op == ReduceOp.MAX:
op_type = 'c_allreduce_max'
elif op == ReduceOp.MIN:
op_type = 'c_allreduce_min'
elif op == ReduceOp.PROD:
op_type = 'c_allreduce_prod'
if not isinstance(group, int):
raise ValueError("The type of 'group' for all_reduce should be int.")
helper = LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
attrs={'ring_id': group,
'use_calc_stream': True})
def reduce(tensor, dst, op=ReduceOp.SUM, group=0):
"""
Reduce a tensor to the destination from all others.
Args:
tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
should be float16, float32, float64, int32 or int64.
dst (int): The destination rank id.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used.
group (int): The id of the process group to work on.
Returns:
None.
Examples:
.. code-block:: python
import numpy as np
import paddle
from paddle.fluid.dygraph.parallel import prepare_context
paddle.disable_static()
paddle.set_device('gpu:%d'%paddle.ParallelEnv().dev_id)
prepare_context()
if paddle.ParallelEnv().local_rank == 0:
np_data = np.array([[4, 5, 6], [4, 5, 6]])
else:
np_data = np.array([[1, 2, 3], [1, 2, 3]])
data = paddle.to_tensor(np_data)
paddle.distributed.reduce(data, 0)
out = data.numpy()
# [[5, 7, 9], [5, 7, 9]]
"""
if in_dygraph_mode():
if op == ReduceOp.SUM:
return core.ops.c_reduce_sum(tensor, tensor, 'use_calc_stream',
True, 'ring_id', group, 'root_id', dst)
elif op == ReduceOp.MAX:
return core.ops.c_reduce_max(tensor, tensor, 'use_calc_stream',
True, 'ring_id', group, 'root_id', dst)
elif op == ReduceOp.MIN:
return core.ops.c_reduce_min(tensor, tensor, 'use_calc_stream',
True, 'ring_id', group, 'root_id', dst)
elif op == ReduceOp.PROD:
return core.ops.c_reduce_prod(tensor, tensor, 'use_calc_stream',
True, 'ring_id', group, 'root_id',
dst)
else:
raise ValueError("Unknown parameter: {}.".format(op))
op_type = 'c_reduce'
check_variable_and_dtype(
    tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
    'reduce')
if op not in [ReduceOp.SUM, ReduceOp.MAX, ReduceOp.MIN, ReduceOp.PROD]:
    raise ValueError("The op for reduce must be one of ReduceOp.PROD, "
                     "ReduceOp.SUM, ReduceOp.MAX, ReduceOp.MIN.")
if op == ReduceOp.SUM:
op_type = 'c_reduce_sum'
elif op == ReduceOp.MAX:
op_type = 'c_reduce_max'
elif op == ReduceOp.MIN:
op_type = 'c_reduce_min'
elif op == ReduceOp.PROD:
op_type = 'c_reduce_prod'
if not isinstance(dst, int) or not isinstance(group, int):
raise ValueError("Both the type of 'dst' and 'group' for reduce "
"should be int.")
helper = LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
attrs={
'ring_id': group,
'use_calc_stream': True,
'root_id': dst,
})
def all_gather(tensor_list, tensor, group=0):
"""
Gather tensors from all participators and all get the result.
Args:
tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32 or int64.
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32 or int64.
group (int): The id of the process group to work on.
Returns:
None.
Examples:
.. code-block:: python
import numpy as np
import paddle
from paddle.fluid.dygraph.parallel import prepare_context
paddle.disable_static()
paddle.set_device('gpu:%d'%paddle.ParallelEnv().dev_id)
prepare_context()
tensor_list = []
if paddle.ParallelEnv().local_rank == 0:
np_data1 = np.array([[4, 5, 6], [4, 5, 6]])
np_data2 = np.array([[4, 5, 6], [4, 5, 6]])
data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2)
paddle.distributed.all_gather(tensor_list, data1)
else:
np_data1 = np.array([[1, 2, 3], [1, 2, 3]])
np_data2 = np.array([[1, 2, 3], [1, 2, 3]])
data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2)
out = paddle.distributed.all_gather(tensor_list, data2)
"""
op_type = 'c_allgather'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
if in_dygraph_mode():
core.ops.c_allgather(tensor, out, 'use_calc_stream', True, 'ring_id',
group, 'nranks', _default_group.nranks)
else:
if not isinstance(tensor_list, list):
raise ValueError("The type of 'tensor_list' for all_gather "
"should be list.")
for elem in tensor_list:
check_variable_and_dtype(
elem, 'tensor_list',
['float16', 'float32', 'float64', 'int32', 'int64'],
'all_gather')
check_variable_and_dtype(
tensor, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], 'all_gather')
if not isinstance(group, int):
raise ValueError("The type of 'group' for all_gather "
"should be int.")
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [out]},
attrs={
'ring_id': group,
'use_calc_stream': True,
'nranks': _default_group.nranks
})
tensor_list.extend(paddle.split(out, _default_group.nranks, 0))
def scatter(tensor, tensor_list=None, src=0, group=0):
"""
Scatter a tensor to all participators.
Args:
tensor (Tensor): The output Tensor. Its data type
should be float16, float32, float64, int32 or int64.
tensor_list (list): A list of Tensors to scatter. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32 or int64.
src (int): The source rank id.
group (int): The id of the process group to work on.
Returns:
None.
Examples:
.. code-block:: python
import numpy as np
import paddle
from paddle.fluid.dygraph.parallel import prepare_context
paddle.disable_static()
paddle.set_device('gpu:%d'%paddle.ParallelEnv().dev_id)
prepare_context()
if paddle.ParallelEnv().local_rank == 0:
np_data1 = np.array([7, 8, 9])
np_data2 = np.array([10, 11, 12])
else:
np_data1 = np.array([1, 2, 3])
np_data2 = np.array([4, 5, 6])
data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2)
if paddle.ParallelEnv().local_rank == 0:
paddle.distributed.scatter(data1, src=1)
else:
paddle.distributed.scatter(data1, tensor_list=[data1, data2], src=1)
out = data1.numpy()
"""
op_type = 'c_scatter'
global _default_group
rank = _default_group.rank
nranks = _default_group.nranks
if rank != src:
tensor_list = []
for _ in range(nranks):
tensor_list.append(tensor)
temp = paddle.concat(tensor_list, axis=0)
if in_dygraph_mode():
return core.ops.c_scatter(temp, tensor, 'use_calc_stream', True,
'ring_id', group, 'nranks',
_default_group.nranks, 'root', src)
check_variable_and_dtype(
tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
'scatter')
if not isinstance(group, int) or not isinstance(src, int):
raise ValueError("Both the type of 'src' and 'group' for scatter "
"should be int.")
helper = LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'X': [temp]},
outputs={'Out': [tensor]},
attrs={
'ring_id': group,
'root': src,
'use_calc_stream': True,
'nranks': nranks,
})
def barrier(group=0):
"""
Barrier among all participators in the group.
Args:
group (int): The id of the process group to work on.
Returns:
None.
Examples:
.. code-block:: python
import paddle
from paddle.fluid.dygraph.parallel import prepare_context
paddle.disable_static()
paddle.set_device('gpu:%d'%paddle.ParallelEnv().dev_id)
prepare_context()
paddle.distributed.barrier()
"""
op_type = 'barrier'
temp = paddle.fill_constant([1], dtype="int32", value="1")
if in_dygraph_mode():
return core.ops.barrier(temp, temp, 'ring_id', group)
if not isinstance(group, int):
raise ValueError("The type of 'group' for barrier must be int.")
helper = LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'X': [temp]},
outputs={'Out': [temp]},
attrs={'ring_id': group})
......@@ -58,6 +58,12 @@ if(NOT WITH_GPU OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_broadcast)
LIST(REMOVE_ITEM TEST_OPS test_collective_reduce)
LIST(REMOVE_ITEM TEST_OPS test_collective_scatter)
LIST(REMOVE_ITEM TEST_OPS test_collective_reduce_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_scatter_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_barrier_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_allreduce_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_broadcast_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_allgather_api)
LIST(REMOVE_ITEM TEST_OPS test_reducescatter)
LIST(REMOVE_ITEM TEST_OPS test_reducescatter_api)
endif()
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
class TestCollectiveAllgatherAPI(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank):
with fluid.program_guard(main_prog, startup_program):
tensor_list = []
tindata = layers.data(
name="tindata", shape=[10, 1000], dtype='float32')
paddle.distributed.all_gather(tensor_list, tindata)
return tensor_list
if __name__ == "__main__":
runtime_main(TestCollectiveAllgatherAPI, "allgather")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
class TestCollectiveAllreduceAPI(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank):
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(
name="tindata", shape=[10, 1000], dtype='float32')
paddle.distributed.all_reduce(tindata)
return [tindata]
if __name__ == "__main__":
runtime_main(TestCollectiveAllreduceAPI, "allreduce")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
class TestCollectiveBarrierAPI(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank):
with fluid.program_guard(main_prog, startup_program):
paddle.distributed.barrier()
return []
if __name__ == "__main__":
runtime_main(TestCollectiveBarrierAPI, "barrier")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
class TestCollectiveBroadcastAPI(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank):
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(
name="tindata", shape=[10, 1000], dtype='float32')
paddle.distributed.broadcast(tindata, src=1)
return [tindata]
if __name__ == "__main__":
runtime_main(TestCollectiveBroadcastAPI, "broadcast")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
class TestCollectiveReduceAPI(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank):
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(
name="tindata", shape=[10, 1000], dtype='float32')
paddle.distributed.reduce(tindata, dst=0)
return [tindata]
if __name__ == "__main__":
runtime_main(TestCollectiveReduceAPI, "reduce")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
class TestCollectiveScatterAPI(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank):
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(
name="tindata",
shape=[10, 1000],
dtype='float64',
append_batch_size=False)
toutdata = layers.fill_constant(
shape=[5, 1000], dtype='float64', value=1.0)
tensor_list = None
if rank == 1:
tensor_list = paddle.split(tindata, 2, axis=0)
paddle.distributed.scatter(toutdata, tensor_list, src=1)
return [toutdata]
if __name__ == "__main__":
runtime_main(TestCollectiveScatterAPI, "scatter")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from test_collective_api_base import TestDistBase
class TestCollectiveAllgatherAPI(TestDistBase):
def _setup_config(self):
pass
def test_allgather_nccl(self):
self.check_with_place("collective_allgather_api.py", "allgather",
"nccl")
def test_allgather_gloo(self):
self.check_with_place("collective_allgather_api.py", "allgather",
"gloo", "3")
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from test_collective_api_base import TestDistBase
class TestCollectiveAllreduceAPI(TestDistBase):
def _setup_config(self):
pass
def test_allreduce_nccl(self):
self.check_with_place("collective_allreduce_api.py", "allreduce",
"nccl")
def test_allreduce_gloo(self):
self.check_with_place("collective_allreduce_api.py", "allreduce",
"gloo", "2")
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import time
import argparse
import os
import six
import sys
import subprocess
import traceback
import functools
import pickle
from contextlib import closing
from six import string_types
import paddle.fluid as fluid
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
class TestCollectiveAPIRunnerBase(object):
def get_model(self, train_prog, startup_prog, rank):
raise NotImplementedError(
"get model should be implemented by child class.")
def wait_server_ready(self, endpoints):
assert not isinstance(endpoints, string_types)
while True:
all_ok = True
not_ready_endpoints = []
for ep in endpoints:
ip_port = ep.split(":")
with closing(
socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex((ip_port[0], int(ip_port[1])))
if result != 0:
all_ok = False
not_ready_endpoints.append(ep)
if not all_ok:
sys.stderr.write("server not ready, wait 3 sec to retry...\n")
sys.stderr.write("not ready endpoints:" + str(
not_ready_endpoints) + "\n")
sys.stderr.flush()
time.sleep(3)
else:
break
def initCommunicator(self, program, rank, nranks, wait_port,
current_endpoint, endpoints):
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
if rank == 0 and wait_port:
self.wait_server_ready(other_endpoints)
block = program.global_block()
nccl_id_var = block.create_var(
name=nameGen.generate('nccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
block.append_op(
type='c_gen_nccl_id',
inputs={},
outputs={'Out': nccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints
})
block.append_op(
type='c_comm_init',
inputs={'X': nccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': self.global_ring_id
})
def run_trainer(self, args):
train_prog = fluid.Program()
startup_prog = fluid.Program()
endpoints = args["endpoints"].split(",")
rank = args["trainerid"]
current_endpoint = args["currentendpoint"]
nranks = 2
result = self.get_model(train_prog, startup_prog, rank)
if args['backend'] == 'nccl':
self.initCommunicator(startup_prog, rank, nranks, True,
current_endpoint, endpoints)
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
place = fluid.CUDAPlace(
device_id) #if args.use_gpu else fluid.CPUPlace()
else:
strategy = fluid.core.GlooParallelStrategy()
strategy.rank = rank
strategy.rank_num = nranks
strategy.prefix = ""
strategy.iface = "lo"
strategy.init_seconds = 999999
strategy.run_seconds = 999999
strategy.path = "/tmp/tmp%d" % args['path_id']
gloo = fluid.core.GlooParallelContext(strategy)
gloo.init()
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
np.random.seed(os.getpid())
indata = np.random.random((10, 1000))
fetch_list = []
for elem in result:
fetch_list.append(elem.name)
out = exe.run(train_prog,
feed={'tindata': indata},
fetch_list=fetch_list)
if six.PY2:
print(pickle.dumps(out))
else:
sys.stdout.buffer.write(pickle.dumps(out))
def runtime_main(test_class, col_type):
args = {}
model = test_class()
args["deviceid"] = os.getenv("FLAGS_selected_gpus")
args["trainerid"] = int(os.getenv("PADDLE_TRAINER_ID"))
args["trainernum"] = int(os.getenv("PADDLE_TRAINERS_NUM"))
args["endpoints"] = os.getenv('PADDLE_TRAINER_ENDPOINTS')
args["currentendpoint"] = os.getenv("PADDLE_CURRENT_ENDPOINT")
args["col_type"] = col_type
args["backend"] = os.getenv("BACKEND")
args["path_id"] = int(os.getenv("PATH_ID"))
model.run_trainer(args)
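runtime_main pulls its whole configuration from the environment; a sketch of what a launcher would export for rank 0 of 2 with the gloo backend (values illustrative; _run_cluster below sets the same keys per trainer):

import os

# Environment consumed by runtime_main/run_trainer for rank 0 of 2
# (illustrative values; ports must be free, PATH_ID selects the store dir).
os.environ.update({
    "PADDLE_TRAINER_ID": "0",
    "PADDLE_TRAINERS_NUM": "2",
    "PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:6170,127.0.0.1:6171",
    "PADDLE_CURRENT_ENDPOINT": "127.0.0.1:6170",
    "FLAGS_selected_gpus": "0",
    "BACKEND": "gloo",
    "PATH_ID": "0",
})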
import paddle.compat as cpt
import socket
from contextlib import closing
class TestDistBase(unittest.TestCase):
def setUp(self):
self._port_set = set()
self._trainers = 2
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
self._find_free_port(), self._find_free_port())
self._python_interp = sys.executable
def _find_free_port(self):
def __free_port():
with closing(socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as s:
s.bind(('', 0))
return s.getsockname()[1]
while True:
port = __free_port()
if port not in self._port_set:
self._port_set.add(port)
return port
def _run_cluster(self, model_file, envs):
worker_endpoints = self._ps_endpoints.split(",")
w0_ep, w1_ep = worker_endpoints
#print("w0_ep:",w0_ep," w1_ep:",w1_ep)
env0 = {
"FLAGS_selected_gpus": "0",
"PADDLE_TRAINER_ID": "0",
"PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
"PADDLE_CURRENT_ENDPOINT": w0_ep
}
env1 = {
"FLAGS_selected_gpus": "1",
"PADDLE_TRAINER_ID": "1",
"PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
"PADDLE_CURRENT_ENDPOINT": w1_ep
}
#update environment
env0.update(envs)
env1.update(envs)
tr_cmd = "%s %s"
tr0_cmd = tr_cmd % (self._python_interp, model_file)
tr1_cmd = tr_cmd % (self._python_interp, model_file)
tr0_pipe = open("/tmp/tr0_err.log", "wb")
tr1_pipe = open("/tmp/tr1_err.log", "wb")
#print(tr0_cmd)
tr0_proc = subprocess.Popen(
tr0_cmd.strip().split(),
stdout=subprocess.PIPE,
stderr=tr0_pipe,
env=env0)
tr1_proc = subprocess.Popen(
tr1_cmd.strip().split(),
stdout=subprocess.PIPE,
stderr=tr1_pipe,
env=env1)
tr0_out, tr0_err = tr0_proc.communicate()
tr1_out, tr1_err = tr1_proc.communicate()
sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err)
sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)
# close trainer file
tr0_pipe.close()
tr1_pipe.close()
return pickle.loads(tr0_out), pickle.loads(
tr1_out), tr0_proc.pid, tr1_proc.pid
def check_with_place(self,
model_file,
col_type,
backend="nccl",
path_id="0",
check_error_log=False,
need_envs={}):
required_envs = {
"FLAGS_fraction_of_gpu_memory_to_use": "0.15",
"FLAGS_eager_delete_tensor_gb": "0.0",
"PATH": os.getenv("PATH"),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"LD_PRELOAD": os.getenv("LD_PRELOAD", ""),
"GLOG_v": "0",
"NCCL_P2P_DISABLE": "1",
"BACKEND": backend,
"PATH_ID": path_id
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
tr0_out, tr1_out, pid0, pid1 = self._run_cluster(model_file,
required_envs)
np.random.seed(pid0)
input1 = np.random.random((10, 1000))
np.random.seed(pid1)
input2 = np.random.random((10, 1000))
if col_type == "allgather":
need_result = np.vstack((input1, input2))
tr_out0 = np.vstack((tr0_out[0], tr0_out[1]))
tr_out1 = np.vstack((tr1_out[0], tr1_out[1]))
self.assertTrue(np.allclose(tr_out0, need_result))
self.assertTrue(np.allclose(tr_out1, need_result))
elif col_type == "broadcast":
need_result = input2
self.assertTrue(np.allclose(tr0_out, need_result))
self.assertTrue(np.allclose(tr1_out, need_result))
elif col_type == "reduce":
need_result = input1 + input2
self.assertTrue(np.allclose(tr0_out, need_result))
elif col_type == "scatter":
need_result = input2
need_result1 = need_result[0:need_result.shape[0] // 2]
need_result2 = need_result[need_result.shape[0] // 2:]
self.assertTrue(np.allclose(tr0_out, need_result1))
self.assertTrue(np.allclose(tr1_out, need_result2))
elif col_type == "allreduce":
need_result = input1 + input2
self.assertTrue(
np.allclose(
tr0_out, need_result, rtol=1e-05, atol=1e-05))
self.assertTrue(
np.allclose(
tr1_out, need_result, rtol=1e-05, atol=1e-05))
else:
pass
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from test_collective_api_base import TestDistBase
class TestCollectiveBarrierAPI(TestDistBase):
def _setup_config(self):
pass
def test_barrier_nccl(self):
self.check_with_place("collective_barrier_api.py", "barrier", "nccl")
def test_barrier_gloo(self):
self.check_with_place("collective_barrier_api.py", "barrier", "gloo",
"5")
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from test_collective_api_base import TestDistBase
class TestCollectiveBroadcastAPI(TestDistBase):
def _setup_config(self):
pass
def test_broadcast_nccl(self):
self.check_with_place("collective_broadcast_api.py", "broadcast",
"nccl")
def test_broadcast_gloo(self):
self.check_with_place("collective_broadcast_api.py", "broadcast",
"gloo", "0")
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from test_collective_api_base import TestDistBase
class TestCollectiveReduceAPI(TestDistBase):
def _setup_config(self):
pass
def test_reduce_nccl(self):
self.check_with_place("collective_reduce_api.py", "reduce", "nccl")
def test_reduce_gloo(self):
self.check_with_place("collective_reduce_api.py", "reduce", "gloo", "1")
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from test_collective_api_base import TestDistBase
class TestCollectiveScatterAPI(TestDistBase):
def _setup_config(self):
pass
def test_scatter_gloo(self):
self.check_with_place("collective_scatter_api.py", "scatter", "gloo",
"4")
def test_scatter_nccl(self):
self.check_with_place("collective_scatter_api.py", "scatter", "nccl")
if __name__ == '__main__':
unittest.main()