未验证 提交 d65209b6 作者: ronnywang 提交者: GitHub

[CustomDevice] optimize SplitDenseTensor by calling split_with_num kernel (#55330)

上级 7a705727
@@ -174,6 +174,15 @@ void ProcessGroupCustom::CreateCustomManagerCache(
ccl_comms[i] = CustomCCLCommManager::Create( ccl_comms[i] = CustomCCLCommManager::Create(
device_type, GetSize(), GetRank(), &ccl_id, new phi::ccl::CCLComm); device_type, GetSize(), GetRank(), &ccl_id, new phi::ccl::CCLComm);
dev_ctx[i].reset(new CustomDeviceContext(places[i])); dev_ctx[i].reset(new CustomDeviceContext(places[i]));
dev_ctx[i]->SetAllocator(
&(phi::DeviceContextPool::Instance().Get(places[i])->GetAllocator()));
dev_ctx[i]->SetHostAllocator(&(
phi::DeviceContextPool::Instance().Get(places[i])->GetHostAllocator()));
dev_ctx[i]->SetZeroAllocator(&(
phi::DeviceContextPool::Instance().Get(places[i])->GetZeroAllocator()));
dev_ctx[i]->SetHostZeroAllocator(&(phi::DeviceContextPool::Instance()
.Get(places[i])
->GetHostZeroAllocator()));
} }
std::vector<CustomEventManager> events; std::vector<CustomEventManager> events;
......
@@ -57,21 +57,22 @@ struct ConcatDenseTensor<platform::CustomDeviceContext, T> {
const std::vector<phi::DenseTensor> &in, const std::vector<phi::DenseTensor> &in,
phi::DenseTensor *out, phi::DenseTensor *out,
int axis UNUSED = 0) { int axis UNUSED = 0) {
auto *out_data = out->data<T>(); VLOG(10) << "ConcatDenseTensor: " << in.size();
auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace()); auto kernel_result =
size_t offset = 0; phi::KernelFactory::Instance().SelectKernelOrThrowError(
phi::stream::Stream stream_wrapper(context.GetPlace(), context.stream()); "concat",
phi::KernelKey(phi::TransToPhiBackend(context.GetPlace()),
for (const auto &tensor : in) { phi::DataLayout::ALL_LAYOUT,
const auto *in_data = tensor.data<T>(); phi::CppTypeToDataType<T>::Type()));
if (out_data + offset != in_data) { const auto &kernel = kernel_result.kernel;
device->MemoryCopyD2D(out_data + offset, using kernel_signature =
in_data, void (*)(const phi::DeviceContext &,
tensor.numel() * sizeof(T), const std::vector<const phi::DenseTensor *> &,
&stream_wrapper); const phi::Scalar &,
} phi::DenseTensor *);
offset += tensor.numel(); auto *kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
} std::vector<const phi::DenseTensor *> inputs;
(*kernel_fn)(context, inputs, phi::Scalar(0), out);
} }
}; };
@@ -81,20 +82,27 @@ struct SplitDenseTensor<platform::CustomDeviceContext, T> {
const phi::DenseTensor &in, const phi::DenseTensor &in,
std::vector<phi::DenseTensor *> *out, std::vector<phi::DenseTensor *> *out,
int axis UNUSED = 0) { int axis UNUSED = 0) {
auto *in_data = in.data<T>(); VLOG(10) << "SplitDenseTensor: " << out->size();
auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace()); auto kernel_result =
size_t offset = 0; phi::KernelFactory::Instance().SelectKernelOrThrowError(
phi::stream::Stream stream_wrapper(context.GetPlace(), context.stream()); "split_with_num",
phi::KernelKey(phi::TransToPhiBackend(context.GetPlace()),
for (auto *p_tensor : *out) { phi::DataLayout::ALL_LAYOUT,
auto *out_data = p_tensor->data<T>(); phi::CppTypeToDataType<T>::Type()));
if (out_data != in_data + offset) { const auto &kernel = kernel_result.kernel;
device->MemoryCopyD2D(out_data, using kernel_signature = void (*)(const phi::DeviceContext &,
in_data + offset, const phi::DenseTensor &,
p_tensor->numel() * sizeof(T), int,
&stream_wrapper); const phi::Scalar &,
std::vector<phi::DenseTensor *>);
auto *kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(context, in, out->size(), phi::Scalar(0), *out);
for (auto *tensor : *out) {
auto dim_vec = phi::vectorize<int>(tensor->dims());
if (dim_vec.size() > 0 && dim_vec[0] == 1) {
tensor->Resize(phi::make_ddim(
std::vector<int>(dim_vec.begin() + 1, dim_vec.end())));
} }
offset += p_tensor->numel();
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册