/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/copy_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/kernel_registry.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device_context.h"

namespace phi {

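// Copies `src` into `*dst`, allocating `*dst` on `dst_place`. When `blocking`
// is false, copies that involve GPU memory are enqueued on the GPU context's
// stream and return without synchronizing; when `blocking` is true a null
// stream is used, i.e. the copy is synchronous.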
template <typename Context>
void Copy(const Context& dev_ctx,
          const DenseTensor& src,
          Place dst_place,
          bool blocking,
          DenseTensor* dst) {
  auto* src_ptr = src.data();
  const auto& src_place = src.place();

  VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
          << dst_place;

  dst->Resize(src.dims());

  void* dst_ptr = nullptr;
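  // Allocate the destination buffer according to dst_place: host memory for
  // CPU, mutable_data for CUDA pinned memory, and device memory otherwise.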
  if (paddle::platform::is_cpu_place(dst_place)) {
    dst_ptr = dev_ctx.HostAlloc(dst, src.dtype());
  } else if (paddle::platform::is_cuda_pinned_place(dst_place)) {
    // For now, pinned memory can only be allocated here through mutable_data;
    // dev_ctx cannot allocate pinned memory yet.
    dst_ptr = dst->mutable_data(dst_place, src.dtype());
  } else {
    dst_ptr = dev_ctx.Alloc(
        dst, src.dtype(), 0, paddle::platform::is_cuda_pinned_place(dst_place));
  }

  if (src_ptr == dst_ptr && src_place == dst_place) {
    VLOG(3) << "Skip copy the same data async from " << src_place << " to "
            << dst_place;
    return;
  }
  VLOG(4) << "src:" << src_ptr << ", dst:" << dst_ptr;

  CHECK(dst->layout() == src.layout());

  auto size = src.numel() * paddle::experimental::SizeOf(src.dtype());

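  // Dispatch on the (src_place, dst_place) pair. Host/pinned-to-host/pinned
  // copies need no stream; any copy that touches GPU memory runs on dev_ctx's
  // stream unless `blocking` forces a null (synchronous) stream.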
  if ((paddle::platform::is_cpu_place(src_place) ||
       paddle::platform::is_cuda_pinned_place(src_place)) &&  // NOLINT
      (paddle::platform::is_cpu_place(dst_place) ||
       paddle::platform::is_cuda_pinned_place(dst_place))) {
    paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, nullptr);
  } else if (paddle::platform::is_gpu_place(src_place) &&  // NOLINT
             paddle::platform::is_cpu_place(dst_place)) {
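    // GPU -> CPU: dev_ctx must be a GPU context bound to the source device.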
    auto src_gpu_place = src_place;
    auto dst_cpu_place = dst_place;
    auto ctx_place = dev_ctx.GetPlace();
    PADDLE_ENFORCE_EQ(
        paddle::platform::is_gpu_place(ctx_place),
        true,
        phi::errors::PreconditionNotMet(
            "Context place error, excepted GPUPlace, but actually %s.",
            ctx_place));
    auto ctx_gpu_place = ctx_place;
    PADDLE_ENFORCE_EQ(src_gpu_place,
                      ctx_gpu_place,
                      phi::errors::Unavailable(
                          "Source place and context place do not match, source "
                          "place is %s, context place is %s.",
                          src_gpu_place,
                          ctx_gpu_place));
    auto stream =
        blocking ? nullptr
                 : reinterpret_cast<const phi::GPUContext&>(dev_ctx).stream();
    paddle::memory::Copy(
        dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
  } else if ((paddle::platform::is_cpu_place(src_place) ||
              paddle::platform::is_cuda_pinned_place(src_place)) &&  // NOLINT
             paddle::platform::is_gpu_place(dst_place)) {
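    // CPU/pinned -> GPU: dev_ctx must be a GPU context bound to the
    // destination device.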
    auto src_cpu_place = src_place;
    auto dst_gpu_place = dst_place;
    auto ctx_place = dev_ctx.GetPlace();
    PADDLE_ENFORCE_EQ(
        paddle::platform::is_gpu_place(ctx_place),
        true,
        phi::errors::PreconditionNotMet(
            "Context place error, excepted GPUPlace, but actually %s.",
            ctx_place));
    auto ctx_gpu_place = ctx_place;
    PADDLE_ENFORCE_EQ(dst_gpu_place,
                      ctx_gpu_place,
                      phi::errors::Unavailable(
                          "Destination place and context place do not match, "
                          "destination place is %s, context place is %s.",
                          dst_gpu_place,
                          ctx_gpu_place));
    auto stream =
        blocking ? nullptr
                 : reinterpret_cast<const phi::GPUContext&>(dev_ctx).stream();
    paddle::memory::Copy(
        dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream);
  } else if (paddle::platform::is_gpu_place(src_place) &&  // NOLINT
             paddle::platform::is_gpu_place(dst_place)) {
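    // GPU -> GPU, possibly across devices.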
    auto src_gpu_place = src_place;
    auto dst_gpu_place = dst_place;
    auto ctx_place = dev_ctx.GetPlace();
    PADDLE_ENFORCE_EQ(
        paddle::platform::is_gpu_place(ctx_place),
        true,
        phi::errors::PreconditionNotMet(
            "Context place error, excepted GPUPlace, but actually %s.",
            ctx_place));
    auto stream =
        blocking ? nullptr
                 : reinterpret_cast<const phi::GPUContext&>(dev_ctx).stream();
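    // Same-device copies go straight onto the stream. For cross-device copies
    // the code synchronizes with the source device: if dev_ctx lives on the
    // source, copy first and then wait; if it lives on the destination, wait
    // for the source before copying.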
    if (paddle::platform::is_same_place(src_place, dst_place)) {
      paddle::memory::Copy(
          dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
    } else {
      if (paddle::platform::is_same_place(ctx_place, src_place)) {
        paddle::memory::Copy(
            dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
        paddle::platform::DeviceContextPool::Instance()
            .Get(src.place())
            ->Wait();
      } else if (paddle::platform::is_same_place(ctx_place, dst_place)) {
        paddle::platform::DeviceContextPool::Instance()
            .Get(src.place())
            ->Wait();
        paddle::memory::Copy(
            dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
      } else {
        PADDLE_THROW(phi::errors::Unavailable(
            "Context place dose not match the source and destination place."));
      }
    }
  } else if (paddle::platform::is_gpu_place(src_place) &&  // NOLINT
             paddle::platform::is_cuda_pinned_place(dst_place)) {
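    // GPU -> CUDA pinned host memory: dev_ctx must be a GPU context bound to
    // the source device; the copy uses its stream unless `blocking` is set.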
    auto src_gpu_place = src_place;
    auto dst_cuda_pinned_place = dst_place;
    auto ctx_place = dev_ctx.GetPlace();
    PADDLE_ENFORCE_EQ(
        paddle::platform::is_gpu_place(ctx_place),
        true,
        phi::errors::PreconditionNotMet(
            "Context place error, excepted GPUPlace, but actually %s.",
            ctx_place));
    auto ctx_gpu_place = ctx_place;
    PADDLE_ENFORCE_EQ(src_gpu_place,
                      ctx_gpu_place,
                      phi::errors::Unavailable(
                          "Source place and context place do not match, source "
                          "place is %s, context place is %s.",
                          src_gpu_place,
                          ctx_gpu_place));
    auto stream =
        blocking ? nullptr
                 : reinterpret_cast<const phi::GPUContext&>(dev_ctx).stream();
    paddle::memory::Copy(
        dst_cuda_pinned_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
  } else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Place type error. Please check the place of src and dst Tensor."));
  }
}

}  // namespace phi

PD_REGISTER_GENERAL_KERNEL(
    copy, GPU, ALL_LAYOUT, phi::Copy<phi::GPUContext>, ALL_DTYPE) {}
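
// Example usage (a minimal sketch; `gpu_ctx` and `src` are placeholders, not
// names defined in this file, for a phi::GPUContext bound to the source
// device and a GPU-resident DenseTensor):
//
//   phi::DenseTensor dst;
//   // Blocking device-to-host copy.
//   phi::Copy(gpu_ctx, src, phi::CPUPlace(), /*blocking=*/true, &dst);
//   // Asynchronous copy on gpu_ctx's stream; synchronize before reading dst.
//   phi::Copy(gpu_ctx, src, src.place(), /*blocking=*/false, &dst);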