/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <memory.h>

#include <cstring>
#include <vector>

#include "paddle/phi/common/place.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
namespace funcs {

/**
 * A thin wrapper for gathering on CPU tensors.
 * Returns a new tensor gathered from the source tensor according to index.
 * input[src]: type-T source Tensor
 * input[index]: type-IndexT index Tensor (0-D, 1-D, or 2-D with shape [n, 1])
 * return: output tensor
 */
template <typename T, typename IndexT = int>
void CPUGather(const phi::CPUContext& ctx UNUSED,
               const DenseTensor& src,
               const DenseTensor& index,
               DenseTensor* output) {
  if (index.dims().size() == 2) {
    PADDLE_ENFORCE_EQ(
        index.dims()[1],
        1,
        phi::errors::InvalidArgument(
            "index.dims()[1] should be 1 when index.dims().size() = 2 "
            "in gather_op, but received value is [%d].",
            index.dims()[1]));
  } else {
    PADDLE_ENFORCE_EQ(
        index.dims().size() == 1 || index.dims().size() == 0,
        true,
        phi::errors::InvalidArgument(
            "The index should be 0D or 1D when it is not 2D, but received a "
            "%d-D index.",
            index.dims().size()));
  }

  int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];

  auto src_dims = src.dims();

  const T* p_src = src.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  // slice size
  int64_t slice_size = 1;
  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
  // input size
  int64_t input_size = src_dims[0] * slice_size;

  const size_t slice_bytes = slice_size * sizeof(T);

  for (int64_t i = 0; i < index_size; ++i) {
    IndexT index_ = p_index[i];
    PADDLE_ENFORCE_LT(p_index[i],
                      input_size,
                      phi::errors::OutOfRange(
                          "The element of Index must be less than the size of "
                          "input dim size of axis which is %d, but received "
                          "index element which is %d in the %d index.",
                          input_size,
                          p_index[i],
                          i));
    PADDLE_ENFORCE_GE(p_index[i],
                      0,
                      phi::errors::OutOfRange(
                          "The element of Index must be greater than or equal "
                          "to 0, but received index element which is %d in the "
                          "%d index.",
                          p_index[i],
                          i));
    memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes);
  }
}
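
// A minimal usage sketch (illustrative only; `dev_ctx`, `src`, `index`, and
// `out` are placeholder names, not part of this header). The caller resizes
// and allocates the output before gathering:
//
//   phi::CPUContext dev_ctx;
//   DenseTensor src;    // e.g. shape [4, 8], dtype float32
//   DenseTensor index;  // e.g. shape [2], dtype int32, holding {3, 1}
//   DenseTensor out;
//   out.Resize(phi::make_ddim({2, 8}));
//   dev_ctx.Alloc<float>(&out);
//   CPUGather<float, int32_t>(dev_ctx, src, index, &out);
//   // Row 0 of `out` is row 3 of `src`; row 1 of `out` is row 1 of `src`.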

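/**
 * A thin wrapper for gather_nd on CPU tensors.
 * Each entry along the last dimension of `index` forms a coordinate into the
 * leading dimensions of `input`; the slice over the remaining trailing
 * dimensions at that coordinate is copied to `output`.
 */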
template <typename T, typename IndexT = int>
void CPUGatherNd(const phi::CPUContext& ctx UNUSED,
                 const DenseTensor& input,
                 const DenseTensor& index,
                 DenseTensor* output) {
  auto index_dims = index.dims();
  auto index_dims_size = index_dims.size();
  auto input_dims = input.dims();
  auto input_dims_size = input_dims.size();

  const T* p_input = input.data<T>();
  const IndexT* p_index = index.data<IndexT>();
  T* p_output = output->data<T>();

  // final dim
  int64_t end_size = index_dims[index_dims_size - 1];
  // remain dim
  auto remain_ddim = phi::slice_ddim(index_dims, 0, index_dims_size - 1);
  int64_t remain_numel = phi::product(remain_ddim);
  // slice size
  int64_t slice_size = 1;
  for (int64_t i = end_size; i < input_dims_size; ++i) {
    slice_size *= input_dims[i];
  }
  const size_t slice_bytes = slice_size * sizeof(T);

  for (int64_t i = 0; i < remain_numel; ++i) {
    int64_t index_ = 0;
    int64_t temp = 1;
    for (int64_t j = end_size - 1; j >= 0; --j) {
      IndexT index_value = p_index[i * end_size + j];
      PADDLE_ENFORCE_LT(
          index_value,
          input_dims[j],
          phi::errors::InvalidArgument(
              "Input(index[-1]) has wrong value, it is [%d]", index_value));
      PADDLE_ENFORCE_GE(
          index_value,
          0,
          phi::errors::InvalidArgument(
              "The value of Input(index) must be no less than 0"));

      index_ += (index_value * temp);
      temp *= input_dims[j];
    }
    memcpy(
        p_output + i * slice_size, p_input + index_ * slice_size, slice_bytes);
  }
}
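
// Illustrative sketch of the indexing (made-up shapes and values): with
// `input` of shape [2, 3] and `index` of shape [2, 2] holding {{0, 1}, {1, 2}},
// each index row addresses a single element, so `output` holds 2 elements with
// output[0] == input[0][1] and output[1] == input[1][2]. With `input` of shape
// [2, 3, 4] instead, each index row would select a slice of 4 elements.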

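/**
 * Gathers along an arbitrary axis on CPU tensors.
 * `out` is resized (and allocated) to the shape of `input` with the `axis`
 * dimension replaced by index->numel() (dropped for a 0-D index), and filled
 * so that, along `axis`, slice j of `out` is slice index[j] of `input`.
 */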
template <typename T, typename U>
void GatherV2Function(const phi::CPUContext& ctx,
                      const DenseTensor* input,
                      const DenseTensor* index,
                      int axis,
                      DenseTensor* out) {
  auto* index_data = index->data<U>();
  int64_t index_size = index->numel();
  int64_t input_size = input->numel();
  auto input_dim = input->dims();
  auto* input_data = input->data<T>();

  if (input->numel() == 0) return;
  int axis_index = axis;

  int64_t input_index_dim_size = input_dim[axis_index];
  for (int64_t i = 0; i < index_size; i++) {
    PADDLE_ENFORCE_LT(index_data[i],
                      input_index_dim_size,
                      phi::errors::OutOfRange(
                          "The element of Index must be less than the size of "
                          "input dim size of axis which is %d, but received "
                          "index element which is %d in the %d index.",
                          input_index_dim_size,
                          index_data[i],
                          i));
    PADDLE_ENFORCE_GE(index_data[i],
                      0,
                      phi::errors::OutOfRange(
                          "The element of Index must be greater than or equal "
                          "to 0, but received index element which is %d in the "
                          "%d index.",
                          index_data[i],
                          i));
  }

  int64_t inner_dim_size = 1;
  int64_t outer_dim_size = 1;
  std::vector<int64_t> out_dim_vec;

  for (int i = 0; i < axis_index; i++) {
    inner_dim_size *= input_dim[i];
    out_dim_vec.push_back(input_dim[i]);
  }
  if (index->dims().size() != 0) {
    out_dim_vec.push_back(index_size);
  }
  for (int i = axis_index + 1; i < input_dim.size(); i++) {
    outer_dim_size *= input_dim[i];
    out_dim_vec.push_back(input_dim[i]);
  }
  auto out_dim = phi::make_ddim(out_dim_vec);

  out->Resize(out_dim);
  auto* out_data = ctx.Alloc<T>(out);

  int out_index = 0;
  for (int64_t i = 0; i < inner_dim_size; i++) {
    for (int64_t j = 0; j < index_size; j++) {
      for (int64_t k = 0; k < outer_dim_size; k++) {
        int64_t index = k + index_data[j] * outer_dim_size +
                        (i * input_size / inner_dim_size);
        out_data[out_index] = input_data[index];
        out_index++;
      }
    }
  }
}
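
// Illustrative sketch (made-up shapes): with `input` of shape [2, 3, 4],
// axis = 1, and a 1-D int index holding {2, 0}, `out` is resized to
// [2, 2, 4] and out[i][j][k] == input[i][index[j]][k].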

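/**
 * Gradient of the axis-wise gather above: zero-fills `out`, then scatter-adds
 * slices of `input` (the upstream gradient) into `out` at the positions given
 * by `index` along `axis`. Repeated index values accumulate.
 */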
template <typename T, typename U>
void GatherV2GradFunction(const phi::CPUContext& ctx,
                          const DenseTensor* input,
                          const DenseTensor* index,
                          const int axis,
                          DenseTensor* out) {
  auto* index_data = index->data<U>();

  auto input_dim = input->dims();
  auto* input_data = input->data<T>();

  if (input->numel() == 0) return;
  int axis_index = axis;
  int64_t input_index_dim_size;
  if (input_dim.size() == out->dims().size()) {
    input_index_dim_size = input_dim[axis_index];
  } else {
    // 0d index
    input_index_dim_size = 1;
  }

  int64_t inner_dim_size = 1;
  int64_t outer_dim_size = 1;

  for (int i = 0; i < axis_index; i++) {
    inner_dim_size *= input_dim[i];
  }
  for (int i = axis_index + 1; i < input_dim.size(); i++) {
    outer_dim_size *= input_dim[i];
  }

  auto* out_data = ctx.Alloc<T>(out);
  auto out_dim = out->dims();
  int64_t out_index_dim_size = out_dim[axis_index];
  phi::funcs::set_constant(ctx, out, 0.0);

  for (int64_t i = 0; i < inner_dim_size; i++) {
    for (int64_t j = 0; j < input_index_dim_size; j++) {
      for (int64_t k = 0; k < outer_dim_size; k++) {
        int64_t index = k + index_data[j] * outer_dim_size +
                        i * outer_dim_size * out_index_dim_size;
        // The upstream-gradient offset must also cover the dimensions before
        // `axis`, mirroring the output offset computed above.
        out_data[index] +=
            input_data[(i * input_index_dim_size + j) * outer_dim_size + k];
      }
    }
  }
}
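
// Illustrative sketch (made-up shapes): for the forward example above
// (input of shape [2, 3, 4], axis = 1, index {2, 0}), the upstream gradient
// has shape [2, 2, 4] and `out` gets shape [2, 3, 4]; slices 2 and 0 along
// axis 1 receive the two gradient slices, slice 1 stays zero, and duplicate
// index values would have their contributions summed.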

}  // namespace funcs
}  // namespace phi