utils.h 3.2 KB
Newer Older
J
Jiabin Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_proto_maker.h"
19
#include "paddle/fluid/operators/common_infer_shape_functions.h"
J
Jiabin Yang 已提交
20 21 22
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h"
23
#include "paddle/phi/core/ddim.h"
24

J
Jiabin Yang 已提交
25 26 27 28 29 30 31 32 33 34 35 36
namespace paddle {
namespace prim {
// Utility-style APIs used by prim are collected in this header.
template <typename T>
paddle::experimental::Tensor empty(const paddle::experimental::IntArray& shape,
                                   paddle::experimental::DataType dype,
                                   const paddle::Place& place);

template <typename T>
paddle::experimental::Tensor empty_like(const paddle::experimental::Tensor& x,
                                        paddle::experimental::DataType dtype,
                                        const paddle::Place& place);
37 38 39
template <typename T>
void by_pass(const paddle::experimental::Tensor& x,
             paddle::experimental::Tensor* out);
40

41 42 43 44
template <typename T>
void set_output(const paddle::experimental::Tensor& x_tmp,
                paddle::experimental::Tensor* x);

45
// The helpers below are defined directly in this header and are not
// templated on T, so they don't need to be specialized.
// Computes the axes of `dout_dims` that must be reduced to bring a tensor of
// shape `dout_dims` (typically an output gradient) back to shape `in_dims`,
// with the two shapes aligned at their trailing axes:
//   * the leading `dout_dims.size() - in_dims.size()` axes, which `in_dims`
//     does not have, are always reduced;
//   * an axis where `in_dims` is 1 is reduced (it was broadcast);
//   * every other axis of `in_dims` must equal the matching axis of dout.
// Returns the reduce axes (indices into `dout_dims`, ascending) as a DDim.
// Fails with InvalidArgument (via PADDLE_ENFORCE_*) if `in_dims` has more axes
// than `dout_dims`, or if a non-1 axis of `in_dims` differs from dout's.
static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims,
                                          const phi::DDim& in_dims) {
  // Number of leading axes that exist only in dout.
  int bat = dout_dims.size() - in_dims.size();
  // Without this guard a negative offset makes `dout_dims[i + bat]` below
  // read out of bounds.
  PADDLE_ENFORCE_GE(
      bat,
      0,
      platform::errors::InvalidArgument(
          "ReduceDims rank mismatch. The rank of dout (%d) must be greater "
          "than or equal to the rank of in_dims (%d), but received "
          "dout = [%s] and in_dims = [%s].",
          dout_dims.size(),
          in_dims.size(),
          dout_dims,
          in_dims));

  std::vector<int64_t> result;
  // At most every axis of dout is reduced.
  result.reserve(dout_dims.size());
  for (int i = 0; i < bat; ++i) {
    result.push_back(i);
  }
  for (int i = 0; i < in_dims.size(); ++i) {
    if (in_dims[i] == 1) {
      result.push_back(i + bat);
    } else {
      PADDLE_ENFORCE_EQ(
          in_dims[i],
          dout_dims[i + bat],
          platform::errors::InvalidArgument(
              "ReduceDims dimension mismatch. Operands could "
              "not be broadcast together with the shape of dout = [%s] and "
              "the shape of in_dims = [%s]. Received [%d] in X is not equal to "
              "[%d] in Y at i:%d.",
              dout_dims,
              in_dims,
              dout_dims[i + bat],
              in_dims[i],
              i));
    }
  }
  return phi::make_ddim(result);
}

// Returns the axes along which a tensor of shape `x_dims` must be reduced
// after it has been broadcast against `y_dims`. The two shapes are broadcast
// together first, and the result is mapped back onto `x_dims` by
// get_reduce_dims_from_out.
static phi::DDim get_reduce_dims(const phi::DDim& x_dims,
                                 const phi::DDim& y_dims) {
  const auto broadcast_dims =
      paddle::operators::details::BroadcastTwoDims(x_dims, y_dims);
  return get_reduce_dims_from_out(broadcast_dims, x_dims);
}

// Converts `src` element by element into a std::vector<DST_T>. Each element is
// constructed directly from the matching source element.
// "Unsafe": no range check is done, so a narrowing conversion (e.g.
// int64_t -> int) can silently change the value, and an out-of-range
// floating -> integer conversion is undefined behavior.
// TODO(cxxly): Check and throws InvalidCastException when overflow.
template <typename SRC_T, typename DST_T>
static std::vector<DST_T> unsafe_vector_cast(const std::vector<SRC_T>& src) {
  return std::vector<DST_T>(src.cbegin(), src.cend());
}
}  // namespace prim
}  // namespace paddle