未验证 提交 822a2d1f 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Fix assign kernel bug (#40927)

* fix assign kernel bug

* fix xpu kernel select error

* add cudn pinned place

* fix copy error

* fix infrt error
上级 cb183762
......@@ -394,6 +394,12 @@ function(op_library TARGET)
if(NOT ${op_name} EQUAL "")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, XPU);\n")
set(pybind_flag 1)
else()
find_register(${xpu_src} "REGISTER_OP_XPU_KERNEL_FUNCTOR" op_name)
if(NOT ${op_name} EQUAL "")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, XPU);\n")
set(pybind_flag 1)
endif()
endif()
endforeach()
endif()
......
......@@ -26,11 +26,13 @@ template <typename Context>
void AssignKernel(const Context& dev_ctx,
paddle::optional<const DenseTensor&> x,
DenseTensor* out) {
if (!x.is_initialized()) {
return;
if (x.get_ptr()) {
if (!x.is_initialized()) {
return;
}
auto& x_tensor = *x.get_ptr();
Copy<Context>(dev_ctx, x_tensor, x_tensor.place(), false, out);
}
auto& x_tensor = *x.get_ptr();
Copy<Context>(dev_ctx, x_tensor, x_tensor.place(), false, out);
}
// Note: use `const paddle::optional<std::vector<const DenseTensor*>&> x`
......@@ -103,7 +105,9 @@ void AssignValueKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_GENERAL_KERNEL(
assign, CPU, ALL_LAYOUT, phi::AssignKernel<phi::CPUContext>, ALL_DTYPE) {}
assign, CPU, ALL_LAYOUT, phi::AssignKernel<phi::CPUContext>, ALL_DTYPE) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_GENERAL_KERNEL(assign_array,
CPU,
ALL_LAYOUT,
......@@ -120,7 +124,9 @@ PD_REGISTER_KERNEL(assign_value,
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_GENERAL_KERNEL(
assign, GPU, ALL_LAYOUT, phi::AssignKernel<phi::GPUContext>, ALL_DTYPE) {}
assign, GPU, ALL_LAYOUT, phi::AssignKernel<phi::GPUContext>, ALL_DTYPE) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_GENERAL_KERNEL(assign_array,
GPU,
ALL_LAYOUT,
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device_context.h"
......@@ -34,12 +35,6 @@ void Copy(const Context& dev_ctx,
auto* src_ptr = src.data();
const auto& src_place = src.place();
if (src_place == dst_place && paddle::platform::is_cpu_place(src_place)) {
PADDLE_THROW(phi::errors::InvalidArgument(
"The src and dst tensor are all CPU tensor, you should call copy "
"function in CPU mode."));
}
VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
<< dst_place;
......@@ -48,6 +43,10 @@ void Copy(const Context& dev_ctx,
void* dst_ptr = nullptr;
if (paddle::platform::is_cpu_place(dst_place)) {
dst_ptr = dev_ctx.HostAlloc(dst, src.dtype());
} else if (paddle::platform::is_cuda_pinned_place(dst_place)) {
// now we only can use mutable_data to Alloc pinned memory here,
// dev_ctx can not alloc pinned memory now
dst_ptr = dst->mutable_data(dst_place, src.dtype());
} else {
dst_ptr = dev_ctx.Alloc(dst, src.dtype());
}
......@@ -63,8 +62,13 @@ void Copy(const Context& dev_ctx,
auto size = src.numel() * paddle::experimental::SizeOf(src.dtype());
if (paddle::platform::is_gpu_place(src_place) && // NOLINT
paddle::platform::is_cpu_place(dst_place)) {
if ((paddle::platform::is_cpu_place(src_place) ||
paddle::platform::is_cuda_pinned_place(src_place)) && // NOLINT
(paddle::platform::is_cpu_place(dst_place) ||
paddle::platform::is_cuda_pinned_place(dst_place))) {
paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, nullptr);
} else if (paddle::platform::is_gpu_place(src_place) && // NOLINT
paddle::platform::is_cpu_place(dst_place)) {
auto src_gpu_place = src_place;
auto dst_cpu_place = dst_place;
auto ctx_place = dev_ctx.GetPlace();
......
......@@ -27,7 +27,7 @@ attr_type_converter = {
"St6vectorIiSaIiEE": 'I32ArrayAttr'
}
target_type_converter = {"CPU": "CPU", "GPU": "GPU"}
target_type_converter = {"CPU": "CPU", "GPU": "GPU", "Undefined": "UNK"}
layout_type_converter = {
"NCHW": "NCHW",
"NHWC": "NHWC",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册