diff --git a/paddle/fluid/framework/details/multi_devices_helper.cc b/paddle/fluid/framework/details/multi_devices_helper.cc
index 2cfd27b0ef91ffcc8a2968c6cfe3077eaddec4d3..945895183d9c0e119b715cc4431fd28d6423a966 100644
--- a/paddle/fluid/framework/details/multi_devices_helper.cc
+++ b/paddle/fluid/framework/details/multi_devices_helper.cc
@@ -42,6 +42,7 @@ static std::unordered_set<std::string> kMultiDeviceOps{
     "c_comm_init_all",
     "c_comm_init_multitrainer",
     "c_gen_nccl_id",
+    "c_gen_bkcl_id",
     "c_sync_comm_stream",
     "send",
     "recv",
diff --git a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
index 7dd9e0779008e6740ab0a17650e91f931a8e88cb..268e9e30137879108c2636c44988556ed27f26ad 100644
--- a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
@@ -261,7 +261,7 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
   if (mul_type == "mul") {
     fc_xpu_op_desc.SetAttr(
         "in_num_col_dims",
-        PADDLE_GET_CONST(int, mul->Op()->GetAttr("in_num_col_dims")));
+        PADDLE_GET_CONST(int, mul->Op()->GetAttr("x_num_col_dims")));
   }
   fc_xpu_op_desc.SetAttr("transpose_x", false);
   fc_xpu_op_desc.SetAttr("alpha", 1.f);
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 5604e913fd01234f4ad578739dec6a21f8966935..dfdbd63767c9e32fc94c6720478932d5ee392b1b 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -562,9 +562,7 @@ bool AnalysisPredictor::PrepareProgram(
       OptimizeInferenceProgram();
     }
   }
-
   executor_->CreateVariables(*inference_program_, 0, false, sub_scope_);
-
   return true;
 }
 
@@ -785,6 +783,30 @@ void AnalysisPredictor::InsertCommOp(
     comm_init_op->SetAttr("op_role",
                           static_cast<int>(framework::OpRole::kForward));
     comm_init_op->CheckAttrs();
+  } else if (config_.use_xpu()) {
+    framework::VarDesc *new_var = block->Var(tmp_var_name);
+    new_var->SetType(framework::proto::VarType::RAW);
+    new_var->SetPersistable(true);
+    framework::OpDesc *gen_bkcl_id_op = block->AppendOp();
+    gen_bkcl_id_op->SetType("c_gen_bkcl_id");
+    gen_bkcl_id_op->SetOutput("Out", {tmp_var_name});
+    gen_bkcl_id_op->SetAttr("rank", rank);
+    gen_bkcl_id_op->SetAttr("endpoint",
+                            config_.dist_config().current_endpoint());
+    gen_bkcl_id_op->SetAttr("other_endpoints", peer_endpoints);
+    gen_bkcl_id_op->SetAttr("ring_id", ring_id);
+    gen_bkcl_id_op->SetAttr("op_role",
+                            static_cast<int>(framework::OpRole::kForward));
+    gen_bkcl_id_op->CheckAttrs();
+    framework::OpDesc *comm_init_op = block->AppendOp();
+    comm_init_op->SetType("c_comm_init");
+    comm_init_op->SetInput("X", {tmp_var_name});
+    comm_init_op->SetAttr("rank", rank);
+    comm_init_op->SetAttr("nranks", nranks);
+    comm_init_op->SetAttr("ring_id", ring_id);
+    comm_init_op->SetAttr("op_role",
+                          static_cast<int>(framework::OpRole::kForward));
+    comm_init_op->CheckAttrs();
   } else {
     LOG(WARNING) << "DistModelInf doesn't init comm.";
     // TODO(fleet exe dev): comm init for more devices
@@ -1319,7 +1341,6 @@ void AnalysisPredictor::PrepareArgument() {
 // NOTE All the members in AnalysisConfig should be copied to Argument.
 void AnalysisPredictor::OptimizeInferenceProgram() {
   PrepareArgument();
-
 #ifdef PADDLE_WITH_TENSORRT
   if (config_.tensorrt_engine_enabled()) {
     inference::tensorrt::TensorRTEngine::predictor_id_per_thread =
@@ -1328,9 +1349,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
             << inference::tensorrt::TensorRTEngine::predictor_id_per_thread;
   }
 #endif
-
   Analyzer().Run(argument_.get());
-
   PADDLE_ENFORCE_EQ(
       argument_->scope_valid(),
       true,
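Note on the InsertCommOp change above: on XPU the predictor now emits the same
two-op initialization sequence the NCCL path uses, with c_gen_bkcl_id
exchanging a BKCL unique id across the configured endpoints and c_comm_init
consuming it to build the communicator for the ring. A minimal Python sketch of
the equivalent startup-program surgery is below; the helper name
init_bkcl_comm is illustrative and not part of this patch, while the op types
and attrs mirror the Python transpiler changes later in this diff.

    from paddle.fluid import core, unique_name


    def init_bkcl_comm(block, rank, nranks, current_endpoint,
                       other_endpoints, ring_id=0):
        # c_gen_bkcl_id writes the BKCL unique id into a RAW variable...
        bkcl_id_var = block.create_var(
            name=unique_name.generate('bkcl_id'),
            persistable=True,
            type=core.VarDesc.VarType.RAW,
        )
        block.append_op(
            type='c_gen_bkcl_id',
            inputs={},
            outputs={'Out': bkcl_id_var},
            attrs={
                'rank': rank,
                'endpoint': current_endpoint,
                'other_endpoints': other_endpoints,
                'ring_id': ring_id,
            },
        )
        # ...and c_comm_init consumes it to set up the communicator.
        block.append_op(
            type='c_comm_init',
            inputs={'X': bkcl_id_var},
            outputs={},
            attrs={'rank': rank, 'nranks': nranks, 'ring_id': ring_id},
        )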
diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index a5cdfda3243eb469418a87709fd851f089f37478..ca552ffc0525f4e4c30ce3dab349240be798616d 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -1194,6 +1194,20 @@ if(WITH_DISTRIBUTE
     --infer_model=${OCR_INSTALL_DIR}/model)
 endif()
 
+if(WITH_DISTRIBUTE
+   AND WITH_PSCORE
+   AND WITH_XPU
+   AND WITH_XPU_BKCL)
+  inference_analysis_test(
+    test_analyzer_dist_model_xpu
+    SRCS
+    analyzer_dist_model_xpu_tester.cc
+    EXTRA_DEPS
+    paddle_inference_shared
+    ARGS
+    --infer_model=${OCR_INSTALL_DIR}/model)
+endif()
+
 inference_analysis_test(
   test_analyzer_paddletensor_tensor
   SRCS
diff --git a/paddle/fluid/inference/tests/api/analyzer_dist_model_xpu_tester.cc b/paddle/fluid/inference/tests/api/analyzer_dist_model_xpu_tester.cc
new file mode 100644
index 0000000000000000000000000000000000000000..cb7688221e2673cbd842f88af25b575317173a69
--- /dev/null
+++ b/paddle/fluid/inference/tests/api/analyzer_dist_model_xpu_tester.cc
@@ -0,0 +1,73 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/block_desc.h"
+#include "paddle/fluid/framework/op_desc.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/inference/tests/api/tester_helper.h"
+#include "paddle/fluid/inference/utils/singleton.h"
+
+namespace paddle {
+namespace inference {
+
+TEST(test_dist_model_xpu, dist_model_xpu) {
+  std::cout << "Analysis Predictor DistModel XPU test." << std::endl;
+  AnalysisConfig config;
+  config.SetModel(FLAGS_infer_model + "/__model__",
+                  FLAGS_infer_model + "/__params__");
+  config.SwitchUseFeedFetchOps(false);
+  config.EnableXpu();
+  config.SetXpuDeviceId(0);
+  DistConfig dist_config;
+  dist_config.SetRanks(1, 0);
+  dist_config.EnableDistModel(true);
+  dist_config.SetEndpoints({""}, "");
+  config.SetDistConfig(dist_config);
+
+  auto predictor = paddle_infer::CreatePredictor(config);
+  int batch_size = 1;
+  int channels = 1;
+  int height = 48;
+  int width = 512;
+  int nums = batch_size * channels * height * width;
+  std::cout << "Created predictor." << std::endl;
+
+  float* input = new float[nums];
+  for (int i = 0; i < nums; ++i) input[i] = 0;
+  auto input_names = predictor->GetInputNames();
+
+  auto input_t = predictor->GetInputHandle(input_names[0]);
+  input_t->Reshape({batch_size, channels, height, width});
+  input_t->CopyFromCpu(input);
+  std::cout << "Input data." << std::endl;
+
+  predictor->Run();
+  std::cout << "Zero Copy Run." << std::endl;
+
+  std::vector<float> out_data;
+  auto output_names = predictor->GetOutputNames();
+  auto output_t = predictor->GetOutputHandle(output_names[0]);
+  std::vector<int> output_shape = output_t->shape();
+  int out_num = std::accumulate(
+      output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
+  out_data.resize(out_num);
+  output_t->CopyToCpu(out_data.data());
+  std::cout << "Output data." << std::endl;
+  delete[] input;
+}
+
+}  // namespace inference
+}  // namespace paddle
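For reference, the same zero-copy feed/run/fetch flow is available from the
Python inference API; a minimal sketch is below. The model path is a
placeholder, and the Python-side dist-model configuration is omitted because
this test drives DistConfig from C++.

    import numpy as np
    from paddle.inference import Config, create_predictor

    config = Config("model/__model__", "model/__params__")  # placeholder paths
    config.enable_xpu()  # run on XPU, like config.EnableXpu() in the C++ test

    predictor = create_predictor(config)
    input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
    data = np.zeros([1, 1, 48, 512], dtype=np.float32)  # same NCHW shape
    input_handle.reshape([1, 1, 48, 512])
    input_handle.copy_from_cpu(data)

    predictor.run()

    output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
    out = output_handle.copy_to_cpu()  # numpy array, one entry per output elem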
diff --git a/paddle/fluid/operators/collective/c_broadcast_op_xpu.cc b/paddle/fluid/operators/collective/c_broadcast_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3f12adcea3f8fce0cf3a8e506853624715699ebe
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_broadcast_op_xpu.cc
@@ -0,0 +1,118 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/c_broadcast_op.h"
+
+#ifdef PADDLE_WITH_XPU_BKCL
+#include "paddle/fluid/platform/collective_helper.h"
+#include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class CBroadcastOpXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+#if defined(PADDLE_WITH_XPU_BKCL)
+    auto x = ctx.Input<phi::DenseTensor>("X");
+    auto out = ctx.Output<phi::DenseTensor>("Out");
+    size_t numel = x->numel();
+
+    BKCLDataType dtype =
+        platform::ToBKCLDataType(framework::TransToProtoVarType(x->dtype()));
+    int ring_id = ctx.Attr<int>("ring_id");
+    auto place = ctx.GetPlace();
+    auto comm =
+        paddle::platform::BKCLCommContext::Instance().Get(ring_id, place);
+
+    XPUStream stream = nullptr;
+    auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+    if (ctx.Attr<bool>("use_calc_stream")) {
+      stream = static_cast<platform::XPUDeviceContext*>(dev_ctx)
+                   ->x_context()
+                   ->xpu_stream;
+    } else {
+      stream = comm->stream();
+    }
+
+    int root = ctx.Attr<int>("root");
+    VLOG(3) << "begin bkcl broadcast, parameter is: "
+            << "root " << root << ", comm: " << comm->comm()
+            << ", stream: " << stream;
+    void* send_recv_buffer = nullptr;
+    if (root == comm->rank()) {
+      // API: BKCLResult_t bkcl_broadcast(const BKCLContext_t ctx,
+      //                                  const void* sendbuf,
+      //                                  void* recvbuf,
+      //                                  size_t count, BKCLDataType datatype,
+      //                                  int root,
+      //                                  XPUStream stream);
+      send_recv_buffer =
+          reinterpret_cast<void*>(const_cast<T*>(x->data<T>()));
+      auto ret = bkcl_broadcast(comm->comm(),
+                                send_recv_buffer,
+                                send_recv_buffer,
+                                numel,
+                                dtype,
+                                root,
+                                stream);
+      PADDLE_ENFORCE_EQ(ret,
+                        BKCL_SUCCESS,
+                        platform::errors::PreconditionNotMet(
+                            "XPU BKCL c_broadcast execute failed"));
+      if (out != x) {
+        framework::TensorCopy(
+            *static_cast<const phi::DenseTensor*>(x),
+            place,
+            *platform::DeviceContextPool::Instance().Get(place),
+            static_cast<phi::DenseTensor*>(out));
+      }
+    } else {
+      auto& dev_ctx =
+          ctx.template device_context<platform::XPUDeviceContext>();
+      dev_ctx.template Alloc<T>(out);
+      send_recv_buffer = out->data<T>();
+      auto ret = bkcl_broadcast(comm->comm(),
+                                send_recv_buffer,
+                                send_recv_buffer,
+                                numel,
+                                dtype,
+                                root,
+                                stream);
+      PADDLE_ENFORCE_EQ(ret,
+                        BKCL_SUCCESS,
+                        platform::errors::PreconditionNotMet(
+                            "XPU BKCL c_broadcast execute failed"));
+    }
+
+    VLOG(3) << "rank " << comm->rank() << " invoke Bcast. received "
+            << phi::product(out->dims());
+    out->Resize(x->dims());
+    out->set_lod(x->lod());
+#else
+    PADDLE_THROW(platform::errors::PreconditionNotMet(
+        "PaddlePaddle should be compiled with XPU and BKCL."));
+#endif
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_XPU_KERNEL(c_broadcast,
+                       ops::CBroadcastOpXPUKernel<float>,
+                       ops::CBroadcastOpXPUKernel<plat::float16>);
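With the kernel above plus the KL2 op-list entry registered later in this
diff, c_broadcast becomes usable on XPU: the root rank broadcasts in place
from its input buffer, while non-root ranks receive into the freshly allocated
output. For orientation, the high-level dygraph counterpart of this static op
is paddle.distributed.broadcast; a minimal two-rank sketch follows (the launch
command is an assumption about the runner, not part of this patch):

    # launch (assumed): python -m paddle.distributed.launch --xpus "0,1" demo.py
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()  # on XPU builds this sets up a BKCL communicator
    if dist.get_rank() == 1:  # rank 1 is the root, matching the unit test below
        data = paddle.to_tensor([[4.0, 5.0, 6.0], [4.0, 5.0, 6.0]])
    else:
        data = paddle.zeros([2, 3], dtype='float32')
    dist.broadcast(data, src=1)
    print(data.numpy())  # every rank now holds rank 1's tensor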
diff --git a/paddle/fluid/operators/collective/c_comm_init_op.cc b/paddle/fluid/operators/collective/c_comm_init_op.cc
index 26e700262f8dda4555836dca2685eb182e08206b..e3aadbf344ac9975f190820e54620354bab3e7c8 100644
--- a/paddle/fluid/operators/collective/c_comm_init_op.cc
+++ b/paddle/fluid/operators/collective/c_comm_init_op.cc
@@ -84,15 +84,6 @@ class CCommInitOp : public framework::OperatorBase {
     int nranks = Attr<int>("nranks");
     int rid = Attr<int>("ring_id");
 
-#if defined(PADDLE_WITH_XPU_BKCL)
-    PADDLE_ENFORCE_EQ(
-        rid,
-        0,
-        platform::errors::OutOfRange(
-            "Ring id must equal 0 in multi Kunlun cards training, but got %d",
-            rid));
-#endif
-
     int device_id = place.device;
     if (Attr<int>("device_id") >= 0) {
       device_id = Attr<int>("device_id");
diff --git a/paddle/fluid/platform/collective_helper.cc b/paddle/fluid/platform/collective_helper.cc
index 41cb9ed1b700d2f0ac13e90b29878ba819736227..948fa300c1f475727ad757cb9da0121e45f4a4de 100644
--- a/paddle/fluid/platform/collective_helper.cc
+++ b/paddle/fluid/platform/collective_helper.cc
@@ -340,9 +340,7 @@ BKCLComm* BKCLCommContext::CreateComm(
   BKCLContext_t comm = nullptr;
   platform::SetXPUDeviceId(dev_id);
   PADDLE_ENFORCE_XPU_SUCCESS(bkcl_init_rank(&comm, rank, nranks, bkcl_id));
-
   auto* comm_wrapper = AssignBKCLComm(comm, nranks, rank, dev_id, ring_id);
-
   VLOG(1) << "bkcl communicator of rank " << rank << " in ring " << ring_id
           << " has been created on device " << dev_id;
 
@@ -372,30 +370,27 @@ BKCLComm* BKCLCommContext::AssignBKCLComm(
       paddle::memory::allocation::AllocatorFacade::Instance()
           .GetZeroAllocator(paddle::platform::CPUPlace())
           .get());
-
   BKCLCommImpl* c = new BKCLCommImpl;
   c->set_ring_id(ring_id);
   c->set_nranks(nranks);
   c->set_rank(rank);
   c->set_comm(comm);
   c->set_dev_ctx(std::move(dev_ctx));
-
   comm_map_mutex_.lock();
   if (comm_map_.count(ring_id) == 0) {
     comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<BKCLComm>>());
   }
   auto& dev2comm = comm_map_[ring_id];
-
   dev2comm.emplace(dev_id, std::unique_ptr<BKCLComm>(c));
   comm_map_mutex_.unlock();
-
   if (ring_id == 0) {
     auto* dev_ctx = static_cast<platform::XPUDeviceContext*>(
         platform::DeviceContextPool::Instance().Get(
             platform::XPUPlace(dev_id)));
     dev_ctx->SetBkclContext(comm);
   }
-
+  VLOG(3) << "add bkcl comm: " << comm_map_[ring_id][dev_id].get()
+          << ", ring_id:" << ring_id << ", dev_id:" << dev_id;
   return comm_map_[ring_id][dev_id].get();
 }
diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index a5935c07b94792e5ec98552fa1e1f51cf5267451..bc883bc6e32363ba6ddba3607c6e10ec990323ee 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -82,6 +82,8 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::FLOAT16,
                    phi::DataType::FLOAT32,
                    phi::DataType::INT32})},
+    {"c_broadcast",
+     XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
     {"c_concat",
      XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
     {"c_embedding", XPUKernelSet({phi::DataType::FLOAT32})},
diff --git a/paddle/phi/kernels/xpu/pool_kernel.cc b/paddle/phi/kernels/xpu/pool_kernel.cc
index d876bbbeb8fbfe08a2ae0222850f116b8833f0e9..ad09a5ed37188e19f23ec71ad276a7d3b10b389e 100644
--- a/paddle/phi/kernels/xpu/pool_kernel.cc
+++ b/paddle/phi/kernels/xpu/pool_kernel.cc
@@ -44,11 +44,12 @@ void Pool2dKernel(const Context& ctx,
                     phi::errors::InvalidArgument(
                         "The Pool2d XPU OP only support 2 dimension pooling!"));
-  PADDLE_ENFORCE_EQ(
+  // old model's data_format may be AnyLayout
+  PADDLE_ENFORCE_NE(
       data_format,
-      "NCHW",
-      phi::errors::InvalidArgument("The Pool2d XPU OP only support "
-                                   "data_format is 'NCHW', but received %s",
+      "NHWC",
+      phi::errors::InvalidArgument("The Pool2d XPU OP does not support "
+                                   "data_format is 'NHWC', but received %s",
                                    data_format));
 
   if (global_pooling) {
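The relaxed check above lets programs saved by older exporters, whose
data_format can come through as AnyLayout rather than NCHW, run on the XPU
pooling kernel while still rejecting genuine NHWC layouts. A quick sketch of
the accepted layout, using only the standard pooling API (nothing
patch-specific):

    import paddle
    import paddle.nn.functional as F

    x = paddle.randn([1, 3, 32, 32])              # NCHW, the layout XPU expects
    y = F.avg_pool2d(x, kernel_size=2, stride=2)  # data_format defaults to NCHW
    print(y.shape)                                # [1, 3, 16, 16]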
diff --git a/python/paddle/distributed/auto_parallel/process_group.py b/python/paddle/distributed/auto_parallel/process_group.py
index 8ad0172ea3d42ec27acaf54b76f15089e648c3d5..1debbaa325638886da874eadc1900c5954f3e4da 100644
--- a/python/paddle/distributed/auto_parallel/process_group.py
+++ b/python/paddle/distributed/auto_parallel/process_group.py
@@ -143,15 +143,25 @@ class ProcessGroup:
                 core.NCCLParallelContext(strategy, place).init_with_ring_id(
                     ring_id
                 )
+            elif core.is_compiled_with_xpu():
+                place = core.XPUPlace(genv.device_id)
+                core.BKCLParallelContext(strategy, place).init_with_ring_id(
+                    ring_id
+                )
             else:
                 assert False, "No CUDA device found"
 
             # TODO(shenliang03): This is a temporary solution to solve the problem of
             # hang caused by cross-creation of new_group
             paddle.disable_static()
-            paddle.set_device(
-                'gpu:%d' % paddle.distributed.ParallelEnv().dev_id
-            )
+            if core.is_compiled_with_cuda():
+                paddle.set_device(
+                    'gpu:%d' % paddle.distributed.ParallelEnv().dev_id
+                )
+            elif core.is_compiled_with_xpu():
+                paddle.set_device(
+                    'xpu:%d' % paddle.distributed.ParallelEnv().dev_id
+                )
             tmp = (
                 paddle.to_tensor([1], dtype="int32")
                 if in_dygraph_mode()
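The process_group.py change above mirrors the CUDA path twice: once to pick
the parallel context (BKCLParallelContext instead of NCCLParallelContext) and
once to pick the device string for the temporary tensor that works around the
new_group hang. A compacted sketch of that selection logic; the helper name is
illustrative, not part of this patch:

    import paddle
    from paddle.fluid import core


    def _parallel_device(dev_id):
        # Choose the device string the way the patched branches do;
        # fail loudly if neither backend was compiled in.
        if core.is_compiled_with_cuda():
            return 'gpu:%d' % dev_id
        elif core.is_compiled_with_xpu():
            return 'xpu:%d' % dev_id
        raise AssertionError("No CUDA or XPU device found")


    # usage, matching the patched code path:
    # paddle.set_device(_parallel_device(paddle.distributed.ParallelEnv().dev_id))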
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 19c7147da2d5095173915d36a3434d18a2e33973..dedcc6f5ac70c4235a73f6114185d352a0c6e3ea 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -1143,6 +1143,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                 "c_sync_comm_stream",
                 "c_calc_comm_stream",
                 "c_gen_nccl_id",
+                "c_gen_bkcl_id",
                 "c_comm_init",
                 'send_v2',
                 'recv_v2',
diff --git a/python/paddle/distributed/ps/utils/collective_transpiler.py b/python/paddle/distributed/ps/utils/collective_transpiler.py
index 5e0fe0d19381d25a929499686e2280cdf886f4b1..ea6f23de48d9736d7908f0638d9989ca88d8651c 100644
--- a/python/paddle/distributed/ps/utils/collective_transpiler.py
+++ b/python/paddle/distributed/ps/utils/collective_transpiler.py
@@ -163,7 +163,36 @@ class Collective:
                     self.op_role_key: OpRole.Forward,
                 },
             )
-        else:
+        elif core.is_compiled_with_xpu():
+            bkcl_id_var = block.create_var(
+                name=unique_name.generate('bkcl_id'),
+                persistable=True,
+                type=core.VarDesc.VarType.RAW,
+            )
+            endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
+            block.append_op(
+                type='c_gen_bkcl_id',
+                inputs={},
+                outputs={'Out': bkcl_id_var},
+                attrs={
+                    'rank': rank,
+                    'endpoint': current_endpoint,
+                    'other_endpoints': other_endpoints,
+                    self.op_role_key: OpRole.Forward,
+                },
+            )
+            block.append_op(
+                type='c_comm_init',
+                inputs={'X': bkcl_id_var},
+                outputs={},
+                attrs={
+                    'nranks': nranks,
+                    'rank': rank,
+                    'ring_id': ring_id,
+                    self.op_role_key: OpRole.Forward,
+                },
+            )
+        elif core.is_compiled_with_cuda():
             nccl_id_var = block.create_var(
                 name=unique_name.generate('nccl_id'),
                 persistable=True,
diff --git a/python/paddle/distributed/transpiler/collective.py b/python/paddle/distributed/transpiler/collective.py
index db07a1887d4bb77fa27825fd1485b2892eb59f0e..3e6b9b6885fda0e23c56c6d9ab4711a52b2602bb 100644
--- a/python/paddle/distributed/transpiler/collective.py
+++ b/python/paddle/distributed/transpiler/collective.py
@@ -161,7 +161,7 @@ class Collective:
                     self.op_role_key: OpRole.Forward,
                 },
             )
-        else:
+        elif core.is_compiled_with_cuda():
             nccl_id_var = block.create_var(
                 name=unique_name.generate('nccl_id'),
                 persistable=True,
@@ -202,6 +202,34 @@ class Collective:
                     self.op_role_key: OpRole.Forward,
                 },
             )
+        elif core.is_compiled_with_xpu():
+            bkcl_id_var = block.create_var(
+                name=unique_name.generate('bkcl_id'),
+                persistable=True,
+                type=core.VarDesc.VarType.RAW,
+            )
+            block.append_op(
+                type='c_gen_bkcl_id',
+                inputs={},
+                outputs={'Out': bkcl_id_var},
+                attrs={
+                    'rank': rank,
+                    'endpoint': current_endpoint,
+                    'other_endpoints': other_endpoints,
+                    self.op_role_key: OpRole.Forward,
+                },
+            )
+            block.append_op(
+                type='c_comm_init',
+                inputs={'X': bkcl_id_var},
+                outputs={},
+                attrs={
+                    'nranks': nranks,
+                    'rank': rank,
+                    'ring_id': ring_id,
+                    self.op_role_key: OpRole.Forward,
+                },
+            )
 
     def _broadcast_params(self):
         block = self.startup_program.global_block()
diff --git a/python/paddle/fluid/tests/unittests/xpu/collective_allgather_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/collective_allgather_op_xpu.py
index 5cd5a92f4b2c4862d1d334eb6923fb08bd352d2f..421155e3b806b22145197ac26cef91fac593c388 100644
--- a/python/paddle/fluid/tests/unittests/xpu/collective_allgather_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/collective_allgather_op_xpu.py
@@ -30,7 +30,7 @@ class TestCollectiveAllGather(TestCollectiveRunnerBase):
         nranks = 2
         with fluid.program_guard(main_prog, startup_program):
             tindata = paddle.static.data(
-                name="tindata", shape=[-1, 10, 1000], dtype='float32'
+                name="tindata", shape=[10, 1000], dtype='float32'
             )
             toutdata = main_prog.current_block().create_var(
                 name="outofgather",
diff --git a/python/paddle/fluid/tests/unittests/xpu/collective_allreduce_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/collective_allreduce_op_xpu.py
index 54b60f76665e3740b818d1bc50f03f17097b01e8..00a5579253c4ca43229635a1902bb8bddc74137d 100644
--- a/python/paddle/fluid/tests/unittests/xpu/collective_allreduce_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/collective_allreduce_op_xpu.py
@@ -31,7 +31,7 @@ class TestCollectiveAllReduce(TestCollectiveRunnerBase):
         ring_id = 0
         with fluid.program_guard(main_prog, startup_program):
             tindata = paddle.static.data(
-                name="tindata", shape=[-1, 10, 1000], dtype='float32'
+                name="tindata", shape=[10, 1000], dtype='float32'
            )
             toutdata = main_prog.current_block().create_var(
                 name="outofreduce",
diff --git a/python/paddle/fluid/tests/unittests/xpu/collective_broadcast_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/collective_broadcast_op_xpu.py
new file mode 100755
index 0000000000000000000000000000000000000000..f35a9c90aa9606b2ae002e30834e880271845e6e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/collective_broadcast_op_xpu.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from test_collective_base_xpu import TestCollectiveRunnerBase, runtime_main
+
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid import core
+
+paddle.enable_static()
+
+
+class TestCollectiveBroadcast(TestCollectiveRunnerBase):
+    def __init__(self):
+        self.global_ring_id = 0
+
+    def get_model(self, main_prog, startup_program):
+        ring_id = 0
+        rootid = 1
+        with fluid.program_guard(main_prog, startup_program):
+            tindata = paddle.static.data(
+                name="tindata", shape=[10, 1000], dtype='float32'
+            )
+            toutdata = main_prog.current_block().create_var(
+                name="outofbroadcast",
+                dtype='float32',
+                type=core.VarDesc.VarType.LOD_TENSOR,
+                persistable=False,
+                stop_gradient=False,
+            )
+            main_prog.global_block().append_op(
+                type="c_broadcast",
+                inputs={'X': tindata},
+                attrs={'ring_id': ring_id, 'root': rootid},
+                outputs={'Out': toutdata},
+            )
+            main_prog.global_block().append_op(
+                type="c_sync_comm_stream",
+                inputs={'X': toutdata},
+                outputs={'Out': toutdata},
+                attrs={'ring_id': ring_id},
+            )
+            return toutdata
+
+
+if __name__ == "__main__":
+    os.environ["BKCL_PCIE_RING"] = "1"
+    runtime_main(TestCollectiveBroadcast, "broadcast", 0)
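The runner above is executed once per rank by the TestDistBase driver added
below, which then compares each rank's fetched output against the root rank's
input. A rough sketch of that comparison, under the assumption that the
"broadcast" case is checked in the obvious way (the actual harness lives in
test_collective_base_xpu.py, which this patch does not touch):

    import numpy as np

    def check_broadcast(rank_outputs, rank_inputs, root=1, rtol=1e-5):
        # every rank should end up holding the root rank's input tensor
        for out in rank_outputs:
            np.testing.assert_allclose(out, rank_inputs[root], rtol=rtol)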
diff --git a/python/paddle/fluid/tests/unittests/xpu/collective_identity_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/collective_identity_op_xpu.py
index 8fea9d7a4ac0e5ff4831949376fbb3e2c2ff333a..71c8be2ad10df365bb6c0e63edebd0d710608e96 100644
--- a/python/paddle/fluid/tests/unittests/xpu/collective_identity_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/collective_identity_op_xpu.py
@@ -30,7 +30,7 @@ class TestCollectiveIdentity(TestCollectiveRunnerBase):
         nranks = 2
         with fluid.program_guard(main_prog, startup_program):
             tindata = paddle.static.data(
-                name="tindata", shape=[-1, 10, 1000], dtype='float32'
+                name="tindata", shape=[10, 1000], dtype='float32'
             )
             toutdata = main_prog.current_block().create_var(
                 name="outofgather",
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_collective_broadcast_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_collective_broadcast_xpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..e015d0f92b11471c3771e0230bb9da94d1256968
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_collective_broadcast_xpu.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import unittest
+
+from test_collective_base_xpu import TestDistBase
+
+import paddle
+from paddle.fluid import core
+
+sys.path.append("..")
+
+from xpu.get_test_cover_info import XPUOpTestWrapper, create_test_class
+
+paddle.enable_static()
+
+
+class XPUTestCBroadcastOP(XPUOpTestWrapper):
+    def __init__(self):
+        self.op_name = 'c_broadcast'
+        self.use_dynamic_create_class = False
+
+    class TestCBroadcastOp(TestDistBase):
+        def _setup_config(self):
+            pass
+
+        def test_broadcast(self):
+            self.check_with_place(
+                "collective_broadcast_op_xpu.py", "broadcast", self.in_type_str
+            )
+
+
+support_types = ["float32"]
+for stype in support_types:
+    create_test_class(
+        globals(),
+        XPUTestCBroadcastOP,
+        stype,
+        ignore_device_version=[core.XPUVersion.XPU1],
+    )
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py
index d5bad6a977639daebd11eda9b667a3d1478f951a..be29ce73be70c036a76fdecccf449fa3eb125eb4 100644
--- a/python/paddle/hapi/model.py
+++ b/python/paddle/hapi/model.py
@@ -184,6 +184,34 @@ def init_communicator(
                 'rank_ids': nranks,
             },
         )
+    elif core.is_compiled_with_xpu():
+        bkcl_id_var = block.create_var(
+            name=fluid.unique_name.generate('bkcl_id'),
+            persistable=True,
+            type=fluid.core.VarDesc.VarType.RAW,
+        )
+
+        block.append_op(
+            type='c_gen_bkcl_id',
+            inputs={},
+            outputs={'Out': bkcl_id_var},
+            attrs={
+                'rank': rank,
+                'endpoint': current_endpoint,
+                'other_endpoints': other_endpoints,
+            },
+        )
+
+        block.append_op(
+            type='c_comm_init',
+            inputs={'X': bkcl_id_var},
+            outputs={},
+            attrs={
+                'nranks': nranks,
+                'rank': rank,
+                'ring_id': 0,
+            },
+        )
 
 
 def prepare_distributed_context(place=None):
diff --git a/python/paddle/incubate/optimizer/distributed_fused_lamb.py b/python/paddle/incubate/optimizer/distributed_fused_lamb.py
index 9aa51cd8122e68114e610714672980ba132f9629..3964f431ac5540c7a576b0ef58f13488886570c7 100644
--- a/python/paddle/incubate/optimizer/distributed_fused_lamb.py
+++ b/python/paddle/incubate/optimizer/distributed_fused_lamb.py
@@ -34,17 +34,30 @@ def init_communicator(block, rank, ranks, ring_id):
     comm_id_var = block.create_var(
         name=comm_var_name, persistable=True, type=core.VarDesc.VarType.RAW
     )
-    block.append_op(
-        type='c_gen_nccl_id',
-        inputs={},
-        outputs={'Out': comm_id_var},
-        attrs={
-            'rank': local_rank,
-            'endpoint': cur_ep,
-            'other_endpoints': other_eps,
-            'ring_id': ring_id,
-        },
-    )
+    if core.is_compiled_with_cuda():
+        block.append_op(
+            type='c_gen_nccl_id',
+            inputs={},
+            outputs={'Out': comm_id_var},
+            attrs={
+                'rank': local_rank,
+                'endpoint': cur_ep,
+                'other_endpoints': other_eps,
+                'ring_id': ring_id,
+            },
+        )
+    elif core.is_compiled_with_xpu():
+        block.append_op(
+            type='c_gen_bkcl_id',
+            inputs={},
+            outputs={'Out': comm_id_var},
+            attrs={
+                'rank': local_rank,
+                'endpoint': cur_ep,
+                'other_endpoints': other_eps,
+                'ring_id': ring_id,
+            },
+        )
     block.append_op(
         type='c_comm_init',
         inputs={'X': comm_id_var},