Unverified · Commit 114a5d21 · Authored by: Z zyfncg · Committed by: GitHub

Make data transform in-place when tensor is on GPUPinned (#43055)

* make data transform in-place when the tensor is on GPUPinned in the new dygraph mode

* fix unittest
Parent 4fd334f5
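In rough terms, the change below removes the stream-callback workaround in `TransDataPlace` (an empty callback that captured the pinned source tensor so it would not be destructed before the copy stream finished) and instead has `TransformData` share the transformed buffer back into the input tensor via `ShareBufferWith` when the input is GPU-pinned and only the place was transformed. The following is a minimal standalone sketch of that buffer-sharing idea; the `Tensor`/`Place` types and the `main` driver are hypothetical simplifications for illustration, not phi's actual API.

```cpp
// Hypothetical, simplified types for illustration only -- not the phi API.
#include <iostream>
#include <memory>
#include <vector>

enum class Place { CPU, GPUPinned, GPU };

struct Tensor {
  Place place = Place::CPU;
  std::shared_ptr<std::vector<float>> holder;  // the underlying allocation

  // Share the other tensor's allocation (analogous to
  // DenseTensor::ShareBufferWith used in the diff below).
  void ShareBufferWith(const Tensor& other) {
    holder = other.holder;
    place = other.place;
  }
};

// Simulate a place-only transform: "copy" the data to the target place.
Tensor TransDataPlace(const Tensor& src, Place dst_place) {
  Tensor out;
  out.place = dst_place;
  out.holder = std::make_shared<std::vector<float>>(*src.holder);
  return out;
}

// Sketch of the new control flow: when neither layout nor dtype changed and
// the source lives in GPU-pinned memory, let the source share the transformed
// buffer, so the place transform is effectively in-place for the caller.
Tensor TransformData(Tensor* tensor, Place target_place) {
  Tensor out = *tensor;
  bool trans_layout = false;  // layout transform omitted in this sketch
  bool trans_dtype = false;   // dtype transform omitted in this sketch
  if (out.place != target_place) {
    out = TransDataPlace(out, target_place);
    if (!trans_layout && !trans_dtype && tensor->place == Place::GPUPinned) {
      tensor->ShareBufferWith(out);  // input now holds the transformed buffer
    }
  }
  return out;
}

int main() {
  Tensor pinned;
  pinned.place = Place::GPUPinned;
  pinned.holder = std::make_shared<std::vector<float>>(4, 1.0f);

  Tensor on_gpu = TransformData(&pinned, Place::GPU);
  // After the transform the input shares the transformed buffer.
  std::cout << std::boolalpha << (pinned.holder == on_gpu.holder) << "\n";  // true
}
```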
@@ -174,20 +174,6 @@ inline phi::DenseTensor TransDataPlace(const phi::DenseTensor& tensor,
   if (!platform::is_cuda_pinned_place(tensor.place())) {
     pool.Get(tensor.place())->Wait();
     pool.Get(dst_place)->Wait();
-  } else if (platform::is_gpu_place(dst_place)) {
-    auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(dst_place));
-    phi::Copy(*dev_ctx, tensor, dst_place, false, &out);
-    // Note: This is an empty callback, the only way is to "reference"
-    // tensor, so it will not be destructed until the kernels launched at the
-    // current stream of given place is finished.
-    auto callback = [tensor, dst_place]() {
-      VLOG(4) << "Run callback of tensor:" << &tensor << " at place "
-              << dst_place;
-    };
-    dev_ctx->AddStreamCallback(callback);
-    return out;
   }
 #endif
@@ -204,23 +190,31 @@ inline phi::DenseTensor TransDataPlace(const phi::DenseTensor& tensor,
   return out;
 }
 
-phi::DenseTensor TransformData(const phi::DenseTensor& tensor,
+phi::DenseTensor TransformData(phi::DenseTensor* tensor,
                                const phi::TensorArgDef& target_args_def,
                                const TransformFlag& transform_flag) {
-  phi::DenseTensor out = tensor;
+  phi::DenseTensor out = *tensor;
+  bool trans_layout = false;
+  bool trans_dtype = false;
   if (NeedTransformLayout(
-          tensor.layout(), target_args_def.layout, transform_flag)) {
+          tensor->layout(), target_args_def.layout, transform_flag)) {
     out = TransDataLayout(out, target_args_def.layout);
+    trans_layout = true;
   }
 
   if (NeedTransformDataType(
-          tensor.dtype(), target_args_def.dtype, transform_flag)) {
+          tensor->dtype(), target_args_def.dtype, transform_flag)) {
     out = TransDataType(out, target_args_def.dtype);
+    trans_dtype = true;
   }
 
   if (NeedTransformPlace(
          out.place(), target_args_def.backend, transform_flag)) {
     out = TransDataPlace(out, phi::TransToPhiPlace(target_args_def.backend));
+    if (!trans_layout && !trans_dtype &&
+        tensor->place().GetType() == AllocationType::GPUPINNED) {
+      tensor->ShareBufferWith(out);
+    }
   }
   return out;
 }
@@ -243,7 +237,7 @@ std::shared_ptr<phi::DenseTensor> PrepareData(
       return std::static_pointer_cast<phi::DenseTensor>(tensor_in);
     }
     phi::DenseTensor out =
-        TransformData(dense_tensor, target_args_def, transform_flag);
+        TransformData(&dense_tensor, target_args_def, transform_flag);
     return std::make_shared<phi::DenseTensor>(std::move(out));
   }
   return nullptr;
@@ -279,7 +273,7 @@ std::unique_ptr<std::vector<phi::DenseTensor>> PrepareData(
             *std::dynamic_pointer_cast<phi::DenseTensor>(tensor_in));
       } else {
         pt_tensors->emplace_back(
-            TransformData(*(static_cast<phi::DenseTensor*>(tensor_in.get())),
+            TransformData((static_cast<phi::DenseTensor*>(tensor_in.get())),
                           target_args_def,
                           transform_flag));
       }
...
@@ -25,8 +25,10 @@ limitations under the License. */
 #include "gtest/gtest.h"
 
 PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT);
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 PD_DECLARE_KERNEL(full, GPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(copy, GPU, ALL_LAYOUT);
 #endif
 
 namespace phi {
...
@@ -96,7 +96,9 @@ class TestAsyncRead(unittest.TestCase):
         with _test_eager_guard():
             self.func_setUp()
             self.func_test_async_read_empty_offset_and_count()
+            self.func_setUp()
             self.func_test_async_read_success()
+            self.func_setUp()
             self.func_test_async_read_only_1dim()
         self.func_setUp()
         self.func_test_async_read_empty_offset_and_count()
...