From 114a5d214977507c20c2b8f770301e3187f3ab04 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 30 May 2022 10:39:37 +0800 Subject: [PATCH] Make data transform inplaced when tensor is on GPUPinned (#43055) * make data transform inplace when tensor is on gpupinned in new dygraph * fix unittest --- paddle/phi/api/lib/data_transform.cc | 34 ++++++++------------ paddle/phi/tests/common/test_int_array.cc | 2 ++ python/paddle/tests/test_async_read_write.py | 2 ++ 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 598559cc4df..12f7b8bba58 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -174,20 +174,6 @@ inline phi::DenseTensor TransDataPlace(const phi::DenseTensor& tensor, if (!platform::is_cuda_pinned_place(tensor.place())) { pool.Get(tensor.place())->Wait(); pool.Get(dst_place)->Wait(); - } else if (platform::is_gpu_place(dst_place)) { - auto* dev_ctx = static_cast(pool.Get(dst_place)); - phi::Copy(*dev_ctx, tensor, dst_place, false, &out); - - // Note: This is an empty callback, the only way is to "reference" - // tensor, so it will not be destructed until the kernels launched at - // current - // stream of given place is finished. - auto callback = [tensor, dst_place]() { - VLOG(4) << "Run callback of tensor:" << &tensor << " at place " - << dst_place; - }; - dev_ctx->AddStreamCallback(callback); - return out; } #endif @@ -204,23 +190,31 @@ inline phi::DenseTensor TransDataPlace(const phi::DenseTensor& tensor, return out; } -phi::DenseTensor TransformData(const phi::DenseTensor& tensor, +phi::DenseTensor TransformData(phi::DenseTensor* tensor, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag) { - phi::DenseTensor out = tensor; + phi::DenseTensor out = *tensor; + bool trans_layout = false; + bool trans_dtype = false; if (NeedTransformLayout( - tensor.layout(), target_args_def.layout, transform_flag)) { + tensor->layout(), target_args_def.layout, transform_flag)) { out = TransDataLayout(out, target_args_def.layout); + trans_layout = true; } if (NeedTransformDataType( - tensor.dtype(), target_args_def.dtype, transform_flag)) { + tensor->dtype(), target_args_def.dtype, transform_flag)) { out = TransDataType(out, target_args_def.dtype); + trans_dtype = true; } if (NeedTransformPlace( out.place(), target_args_def.backend, transform_flag)) { out = TransDataPlace(out, phi::TransToPhiPlace(target_args_def.backend)); + if (!trans_layout && !trans_dtype && + tensor->place().GetType() == AllocationType::GPUPINNED) { + tensor->ShareBufferWith(out); + } } return out; } @@ -243,7 +237,7 @@ std::shared_ptr PrepareData( return std::static_pointer_cast(tensor_in); } phi::DenseTensor out = - TransformData(dense_tensor, target_args_def, transform_flag); + TransformData(&dense_tensor, target_args_def, transform_flag); return std::make_shared(std::move(out)); } return nullptr; @@ -279,7 +273,7 @@ std::unique_ptr> PrepareData( *std::dynamic_pointer_cast(tensor_in)); } else { pt_tensors->emplace_back( - TransformData(*(static_cast(tensor_in.get())), + TransformData((static_cast(tensor_in.get())), target_args_def, transform_flag)); } diff --git a/paddle/phi/tests/common/test_int_array.cc b/paddle/phi/tests/common/test_int_array.cc index b6c4f2b1ea8..a6278ee4a34 100644 --- a/paddle/phi/tests/common/test_int_array.cc +++ b/paddle/phi/tests/common/test_int_array.cc @@ -25,8 +25,10 @@ limitations under the License. */ #include "gtest/gtest.h" PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PD_DECLARE_KERNEL(full, GPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(copy, GPU, ALL_LAYOUT); #endif namespace phi { diff --git a/python/paddle/tests/test_async_read_write.py b/python/paddle/tests/test_async_read_write.py index babdf43199d..14320634215 100644 --- a/python/paddle/tests/test_async_read_write.py +++ b/python/paddle/tests/test_async_read_write.py @@ -96,7 +96,9 @@ class TestAsyncRead(unittest.TestCase): with _test_eager_guard(): self.func_setUp() self.func_test_async_read_empty_offset_and_count() + self.func_setUp() self.func_test_async_read_success() + self.func_setUp() self.func_test_async_read_only_1dim() self.func_setUp() self.func_test_async_read_empty_offset_and_count() -- GitLab