未验证 提交 cd450c0a 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] add data transform support (#56627)

上级 3945439b
...@@ -157,6 +157,26 @@ inline phi::DenseTensor TransDataType(const phi::DenseTensor& tensor, ...@@ -157,6 +157,26 @@ inline phi::DenseTensor TransDataType(const phi::DenseTensor& tensor,
} else if (tensor.place().GetType() == phi::AllocationType::GPU) { } else if (tensor.place().GetType() == phi::AllocationType::GPU) {
auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(tensor.place())); auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(tensor.place()));
return CastDataType(*dev_ctx, tensor, dtype); return CastDataType(*dev_ctx, tensor, dtype);
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
} else if (tensor.place().GetType() == phi::AllocationType::CUSTOM) {
phi::DenseTensor out;
out.Resize(tensor.dims());
auto* dev_ctx = static_cast<phi::CustomContext*>(pool.Get(tensor.place()));
auto kernel_result =
phi::KernelFactory::Instance().SelectKernelOrThrowError(
"cast",
{phi::TransToPhiBackend(tensor.place()),
phi::DataLayout::ALL_LAYOUT,
tensor.dtype()});
using kernel_signature = void (*)(const phi::DeviceContext&,
const phi::DenseTensor&,
phi::DataType,
phi::DenseTensor*);
const auto& kernel = kernel_result.kernel;
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, tensor, dtype, &out);
return out;
#endif #endif
} else { } else {
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册