未验证 提交 2360406d 编写于 作者: C chentianyu03 提交者: GitHub

[PTen] Fix incorrect usage of pten::Copy (#37982)

* fix pten::Copy use error in reduce_impl

* remove in_dtype args in reduce kernel

* fix copy error

* fix copy error
上级 62b1f38c
......@@ -769,6 +769,23 @@ static void LaunchReduceKernel(const Tx* x_data,
}
}
// Copies `src` into `dst` via pten::Copy, using a device context fetched from
// the process-wide DeviceContextPool.
//
// The context is looked up by the destination's place when `dst` is on a GPU
// or NPU place; in every other case the source's place is used instead.
//
// NOTE(review): the pool entry is static_cast to CUDADeviceContext on both
// paths, including when `dst` is on an NPU place -- confirm this is intended.
void TensorCopy(const DenseTensor& src, DenseTensor* dst) {
  const bool use_dst_place = paddle::platform::is_gpu_place(dst->place()) ||
                             paddle::platform::is_npu_place(dst->place());
  const auto& ctx_place = use_dst_place ? dst->place() : src.place();

  auto& ctx_pool = paddle::platform::DeviceContextPool::Instance();
  const auto* dev_ctx = static_cast<paddle::platform::CUDADeviceContext*>(
      ctx_pool.Get(ctx_place));

  // The third argument is passed as `false`; presumably this is the blocking
  // flag -- verify against the declaration of pten::Copy.
  pten::Copy(*dev_ctx, src, false, dst);
}
template <typename Tx,
typename Ty,
template <typename, typename> class ReduceOp>
......@@ -800,7 +817,7 @@ void TensorReduceFunctorImpl(const pten::DenseTensor& x,
if (config.reduce_num == 1) {
auto out_dims = y->dims();
if (x.dtype() == y->dtype()) {
pten::Copy(*dev_ctx, x, true, y);
TensorCopy(x, y);
y->Resize(out_dims);
} else {
PD_VISIT_ALL_TYPES(y->dtype(), "CastKernelImpl", ([&] {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册