// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

namespace paddle {
namespace prim {
// We put some utility APIs here.
template <typename T>
Tensor empty(const paddle::experimental::IntArray& shape,
             phi::DataType dtype,
             const paddle::Place& place);

template <typename T>
Tensor empty_like(const Tensor& x,
                  phi::DataType dtype,
                  const paddle::Place& place);
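
// Usage sketch (illustrative only; T is the tensor type the prim backend
// is instantiated with, and the dtype/place values here are assumptions):
//   auto t = empty<T>({2, 3}, phi::DataType::FLOAT32, phi::CPUPlace());
//   auto u = empty_like<T>(t, phi::DataType::FLOAT32, phi::CPUPlace());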

// Copy a tensor to the output pointer; in static mode this requires an
// assign op.
template <typename T>
void by_pass(const Tensor& x, Tensor* out);

// Set the output pointer's impl with the tmp tensor's impl; in dygraph
// mode, OutGradMeta should be set.
template <typename T>
void set_output(const Tensor& x_tmp, Tensor* x);

// These methods don't need to be specialized.
static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims,
                                          const phi::DDim& in_dims) {
  std::vector<int64_t> result;
  int bat = dout_dims.size() - in_dims.size();
  for (int i = 0; i < bat; ++i) {
    result.push_back(i);
  }
  for (int i = 0; i < in_dims.size(); ++i) {
    if (in_dims[i] == 1) {
      result.push_back(i + bat);
    } else {
      PADDLE_ENFORCE_EQ(
          in_dims[i],
          dout_dims[i + bat],
          platform::errors::InvalidArgument(
              "ReduceDims dimension mismatch. Operands could "
              "not be broadcast together with the shape of dout = [%s] and "
              "the shape of in_dims = [%s]. Received [%d] in X is not equal to "
              "[%d] in Y at i:%d.",
              dout_dims,
              in_dims,
              dout_dims[i + bat],
              in_dims[i],
              i));
    }
  }
  return phi::make_ddim(result);
}
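// For example (illustrative trace): with dout_dims = [8, 3, 4] and
// in_dims = [3, 1], bat = 1, so batch dim 0 is reduced; in_dims[1] == 1,
// so dim 1 + bat = 2 is reduced as well, and [0, 2] is returned.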

static phi::DDim get_reduce_dims(const phi::DDim& x_dims,
                                 const phi::DDim& y_dims) {
  auto out_dims = paddle::operators::details::BroadcastTwoDims(x_dims, y_dims);
  return get_reduce_dims_from_out(out_dims, x_dims);
}
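// For example, with x_dims = [3, 1] and y_dims = [8, 3, 4], the broadcast
// output has shape [8, 3, 4], so the gradient w.r.t. x is reduced over
// dims [0, 2] before being reshaped back to [3, 1].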

static std::vector<int> get_reduce_dims(const Tensor& dx,
                                        const int& dout_ndim,
                                        const int& x_ndim,
                                        std::vector<int64_t>* x_dims) {
  // This branch handles broadcasting with a 1-D operand: the 1-D tensor is
  // promoted to 2-D, which can make ddout_ndim > dout_ndim; ddout_ndim is
  // only usable when grad_out_grad != nullptr.
  if (dout_ndim < x_ndim) {
    return std::vector<int>({});
  }
  const std::vector<std::int64_t> dx_dims = phi::vectorize(dx.dims());
  std::vector<std::int64_t> broadcast_dims(dout_ndim);
  std::fill(
      broadcast_dims.data(), broadcast_dims.data() + dout_ndim - x_ndim, 1);
  std::copy(x_dims->data(),
            x_dims->data() + x_ndim,
            broadcast_dims.data() + dout_ndim - x_ndim);
  std::vector<int> reduce_dims;
  for (int i = 0; i <= dout_ndim - 3; i++) {
    if (dx_dims[i] != 1 && broadcast_dims[i] == 1) {
      reduce_dims.push_back(i);
    }
  }
  return reduce_dims;
}
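// For example (illustrative trace): with dx.dims() = [2, 3, 4, 5],
// dout_ndim = 4, x_ndim = 2 and *x_dims = [4, 5], broadcast_dims becomes
// [1, 1, 4, 5], so the batch dims [0, 1] are returned for reduction; the
// two trailing matrix dims are never reduced.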

// TODO(cxxly): Check and throws InvalidCastException when overflow.
template <typename SRC_T, typename DST_T>
static std::vector<DST_T> unsafe_vector_cast(const std::vector<SRC_T>& src) {
  std::vector<DST_T> dst(src.begin(), src.end());
  return dst;
}
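// For example, unsafe_vector_cast<int64_t, int>({1, 2, 3}) yields the
// std::vector<int> {1, 2, 3}; values outside int's range are silently
// narrowed (see the TODO above).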

// This function computes unsqueeze dims for a reshape that replaces
// unsqueeze.
static std::vector<int> get_unsqueeze_dims(const Tensor& origin,
                                           const IntArray& axis) {
  auto origin_dims = origin.shape();
  auto total_shape_size = origin_dims.size() + axis.size();
  std::vector<int> result;
  size_t j = 0, k = 0;
  for (size_t i = 0; i < total_shape_size; ++i) {
    // Guard j against running past the end of axis once every unsqueeze
    // position has been emitted.
    if (j < axis.size() && axis[j] == int64_t(i)) {
      result.push_back(1);
      j++;
    } else {
      result.push_back(origin_dims[k]);
      k++;
    }
  }
  return result;
}
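// For example, with origin.shape() = [3, 4] and axis = {0, 2}, the result
// is [1, 3, 1, 4]: reshape(origin, {1, 3, 1, 4}) is then equivalent to
// unsqueeze(origin, {0, 2}).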
}  // namespace prim
}  // namespace paddle