Unverified · Commit c191b707 · Authored by TaoTao Li · Committed by GitHub

Add comm context manager, add phi broadcast op (#51072)

* add comm context for device context

* add broadcast phi operator kernel and api

* add dtype support for broadcast, update ut

* fix broadcast bfloat16 type

* fix ut

* update test_collective_broadcast_api timeout to 300
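
For orientation, a minimal dygraph usage sketch of the collective this commit backs with a phi kernel (assumes a 2-rank job launched with python -m paddle.distributed.launch; whether execution takes the new comm-context path depends on runtime configuration, cf. the USE_COMM_CONTEXT toggle in the tests below):

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    # Each rank starts with its own tensor ...
    if dist.get_rank() == 0:
        data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
    else:
        data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
    # ... and after the broadcast every rank holds rank 1's tensor.
    dist.broadcast(data, src=1)
    print(data)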
Parent ff803bdc
......@@ -28,6 +28,7 @@ set(INTERPRETER_DEPS
enforce
scope
glog
comm_context_manager
${DEVICE_EVENT_LIBS}
glog)
......
......@@ -26,6 +26,7 @@
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"
......@@ -712,6 +713,7 @@ void BuildOpFuncList(const platform::Place& place,
auto& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
SetDeviceCommContext(op, dev_ctx);
auto exec_ctx = ExecutionContext(
*op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
auto expected_kernel_key = framework::TransPhiKernelKeyToOpKernelType(
......@@ -1156,6 +1158,24 @@ void LogDeviceMemoryStats(const platform::Place& place) {
}
}
void SetDeviceCommContext(framework::OperatorBase* operator_base,
platform::DeviceContext* dev_ctx) {
if (operator_base->HasAttr("ring_id")) {
int ring_id = operator_base->Attr<int>("ring_id");
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (comm_context_manager.Has(ring_id)) {
auto comm_context = comm_context_manager.Get(ring_id);
if (!dev_ctx->GetCommContext()) {
dev_ctx->SetCommContext(comm_context);
}
} else {
LOG(WARNING) << "op: " << operator_base->Type()
<< ", ring_id: " << ring_id << ", get comm_context failed!";
}
}
}
} // namespace interpreter
} // namespace framework
} // namespace paddle
......@@ -111,6 +111,8 @@ void FakeInitializeOutputsForStructureKernel(
void LogDeviceMemoryStats(const platform::Place& place);
void SetDeviceCommContext(framework::OperatorBase* operator_base,
platform::DeviceContext* dev_ctx);
} // namespace interpreter
} // namespace framework
} // namespace paddle
......@@ -148,6 +148,9 @@ DeviceContext* StreamAnalyzer::ParseDeviceContext(
const int stream_priority = op_func_node.stream_priority_;
ContextManager& ctx_manager = ContextManager::Instance();
auto dev_ctx = ctx_manager.Get(op_type, place_, stream_priority).get().get();
SetDeviceCommContext(op.get(), dev_ctx);
// only gpu/npu need this update; xpu does not, because the xpu memcpy op
// kernel is synchronous.
if (platform::is_gpu_place(place_) || platform::is_npu_place(place_) ||
......
......@@ -170,6 +170,16 @@
func : bmm
backward : bmm_grad
- op : broadcast
args : (Tensor X, int ring_id = 0, int root = 0)
output : Tensor(Out)
infer_meta :
func : BroadcastBaseInferMeta
param: [X]
kernel :
func : broadcast
param: [X, root]
- op : broadcast_tensors
args: (Tensor[] input)
output: Tensor[]{input.size()}
......
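
Reading aid for the yaml entry above: args defines the op signature, infer_meta derives the output meta from X alone, and kernel forwards only [X, root] to the kernel function; the ring_id attr is consumed framework-side by SetDeviceCommContext rather than by the kernel. A hedged sketch of the code-generated eager binding (the module path paddle._C_ops and the positional argument order are assumptions, not confirmed by this diff):

    import paddle
    from paddle import _C_ops  # assumed home of yaml-generated eager ops

    x = paddle.ones([10, 1000])
    # Hypothetical call following args : (Tensor X, int ring_id = 0, int root = 0);
    # in a real job the comm context for ring_id 0 must already be initialized.
    out = _C_ops.broadcast(x, 0, 0)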
......@@ -238,6 +238,12 @@ struct DeviceContext::Impl {
return host_generator_;
}
distributed::CommContext* GetCommContext() const { return comm_context_; }
void SetCommContext(distributed::CommContext* comm_context) {
comm_context_ = comm_context;
}
private:
void ClearHolder(TensorBase* tensor) const {
if (!tensor->initialized()) return;
......@@ -264,6 +270,8 @@ struct DeviceContext::Impl {
#endif
Generator* device_generator_{nullptr};
Generator* host_generator_{nullptr};
distributed::CommContext* comm_context_{nullptr};
};
DeviceContext::DeviceContext() { impl_ = std::make_unique<Impl>(); }
......@@ -418,4 +426,12 @@ Generator* DeviceContext::GetHostGenerator() const {
return impl_->GetHostGenerator();
}
void DeviceContext::SetCommContext(distributed::CommContext* comm_context) {
impl_->SetCommContext(comm_context);
}
distributed::CommContext* DeviceContext::GetCommContext() const {
return impl_->GetCommContext();
}
} // namespace phi
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/allocator.h"
#include "paddle/phi/core/distributed/comm_context.h"
#include "paddle/phi/core/generator.h"
#include "paddle/phi/core/utils/type_registry.h"
......@@ -208,6 +209,20 @@ class PADDLE_API DeviceContext {
*/
TypeInfo<DeviceContext> type_info() const { return type_info_; }
/**
* @brief Set the comm context pointer.
*
* @param comm_context comm context pointer
*/
void SetCommContext(distributed::CommContext* comm_context);
/**
* @brief Get the comm context pointer.
*
* @return comm context pointer
*/
distributed::CommContext* GetCommContext() const;
private:
struct Impl;
std::unique_ptr<Impl> impl_;
......
......@@ -24,6 +24,9 @@ class CommContext {
CommContext(int rank, int size) : rank_(rank), size_(size) {}
virtual ~CommContext() = default;
int GetRank() { return rank_; }
int GetSize() { return size_; }
protected:
int rank_;
int size_;
......
......@@ -356,6 +356,11 @@ void BatchSizeLikeInferMeta(const MetaTensor& x,
out->set_dims(output_dim);
}
void BroadcastBaseInferMeta(const MetaTensor& x, MetaTensor* out) {
out->set_dtype(x.dtype());
out->set_dims(x.dims());
}
void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(out_dtype);
......
......@@ -63,6 +63,8 @@ void BatchSizeLikeInferMeta(const MetaTensor& x,
int out_batch_size_dim,
MetaTensor* out);
void BroadcastBaseInferMeta(const MetaTensor& x, MetaTensor* out);
void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out);
void ChannelShuffleInferMeta(const MetaTensor& x,
......
......@@ -85,7 +85,11 @@ endif()
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group)
if(WITH_NCCL OR WITH_RCCL)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group_nccl)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group_nccl
nccl_comm_context)
endif()
if(WITH_GLOO)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} gloo_comm_context)
endif()
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} processgroup_comm_utils)
if(WITH_CUDNN_FRONTEND)
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void BroadcastKernel(const Context& dev_ctx,
const DenseTensor& x,
int root,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/broadcast_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
#if defined(PADDLE_WITH_GLOO)
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#endif
namespace phi {
template <typename T, typename Context>
void BroadcastKernel(const Context& dev_ctx,
const DenseTensor& x,
int root,
DenseTensor* out) {
#if defined(PADDLE_WITH_GLOO)
dev_ctx.template Alloc<T>(out);
auto comm_context =
static_cast<distributed::GlooCommContext*>(dev_ctx.GetCommContext());
PADDLE_ENFORCE_NE(
comm_context,
nullptr,
errors::Unavailable("NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
comm_context->Broadcast(out, x, root);
#else
PADDLE_THROW(errors::Unavailable(
"PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON"));
#endif
}
} // namespace phi
PD_REGISTER_KERNEL(broadcast,
CPU,
ALL_LAYOUT,
phi::BroadcastKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/broadcast_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif
namespace phi {
template <typename T, typename Context>
void BroadcastKernel(const Context& dev_ctx,
const DenseTensor& x,
int root,
DenseTensor* out) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
dev_ctx.template Alloc<T>(out);
gpuStream_t stream = dev_ctx.stream();
auto comm_context =
static_cast<distributed::NCCLCommContext*>(dev_ctx.GetCommContext());
PADDLE_ENFORCE_NE(
comm_context,
nullptr,
errors::Unavailable("NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
comm_context->Broadcast(out, x, root, stream);
out->set_lod(x.lod());
#else
PADDLE_THROW(
errors::PreconditionNotMet("PaddlePaddle should compile with GPU."));
#endif
}
} // namespace phi
#if NCCL_VERSION_CODE >= 21000
PD_REGISTER_KERNEL(broadcast,
GPU,
ALL_LAYOUT,
phi::BroadcastKernel,
float,
double,
phi::dtype::bfloat16,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(broadcast,
GPU,
ALL_LAYOUT,
phi::BroadcastKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
#endif
......@@ -125,7 +125,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_broadcast_api MODULES test_collective_broadcast_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_broadcast_api
PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
......
......@@ -16,10 +16,44 @@ from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
import paddle
import paddle.fluid as fluid
import paddle.fluid.data_feeder as data_feeder
import paddle.framework as framework
paddle.enable_static()
def broadcast_new(tensor, src, group=None, sync_op=True):
op_type = 'broadcast'
data_feeder.check_variable_and_dtype(
tensor,
'tensor',
[
'float16',
'float32',
'float64',
'int32',
'int64',
'int8',
'uint8',
'bool',
],
op_type,
)
helper = framework.LayerHelper(op_type, **locals())
ring_id = 0 if group is None else group.id
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
attrs={
'root': src,
'ring_id': ring_id,
},
)
class TestCollectiveBroadcastAPI(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
......@@ -33,6 +67,15 @@ class TestCollectiveBroadcastAPI(TestCollectiveAPIRunnerBase):
paddle.distributed.broadcast(tindata, src=1)
return [tindata]
def get_model_new(self, main_prog, startup_program, rank, dtype=None):
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.static.data(
name="tindata", shape=[-1, 10, 1000], dtype=dtype
)
tindata.desc.set_need_check_feed(False)
broadcast_new(tindata, src=1)
return [tindata]
if __name__ == "__main__":
runtime_main(TestCollectiveBroadcastAPI, "broadcast")
......@@ -26,31 +26,61 @@ class TestCollectiveBroadcastAPI(TestDistBase):
pass
def test_broadcast_nccl(self):
self.check_with_place(
"collective_broadcast_api.py", "broadcast", "nccl"
)
def test_broadcast_nccl_with_comm_context(self):
self.check_with_place(
"collective_broadcast_api.py",
"broadcast",
"nccl",
need_envs={"USE_COMM_CONTEXT": "1"},
)
def test_broadcast_gloo(self):
self.check_with_place(
"collective_broadcast_api.py", "broadcast", "gloo", "0"
)
def test_broadcast_nccl_with_comm_context(self):
dtypes_to_test = [
"float16",
"float32",
"float64",
"int32",
"int64",
"int8",
"uint8",
"bool",
]
if self._nccl_version >= 2100:
dtypes_to_test.append("bfloat16")
for dtype in dtypes_to_test:
self.check_with_place(
"collective_broadcast_api.py",
"broadcast",
"nccl",
dtype=dtype,
need_envs={"USE_COMM_CONTEXT": "1"},
)
def test_broadcast_gloo_with_comm_context(self):
def test_broadcast_gloo(self):
self.check_with_place(
"collective_broadcast_api.py",
"broadcast",
"gloo",
need_envs={"USE_COMM_CONTEXT": "1"},
)
def test_broadcast_gloo_with_comm_context(self):
dtypes_to_test = [
"float16",
"float32",
"float64",
"int32",
"int64",
"int8",
"uint8",
"bool",
]
for dtype in dtypes_to_test:
self.check_with_place(
"collective_broadcast_api.py",
"broadcast",
"gloo",
dtype=dtype,
need_envs={"USE_COMM_CONTEXT": "1"},
)
def test_broadcast_nccl_dygraph(self):
dtypes_to_test = [
"float16",
......
......@@ -127,7 +127,11 @@ class TestCollectiveAPIRunnerBase:
shape=(10, 1000), dtype=args["dtype"], seed=os.getpid()
)
if args['static_mode']:
result = self.get_model(train_prog, startup_prog, rank)
result = (
self.get_model_new(train_prog, startup_prog, rank)
if args["use_comm_context"]
else self.get_model(train_prog, startup_prog, rank)
)
exe = fluid.Executor(place)
exe.run(startup_prog)
fetch_list = []
......
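
The use_comm_context flag read above is driven by the USE_COMM_CONTEXT=1 entry the new tests pass via need_envs. A minimal sketch of that wiring, with hypothetical local names (the real parsing lives in test_collective_api_base.py):

    import os

    # The test harness exports USE_COMM_CONTEXT=1 for the *_with_comm_context
    # cases; the runner turns it into a boolean that selects get_model_new
    # over get_model in the static-mode branch.
    use_comm_context = bool(int(os.getenv("USE_COMM_CONTEXT", "0")))
    args = {"use_comm_context": use_comm_context}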