/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>

#include "paddle/phi/kernels/funcs/detail/strided_memcpy.h"

#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/dense_tensor.h"

namespace phi {
namespace funcs {

// Strided memory copy from src to dst.
//
// Both src and dst must be on dev_ctx.GetPlace(); otherwise, a segmentation
// fault will occur.
//
// The stride of an array (also referred to as increment, pitch or step size) is
// the number of locations in memory between beginnings of successive array
// elements
//
// For example, for a tensor of shape [1, 3, 300, 300] with no padding, the
// stride is [270000, 90000, 300, 1].
//
// NOTE: When using GPU, the memcpy is asynchronous. To synchronize it, invoke
// `dev_ctx.Wait()`.
template <typename T>
inline void StridedMemcpy(const phi::DeviceContext& dev_ctx,
                          const T* src,
                          const phi::DDim& src_stride,
                          const phi::DDim& dst_dim,
                          const phi::DDim& dst_stride,
                          T* dst) {
  detail::StridedCopyDimVisitor<T> func(
      dev_ctx, src, src_stride, dst_stride, dst);
  dst_dim.apply_visitor(func);
}
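
// A minimal usage sketch (hypothetical values and pointer names; assumes
// `src` points to a contiguous 2x3 float block and `dst` to a block whose
// rows are padded to 4 elements, both on dev_ctx.GetPlace()):
//
//   phi::DDim src_stride = phi::make_ddim({3, 1});  // contiguous 2x3 layout
//   phi::DDim dst_dim = phi::make_ddim({2, 3});     // extent of the copy
//   phi::DDim dst_stride = phi::make_ddim({4, 1});  // rows padded to 4
//   StridedMemcpy<float>(dev_ctx, src, src_stride, dst_dim, dst_stride, dst);
//   dev_ctx.Wait();  // required on GPU before reading `dst`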

// Strided numel memory copy from src to dst along the specified axis
//
// For example, for a tensor with dims [4, 20, 100], the strided numel is
// [8000, 2000, 100]
//
// NOTE: The src and dst tensors should have the same number of elements
// along every dimension except the specified axis.
template <typename T>
inline void StridedNumelCopyWithAxis(const phi::DeviceContext& ctx,
                                     int64_t axis,
                                     T* dst,
                                     const phi::DDim& dst_stride_numel,
                                     const T* src,
                                     const phi::DDim& src_stride_numel,
                                     int64_t size) {
  int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
  int64_t src_after = src_stride_numel[axis];
  int64_t dst_after = dst_stride_numel[axis];
  auto place = ctx.GetPlace();

  PADDLE_ENFORCE_EQ(src_stride_numel.size(),
                    dst_stride_numel.size(),
                    phi::errors::InvalidArgument(
                        "Source and destination tensor should have the same "
                        "dimension size, but source tensor dimension size is "
                        "%u, destination tensor size is %u.",
                        src_stride_numel.size(),
                        dst_stride_numel.size()));

  for (int64_t i = 0; i < src_stride_numel.size(); ++i) {
    if (i < axis) {
      PADDLE_ENFORCE_EQ(
          src_stride_numel[i] / src_stride_numel[axis],
          dst_stride_numel[i] / dst_stride_numel[axis],
          phi::errors::InvalidArgument(
              "Source and destination tensor should have the same number of "
              "elements except the specified axis, but the source elements "
              "number is %d, destination elements number is %d.",
              src_stride_numel[i] / src_stride_numel[axis],
              dst_stride_numel[i] / dst_stride_numel[axis]));
    } else if (i == axis) {
      continue;
    } else {
      PADDLE_ENFORCE_EQ(
          src_stride_numel[i],
          dst_stride_numel[i],
          phi::errors::InvalidArgument(
              "Source and destination tensor should have the same number of "
              "elements except the specified axis, but the source elements "
              "number is %d, destination elements number is %d.",
              src_stride_numel[i],
              dst_stride_numel[i]));
    }
  }

  for (int64_t i = 0; i < before; ++i) {
    if (place.GetType() == phi::AllocationType::CPU) {
      auto& cpu_place = place;
      paddle::memory::Copy(cpu_place,
                           dst + i * dst_after,
                           cpu_place,
                           src + i * src_after,
                           sizeof(T) * size);
    } else {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      auto& gpu_place = place;
      auto& cuda_ctx = reinterpret_cast<const phi::GPUContext&>(ctx);
      paddle::memory::Copy(gpu_place,
                           dst + i * dst_after,
                           gpu_place,
                           src + i * src_after,
                           sizeof(T) * size,
                           cuda_ctx.stream());
#elif defined(PADDLE_WITH_ASCEND_CL)
      auto& npu_place = place;
      auto& npu_ctx = reinterpret_cast<const platform::NPUDeviceContext&>(ctx);
      paddle::memory::Copy(npu_place,
                           dst + i * dst_after,
                           npu_place,
                           src + i * src_after,
                           sizeof(T) * size,
                           npu_ctx.stream());
#elif defined(PADDLE_WITH_MLU)
      auto& mlu_place = place;
      auto& mlu_ctx = reinterpret_cast<const platform::MLUDeviceContext&>(ctx);
      paddle::memory::Copy(mlu_place,
                           dst + i * dst_after,
                           mlu_place,
                           src + i * src_after,
                           sizeof(T) * size,
                           mlu_ctx.stream());
#else
      PADDLE_THROW(
          phi::errors::PreconditionNotMet("Paddle is not compiled with GPU."));
#endif
    }
  }
}
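
// A hedged sketch of concatenating two tensors along axis 1 (hypothetical
// data pointers; `a` has dims [4, 20, 100], `b` has dims [4, 10, 100], and
// `out` has dims [4, 30, 100], all on ctx.GetPlace()):
//
//   phi::DDim a_numel = phi::make_ddim({8000, 2000, 100});     // strided numel of a
//   phi::DDim b_numel = phi::make_ddim({4000, 1000, 100});     // strided numel of b
//   phi::DDim out_numel = phi::make_ddim({12000, 3000, 100});  // strided numel of out
//   // Each of the 4 outer rows receives 2000 elements of a...
//   StridedNumelCopyWithAxis<float>(
//       ctx, 1, out_data, out_numel, a_data, a_numel, a_numel[1]);
//   // ...followed by 1000 elements of b at offset 2000 within the row.
//   StridedNumelCopyWithAxis<float>(
//       ctx, 1, out_data + a_numel[1], out_numel, b_data, b_numel, b_numel[1]);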

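// Copy `input` into the given `outputs` along axis 0. The tensors in
// `shape_refer` determine how many elements each output receives; outputs
// that are null, uninitialized, or empty are skipped, but their slot in
// `shape_refer` still advances the read offset into `input`.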
template <typename T>
inline void StridedMemcpyWithAxis0(
    const phi::DeviceContext& dev_ctx,
    const phi::DenseTensor& input,
    const std::vector<const phi::DenseTensor*>& shape_refer,
    std::vector<phi::DenseTensor*>* outputs) {
  const phi::DDim in_stride = stride_numel(input.dims());
  const int axis = 0;
  size_t input_offset = 0;

  for (size_t i = 0; i < outputs->size(); ++i) {
    auto out_stride = stride_numel(shape_refer[i]->dims());
    auto out = outputs->at(i);
    if (out != nullptr && out->initialized() && out->numel() > 0) {
      StridedNumelCopyWithAxis<T>(dev_ctx,
                                  axis,
                                  out->data<T>(),
                                  out_stride,
                                  input.data<T>() + input_offset,
                                  in_stride,
                                  out_stride[axis]);
    }
    input_offset += out_stride[axis];
  }
}
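
// A hedged sketch of splitting a tensor along axis 0 (hypothetical names;
// `in` has dims [6, 50], while `out1` and `out2` are pre-allocated tensors
// with dims [2, 50] and [4, 50] on dev_ctx.GetPlace()):
//
//   std::vector<const phi::DenseTensor*> refs = {&out1, &out2};
//   std::vector<phi::DenseTensor*> outs = {&out1, &out2};
//   StridedMemcpyWithAxis0<float>(dev_ctx, in, refs, &outs);
//   // out1 receives rows 0-1 of `in`, out2 receives rows 2-5.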

}  // namespace funcs
}  // namespace phi