Unverified commit 584ae4d7, authored by ronnywang and committed by GitHub

[CustomDevice] add MOE support, PART3 (#54676)

Parent ff806111
@@ -279,7 +279,7 @@ class CEmbeddingGradOpCustomDeviceKernel : public framework::OpKernel<T> {
x_tensor, start_index + N, x_tensor.dtype(), x_tensor.place());
auto ids_mask_tensor = paddle::experimental::logical_and(
x_tensor.greater_equal(start_index_tensor),
x_tensor.less_equal(end_index_tensor));
x_tensor.less_than(end_index_tensor));
auto real_ids_tensor = (x_tensor - start_index_tensor)
.multiply(paddle::experimental::cast(
ids_mask_tensor, x_tensor.dtype()));
@@ -668,6 +668,594 @@ class BarrierOpCustomDeviceKernel : public framework::OpKernel<T> {
}
};
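// number_count: computed on the host after copying `numbers` to the CPU;
// counts how many entries fall into each bucket in [0, upper_range),
// ignoring out-of-range values, then copies the histogram back to the device.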
template <typename T>
class NumberCountOpCustomDeviceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto numbers = context.Input<phi::DenseTensor>("numbers");
auto upper_range = context.Attr<int>("upper_range");
auto number_count = context.Output<phi::DenseTensor>("Out");
const auto& dev_ctx = context.template device_context<phi::CustomContext>();
number_count->Resize({upper_range});
dev_ctx.template Alloc<T>(number_count);
phi::DenseTensor cpu_tensor;
framework::TensorCopySync(*numbers, platform::CPUPlace(), &cpu_tensor);
std::vector<T> count(upper_range);
for (auto i = 0; i < cpu_tensor.numel(); ++i) {
auto idx = static_cast<int64_t>(cpu_tensor.data<T>()[i]);
if (idx >= 0 && idx < upper_range) {
count[idx] += 1;
}
}
framework::TensorFromVector<T>(count, dev_ctx, number_count);
number_count->Resize({upper_range});
}
};
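// limit_by_capacity: computed on the host; caps each worker's per-expert
// token count so that the total admitted for an expert never exceeds its
// capacity, with capacity consumed in worker order.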
template <typename T>
class LimitByCapacityOpCustomDeviceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto expert_count = context.Input<phi::DenseTensor>("expert_count");
auto capacity = context.Input<phi::DenseTensor>("capacity");
auto out = context.Output<phi::DenseTensor>("Out");
auto n_worker = context.Attr<int>("n_worker");
auto n_expert = expert_count->numel() / n_worker;
const auto& dev_ctx = context.template device_context<phi::CustomContext>();
dev_ctx.template Alloc<T>(out);
std::vector<T> out_data(out->numel());
phi::DenseTensor expert_count_cpu, capacity_cpu;
framework::TensorCopySync(
*expert_count, platform::CPUPlace(), &expert_count_cpu);
framework::TensorCopySync(*capacity, platform::CPUPlace(), &capacity_cpu);
auto* ec_data = expert_count_cpu.data<T>();
auto* capacity_data = capacity_cpu.data<T>();
int eid, wid;
for (int64_t i = 0; i < expert_count->numel(); ++i) {
wid = i / n_expert;
eid = i % n_expert;
auto proposal = ec_data[i];
auto cap_left = capacity_data[eid];
capacity_data[eid] -= proposal;
if (cap_left >= proposal) {
out_data[wid * n_expert + eid] = proposal;
} else if (cap_left >= 0) {
out_data[wid * n_expert + eid] = cap_left;
} else {
out_data[wid * n_expert + eid] = 0;
}
}
auto out_dims = out->dims();
framework::TensorFromVector<T>(out_data, dev_ctx, out);
out->Resize(out_dims);
}
};
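// prune_gate_by_capacity: computed on the host; walks the selected gate
// indices, decrementing the chosen expert's remaining count, and replaces
// indices whose expert is already exhausted with -1.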
template <typename T>
class PruneGateByCapacityCustomDeviceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* gate_idx = context.Input<phi::DenseTensor>("GateIdx");
auto* expert_count = context.Input<phi::DenseTensor>("ExpertCount");
auto* new_gate_idx = context.Output<phi::DenseTensor>("NewGateIdx");
const auto& dev_ctx = context.template device_context<phi::CustomContext>();
dev_ctx.template Alloc<T>(new_gate_idx);
phi::DenseTensor expert_count_cpu, gate_idx_cpu;
framework::TensorCopySync(
*expert_count, platform::CPUPlace(), &expert_count_cpu);
framework::TensorCopySync(*gate_idx, platform::CPUPlace(), &gate_idx_cpu);
auto expert_count_data = expert_count_cpu.data<T>();
auto gate_idx_data = gate_idx_cpu.data<T>();
std::vector<T> new_gate_idx_data(gate_idx->numel());
for (auto i = 0; i < gate_idx->numel(); ++i) {
auto orig_cap = expert_count_data[gate_idx_data[i]]--;
if (orig_cap <= 0) {
new_gate_idx_data[i] = -1;
} else {
new_gate_idx_data[i] = gate_idx_data[i];
}
}
framework::TensorFromVector<T>(new_gate_idx_data, dev_ctx, new_gate_idx);
}
};
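// random_routing: computed on the host; drops the second expert in each row
// (output index set to -1) whenever twice its gate value is smaller than the
// sampled probability for that row.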
template <typename T>
class RandomRoutingOpCustomDeviceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto topk_idx = context.Input<phi::DenseTensor>("TopK_Idx");
auto topk_value = context.Input<phi::DenseTensor>("TopK_Value");
auto prob = context.Input<phi::DenseTensor>("Prob");
auto out = context.Output<phi::DenseTensor>("Out");
const auto& dev_ctx = context.template device_context<phi::CustomContext>();
size_t D = topk_idx->dims()[1];
phi::DenseTensor topk_idx_cpu, topk_value_cpu, prob_cpu;
framework::TensorCopySync(*topk_idx, platform::CPUPlace(), &topk_idx_cpu);
framework::TensorCopySync(
*topk_value, platform::CPUPlace(), &topk_value_cpu);
framework::TensorCopySync(*prob, platform::CPUPlace(), &prob_cpu);
auto* topk_idx_data = topk_idx_cpu.data<int64_t>();
auto* topk_value_data = topk_value_cpu.data<T>();
auto* prob_data = prob_cpu.data<T>();
std::vector<int64_t> out_data(topk_idx->numel());
for (int64_t idx = 0; idx < topk_idx->numel(); ++idx) {
size_t row = idx / D;
size_t col = idx % D;
// Keep the routed expert index; drop the second expert when twice its
// gate value is below the sampled probability for this row.
out_data[idx] = topk_idx_data[idx];
if (col == 1 &&
static_cast<T>(2) * topk_value_data[idx] < prob_data[row]) {
out_data[idx] = static_cast<int64_t>(-1);
}
}
auto out_dims = out->dims();
framework::TensorFromVector<int64_t>(out_data, dev_ctx, out);
out->Resize(out_dims);
}
};
template <typename T>
class AssignPosCustomDeviceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
// assign_pos decides which token positions belong to each counter (expert)
// and writes them out contiguously, segment by segment, in counter order.
auto cum_count = context.Input<phi::DenseTensor>(
"cum_count"); // (counter number) int32 | int64
auto numbers = context.Input<phi::DenseTensor>(
"X"); // (batch_size * seq_len, topk) int32
auto eff_num_len =
context.Input<phi::DenseTensor>("eff_num_len"); // (sum(cum_count))
auto out =
context.Output<phi::DenseTensor>("Out"); // (cum_count) value ranges
// from 0 to batch_size *
// seq_len * topk
const auto& dev_ctx = context.template device_context<phi::CustomContext>();
phi::DenseTensor cpu_eff_num_len;
int64_t cpu_eff_num_len_data = 0;
if (platform::is_cpu_place(eff_num_len->place())) {
cpu_eff_num_len_data = eff_num_len->data<T>()[0];
} else {
framework::TensorCopySync(
*eff_num_len, platform::CPUPlace(), &cpu_eff_num_len);
cpu_eff_num_len_data = cpu_eff_num_len.data<T>()[0];
}
out->Resize({cpu_eff_num_len_data});
dev_ctx.template Alloc<T>(out);
phi::DenseTensor numbers_cpu, cum_count_cpu;
framework::TensorCopySync(*numbers, platform::CPUPlace(), &numbers_cpu);
framework::TensorCopySync(*cum_count, platform::CPUPlace(), &cum_count_cpu);
auto* numbers_data = numbers_cpu.data<T>();
auto* cum_count_data = cum_count_cpu.data<T>();
std::vector<T> out_data(cpu_eff_num_len_data);
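// Worked example (hypothetical values): numbers = [1, 0, 1] and
// cum_count = [1, 3] (cumulative per-counter totals). The scan below fills
// each counter's output segment from its end, yielding out = [1, 2, 0]:
// slot 0 holds the token routed to counter 0, slots 1-2 the tokens routed
// to counter 1.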
for (int64_t i = 0; i < numbers->numel(); ++i) {
int number_idx = numbers_data[i];
if (number_idx > -1) {
cum_count_data[number_idx] -= 1;
int p = cum_count_data[number_idx];
out_data[p] = i;
}
}
framework::TensorFromVector<int64_t>(out_data, dev_ctx, out);
}
};
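// global_scatter: point-to-point exchange of token rows for expert
// parallelism. local_count[i + j * n_expert] rows are sent to expert i on
// rank j, and global_count[i + j * n_expert] rows are received from rank j
// for local expert i; the rank-local portion is copied device-to-device.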
template <typename T>
class GlobalScatterOpCustomDeviceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto x = ctx.Input<phi::DenseTensor>("X");
auto local_count = ctx.Input<phi::DenseTensor>("local_count");
auto global_count = ctx.Input<phi::DenseTensor>("global_count");
auto out = ctx.Output<phi::DenseTensor>("Out");
const int rid = ctx.Attr<int>("ring_id");
const auto& dev_ctx = ctx.template device_context<phi::CustomContext>();
auto place = ctx.GetPlace();
PADDLE_ENFORCE_EQ(local_count->dtype(),
phi::DataType::INT64,
platform::errors::InvalidArgument(
"Please use int64 type in local_count."));
PADDLE_ENFORCE_EQ(global_count->dtype(),
phi::DataType::INT64,
platform::errors::InvalidArgument(
"Please use int64 type in global_count."));
auto map = distributed::ProcessGroupMapFromGid::getInstance();
const int64_t* cpu_local_count_data;
const int64_t* cpu_global_count_data;
phi::DenseTensor cpu_local_count;
if (platform::is_cpu_place(local_count->place())) {
cpu_local_count_data = local_count->data<int64_t>();
} else {
framework::TensorCopySync(
*local_count, platform::CPUPlace(), &cpu_local_count);
cpu_local_count_data = cpu_local_count.data<int64_t>();
}
auto global_count_len = 0;
phi::DenseTensor cpu_global_count;
if (platform::is_cpu_place(global_count->place())) {
cpu_global_count_data = global_count->data<int64_t>();
global_count_len = global_count->numel();
} else {
framework::TensorCopySync(
*global_count, platform::CPUPlace(), &cpu_global_count);
cpu_global_count_data = cpu_global_count.data<int64_t>();
global_count_len = cpu_global_count.numel();
}
if (map->has(rid)) {
distributed::ProcessGroup* pg = map->get(rid);
auto stream =
reinterpret_cast<phi::CustomContext*>(pg->GetDeviceContext(place))
->GetStream();
int nranks = pg->GetSize();
int rank = pg->GetRank();
auto in_feat = x->dims()[1];
auto n_expert = local_count->dims()[0] / nranks;
int64_t fwd_count = 0;
for (auto i = 0; i < global_count_len; ++i) {
fwd_count += cpu_global_count_data[i];
}
framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat});
int64_t* expert_ptr = new int64_t[n_expert * nranks];
expert_ptr[0] = 0;
auto tot_experts = n_expert * nranks;
for (auto i = 1; i < tot_experts; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1];
}
auto recv_ptr = 0;
out->Resize(out_dims);
dev_ctx.template Alloc<T>(out);
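// For each local expert slot: receive from lower-ranked peers, send this
// rank's rows to every other rank, copy the rank-local slice on-device,
// then receive from higher-ranked peers.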
for (auto i = 0; i < n_expert; ++i) {
for (auto j = 0; j < rank; ++j) {
int idx = i + j * n_expert;
if (cpu_global_count_data[idx]) {
pg->Recv(out,
j,
recv_ptr * in_feat,
cpu_global_count_data[idx] * in_feat,
/*sync_op*/ true);
recv_ptr += cpu_global_count_data[idx];
}
}
for (auto j = 0; j < nranks; ++j) {
if (j != rank) {
int idx = i + j * n_expert;
if (cpu_local_count_data[idx]) {
phi::DenseTensor tmp = *x;
pg->Send(tmp,
j,
expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat,
/*sync_op*/ true);
}
}
}
// Rank-local slice: copy directly on the device instead of send/recv.
int self_idx = i + rank * n_expert;
if (cpu_local_count_data[self_idx]) {
phi::DeviceManager::GetDeviceWithPlace(place)->MemoryCopyD2D(
reinterpret_cast<void*>(out->data<T>() + recv_ptr * in_feat),
reinterpret_cast<const void*>(x->data<T>() +
expert_ptr[self_idx] * in_feat),
(cpu_local_count_data[self_idx] * in_feat) * phi::SizeOf(x->dtype()),
stream.get());
recv_ptr += cpu_global_count_data[self_idx];
}
for (auto j = rank + 1; j < nranks; ++j) {
int idx = i + j * n_expert;
if (cpu_global_count_data[idx]) {
pg->Recv(out,
j,
recv_ptr * in_feat,
cpu_global_count_data[idx] * in_feat,
/*sync_op*/ true);
recv_ptr += cpu_global_count_data[idx];
}
}
}
} else {
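// No ProcessGroup is registered for this ring id, so fall back to the raw
// XCCL communicator and issue CCLSend/CCLRecv directly.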
auto comm = platform::XCCLCommContext::Instance(place.GetDeviceType())
.Get(rid, place);
std::shared_ptr<phi::stream::Stream> stream;
if (ctx.Attr<bool>("use_calc_stream")) {
stream = dev_ctx.GetStream();
} else {
stream = comm->stream();
}
int nranks = comm->nranks();
int rank = comm->rank();
auto in_feat = x->dims()[1];
auto n_expert = local_count->dims()[0] / nranks;
int64_t fwd_count = 0;
for (auto i = 0; i < global_count_len; ++i) {
fwd_count += cpu_global_count_data[i];
}
framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat});
int64_t* expert_ptr = new int64_t[n_expert * nranks];
expert_ptr[0] = 0;
auto tot_experts = n_expert * nranks;
for (auto i = 1; i < tot_experts; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1];
}
auto recv_ptr = 0;
auto send_buf = x->data<T>();
out->Resize(out_dims);
auto recv_buf = dev_ctx.template Alloc<T>(out);
for (auto i = 0; i < n_expert; ++i) {
for (auto j = 0; j < rank; ++j) {
int idx = i + j * n_expert;
if (cpu_global_count_data[idx]) {
phi::DeviceManager::CCLRecv(
place.GetDeviceType(),
reinterpret_cast<void*>(recv_buf + recv_ptr * in_feat),
cpu_global_count_data[idx] * in_feat,
phi::ccl::ToCCLDataType(x->dtype()),
j,
comm->comm(),
*stream);
recv_ptr += cpu_global_count_data[idx];
}
}
for (auto j = 0; j < nranks; ++j) {
if (j != rank) {
int idx = i + j * n_expert;
if (cpu_local_count_data[idx]) {
phi::DeviceManager::CCLSend(
place.GetDeviceType(),
const_cast<void*>(reinterpret_cast<const void*>(
send_buf + expert_ptr[idx] * in_feat)),
cpu_local_count_data[idx] * in_feat,
phi::ccl::ToCCLDataType(x->dtype()),
j,
comm->comm(),
*stream);
}
}
}
// Rank-local slice: copy directly on the device instead of send/recv.
int self_idx = i + rank * n_expert;
if (cpu_local_count_data[self_idx]) {
phi::DeviceManager::GetDeviceWithPlace(place)->MemoryCopyD2D(
reinterpret_cast<void*>(recv_buf + recv_ptr * in_feat),
reinterpret_cast<const void*>(send_buf +
expert_ptr[self_idx] * in_feat),
(cpu_local_count_data[self_idx] * in_feat) * phi::SizeOf(x->dtype()),
stream.get());
recv_ptr += cpu_global_count_data[self_idx];
}
for (auto j = rank + 1; j < nranks; ++j) {
int idx = i + j * n_expert;
if (cpu_global_count_data[idx]) {
phi::DeviceManager::CCLRecv(
place.GetDeviceType(),
reinterpret_cast<void*>(recv_buf + recv_ptr * in_feat),
cpu_global_count_data[idx] * in_feat,
phi::ccl::ToCCLDataType(x->dtype()),
j,
comm->comm(),
*stream);
recv_ptr += cpu_global_count_data[idx];
}
}
}
}
phi::DeviceManager::SynchronizeDevice(ctx.GetPlace());
}
};
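// global_gather: inverse of global_scatter. Each rank sends back the rows it
// processed on behalf of other ranks (global_count) and receives the rows it
// originally scattered (local_count), restoring the pre-scatter layout.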
template <typename T>
class GlobalGatherOpCustomDeviceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto x = ctx.Input<phi::DenseTensor>("X");
auto local_count = ctx.Input<phi::DenseTensor>("local_count");
auto global_count = ctx.Input<phi::DenseTensor>("global_count");
const int rid = ctx.Attr<int>("ring_id");
const auto& dev_ctx = ctx.template device_context<phi::CustomContext>();
auto place = ctx.GetPlace();
auto out = ctx.Output<phi::DenseTensor>("Out");
PADDLE_ENFORCE_EQ(local_count->dtype(),
phi::DataType::INT64,
platform::errors::InvalidArgument(
"Please use int64 type in local_count."));
PADDLE_ENFORCE_EQ(global_count->dtype(),
phi::DataType::INT64,
platform::errors::InvalidArgument(
"Please use int64 type in global_count."));
const int64_t* cpu_local_count_data;
const int64_t* cpu_global_count_data;
auto local_count_len = 0;
phi::DenseTensor cpu_local_count;
if (platform::is_cpu_place(local_count->place())) {
cpu_local_count_data = local_count->data<int64_t>();
local_count_len = local_count->numel();
} else {
framework::TensorCopySync(
*local_count, platform::CPUPlace(), &cpu_local_count);
cpu_local_count_data = cpu_local_count.data<int64_t>();
local_count_len = cpu_local_count.numel();
}
phi::DenseTensor cpu_global_count;
if (platform::is_cpu_place(global_count->place())) {
cpu_global_count_data = global_count->data<int64_t>();
} else {
framework::TensorCopySync(
*global_count, platform::CPUPlace(), &cpu_global_count);
cpu_global_count_data = cpu_global_count.data<int64_t>();
}
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
distributed::ProcessGroup* pg = map->get(rid);
auto stream =
reinterpret_cast<phi::CustomContext*>(pg->GetDeviceContext(place))
->GetStream();
int nranks = pg->GetSize();
int rank = pg->GetRank();
auto in_feat = x->dims()[1];
auto n_expert = local_count->dims()[0] / nranks;
auto fwd_count = 0;
for (auto i = 0; i < local_count_len; ++i) {
fwd_count += cpu_local_count_data[i];
}
framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat});
int64_t* expert_ptr = new int64_t[n_expert * nranks];
expert_ptr[0] = 0;
auto tot_experts = n_expert * nranks;
for (auto i = 1; i < tot_experts; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1];
}
auto send_ptr = 0;
out->Resize(out_dims);
dev_ctx.template Alloc<T>(out);
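// Mirror of the scatter loop: receive this rank's rows back from
// lower-ranked peers, return the locally processed rows to their owners
// (the rank-local slice via an on-device copy), then receive from
// higher-ranked peers.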
for (auto i = 0; i < n_expert; ++i) {
for (auto j = 0; j < rank; ++j) {
int idx = i + j * n_expert;
if (cpu_local_count_data[idx]) {
pg->Recv(out,
j,
expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat,
/*sync_op*/ true);
}
}
for (auto j = 0; j < nranks; ++j) {
int idx = i + j * n_expert;
if (cpu_global_count_data[idx]) {
if (j != rank) {
phi::DenseTensor tmp = *x;
pg->Send(tmp,
j,
send_ptr * in_feat,
cpu_global_count_data[idx] * in_feat,
/*sync_op*/ true);
} else {
phi::DeviceManager::GetDeviceWithPlace(place)->MemoryCopyD2D(
reinterpret_cast<void*>(out->data<T>() +
expert_ptr[idx] * in_feat),
reinterpret_cast<const void*>(x->data<T>() +
send_ptr * in_feat),
(cpu_global_count_data[idx] * in_feat) *
phi::SizeOf(x->dtype()),
stream.get());
}
send_ptr += cpu_global_count_data[idx];
}
}
for (auto j = rank + 1; j < nranks; ++j) {
int idx = i + j * n_expert;
if (cpu_local_count_data[idx]) {
pg->Recv(out,
j,
expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat,
/*sync_op*/ true);
}
}
}
} else {
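// Same exchange as above, issued over the raw XCCL communicator when no
// ProcessGroup is registered for this ring id.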
auto comm = platform::XCCLCommContext::Instance(place.GetDeviceType())
.Get(rid, place);
std::shared_ptr<phi::stream::Stream> stream;
if (ctx.Attr<bool>("use_calc_stream")) {
stream = dev_ctx.GetStream();
} else {
stream = comm->stream();
}
int nranks = comm->nranks();
int rank = comm->rank();
auto in_feat = x->dims()[1];
auto n_expert = local_count->dims()[0] / nranks;
auto fwd_count = 0;
for (auto i = 0; i < local_count_len; ++i) {
fwd_count += cpu_local_count_data[i];
}
framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat});
int64_t* expert_ptr = new int64_t[n_expert * nranks];
expert_ptr[0] = 0;
auto tot_experts = n_expert * nranks;
for (auto i = 1; i < tot_experts; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1];
}
auto send_ptr = 0;
auto send_buf = x->data<T>();
out->Resize(out_dims);
auto recv_buf = dev_ctx.template Alloc<T>(out);
for (auto i = 0; i < n_expert; ++i) {
for (auto j = 0; j < rank + 1; ++j) {
int idx = i + j * n_expert;
if (cpu_local_count_data[idx]) {
phi::DeviceManager::CCLRecv(place.GetDeviceType(),
recv_buf + expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat,
phi::ccl::ToCCLDataType(x->dtype()),
j,
comm->comm(),
*stream);
}
}
for (auto j = 0; j < nranks; ++j) {
int idx = i + j * n_expert;
if (cpu_global_count_data[idx]) {
if (j != rank) {
phi::DeviceManager::CCLSend(
place.GetDeviceType(),
const_cast<void*>(reinterpret_cast<const void*>(
send_buf + send_ptr * in_feat)),
cpu_global_count_data[idx] * in_feat,
phi::ccl::ToCCLDataType(x->dtype()),
j,
comm->comm(),
*stream);
} else {
phi::DeviceManager::GetDeviceWithPlace(place)->MemoryCopyD2D(
reinterpret_cast<void*>(recv_buf + expert_ptr[idx] * in_feat),
reinterpret_cast<const void*>(send_buf + send_ptr * in_feat),
(cpu_global_count_data[idx] * in_feat) *
phi::SizeOf(x->dtype()),
stream.get());
}
send_ptr += cpu_global_count_data[idx];
}
}
for (auto j = rank + 1; j < nranks; ++j) {
int idx = i + j * n_expert;
if (cpu_local_count_data[idx]) {
phi::DeviceManager::CCLRecv(place.GetDeviceType(),
recv_buf + expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat,
phi::ccl::ToCCLDataType(x->dtype()),
j,
comm->comm(),
*stream);
}
}
}
}
phi::DeviceManager::SynchronizeDevice(ctx.GetPlace());
}
};
template <typename Context>
void FeedDenseTensorKernel(const Context& dev_ctx,
const phi::ExtendedTensor& x,
@@ -918,6 +1506,48 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
barrier,
device_type,
paddle::operators::BarrierOpCustomDeviceKernel<int>) {}
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
number_count,
device_type,
paddle::operators::NumberCountOpCustomDeviceKernel<int64_t>) {}
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
limit_by_capacity,
device_type,
paddle::operators::LimitByCapacityOpCustomDeviceKernel<int64_t>) {}
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
prune_gate_by_capacity,
device_type,
paddle::operators::PruneGateByCapacityCustomDeviceKernel<int64_t>) {}
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
random_routing,
device_type,
paddle::operators::RandomRoutingOpCustomDeviceKernel<float>,
paddle::operators::RandomRoutingOpCustomDeviceKernel<double>,
paddle::operators::RandomRoutingOpCustomDeviceKernel<
paddle::platform::float16>) {}
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
assign_pos,
device_type,
paddle::operators::AssignPosCustomDeviceKernel<int64_t>) {}
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
global_scatter,
device_type,
paddle::operators::GlobalScatterOpCustomDeviceKernel<float>,
paddle::operators::GlobalScatterOpCustomDeviceKernel<double>,
paddle::operators::GlobalScatterOpCustomDeviceKernel<int32_t>,
paddle::operators::GlobalScatterOpCustomDeviceKernel<int64_t>,
paddle::operators::GlobalScatterOpCustomDeviceKernel<
paddle::platform::float16>) {}
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
global_gather,
device_type,
paddle::operators::GlobalGatherOpCustomDeviceKernel<float>,
paddle::operators::GlobalGatherOpCustomDeviceKernel<double>,
paddle::operators::GlobalGatherOpCustomDeviceKernel<int32_t>,
paddle::operators::GlobalGatherOpCustomDeviceKernel<int64_t>,
paddle::operators::GlobalGatherOpCustomDeviceKernel<
paddle::platform::float16>) {}
#endif
}
@@ -19,6 +19,8 @@
# Copyright 2021, Jiaao He. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
import os
import numpy as np
import paddle
@@ -352,7 +354,10 @@ class MoELayer(nn.Layer):
assert experts is not None
self.experts = experts
if self.world_size > 1:
if (
self.world_size > 1
and os.getenv("PADDLE_DISTRI_BACKEND", None) != "xccl"
):
check_nccl_version_for_p2p()
self.mp_group = mp_group
@@ -1913,6 +1913,13 @@ class Layer:
p = core.Place()
p.set_place(t._place())
place = core.XPUPlace(p.xpu_device_id())
elif p.is_custom_place():
p = core.Place()
p.set_place(t._place())
place = core.CustomPlace(
paddle.device.get_device().split(':')[0],
p.custom_device_id(),
)
else:
p = core.Place()
p.set_place(t._place())
@@ -1540,7 +1540,12 @@ def load(program, model_path, executor=None, var_list=None):
p = paddle.fluid.core.Place()
p.set_place(t._place())
place = paddle.fluid.XPUPlace(p.xpu_device_id())
elif p.is_custom_place():
p = paddle.fluid.core.Place()
p.set_place(t._place())
place = paddle.fluid.CustomPlace(
paddle.device.get_device().split(':')[0], p.custom_device_id()
)
else:
p = paddle.fluid.core.Place()
p.set_place(t._place())