未验证 提交 d65209b6 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] optimize SplitDenseTensor by calling split_with_num kernel (#55330)

上级 7a705727
......@@ -174,6 +174,15 @@ void ProcessGroupCustom::CreateCustomManagerCache(
ccl_comms[i] = CustomCCLCommManager::Create(
device_type, GetSize(), GetRank(), &ccl_id, new phi::ccl::CCLComm);
dev_ctx[i].reset(new CustomDeviceContext(places[i]));
dev_ctx[i]->SetAllocator(
&(phi::DeviceContextPool::Instance().Get(places[i])->GetAllocator()));
dev_ctx[i]->SetHostAllocator(&(
phi::DeviceContextPool::Instance().Get(places[i])->GetHostAllocator()));
dev_ctx[i]->SetZeroAllocator(&(
phi::DeviceContextPool::Instance().Get(places[i])->GetZeroAllocator()));
dev_ctx[i]->SetHostZeroAllocator(&(phi::DeviceContextPool::Instance()
.Get(places[i])
->GetHostZeroAllocator()));
}
std::vector<CustomEventManager> events;
......
......@@ -57,21 +57,22 @@ struct ConcatDenseTensor<platform::CustomDeviceContext, T> {
const std::vector<phi::DenseTensor> &in,
phi::DenseTensor *out,
int axis UNUSED = 0) {
auto *out_data = out->data<T>();
auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
size_t offset = 0;
phi::stream::Stream stream_wrapper(context.GetPlace(), context.stream());
for (const auto &tensor : in) {
const auto *in_data = tensor.data<T>();
if (out_data + offset != in_data) {
device->MemoryCopyD2D(out_data + offset,
in_data,
tensor.numel() * sizeof(T),
&stream_wrapper);
}
offset += tensor.numel();
}
VLOG(10) << "ConcatDenseTensor: " << in.size();
auto kernel_result =
phi::KernelFactory::Instance().SelectKernelOrThrowError(
"concat",
phi::KernelKey(phi::TransToPhiBackend(context.GetPlace()),
phi::DataLayout::ALL_LAYOUT,
phi::CppTypeToDataType<T>::Type()));
const auto &kernel = kernel_result.kernel;
using kernel_signature =
void (*)(const phi::DeviceContext &,
const std::vector<const phi::DenseTensor *> &,
const phi::Scalar &,
phi::DenseTensor *);
auto *kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
std::vector<const phi::DenseTensor *> inputs;
(*kernel_fn)(context, inputs, phi::Scalar(0), out);
}
};
......@@ -81,20 +82,27 @@ struct SplitDenseTensor<platform::CustomDeviceContext, T> {
const phi::DenseTensor &in,
std::vector<phi::DenseTensor *> *out,
int axis UNUSED = 0) {
auto *in_data = in.data<T>();
auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
size_t offset = 0;
phi::stream::Stream stream_wrapper(context.GetPlace(), context.stream());
for (auto *p_tensor : *out) {
auto *out_data = p_tensor->data<T>();
if (out_data != in_data + offset) {
device->MemoryCopyD2D(out_data,
in_data + offset,
p_tensor->numel() * sizeof(T),
&stream_wrapper);
VLOG(10) << "SplitDenseTensor: " << out->size();
auto kernel_result =
phi::KernelFactory::Instance().SelectKernelOrThrowError(
"split_with_num",
phi::KernelKey(phi::TransToPhiBackend(context.GetPlace()),
phi::DataLayout::ALL_LAYOUT,
phi::CppTypeToDataType<T>::Type()));
const auto &kernel = kernel_result.kernel;
using kernel_signature = void (*)(const phi::DeviceContext &,
const phi::DenseTensor &,
int,
const phi::Scalar &,
std::vector<phi::DenseTensor *>);
auto *kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(context, in, out->size(), phi::Scalar(0), *out);
for (auto *tensor : *out) {
auto dim_vec = phi::vectorize<int>(tensor->dims());
if (dim_vec.size() > 0 && dim_vec[0] == 1) {
tensor->Resize(phi::make_ddim(
std::vector<int>(dim_vec.begin() + 1, dim_vec.end())));
}
offset += p_tensor->numel();
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册