diff --git a/paddle/fluid/operators/custom_device_common_op_registry.cc b/paddle/fluid/operators/custom_device_common_op_registry.cc index 729e840e89d3e9b437e1654074687bb64d2df450..6fded8e113751465dab492c7a07683a38b62261b 100644 --- a/paddle/fluid/operators/custom_device_common_op_registry.cc +++ b/paddle/fluid/operators/custom_device_common_op_registry.cc @@ -279,7 +279,7 @@ class CEmbeddingGradOpCustomDeviceKernel : public framework::OpKernel { x_tensor, start_index + N, x_tensor.dtype(), x_tensor.place()); auto ids_mask_tensor = paddle::experimental::logical_and( x_tensor.greater_equal(start_index_tensor), - x_tensor.less_equal(end_index_tensor)); + x_tensor.less_than(end_index_tensor)); auto real_ids_tensor = (x_tensor - start_index_tensor) .multiply(paddle::experimental::cast( ids_mask_tensor, x_tensor.dtype())); @@ -668,6 +668,594 @@ class BarrierOpCustomDeviceKernel : public framework::OpKernel { } }; +template +class NumberCountOpCustomDeviceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto numbers = context.Input("numbers"); + auto upper_range = context.Attr("upper_range"); + auto number_count = context.Output("Out"); + const auto& dev_ctx = context.template device_context(); + number_count->Resize({upper_range}); + dev_ctx.template Alloc(number_count); + phi::DenseTensor cpu_tensor; + framework::TensorCopySync(*numbers, platform::CPUPlace(), &cpu_tensor); + std::vector count(upper_range); + for (auto i = 0; i < cpu_tensor.numel(); ++i) { + auto idx = static_cast(cpu_tensor.data()[i]); + if (idx >= 0 && idx < upper_range) { + count[idx] += 1; + } + } + framework::TensorFromVector(count, dev_ctx, number_count); + number_count->Resize({upper_range}); + } +}; + +template +class LimitByCapacityOpCustomDeviceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto expert_count = context.Input("expert_count"); + auto capacity = context.Input("capacity"); + auto out = context.Output("Out"); + auto n_worker = context.Attr("n_worker"); + auto n_expert = expert_count->numel() / n_worker; + + const auto& dev_ctx = context.template device_context(); + + dev_ctx.template Alloc(out); + std::vector out_data(out->numel()); + phi::DenseTensor expert_count_cpu, capacity_cpu; + framework::TensorCopySync( + *expert_count, platform::CPUPlace(), &expert_count_cpu); + framework::TensorCopySync(*capacity, platform::CPUPlace(), &capacity_cpu); + + auto* ec_data = expert_count_cpu.data(); + auto* capacity_data = capacity_cpu.data(); + int eid, wid; + for (int64_t i = 0; i < expert_count->numel(); ++i) { + wid = i / n_expert; + eid = i % n_expert; + auto proposal = ec_data[i]; + auto cap_left = capacity_data[eid]; + capacity_data[eid] -= proposal; + if (cap_left >= proposal) { + out_data[wid * n_expert + eid] = proposal; + } else if (cap_left >= 0) { + out_data[wid * n_expert + eid] = cap_left; + } else { + out_data[wid * n_expert + eid] = 0; + } + } + + auto out_dims = out->dims(); + framework::TensorFromVector(out_data, dev_ctx, out); + out->Resize(out_dims); + } +}; + +template +class PruneGateByCapacityCustomDeviceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* gate_idx = context.Input("GateIdx"); + auto* expert_count = context.Input("ExpertCount"); + auto* new_gate_idx = context.Output("NewGateIdx"); + const auto& dev_ctx = context.template device_context(); + 
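+    // prune_gate_by_capacity is evaluated on the host: GateIdx and
+    // ExpertCount are staged on the CPU, every token decrements the
+    // remaining capacity of its expert, and tokens that arrive after that
+    // capacity is exhausted have their gate index replaced with -1.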
dev_ctx.template Alloc(new_gate_idx); + + phi::DenseTensor expert_count_cpu, gate_idx_cpu; + framework::TensorCopySync( + *expert_count, platform::CPUPlace(), &expert_count_cpu); + framework::TensorCopySync(*gate_idx, platform::CPUPlace(), &gate_idx_cpu); + auto expert_count_data = expert_count_cpu.data(); + auto gate_idx_data = gate_idx_cpu.data(); + std::vector new_gate_idx_data(gate_idx->numel()); + for (auto i = 0; i < gate_idx->numel(); ++i) { + auto orig_cap = expert_count_data[gate_idx_data[i]]--; + if (orig_cap <= 0) { + new_gate_idx_data[i] = -1; + } else { + new_gate_idx_data[i] = gate_idx_data[i]; + } + } + } +}; + +template +class RandomRoutingOpCustomDeviceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto topk_idx = context.Input("TopK_Idx"); + auto topk_value = context.Input("TopK_Value"); + auto prob = context.Input("Prob"); + auto out = context.Output("Out"); + + const auto& dev_ctx = context.template device_context(); + size_t D = topk_idx->dims()[1]; + + phi::DenseTensor topk_value_cpu, prob_cpu; + framework::TensorCopySync( + *topk_value, platform::CPUPlace(), &topk_value_cpu); + framework::TensorCopySync(*prob, platform::CPUPlace(), &prob_cpu); + auto* topk_value_data = topk_value_cpu.data(); + auto* prob_data = prob_cpu.data(); + std::vector out_data(topk_idx->numel()); + + for (int64_t idx = 0; idx < topk_idx->numel(); ++idx) { + size_t row = idx / D; + size_t col = idx % D; + if (col == 1 && + static_cast(2) * topk_value_data[idx] < prob_data[row]) { + out_data[idx] = static_cast(-1); + } + } + auto out_dims = out->dims(); + framework::TensorFromVector(out_data, dev_ctx, out); + out->Resize(out_dims); + } +}; + +template +class AssignPosCustomDeviceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + // assign pos decides which tokens should be fetched belong to specially + // counter orderingly. 
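+    // Concretely: cum_count holds, for each expert, the (exclusive) end of
+    // that expert's slice in Out. Walking the token ids in X, every token i
+    // is written at position --cum_count[expert_id], so Out ends up listing,
+    // expert by expert, the positions of the tokens routed to that expert;
+    // negative ids (tokens pruned earlier) are skipped.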
+ auto cum_count = context.Input( + "cum_count"); // (counter number) int32 | int64 + auto numbers = context.Input( + "X"); // (batch_size * seq_len, topk) int32 + auto eff_num_len = + context.Input("eff_num_len"); // (sum(cum_count)) + auto out = + context.Output("Out"); // (cum_count) value ranges + // from 0 to batch_size * + // seq_len * topk + const auto& dev_ctx = context.template device_context(); + + phi::DenseTensor cpu_eff_num_len; + int64_t cpu_eff_num_len_data = 0; + if (platform::is_cpu_place(eff_num_len->place())) { + cpu_eff_num_len_data = eff_num_len->data()[0]; + } else { + framework::TensorCopySync( + *eff_num_len, platform::CPUPlace(), &cpu_eff_num_len); + cpu_eff_num_len_data = cpu_eff_num_len.data()[0]; + } + + out->Resize({cpu_eff_num_len_data}); + dev_ctx.template Alloc(out); + + phi::DenseTensor numbers_cpu, cum_count_cpu; + framework::TensorCopySync(*numbers, platform::CPUPlace(), &numbers_cpu); + framework::TensorCopySync(*cum_count, platform::CPUPlace(), &cum_count_cpu); + auto* numbers_data = numbers_cpu.data(); + auto* cum_count_data = cum_count_cpu.data(); + + std::vector out_data(cpu_eff_num_len_data); + for (int64_t i = 0; i < numbers->numel(); ++i) { + int number_idx = numbers_data[i]; + if (number_idx > -1) { + cum_count_data[number_idx] -= 1; + int p = cum_count_data[number_idx]; + out_data[p] = i; + } + } + framework::TensorFromVector(out_data, dev_ctx, out); + } +}; + +template +class GlobalScatterOpCustomDeviceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto x = ctx.Input("X"); + auto local_count = ctx.Input("local_count"); + auto global_count = ctx.Input("global_count"); + auto out = ctx.Output("Out"); + const int rid = ctx.Attr("ring_id"); + const auto& dev_ctx = ctx.template device_context(); + auto place = ctx.GetPlace(); + + PADDLE_ENFORCE_EQ(local_count->dtype(), + phi::DataType::INT64, + platform::errors::InvalidArgument( + "Please use int64 type in local_count.")); + PADDLE_ENFORCE_EQ(global_count->dtype(), + phi::DataType::INT64, + platform::errors::InvalidArgument( + "Please use int64 type in global_count.")); + + auto map = distributed::ProcessGroupMapFromGid::getInstance(); + const int64_t* cpu_local_count_data; + const int64_t* cpu_global_count_data; + phi::DenseTensor cpu_local_count; + if (platform::is_cpu_place(local_count->place())) { + cpu_local_count_data = local_count->data(); + } else { + framework::TensorCopySync( + *local_count, platform::CPUPlace(), &cpu_local_count); + cpu_local_count_data = cpu_local_count.data(); + } + auto global_count_len = 0; + phi::DenseTensor cpu_global_count; + if (platform::is_cpu_place(global_count->place())) { + cpu_global_count_data = global_count->data(); + global_count_len = global_count->numel(); + } else { + framework::TensorCopySync( + *global_count, platform::CPUPlace(), &cpu_global_count); + cpu_global_count_data = cpu_global_count.data(); + global_count_len = cpu_global_count.numel(); + } + + if (map->has(rid)) { + distributed::ProcessGroup* pg = map->get(rid); + auto stream = + reinterpret_cast(pg->GetDeviceContext(place)) + ->GetStream(); + int nranks = pg->GetSize(); + int rank = pg->GetRank(); + auto in_feat = x->dims()[1]; + auto n_expert = local_count->dims()[0] / nranks; + int64_t fwd_count = 0; + + for (auto i = 0; i < global_count_len; ++i) { + fwd_count += cpu_global_count_data[i]; + } + framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat}); + int64_t* expert_ptr = new int64_t[n_expert * 
nranks]; + expert_ptr[0] = 0; + auto tot_experts = n_expert * nranks; + for (auto i = 1; i < tot_experts; ++i) { + expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; + } + + auto recv_ptr = 0; + out->Resize(out_dims); + dev_ctx.template Alloc(out); + + for (auto i = 0; i < n_expert; ++i) { + for (auto j = 0; j < rank; ++j) { + int idx = i + j * n_expert; + if (cpu_global_count_data[idx]) { + pg->Recv(out, + j, + recv_ptr * in_feat, + cpu_global_count_data[idx] * in_feat, + /*sync_op*/ true); + recv_ptr += cpu_global_count_data[idx]; + } + } + for (auto j = 0; j < nranks; ++j) { + if (j != rank) { + int idx = i + j * n_expert; + if (cpu_local_count_data[idx]) { + phi::DenseTensor tmp = *x; + pg->Send(tmp, + j, + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat, + /*sync_op*/ true); + } + } + } + if (cpu_local_count_data[i + rank * n_expert]) { + phi::DeviceManager::GetDeviceWithPlace(place)->MemoryCopyD2D( + reinterpret_cast(out->data() + recv_ptr * in_feat), + reinterpret_cast(x->data() + + expert_ptr[rank] * in_feat), + (cpu_local_count_data[rank] * in_feat) * phi::SizeOf(x->dtype()), + stream.get()); + recv_ptr += cpu_global_count_data[rank]; + } + for (auto j = rank + 1; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_global_count_data[idx]) { + pg->Recv(out, + j, + recv_ptr * in_feat, + cpu_global_count_data[idx] * in_feat, + /*sync_op*/ true); + recv_ptr += cpu_global_count_data[idx]; + } + } + } + } else { + auto comm = platform::XCCLCommContext::Instance(place.GetDeviceType()) + .Get(rid, place); + std::shared_ptr stream; + if (ctx.Attr("use_calc_stream")) { + stream = dev_ctx.GetStream(); + } else { + stream = comm->stream(); + } + int nranks = comm->nranks(); + int rank = comm->rank(); + auto in_feat = x->dims()[1]; + auto n_expert = local_count->dims()[0] / nranks; + int64_t fwd_count = 0; + + for (auto i = 0; i < global_count_len; ++i) { + fwd_count += cpu_global_count_data[i]; + } + framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat}); + int64_t* expert_ptr = new int64_t[n_expert * nranks]; + expert_ptr[0] = 0; + auto tot_experts = n_expert * nranks; + for (auto i = 1; i < tot_experts; ++i) { + expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; + } + + auto recv_ptr = 0; + auto send_buf = x->data(); + out->Resize(out_dims); + auto recv_buf = dev_ctx.template Alloc(out); + + for (auto i = 0; i < n_expert; ++i) { + for (auto j = 0; j < rank; ++j) { + int idx = i + j * n_expert; + if (cpu_global_count_data[idx]) { + phi::DeviceManager::CCLRecv( + place.GetDeviceType(), + reinterpret_cast(recv_buf + recv_ptr * in_feat), + cpu_global_count_data[idx] * in_feat, + phi::ccl::ToCCLDataType(x->dtype()), + j, + comm->comm(), + *stream); + recv_ptr += cpu_global_count_data[idx]; + } + } + for (auto j = 0; j < nranks; ++j) { + if (j != rank) { + int idx = i + j * n_expert; + if (cpu_local_count_data[idx]) { + phi::DeviceManager::CCLSend( + place.GetDeviceType(), + const_cast(reinterpret_cast( + send_buf + expert_ptr[idx] * in_feat)), + cpu_local_count_data[idx] * in_feat, + phi::ccl::ToCCLDataType(x->dtype()), + j, + comm->comm(), + *stream); + } + } + } + if (cpu_local_count_data[i + rank * n_expert]) { + phi::DeviceManager::GetDeviceWithPlace(place)->MemoryCopyD2D( + reinterpret_cast(recv_buf + recv_ptr * in_feat), + reinterpret_cast(send_buf + + expert_ptr[rank] * in_feat), + (cpu_local_count_data[rank] * in_feat) * phi::SizeOf(x->dtype()), + stream.get()); + recv_ptr += cpu_global_count_data[rank]; + } + for (auto j = 
rank + 1; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_global_count_data[idx]) { + phi::DeviceManager::CCLRecv( + place.GetDeviceType(), + reinterpret_cast(recv_buf + recv_ptr * in_feat), + cpu_global_count_data[idx] * in_feat, + phi::ccl::ToCCLDataType(x->dtype()), + j, + comm->comm(), + *stream); + recv_ptr += cpu_global_count_data[idx]; + } + } + } + } + + phi::DeviceManager::SynchronizeDevice(ctx.GetPlace()); + } +}; + +template +class GlobalGatherOpCustomDeviceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto x = ctx.Input("X"); + auto local_count = ctx.Input("local_count"); + auto global_count = ctx.Input("global_count"); + const int rid = ctx.Attr("ring_id"); + const auto& dev_ctx = ctx.template device_context(); + auto place = ctx.GetPlace(); + auto out = ctx.Output("Out"); + + PADDLE_ENFORCE_EQ(local_count->dtype(), + phi::DataType::INT64, + platform::errors::InvalidArgument( + "Please use int64 type in local_count.")); + PADDLE_ENFORCE_EQ(global_count->dtype(), + phi::DataType::INT64, + platform::errors::InvalidArgument( + "Please use int64 type in global_count.")); + + const int64_t* cpu_local_count_data; + const int64_t* cpu_global_count_data; + auto local_count_len = 0; + phi::DenseTensor cpu_local_count; + if (platform::is_cpu_place(local_count->place())) { + cpu_local_count_data = local_count->data(); + local_count_len = local_count->numel(); + } else { + framework::TensorCopySync( + *local_count, platform::CPUPlace(), &cpu_local_count); + cpu_local_count_data = cpu_local_count.data(); + local_count_len = cpu_local_count.numel(); + } + phi::DenseTensor cpu_global_count; + if (platform::is_cpu_place(global_count->place())) { + cpu_global_count_data = global_count->data(); + } else { + framework::TensorCopySync( + *global_count, platform::CPUPlace(), &cpu_global_count); + cpu_global_count_data = cpu_global_count.data(); + } + + auto map = distributed::ProcessGroupMapFromGid::getInstance(); + + if (map->has(rid)) { + distributed::ProcessGroup* pg = map->get(rid); + auto stream = + reinterpret_cast(pg->GetDeviceContext(place)) + ->GetStream(); + int nranks = pg->GetSize(); + int rank = pg->GetRank(); + auto in_feat = x->dims()[1]; + auto n_expert = local_count->dims()[0] / nranks; + auto fwd_count = 0; + for (auto i = 0; i < local_count_len; ++i) { + fwd_count += cpu_local_count_data[i]; + } + framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat}); + int64_t* expert_ptr = new int64_t[n_expert * nranks]; + expert_ptr[0] = 0; + auto tot_experts = n_expert * nranks; + for (auto i = 1; i < tot_experts; ++i) { + expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; + } + auto send_ptr = 0; + out->Resize(out_dims); + dev_ctx.template Alloc(out); + + for (auto i = 0; i < n_expert; ++i) { + for (auto j = 0; j < rank; ++j) { + int idx = i + j * n_expert; + if (cpu_local_count_data[idx]) { + pg->Recv(out, + j, + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat, + /*sync_op*/ true); + } + } + for (auto j = 0; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_global_count_data[idx]) { + if (j != rank) { + phi::DenseTensor tmp = *x; + pg->Send(tmp, + j, + send_ptr * in_feat, + cpu_global_count_data[idx] * in_feat, + /*sync_op*/ true); + } else { + phi::DeviceManager::GetDeviceWithPlace(place)->MemoryCopyD2D( + reinterpret_cast(out->data() + + expert_ptr[idx] * in_feat), + reinterpret_cast(x->data() + + send_ptr * in_feat), + (cpu_global_count_data[idx] * 
in_feat) * + phi::SizeOf(x->dtype()), + stream.get()); + } + send_ptr += cpu_global_count_data[idx]; + } + } + for (auto j = rank + 1; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_local_count_data[idx]) { + pg->Recv(out, + j, + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat, + /*sync_op*/ true); + } + } + } + } else { + auto comm = platform::XCCLCommContext::Instance(place.GetDeviceType()) + .Get(rid, place); + std::shared_ptr stream; + if (ctx.Attr("use_calc_stream")) { + stream = dev_ctx.GetStream(); + } else { + stream = comm->stream(); + } + int nranks = comm->nranks(); + int rank = comm->rank(); + auto in_feat = x->dims()[1]; + auto n_expert = local_count->dims()[0] / nranks; + + auto fwd_count = 0; + + for (auto i = 0; i < local_count_len; ++i) { + fwd_count += cpu_local_count_data[i]; + } + framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat}); + int64_t* expert_ptr = new int64_t[n_expert * nranks]; + expert_ptr[0] = 0; + auto tot_experts = n_expert * nranks; + for (auto i = 1; i < tot_experts; ++i) { + expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; + } + auto send_ptr = 0; + auto send_buf = x->data(); + out->Resize(out_dims); + auto recv_buf = dev_ctx.template Alloc(out); + + for (auto i = 0; i < n_expert; ++i) { + for (auto j = 0; j < rank + 1; ++j) { + int idx = i + j * n_expert; + if (cpu_local_count_data[idx]) { + phi::DeviceManager::CCLRecv(place.GetDeviceType(), + recv_buf + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat, + phi::ccl::ToCCLDataType(x->dtype()), + j, + comm->comm(), + *stream); + } + } + for (auto j = 0; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_global_count_data[idx]) { + if (j != rank) { + phi::DeviceManager::CCLSend( + place.GetDeviceType(), + const_cast(reinterpret_cast( + send_buf + send_ptr * in_feat)), + cpu_global_count_data[idx] * in_feat, + phi::ccl::ToCCLDataType(x->dtype()), + j, + comm->comm(), + *stream); + } else { + phi::DeviceManager::GetDeviceWithPlace(place)->MemoryCopyD2D( + reinterpret_cast(recv_buf + expert_ptr[idx] * in_feat), + reinterpret_cast(send_buf + send_ptr * in_feat), + (cpu_global_count_data[idx] * in_feat) * + phi::SizeOf(x->dtype()), + stream.get()); + } + send_ptr += cpu_global_count_data[idx]; + } + } + for (auto j = rank + 1; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_local_count_data[idx]) { + phi::DeviceManager::CCLRecv(place.GetDeviceType(), + recv_buf + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat, + phi::ccl::ToCCLDataType(x->dtype()), + j, + comm->comm(), + *stream); + } + } + } + } + + phi::DeviceManager::SynchronizeDevice(ctx.GetPlace()); + } +}; + template void FeedDenseTensorKernel(const Context& dev_ctx, const phi::ExtendedTensor& x, @@ -918,6 +1506,48 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) { barrier, device_type, paddle::operators::BarrierOpCustomDeviceKernel) {} + REGISTER_OP_CUSTOM_DEVICE_KERNEL( + number_count, + device_type, + paddle::operators::NumberCountOpCustomDeviceKernel) {} + REGISTER_OP_CUSTOM_DEVICE_KERNEL( + limit_by_capacity, + device_type, + paddle::operators::LimitByCapacityOpCustomDeviceKernel) {} + REGISTER_OP_CUSTOM_DEVICE_KERNEL( + prune_gate_by_capacity, + device_type, + paddle::operators::PruneGateByCapacityCustomDeviceKernel) {} + REGISTER_OP_CUSTOM_DEVICE_KERNEL( + random_routing, + device_type, + paddle::operators::RandomRoutingOpCustomDeviceKernel, + paddle::operators::RandomRoutingOpCustomDeviceKernel, + 
paddle::operators::RandomRoutingOpCustomDeviceKernel< + paddle::platform::float16>) {} + REGISTER_OP_CUSTOM_DEVICE_KERNEL( + assign_pos, + device_type, + paddle::operators::AssignPosCustomDeviceKernel) {} + + REGISTER_OP_CUSTOM_DEVICE_KERNEL( + global_scatter, + device_type, + paddle::operators::GlobalScatterOpCustomDeviceKernel, + paddle::operators::GlobalScatterOpCustomDeviceKernel, + paddle::operators::GlobalScatterOpCustomDeviceKernel, + paddle::operators::GlobalScatterOpCustomDeviceKernel, + paddle::operators::GlobalScatterOpCustomDeviceKernel< + paddle::platform::float16>) {} + REGISTER_OP_CUSTOM_DEVICE_KERNEL( + global_gather, + device_type, + paddle::operators::GlobalGatherOpCustomDeviceKernel, + paddle::operators::GlobalGatherOpCustomDeviceKernel, + paddle::operators::GlobalGatherOpCustomDeviceKernel, + paddle::operators::GlobalGatherOpCustomDeviceKernel, + paddle::operators::GlobalGatherOpCustomDeviceKernel< + paddle::platform::float16>) {} #endif } diff --git a/python/paddle/incubate/distributed/models/moe/moe_layer.py b/python/paddle/incubate/distributed/models/moe/moe_layer.py index 73d64d1f5243b01c7cf422cf891c67780c9821e9..7c63c431bb95058b5ba494fe281cdbd1410383b9 100644 --- a/python/paddle/incubate/distributed/models/moe/moe_layer.py +++ b/python/paddle/incubate/distributed/models/moe/moe_layer.py @@ -19,6 +19,8 @@ # Copyright 2021, Jiaao He. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"). +import os + import numpy as np import paddle @@ -352,7 +354,10 @@ class MoELayer(nn.Layer): assert experts is not None self.experts = experts - if self.world_size > 1: + if ( + self.world_size > 1 + and os.getenv("PADDLE_DISTRI_BACKEND", None) != "xccl" + ): check_nccl_version_for_p2p() self.mp_group = mp_group diff --git a/python/paddle/nn/layer/layers.py b/python/paddle/nn/layer/layers.py index cbd96b97be2b6b26f34fe93ca387cafb8521d9ec..6155fdae7eee1c6f95d3e7c07f5bf971d5548f34 100644 --- a/python/paddle/nn/layer/layers.py +++ b/python/paddle/nn/layer/layers.py @@ -1913,6 +1913,13 @@ class Layer: p = core.Place() p.set_place(t._place()) place = core.XPUPlace(p.xpu_device_id()) + elif p.is_custom_place(): + p = core.Place() + p.set_place(t._place()) + place = core.CustomPlace( + paddle.device.get_device().split(':')[0], + p.custom_device_id(), + ) else: p = core.Place() p.set_place(t._place()) diff --git a/python/paddle/static/io.py b/python/paddle/static/io.py index b0bb6eabbe652a05b478beb4ddbd75b7a90692c0..eef8cacc0e60584895c9baa18835aa474bd2fd2c 100644 --- a/python/paddle/static/io.py +++ b/python/paddle/static/io.py @@ -1540,7 +1540,12 @@ def load(program, model_path, executor=None, var_list=None): p = paddle.fluid.core.Place() p.set_place(t._place()) place = paddle.fluid.XPUPlace(p.xpu_device_id()) - + elif p.is_custom_place(): + p = paddle.fluid.core.Place() + p.set_place(t._place()) + place = paddle.fluid.CustomPlace( + paddle.device.get_device().split(':')[0], p.custom_device_id() + ) else: p = paddle.fluid.core.Place() p.set_place(t._place())
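Reviewer note on the new global_scatter / global_gather kernels: both size their output from one of the two count tensors (global_count for scatter, local_count for gather) and derive the per-(expert, worker) row offsets from local_count with an exclusive prefix sum. The sketch below reproduces only that bookkeeping on the host so it can be sanity-checked in isolation; the struct and function names are illustrative and are not part of this patch.

#include <cstdint>
#include <vector>

// Host-side sketch of the offset bookkeeping shared by the global_scatter /
// global_gather kernels above. Both count vectors are laid out as
// counts[i + j * n_expert]: rows for expert i exchanged with worker j.
// Names are illustrative only and not part of the patch.
struct ExpertLayout {
  int64_t fwd_count = 0;            // number of rows in the output tensor
  std::vector<int64_t> expert_ptr;  // starting row of each (expert, worker) slot
};

inline ExpertLayout BuildExpertLayout(
    const std::vector<int64_t>& offset_counts,  // local_count in both kernels
    const std::vector<int64_t>& row_counts) {   // global_count for scatter,
                                                // local_count for gather
  ExpertLayout layout;
  for (int64_t c : row_counts) layout.fwd_count += c;
  layout.expert_ptr.assign(offset_counts.size(), 0);
  for (size_t i = 1; i < offset_counts.size(); ++i) {
    // exclusive prefix sum, matching the expert_ptr loop in both kernels
    layout.expert_ptr[i] = layout.expert_ptr[i - 1] + offset_counts[i - 1];
  }
  return layout;
}

With this layout, the chunk exchanged with worker j for expert i starts expert_ptr[i + j * n_expert] * in_feat elements into the flattened feature buffer, which is the offset the kernels hand to Send/Recv on the process-group path and to CCLSend/CCLRecv on the XCCL path.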