未验证 提交 517d8074 编写于 作者: S shentanyue 提交者: GitHub

[XPU][Fleet] Support multi-card infer for xpu (#50490)

* support xpu multi-card infer

* add ut

* clean code

* clean code

* fix

* fix

* fix

* fix
上级 3b6ebc9d
......@@ -42,6 +42,7 @@ static std::unordered_set<std::string> kMultiDeviceOps{
"c_comm_init_all",
"c_comm_init_multitrainer",
"c_gen_nccl_id",
"c_gen_bkcl_id",
"c_sync_comm_stream",
"send",
"recv",
......
......@@ -261,7 +261,7 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
if (mul_type == "mul") {
fc_xpu_op_desc.SetAttr(
"in_num_col_dims",
PADDLE_GET_CONST(int, mul->Op()->GetAttr("in_num_col_dims")));
PADDLE_GET_CONST(int, mul->Op()->GetAttr("x_num_col_dims")));
}
fc_xpu_op_desc.SetAttr("transpose_x", false);
fc_xpu_op_desc.SetAttr("alpha", 1.f);
......
......@@ -562,9 +562,7 @@ bool AnalysisPredictor::PrepareProgram(
OptimizeInferenceProgram();
}
}
executor_->CreateVariables(*inference_program_, 0, false, sub_scope_);
return true;
}
......@@ -785,6 +783,30 @@ void AnalysisPredictor::InsertCommOp(
comm_init_op->SetAttr("op_role",
static_cast<int>(framework::OpRole::kForward));
comm_init_op->CheckAttrs();
} else if (config_.use_xpu()) {
framework::VarDesc *new_var = block->Var(tmp_var_name);
new_var->SetType(framework::proto::VarType::RAW);
new_var->SetPersistable(true);
framework::OpDesc *gen_bkcl_id_op = block->AppendOp();
gen_bkcl_id_op->SetType("c_gen_bkcl_id");
gen_bkcl_id_op->SetOutput("Out", {tmp_var_name});
gen_bkcl_id_op->SetAttr("rank", rank);
gen_bkcl_id_op->SetAttr("endpoint",
config_.dist_config().current_endpoint());
gen_bkcl_id_op->SetAttr("other_endpoints", peer_endpoints);
gen_bkcl_id_op->SetAttr("ring_id", ring_id);
gen_bkcl_id_op->SetAttr("op_role",
static_cast<int>(framework::OpRole::kForward));
gen_bkcl_id_op->CheckAttrs();
framework::OpDesc *comm_init_op = block->AppendOp();
comm_init_op->SetType("c_comm_init");
comm_init_op->SetInput("X", {tmp_var_name});
comm_init_op->SetAttr("rank", rank);
comm_init_op->SetAttr("nranks", nranks);
comm_init_op->SetAttr("ring_id", ring_id);
comm_init_op->SetAttr("op_role",
static_cast<int>(framework::OpRole::kForward));
comm_init_op->CheckAttrs();
} else {
LOG(WARNING) << "DistModelInf doesn't init comm.";
// TODO(fleet exe dev): comm init for more devices
......@@ -1319,7 +1341,6 @@ void AnalysisPredictor::PrepareArgument() {
// NOTE All the members in AnalysisConfig should be copied to Argument.
void AnalysisPredictor::OptimizeInferenceProgram() {
PrepareArgument();
#ifdef PADDLE_WITH_TENSORRT
if (config_.tensorrt_engine_enabled()) {
inference::tensorrt::TensorRTEngine::predictor_id_per_thread =
......@@ -1328,9 +1349,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
<< inference::tensorrt::TensorRTEngine::predictor_id_per_thread;
}
#endif
Analyzer().Run(argument_.get());
PADDLE_ENFORCE_EQ(
argument_->scope_valid(),
true,
......
......@@ -1194,6 +1194,20 @@ if(WITH_DISTRIBUTE
--infer_model=${OCR_INSTALL_DIR}/model)
endif()
if(WITH_DISTRIBUTE
AND WITH_PSCORE
AND WITH_XPU
AND WITH_XPU_BKCL)
inference_analysis_test(
test_analyzer_dist_model_xpu
SRCS
analyzer_dist_model_xpu_tester.cc
EXTRA_DEPS
paddle_inference_shared
ARGS
--infer_model=${OCR_INSTALL_DIR}/model)
endif()
inference_analysis_test(
test_analyzer_paddletensor_tensor
SRCS
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle {
namespace inference {
TEST(test_dist_model_xpu, dist_model_xpu) {
std::cout << "Analysis Predictor DistModel XPU test." << std::endl;
AnalysisConfig config;
config.SetModel(FLAGS_infer_model + "/__model__",
FLAGS_infer_model + "/__params__");
config.SwitchUseFeedFetchOps(false);
config.EnableXpu();
config.SetXpuDeviceId(0);
DistConfig dist_config;
dist_config.SetRanks(1, 0);
dist_config.EnableDistModel(true);
dist_config.SetEndpoints({""}, "");
config.SetDistConfig(dist_config);
auto predictor = paddle_infer::CreatePredictor(config);
int batch_size = 1;
int channels = 1;
int height = 48;
int width = 512;
int nums = batch_size * channels * height * width;
std::cout << "Created predictor." << std::endl;
float* input = new float[nums];
for (int i = 0; i < nums; ++i) input[i] = 0;
auto input_names = predictor->GetInputNames();
auto input_t = predictor->GetInputHandle(input_names[0]);
input_t->Reshape({batch_size, channels, height, width});
input_t->CopyFromCpu(input);
std::cout << "Input data." << std::endl;
predictor->Run();
std::cout << "Zero Copy Run." << std::endl;
std::vector<float> out_data;
auto output_names = predictor->GetOutputNames();
auto output_t = predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_t->shape();
int out_num = std::accumulate(
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
out_data.resize(out_num);
output_t->CopyToCpu(out_data.data());
std::cout << "Output data." << std::endl;
delete[] input;
}
} // namespace inference
} // namespace paddle
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_broadcast_op.h"
#ifdef PADDLE_WITH_XPU_BKCL
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename T>
class CBroadcastOpXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_XPU_BKCL)
auto x = ctx.Input<phi::DenseTensor>("X");
auto out = ctx.Output<phi::DenseTensor>("Out");
size_t numel = x->numel();
BKCLDataType dtype =
platform::ToBKCLDataType(framework::TransToProtoVarType(x->dtype()));
int ring_id = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
auto comm =
paddle::platform::BKCLCommContext::Instance().Get(ring_id, place);
XPUStream stream = nullptr;
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
if (ctx.Attr<bool>("use_calc_stream")) {
stream = static_cast<platform::XPUDeviceContext*>(dev_ctx)
->x_context()
->xpu_stream;
} else {
stream = comm->stream();
}
int root = ctx.Attr<int>("root");
VLOG(3) << "begin bkcl broadcast, parameter is: "
<< "root " << root << ", comm: " << comm->comm()
<< ", stream: " << stream;
void* send_recv_buffer = nullptr;
if (root == comm->rank()) {
// API: BKCLResult_t bkcl_broadcast(const BKCLContext_t ctx,
// const void* sendbuf,
// void* recvbuf,
// size_t count, BKCLDataType datatype,
// int root,
// XPUStream stream);
send_recv_buffer = reinterpret_cast<void*>(const_cast<T*>(x->data<T>()));
auto ret = bkcl_broadcast(comm->comm(),
send_recv_buffer,
send_recv_buffer,
numel,
dtype,
root,
stream);
PADDLE_ENFORCE_EQ(ret,
BKCL_SUCCESS,
platform::errors::PreconditionNotMet(
"XPU BKCL c_broadcast execute failed"));
if (out != x) {
framework::TensorCopy(
*static_cast<const phi::DenseTensor*>(x),
place,
*platform::DeviceContextPool::Instance().Get(place),
static_cast<phi::DenseTensor*>(out));
}
} else {
auto& dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
dev_ctx.template Alloc<T>(out);
send_recv_buffer = out->data<T>();
auto ret = bkcl_broadcast(comm->comm(),
send_recv_buffer,
send_recv_buffer,
numel,
dtype,
root,
stream);
PADDLE_ENFORCE_EQ(ret,
BKCL_SUCCESS,
platform::errors::PreconditionNotMet(
"XPU BKCL c_broadcast execute failed"));
}
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. received "
<< phi::product(out->dims());
out->Resize(x->dims());
out->set_lod(x->lod());
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should be compiled with XPU and BKCL."));
#endif
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(c_broadcast,
ops::CBroadcastOpXPUKernel<float>,
ops::CBroadcastOpXPUKernel<plat::float16>);
......@@ -84,15 +84,6 @@ class CCommInitOp : public framework::OperatorBase {
int nranks = Attr<int>("nranks");
int rid = Attr<int>("ring_id");
#if defined(PADDLE_WITH_XPU_BKCL)
PADDLE_ENFORCE_EQ(
rid,
0,
platform::errors::OutOfRange(
"Ring id must equal 0 in multi Kunlun cards training, but got %d",
rid));
#endif
int device_id = place.device;
if (Attr<int>("device_id") >= 0) {
device_id = Attr<int>("device_id");
......
......@@ -340,9 +340,7 @@ BKCLComm* BKCLCommContext::CreateComm(
BKCLContext_t comm = nullptr;
platform::SetXPUDeviceId(dev_id);
PADDLE_ENFORCE_XPU_SUCCESS(bkcl_init_rank(&comm, rank, nranks, bkcl_id));
auto* comm_wrapper = AssignBKCLComm(comm, nranks, rank, dev_id, ring_id);
VLOG(1) << "bkcl communicator of rank " << rank << " in ring " << ring_id
<< " has been created on device " << dev_id;
......@@ -372,30 +370,27 @@ BKCLComm* BKCLCommContext::AssignBKCLComm(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(paddle::platform::CPUPlace())
.get());
BKCLCommImpl* c = new BKCLCommImpl;
c->set_ring_id(ring_id);
c->set_nranks(nranks);
c->set_rank(rank);
c->set_comm(comm);
c->set_dev_ctx(std::move(dev_ctx));
comm_map_mutex_.lock();
if (comm_map_.count(ring_id) == 0) {
comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<BKCLComm>>());
}
auto& dev2comm = comm_map_[ring_id];
dev2comm.emplace(dev_id, std::unique_ptr<BKCLComm>(c));
comm_map_mutex_.unlock();
if (ring_id == 0) {
auto* dev_ctx = static_cast<platform::XPUDeviceContext*>(
platform::DeviceContextPool::Instance().Get(
platform::XPUPlace(dev_id)));
dev_ctx->SetBkclContext(comm);
}
VLOG(3) << "add bkcl comm: " << comm_map_[ring_id][dev_id].get()
<< ", ring_id:" << ring_id << ", dev_id:" << dev_id;
return comm_map_[ring_id][dev_id].get();
}
......
......@@ -82,6 +82,8 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT16,
phi::DataType::FLOAT32,
phi::DataType::INT32})},
{"c_broadcast",
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
{"c_concat",
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
{"c_embedding", XPUKernelSet({phi::DataType::FLOAT32})},
......
......@@ -44,11 +44,12 @@ void Pool2dKernel(const Context& ctx,
phi::errors::InvalidArgument(
"The Pool2d XPU OP only support 2 dimension pooling!"));
PADDLE_ENFORCE_EQ(
// old model's data_format maybe AnyLayout
PADDLE_ENFORCE_NE(
data_format,
"NCHW",
phi::errors::InvalidArgument("The Pool2d XPU OP only support "
"data_format is 'NCHW', but received %s",
"NHWC",
phi::errors::InvalidArgument("The Pool2d XPU OP does not support "
"data_format is 'NHWC', but received %s",
data_format));
if (global_pooling) {
......
......@@ -143,15 +143,25 @@ class ProcessGroup:
core.NCCLParallelContext(strategy, place).init_with_ring_id(
ring_id
)
elif core.is_compiled_with_xpu():
place = core.XPUPlace(genv.device_id)
core.BKCLParallelContext(strategy, place).init_with_ring_id(
ring_id
)
else:
assert False, "No CUDA device found"
# TODO(shenliang03): This is a temporary solution to solve the problem of
# hang caused by cross-creation of new_group
paddle.disable_static()
if core.is_compiled_with_cuda():
paddle.set_device(
'gpu:%d' % paddle.distributed.ParallelEnv().dev_id
)
elif core.is_compiled_with_xpu():
paddle.set_device(
'xpu:%d' % paddle.distributed.ParallelEnv().dev_id
)
tmp = (
paddle.to_tensor([1], dtype="int32")
if in_dygraph_mode()
......
......@@ -1143,6 +1143,7 @@ class ShardingOptimizer(MetaOptimizerBase):
"c_sync_comm_stream",
"c_calc_comm_stream",
"c_gen_nccl_id",
"c_gen_bkcl_id",
"c_comm_init",
'send_v2',
'recv_v2',
......
......@@ -163,7 +163,36 @@ class Collective:
self.op_role_key: OpRole.Forward,
},
)
else:
elif core.is_compiled_with_xpu():
bkcl_id_var = block.create_var(
name=unique_name.generate('bkcl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW,
)
endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
block.append_op(
type='c_gen_bkcl_id',
inputs={},
outputs={'Out': bkcl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
self.op_role_key: OpRole.Forward,
},
)
block.append_op(
type='c_comm_init',
inputs={'X': bkcl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
self.op_role_key: OpRole.Forward,
},
)
elif core.is_compiled_with_cuda():
nccl_id_var = block.create_var(
name=unique_name.generate('nccl_id'),
persistable=True,
......
......@@ -161,7 +161,7 @@ class Collective:
self.op_role_key: OpRole.Forward,
},
)
else:
elif core.is_compiled_with_cuda():
nccl_id_var = block.create_var(
name=unique_name.generate('nccl_id'),
persistable=True,
......@@ -202,6 +202,34 @@ class Collective:
self.op_role_key: OpRole.Forward,
},
)
elif core.is_compiled_with_xpu():
bkcl_id_var = block.create_var(
name=unique_name.generate('bkcl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW,
)
block.append_op(
type='c_gen_bkcl_id',
inputs={},
outputs={'Out': bkcl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
self.op_role_key: OpRole.Forward,
},
)
block.append_op(
type='c_comm_init',
inputs={'X': bkcl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
self.op_role_key: OpRole.Forward,
},
)
def _broadcast_params(self):
block = self.startup_program.global_block()
......
......@@ -30,7 +30,7 @@ class TestCollectiveAllGather(TestCollectiveRunnerBase):
nranks = 2
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.static.data(
name="tindata", shape=[-1, 10, 1000], dtype='float32'
name="tindata", shape=[10, 1000], dtype='float32'
)
toutdata = main_prog.current_block().create_var(
name="outofgather",
......
......@@ -31,7 +31,7 @@ class TestCollectiveAllReduce(TestCollectiveRunnerBase):
ring_id = 0
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.static.data(
name="tindata", shape=[-1, 10, 1000], dtype='float32'
name="tindata", shape=[10, 1000], dtype='float32'
)
toutdata = main_prog.current_block().create_var(
name="outofreduce",
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from test_collective_base_xpu import TestCollectiveRunnerBase, runtime_main
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
paddle.enable_static()
class TestCollectiveBroadcast(TestCollectiveRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program):
ring_id = 0
rootid = 1
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.static.data(
name="tindata", shape=[10, 1000], dtype='float32'
)
toutdata = main_prog.current_block().create_var(
name="outofbroadcast",
dtype='float32',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False,
)
main_prog.global_block().append_op(
type="c_broadcast",
inputs={'X': tindata},
attrs={'ring_id': ring_id, 'root': rootid},
outputs={'Out': toutdata},
)
main_prog.global_block().append_op(
type="c_sync_comm_stream",
inputs={'X': toutdata},
outputs={'Out': toutdata},
attrs={'ring_id': ring_id},
)
return toutdata
if __name__ == "__main__":
os.environ["BKCL_PCIE_RING"] = "1"
runtime_main(TestCollectiveBroadcast, "broadcast", 0)
......@@ -30,7 +30,7 @@ class TestCollectiveIdentity(TestCollectiveRunnerBase):
nranks = 2
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.static.data(
name="tindata", shape=[-1, 10, 1000], dtype='float32'
name="tindata", shape=[10, 1000], dtype='float32'
)
toutdata = main_prog.current_block().create_var(
name="outofgather",
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
from test_collective_base_xpu import TestDistBase
import paddle
from paddle.fluid import core
sys.path.append("..")
from xpu.get_test_cover_info import XPUOpTestWrapper, create_test_class
paddle.enable_static()
class XPUTestCBroadcastOP(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'c_broadcast'
self.use_dynamic_create_class = False
class TestCBroadcastOp(TestDistBase):
def _setup_config(self):
pass
def test_broadcast(self):
self.check_with_place(
"collective_broadcast_op_xpu.py", "broadcast", self.in_type_str
)
support_types = ["float32"]
for stype in support_types:
create_test_class(
globals(),
XPUTestCBroadcastOP,
stype,
ignore_device_version=[core.XPUVersion.XPU1],
)
if __name__ == '__main__':
unittest.main()
......@@ -184,6 +184,34 @@ def init_communicator(
'rank_ids': nranks,
},
)
elif core.is_compiled_with_xpu():
bkcl_id_var = block.create_var(
name=fluid.unique_name.generate('bkcl_id'),
persistable=True,
type=fluid.core.VarDesc.VarType.RAW,
)
block.append_op(
type='c_gen_bkcl_id',
inputs={},
outputs={'Out': bkcl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
},
)
block.append_op(
type='c_comm_init',
inputs={'X': bkcl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': 0,
},
)
def prepare_distributed_context(place=None):
......
......@@ -34,6 +34,7 @@ def init_communicator(block, rank, ranks, ring_id):
comm_id_var = block.create_var(
name=comm_var_name, persistable=True, type=core.VarDesc.VarType.RAW
)
if core.is_compiled_with_cuda():
block.append_op(
type='c_gen_nccl_id',
inputs={},
......@@ -45,6 +46,18 @@ def init_communicator(block, rank, ranks, ring_id):
'ring_id': ring_id,
},
)
elif core.is_compiled_with_xpu():
block.append_op(
type='c_gen_bkcl_id',
inputs={},
outputs={'Out': comm_id_var},
attrs={
'rank': local_rank,
'endpoint': cur_ep,
'other_endpoints': other_eps,
'ring_id': ring_id,
},
)
block.append_op(
type='c_comm_init',
inputs={'X': comm_id_var},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册