diff --git a/paddle/fluid/operators/memcpy_d2h_op.cc b/paddle/fluid/operators/memcpy_d2h_op.cc index f3f15f000d50ab57fb50839a414dc02ed529477d..ed99fd5bf8783ada7fdd00f031d9e1de80248577 100644 --- a/paddle/fluid/operators/memcpy_d2h_op.cc +++ b/paddle/fluid/operators/memcpy_d2h_op.cc @@ -13,6 +13,10 @@ limitations under the License. */ #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" + namespace paddle { namespace framework { class OpDesc; @@ -32,17 +36,6 @@ class MemcpyD2HOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - auto type = ctx->GetInputsVarType("X")[0]; - if (type == framework::proto::VarType::SELECTED_ROWS || - type == framework::proto::VarType::LOD_TENSOR) { - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); - if (type == framework::proto::VarType::LOD_TENSOR) { - ctx->ShareLoD("X", /*->*/ "Out"); - } - } - } - protected: framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, @@ -116,95 +109,19 @@ raise error if the type is not listed above. namespace ops = paddle::operators; namespace plat = paddle::platform; + +DECLARE_INFER_SHAPE_FUNCTOR(memcpy_d2h, + MemcpyD2HInferShapeFunctor, + PD_INFER_META(phi::UnchangedInferMeta)); + REGISTER_OPERATOR( memcpy_d2h, ops::MemcpyD2HOp, ops::MemcpyD2HOpProtoMaker, ops::MemcpyD2HInferVarType, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); - -REGISTER_OP_CPU_KERNEL_FUNCTOR(memcpy_d2h, - float, - ops::MemcpyD2HKernel, - double, - ops::MemcpyD2HKernel, - int8_t, - ops::MemcpyD2HKernel, - uint8_t, - ops::MemcpyD2HKernel, - int, - ops::MemcpyD2HKernel, - int64_t, - ops::MemcpyD2HKernel, - bool, - ops::MemcpyD2HKernel, - paddle::platform::bfloat16, - ops::MemcpyD2HKernel, - paddle::platform::complex, - ops::MemcpyD2HKernel, - paddle::platform::complex, - ops::MemcpyD2HKernel, - plat::float16, - ops::MemcpyD2HKernel, - int16_t, - ops::MemcpyD2HKernel); - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -REGISTER_OP_CUDA_KERNEL_FUNCTOR(memcpy_d2h, - float, - ops::MemcpyD2HKernel, - double, - ops::MemcpyD2HKernel, - int8_t, - ops::MemcpyD2HKernel, - uint8_t, - ops::MemcpyD2HKernel, - int, - ops::MemcpyD2HKernel, - int64_t, - ops::MemcpyD2HKernel, - bool, - ops::MemcpyD2HKernel, - paddle::platform::bfloat16, - ops::MemcpyD2HKernel, - paddle::platform::complex, - ops::MemcpyD2HKernel, - paddle::platform::complex, - ops::MemcpyD2HKernel, - plat::float16, - ops::MemcpyD2HKernel, - int16_t, - ops::MemcpyD2HKernel); -#endif - -#ifdef PADDLE_WITH_XPU -REGISTER_OP_XPU_KERNEL_FUNCTOR(memcpy_d2h, - float, - ops::MemcpyD2HKernel, - double, - ops::MemcpyD2HKernel, - int8_t, - ops::MemcpyD2HKernel, - uint8_t, - ops::MemcpyD2HKernel, - int, - ops::MemcpyD2HKernel, - int64_t, - ops::MemcpyD2HKernel, - bool, - ops::MemcpyD2HKernel, - paddle::platform::bfloat16, - ops::MemcpyD2HKernel, - paddle::platform::complex, - ops::MemcpyD2HKernel, - paddle::platform::complex, - ops::MemcpyD2HKernel, - plat::float16, - ops::MemcpyD2HKernel, - int16_t, - ops::MemcpyD2HKernel); -#endif + paddle::framework::EmptyGradOpMaker, + MemcpyD2HInferShapeFunctor); #ifdef PADDLE_WITH_ASCEND_CL REGISTER_OP_NPU_KERNEL_FUNCTOR(memcpy_d2h, diff --git a/paddle/phi/core/tensor_utils.cc b/paddle/phi/core/tensor_utils.cc index dcd25180e299760ce748e239bf8268341c87238c..e2ec4a6f14cef59e95e023292418e240d460b16b 100644 --- a/paddle/phi/core/tensor_utils.cc +++ b/paddle/phi/core/tensor_utils.cc @@ -52,6 +52,7 @@ void Copy(const Context& dev_ctx, << dst_place; dst->Resize(src.dims()); + dst->mutable_data(dst_place); void* dst_ptr = nullptr; if (paddle::platform::is_cpu_place(dst_place)) { diff --git a/paddle/phi/kernels/memcpy_h2d_kernel.cc b/paddle/phi/kernels/memcpy_h2d_kernel.cc deleted file mode 100644 index 6d20475b9fa65e0d4a43342d7eae9a266cd2b854..0000000000000000000000000000000000000000 --- a/paddle/phi/kernels/memcpy_h2d_kernel.cc +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/kernels/memcpy_h2d_kernel.h" - -#include "paddle/phi/common/complex.h" -#include "paddle/phi/core/kernel_registry.h" - -namespace phi { - -template -void MemcpyH2DKernel(const Context& dev_ctx, - const DenseTensor& x, - int dst_place_type, - DenseTensor* out) { - PADDLE_ENFORCE_GE( - dst_place_type, - 0, - errors::OutOfRange("dst_place_type only support 0-3, but got: %d", - dst_place_type)); - PADDLE_ENFORCE_LE( - dst_place_type, - 3, - errors::OutOfRange("dst_place_type only support 0-3, but got: %d", - dst_place_type)); - - // Copy will set the stream of the tensor while setting blocking to false - Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); -} - -} // namespace phi - -PD_REGISTER_KERNEL(memcpy_h2d, - CPU, - ALL_LAYOUT, - phi::MemcpyH2DKernel, - float, - double, - int8_t, - uint8_t, - int, - int64_t, - bool, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex, - phi::dtype::float16, - int16_t) {} - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_REGISTER_KERNEL(memcpy_h2d, - GPU, - ALL_LAYOUT, - phi::MemcpyH2DKernel, - float, - double, - int8_t, - uint8_t, - int, - int64_t, - bool, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex, - phi::dtype::float16, - int16_t) {} -#endif - -#ifdef PADDLE_WITH_XPU -PD_REGISTER_KERNEL(memcpy_h2d, - XPU, - ALL_LAYOUT, - phi::MemcpyH2DKernel, - float, - double, - int8_t, - uint8_t, - int, - int64_t, - bool, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex, - phi::dtype::float16, - int16_t) {} -#endif diff --git a/paddle/phi/kernels/memcpy_kernel.cc b/paddle/phi/kernels/memcpy_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..a9de4a4dd6eb211a96f69c215e41ac979d85c3cf --- /dev/null +++ b/paddle/phi/kernels/memcpy_kernel.cc @@ -0,0 +1,161 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/memcpy_kernel.h" + +#include + +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +static constexpr size_t WAIT_THRESHOLD = 64 * 1024; + +template +void MemcpyH2DKernel(const Context& dev_ctx, + const DenseTensor& x, + int dst_place_type, + DenseTensor* out) { + PADDLE_ENFORCE_GE( + dst_place_type, + 0, + errors::OutOfRange("dst_place_type only support 0-3, but got: %d", + dst_place_type)); + PADDLE_ENFORCE_LE( + dst_place_type, + 3, + errors::OutOfRange("dst_place_type only support 0-3, but got: %d", + dst_place_type)); + + // Copy will set the stream of the tensor while setting blocking to false + Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); +} + +template +void MemcpyD2HKernel(const Context& dev_ctx, + const DenseTensor& x, + int dst_place_type, + DenseTensor* out) { + // Copy will set the stream of the tensor while setting blocking to false + switch (dst_place_type) { + case 0: + Copy(dev_ctx, x, CPUPlace(), false, out); + // NOTE(copy from Aurelius84): host <-> device memory copies of a memory + // block of 64 KB or less are asynchronous. See + // https://forums.developer.nvidia.com/t/host-device-memory-copies-up-to-64-kb-are-asynchronous/17907 + if (x.memory_size() <= WAIT_THRESHOLD) { + dev_ctx.Wait(); + } + break; + + case 1: + Copy(dev_ctx, x, GPUPinnedPlace(), false, out); + // paddle::memory::Copy use async copy for GPUPinnedPlace + dev_ctx.Wait(); + break; + + default: + PADDLE_THROW(errors::InvalidArgument( + "Arugment 'dst_place_type' only support 0-1, but got: %d", + dst_place_type)); + break; + } +} + +template +void MemcpyD2HMultiIOKernel(const Context& dev_ctx, + const std::vector& array, + int dst_place_type, + std::vector out_array) { + PADDLE_ENFORCE_EQ( + array.size(), + out_array.size(), + errors::PreconditionNotMet( + "input size %d != output size %d", array.size(), out_array.size())); + + for (size_t i = 0; i < array.size(); i++) { + PADDLE_ENFORCE_NOT_NULL( + array[i], + errors::PreconditionNotMet("input tesnor %d should not be nullptr", i)); + PADDLE_ENFORCE_NOT_NULL( + out_array[i], + errors::PreconditionNotMet("input tesnor %d should not be nullptr", i)); + + const auto& x = *(array[i]); + MemcpyD2HKernel(dev_ctx, x, dst_place_type, out_array[i]); + } +} + +} // namespace phi + +PD_REGISTER_GENERAL_KERNEL(memcpy_h2d, + CPU, + ALL_LAYOUT, + phi::MemcpyH2DKernel, + ALL_DTYPE) {} + +PD_REGISTER_GENERAL_KERNEL(memcpy_d2h, + CPU, + ALL_LAYOUT, + phi::MemcpyD2HKernel, + ALL_DTYPE) {} + +PD_REGISTER_GENERAL_KERNEL(memcpy_d2h_multi_io, + CPU, + ALL_LAYOUT, + phi::MemcpyD2HMultiIOKernel, + ALL_DTYPE) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_GENERAL_KERNEL(memcpy_h2d, + GPU, + ALL_LAYOUT, + phi::MemcpyH2DKernel, + ALL_DTYPE) {} + +PD_REGISTER_GENERAL_KERNEL(memcpy_d2h, + GPU, + ALL_LAYOUT, + phi::MemcpyD2HKernel, + ALL_DTYPE) {} + +PD_REGISTER_GENERAL_KERNEL(memcpy_d2h_multi_io, + GPU, + ALL_LAYOUT, + phi::MemcpyD2HMultiIOKernel, + ALL_DTYPE) {} + +#endif + +#ifdef PADDLE_WITH_XPU +PD_REGISTER_GENERAL_KERNEL(memcpy_h2d, + XPU, + ALL_LAYOUT, + phi::MemcpyH2DKernel, + ALL_DTYPE) {} + +PD_REGISTER_GENERAL_KERNEL(memcpy_d2h, + XPU, + ALL_LAYOUT, + phi::MemcpyD2HKernel, + ALL_DTYPE) {} + +PD_REGISTER_GENERAL_KERNEL(memcpy_d2h_multi_io, + XPU, + ALL_LAYOUT, + phi::MemcpyD2HMultiIOKernel, + ALL_DTYPE) {} + +#endif diff --git a/paddle/phi/kernels/memcpy_h2d_kernel.h b/paddle/phi/kernels/memcpy_kernel.h similarity index 62% rename from paddle/phi/kernels/memcpy_h2d_kernel.h rename to paddle/phi/kernels/memcpy_kernel.h index 18b642cabe822e84ca791e3b35067a968df73fe1..9f72946dd67d6b0e6f519a5688750b61f1da72b0 100644 --- a/paddle/phi/kernels/memcpy_h2d_kernel.h +++ b/paddle/phi/kernels/memcpy_kernel.h @@ -14,15 +14,30 @@ #pragma once +#include + #include "paddle/phi/core/dense_tensor.h" namespace phi { // used in new executor, for memory copy from host to device -template +template void MemcpyH2DKernel(const Context& dev_ctx, const DenseTensor& x, int dst_place_type, DenseTensor* out); +// used in new executor, for memory copy from device to host +template +void MemcpyD2HKernel(const Context& dev_ctx, + const DenseTensor& x, + int dst_place_type, + DenseTensor* out); + +template +void MemcpyD2HMultiIOKernel(const Context& dev_ctx, + const std::vector& array, + int dst_place_type, + std::vector out_array); + } // namespace phi diff --git a/paddle/phi/ops/compat/memcpy_d2h_sig.cc b/paddle/phi/ops/compat/memcpy_d2h_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..38b0f01082e757d0b4f6b821554788d0e672211c --- /dev/null +++ b/paddle/phi/ops/compat/memcpy_d2h_sig.cc @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +#include "glog/logging.h" + +namespace phi { + +KernelSignature MemcpyD2HOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.IsDenseTensorVectorInput("X")) { + return KernelSignature( + "memcpy_d2h_multi_io", {"X"}, {"dst_place_type"}, {"Out"}); + } + + return KernelSignature("memcpy_d2h", {"X"}, {"dst_place_type"}, {"Out"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(memcpy_d2h, phi::MemcpyD2HOpArgumentMapping); diff --git a/paddle/phi/tests/kernels/CMakeLists.txt b/paddle/phi/tests/kernels/CMakeLists.txt index ca466780da45037a1515b60aefc543c8b51a306f..152bc0dd0c060e9b49ce1b520a65a2cfda6af595 100644 --- a/paddle/phi/tests/kernels/CMakeLists.txt +++ b/paddle/phi/tests/kernels/CMakeLists.txt @@ -129,3 +129,8 @@ elseif(WITH_ROCM) SRCS test_strings_copy_dev_api.cu DEPS phi phi_api_utils) endif() + +cc_test( + test_memcpy_dev_api + SRCS test_memcpy_dev_api.cc + DEPS phi phi_api_utils) diff --git a/paddle/phi/tests/kernels/test_memcpy_dev_api.cc b/paddle/phi/tests/kernels/test_memcpy_dev_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..732b51f4204bf8b05da083a94514921572279611 --- /dev/null +++ b/paddle/phi/tests/kernels/test_memcpy_dev_api.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include + +#include "paddle/phi/api/lib/utils/allocator.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/kernels/memcpy_kernel.h" + +namespace phi { +namespace tests { + +namespace framework = paddle::framework; +using DDim = phi::DDim; + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +TEST(DEV_API, memcpy_d2h) { + // 1. create tensor + const auto cpu_alloc = + std::make_unique(phi::CPUPlace()); + phi::DenseTensor x_cpu(cpu_alloc.get(), + phi::DenseTensorMeta(phi::DataType::FLOAT32, + phi::make_ddim({3, 2, 2, 3}), + phi::DataLayout::NCHW)); + auto* x_cpu_data = x_cpu.mutable_data(paddle::platform::CPUPlace()); + + for (int i = 0; i < x_cpu.numel(); i++) { + x_cpu_data[i] = i; + } + + const auto alloc = + std::make_unique(phi::GPUPlace()); + phi::DenseTensor x; + + // 2. test API + auto& pool = phi::DeviceContextPool::Instance(); + auto place = phi::GPUPlace(); + auto* dev_ctx = static_cast(pool.GetByPlace(place)); + + phi::MemcpyH2DKernel(*dev_ctx, x_cpu, 1, &x); + phi::DenseTensor out; + phi::MemcpyD2HKernel(*dev_ctx, x, 1, &out); + + // 3. check result + std::vector expect_shape = {12, 3}; + ASSERT_EQ(out.dims(), x.dims()); + ASSERT_EQ(out.meta().dtype, phi::DataType::FLOAT32); + ASSERT_EQ(out.meta().layout, phi::DataLayout::NCHW); + + bool value_equal = true; + auto* dense_out_data = out.data(); + for (int i = 0; i < x_cpu.numel(); i++) { + if (x_cpu_data[i] != dense_out_data[i]) { + value_equal = false; + break; + } + } + ASSERT_EQ(value_equal, true); +} + +#endif +} // namespace tests +} // namespace phi