未验证 提交 0d14e74a 编写于 作者: K kangguangli 提交者: GitHub

Transfer memcpy d2h from fluid to phi (#45150)

* transfer memcpy_d2h from fluid to phi

* refine arg check and add comment

* fix cannot fallback to phi kernel

* fix gpu_context host alloc when tensor size = 0

* add kernel for std::vector<DenseTensor> args

* fix bugs in MemcpyD2HMultiIOKernel

* remove useless header file

* polish format

* fix typo

* add testcase for cudapinned place

* refine check condition in test

* polish error message

* polish error message

* remove header in fluid  directory

* merge memcpy_h2d and memcpy_d2h into one file, change register method to simplify implementation

* fix code style check
上级 64afa638
......@@ -13,6 +13,10 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace framework {
class OpDesc;
......@@ -32,17 +36,6 @@ class MemcpyD2HOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
auto type = ctx->GetInputsVarType("X")[0];
if (type == framework::proto::VarType::SELECTED_ROWS ||
type == framework::proto::VarType::LOD_TENSOR) {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
if (type == framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("X", /*->*/ "Out");
}
}
}
protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name,
......@@ -116,95 +109,19 @@ raise error if the type is not listed above.
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(memcpy_d2h,
MemcpyD2HInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(
memcpy_d2h,
ops::MemcpyD2HOp,
ops::MemcpyD2HOpProtoMaker,
ops::MemcpyD2HInferVarType,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL_FUNCTOR(memcpy_d2h,
float,
ops::MemcpyD2HKernel,
double,
ops::MemcpyD2HKernel,
int8_t,
ops::MemcpyD2HKernel,
uint8_t,
ops::MemcpyD2HKernel,
int,
ops::MemcpyD2HKernel,
int64_t,
ops::MemcpyD2HKernel,
bool,
ops::MemcpyD2HKernel,
paddle::platform::bfloat16,
ops::MemcpyD2HKernel,
paddle::platform::complex<float>,
ops::MemcpyD2HKernel,
paddle::platform::complex<double>,
ops::MemcpyD2HKernel,
plat::float16,
ops::MemcpyD2HKernel,
int16_t,
ops::MemcpyD2HKernel);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
REGISTER_OP_CUDA_KERNEL_FUNCTOR(memcpy_d2h,
float,
ops::MemcpyD2HKernel,
double,
ops::MemcpyD2HKernel,
int8_t,
ops::MemcpyD2HKernel,
uint8_t,
ops::MemcpyD2HKernel,
int,
ops::MemcpyD2HKernel,
int64_t,
ops::MemcpyD2HKernel,
bool,
ops::MemcpyD2HKernel,
paddle::platform::bfloat16,
ops::MemcpyD2HKernel,
paddle::platform::complex<float>,
ops::MemcpyD2HKernel,
paddle::platform::complex<double>,
ops::MemcpyD2HKernel,
plat::float16,
ops::MemcpyD2HKernel,
int16_t,
ops::MemcpyD2HKernel);
#endif
#ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL_FUNCTOR(memcpy_d2h,
float,
ops::MemcpyD2HKernel,
double,
ops::MemcpyD2HKernel,
int8_t,
ops::MemcpyD2HKernel,
uint8_t,
ops::MemcpyD2HKernel,
int,
ops::MemcpyD2HKernel,
int64_t,
ops::MemcpyD2HKernel,
bool,
ops::MemcpyD2HKernel,
paddle::platform::bfloat16,
ops::MemcpyD2HKernel,
paddle::platform::complex<float>,
ops::MemcpyD2HKernel,
paddle::platform::complex<double>,
ops::MemcpyD2HKernel,
plat::float16,
ops::MemcpyD2HKernel,
int16_t,
ops::MemcpyD2HKernel);
#endif
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
MemcpyD2HInferShapeFunctor);
#ifdef PADDLE_WITH_ASCEND_CL
REGISTER_OP_NPU_KERNEL_FUNCTOR(memcpy_d2h,
......
......@@ -52,6 +52,7 @@ void Copy(const Context& dev_ctx,
<< dst_place;
dst->Resize(src.dims());
dst->mutable_data(dst_place);
void* dst_ptr = nullptr;
if (paddle::platform::is_cpu_place(dst_place)) {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/memcpy_h2d_kernel.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void MemcpyH2DKernel(const Context& dev_ctx,
const DenseTensor& x,
int dst_place_type,
DenseTensor* out) {
PADDLE_ENFORCE_GE(
dst_place_type,
0,
errors::OutOfRange("dst_place_type only support 0-3, but got: %d",
dst_place_type));
PADDLE_ENFORCE_LE(
dst_place_type,
3,
errors::OutOfRange("dst_place_type only support 0-3, but got: %d",
dst_place_type));
// Copy will set the stream of the tensor while setting blocking to false
Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
}
} // namespace phi
PD_REGISTER_KERNEL(memcpy_h2d,
CPU,
ALL_LAYOUT,
phi::MemcpyH2DKernel,
float,
double,
int8_t,
uint8_t,
int,
int64_t,
bool,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::float16,
int16_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(memcpy_h2d,
GPU,
ALL_LAYOUT,
phi::MemcpyH2DKernel,
float,
double,
int8_t,
uint8_t,
int,
int64_t,
bool,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::float16,
int16_t) {}
#endif
#ifdef PADDLE_WITH_XPU
PD_REGISTER_KERNEL(memcpy_h2d,
XPU,
ALL_LAYOUT,
phi::MemcpyH2DKernel,
float,
double,
int8_t,
uint8_t,
int,
int64_t,
bool,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::float16,
int16_t) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/memcpy_kernel.h"
#include <vector>
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
static constexpr size_t WAIT_THRESHOLD = 64 * 1024;
template <typename Context>
void MemcpyH2DKernel(const Context& dev_ctx,
const DenseTensor& x,
int dst_place_type,
DenseTensor* out) {
PADDLE_ENFORCE_GE(
dst_place_type,
0,
errors::OutOfRange("dst_place_type only support 0-3, but got: %d",
dst_place_type));
PADDLE_ENFORCE_LE(
dst_place_type,
3,
errors::OutOfRange("dst_place_type only support 0-3, but got: %d",
dst_place_type));
// Copy will set the stream of the tensor while setting blocking to false
Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
}
template <typename Context>
void MemcpyD2HKernel(const Context& dev_ctx,
const DenseTensor& x,
int dst_place_type,
DenseTensor* out) {
// Copy will set the stream of the tensor while setting blocking to false
switch (dst_place_type) {
case 0:
Copy(dev_ctx, x, CPUPlace(), false, out);
// NOTE(copy from Aurelius84): host <-> device memory copies of a memory
// block of 64 KB or less are asynchronous. See
// https://forums.developer.nvidia.com/t/host-device-memory-copies-up-to-64-kb-are-asynchronous/17907
if (x.memory_size() <= WAIT_THRESHOLD) {
dev_ctx.Wait();
}
break;
case 1:
Copy(dev_ctx, x, GPUPinnedPlace(), false, out);
// paddle::memory::Copy use async copy for GPUPinnedPlace
dev_ctx.Wait();
break;
default:
PADDLE_THROW(errors::InvalidArgument(
"Arugment 'dst_place_type' only support 0-1, but got: %d",
dst_place_type));
break;
}
}
template <typename Context>
void MemcpyD2HMultiIOKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& array,
int dst_place_type,
std::vector<DenseTensor*> out_array) {
PADDLE_ENFORCE_EQ(
array.size(),
out_array.size(),
errors::PreconditionNotMet(
"input size %d != output size %d", array.size(), out_array.size()));
for (size_t i = 0; i < array.size(); i++) {
PADDLE_ENFORCE_NOT_NULL(
array[i],
errors::PreconditionNotMet("input tesnor %d should not be nullptr", i));
PADDLE_ENFORCE_NOT_NULL(
out_array[i],
errors::PreconditionNotMet("input tesnor %d should not be nullptr", i));
const auto& x = *(array[i]);
MemcpyD2HKernel<Context>(dev_ctx, x, dst_place_type, out_array[i]);
}
}
} // namespace phi
PD_REGISTER_GENERAL_KERNEL(memcpy_h2d,
CPU,
ALL_LAYOUT,
phi::MemcpyH2DKernel<phi::CPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(memcpy_d2h,
CPU,
ALL_LAYOUT,
phi::MemcpyD2HKernel<phi::CPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(memcpy_d2h_multi_io,
CPU,
ALL_LAYOUT,
phi::MemcpyD2HMultiIOKernel<phi::CPUContext>,
ALL_DTYPE) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_GENERAL_KERNEL(memcpy_h2d,
GPU,
ALL_LAYOUT,
phi::MemcpyH2DKernel<phi::GPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(memcpy_d2h,
GPU,
ALL_LAYOUT,
phi::MemcpyD2HKernel<phi::GPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(memcpy_d2h_multi_io,
GPU,
ALL_LAYOUT,
phi::MemcpyD2HMultiIOKernel<phi::GPUContext>,
ALL_DTYPE) {}
#endif
#ifdef PADDLE_WITH_XPU
PD_REGISTER_GENERAL_KERNEL(memcpy_h2d,
XPU,
ALL_LAYOUT,
phi::MemcpyH2DKernel<phi::XPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(memcpy_d2h,
XPU,
ALL_LAYOUT,
phi::MemcpyD2HKernel<phi::XPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(memcpy_d2h_multi_io,
XPU,
ALL_LAYOUT,
phi::MemcpyD2HMultiIOKernel<phi::XPUContext>,
ALL_DTYPE) {}
#endif
......@@ -14,15 +14,30 @@
#pragma once
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
// used in new executor, for memory copy from host to device
template <typename T, typename Context>
template <typename Context>
void MemcpyH2DKernel(const Context& dev_ctx,
const DenseTensor& x,
int dst_place_type,
DenseTensor* out);
// used in new executor, for memory copy from device to host
template <typename Context>
void MemcpyD2HKernel(const Context& dev_ctx,
const DenseTensor& x,
int dst_place_type,
DenseTensor* out);
template <typename Context>
void MemcpyD2HMultiIOKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& array,
int dst_place_type,
std::vector<DenseTensor*> out_array);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
#include "glog/logging.h"
namespace phi {
KernelSignature MemcpyD2HOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorVectorInput("X")) {
return KernelSignature(
"memcpy_d2h_multi_io", {"X"}, {"dst_place_type"}, {"Out"});
}
return KernelSignature("memcpy_d2h", {"X"}, {"dst_place_type"}, {"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(memcpy_d2h, phi::MemcpyD2HOpArgumentMapping);
......@@ -129,3 +129,8 @@ elseif(WITH_ROCM)
SRCS test_strings_copy_dev_api.cu
DEPS phi phi_api_utils)
endif()
cc_test(
test_memcpy_dev_api
SRCS test_memcpy_dev_api.cc
DEPS phi phi_api_utils)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <memory>
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/kernels/memcpy_kernel.h"
namespace phi {
namespace tests {
namespace framework = paddle::framework;
using DDim = phi::DDim;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TEST(DEV_API, memcpy_d2h) {
// 1. create tensor
const auto cpu_alloc =
std::make_unique<paddle::experimental::DefaultAllocator>(phi::CPUPlace());
phi::DenseTensor x_cpu(cpu_alloc.get(),
phi::DenseTensorMeta(phi::DataType::FLOAT32,
phi::make_ddim({3, 2, 2, 3}),
phi::DataLayout::NCHW));
auto* x_cpu_data = x_cpu.mutable_data<float>(paddle::platform::CPUPlace());
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = i;
}
const auto alloc =
std::make_unique<paddle::experimental::DefaultAllocator>(phi::GPUPlace());
phi::DenseTensor x;
// 2. test API
auto& pool = phi::DeviceContextPool::Instance();
auto place = phi::GPUPlace();
auto* dev_ctx = static_cast<const phi::GPUContext*>(pool.GetByPlace(place));
phi::MemcpyH2DKernel<phi::GPUContext>(*dev_ctx, x_cpu, 1, &x);
phi::DenseTensor out;
phi::MemcpyD2HKernel<phi::GPUContext>(*dev_ctx, x, 1, &out);
// 3. check result
std::vector<int64_t> expect_shape = {12, 3};
ASSERT_EQ(out.dims(), x.dims());
ASSERT_EQ(out.meta().dtype, phi::DataType::FLOAT32);
ASSERT_EQ(out.meta().layout, phi::DataLayout::NCHW);
bool value_equal = true;
auto* dense_out_data = out.data<float>();
for (int i = 0; i < x_cpu.numel(); i++) {
if (x_cpu_data[i] != dense_out_data[i]) {
value_equal = false;
break;
}
}
ASSERT_EQ(value_equal, true);
}
#endif
} // namespace tests
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册