// math_kernel.h
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/empty_kernel.h"
namespace phi {

24 25 26 27 28 29 30 31
template <typename T, typename Context>
void MeanRawKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   const std::vector<int64_t>& dims,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out);

32
template <typename T, typename Context>
33 34 35 36 37
void MeanKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const std::vector<int64_t>& dims,
                bool keep_dim,
                DenseTensor* out);
38

39 40 41 42 43 44 45 46 47 48 49 50 51 52
template <typename T, typename Context>
void SumRawKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  const std::vector<int64_t>& dims,
                  bool keep_dim,
                  bool reduce_all,
                  DataType out_dtype,
                  DenseTensor* out);

template <typename T, typename Context>
void SumKernel(const Context& dev_ctx,
               const DenseTensor& x,
               const std::vector<int64_t>& dims,
               DataType out_dtype,
53
               bool keep_dim,
54 55 56 57 58 59 60 61 62
               DenseTensor* out);

/**
 * @brief Declares the raw add kernel; no definition is visible in this header.
 *
 * Differs from AddKernel only by the extra `axis` argument.
 *
 * @tparam T        Not deducible from the arguments, so callers must name it.
 * @tparam Context  Type of the device context.
 * @param dev_ctx  Device context, taken by const reference.
 * @param x        First operand (read-only).
 * @param y        Second operand (read-only).
 * @param axis     Meaning not visible here; presumably controls how x and y
 *                 are aligned — TODO confirm against the implementation.
 * @param out      Output tensor, passed as a non-const pointer.
 */
template <typename T, typename Context>
void AddRawKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  const DenseTensor& y,
                  int axis,
                  DenseTensor* out);

63 64 65 66 67 68
template <typename T, typename Context>
void AddKernel(const Context& dev_ctx,
               const DenseTensor& x,
               const DenseTensor& y,
               DenseTensor* out);

69 70 71 72 73 74 75
template <typename T, typename Context>
void SubtractRawKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& y,
                       int axis,
                       DenseTensor* out);

76 77 78 79 80 81
template <typename T, typename Context>
void SubtractKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const DenseTensor& y,
                    DenseTensor* out);

82 83 84 85 86 87 88
template <typename T, typename Context>
void DivideRawKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const DenseTensor& y,
                     int axis,
                     DenseTensor* out);

89 90 91 92 93 94
template <typename T, typename Context>
void DivideKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  const DenseTensor& y,
                  DenseTensor* out);

95 96 97 98 99 100 101
template <typename T, typename Context>
void MultiplyRawKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& y,
                       int axis,
                       DenseTensor* out);

102 103 104 105 106 107
template <typename T, typename Context>
void MultiplyKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const DenseTensor& y,
                    DenseTensor* out);

C
Chen Weihang 已提交
108 109
template <typename T, typename Context>
DenseTensor Add(const Context& dev_ctx,
110
                const DenseTensor& x,
111
                const DenseTensor& y) {
112
  auto dense_out = phi::Empty<T, Context>(dev_ctx);
113 114
  MetaTensor meta_out(&dense_out);
  ElementwiseInferMeta(x, y, &meta_out);
115
  AddKernel<T, Context>(dev_ctx, x, y, &dense_out);
116 117 118
  return dense_out;
}

C
Chen Weihang 已提交
119 120
template <typename T, typename Context>
DenseTensor Subtract(const Context& dev_ctx,
121
                     const DenseTensor& x,
122
                     const DenseTensor& y) {
123
  auto dense_out = phi::Empty<T, Context>(dev_ctx);
124 125
  MetaTensor meta_out(&dense_out);
  ElementwiseInferMeta(x, y, &meta_out);
126
  SubtractKernel<T, Context>(dev_ctx, x, y, &dense_out);
127 128 129
  return dense_out;
}

C
Chen Weihang 已提交
130 131
template <typename T, typename Context>
DenseTensor Divide(const Context& dev_ctx,
132
                   const DenseTensor& x,
133
                   const DenseTensor& y) {
134
  auto dense_out = phi::Empty<T, Context>(dev_ctx);
135 136
  MetaTensor meta_out(&dense_out);
  ElementwiseInferMeta(x, y, &meta_out);
137
  DivideKernel<T, Context>(dev_ctx, x, y, &dense_out);
138 139 140
  return dense_out;
}

C
Chen Weihang 已提交
141 142
template <typename T, typename Context>
DenseTensor Multiply(const Context& dev_ctx,
143
                     const DenseTensor& x,
144
                     const DenseTensor& y) {
145
  auto dense_out = phi::Empty<T, Context>(dev_ctx);
146 147
  MetaTensor meta_out(&dense_out);
  ElementwiseInferMeta(x, y, &meta_out);
148
  MultiplyKernel<T, Context>(dev_ctx, x, y, &dense_out);
149 150 151
  return dense_out;
}

152 153 154 155 156
template <typename T, typename Context>
DenseTensor Mean(const Context& dev_ctx,
                 const DenseTensor& x,
                 const std::vector<int64_t>& axis,
                 bool keep_dim) {
157
  auto dense_out = phi::Empty<T, Context>(dev_ctx);
158
  MetaTensor meta_out(&dense_out);
159
  ReduceInferMetaBase(x, axis, keep_dim, false, x.dtype(), &meta_out);
160
  MeanKernel<T, Context>(dev_ctx, x, axis, keep_dim, &dense_out);
161 162 163 164 165 166 167 168 169
  return dense_out;
}

/**
 * @brief Convenience wrapper that runs SumKernel and returns the result by
 *        value.
 *
 * Steps, in order: create the output with phi::Empty<T, Context>(dev_ctx),
 * wrap its address in a MetaTensor, call SumInferMeta with x, axis, dtype and
 * keep_dim, then call SumKernel<T, Context> on the same tensor.
 *
 * @tparam T        Forwarded to phi::Empty and SumKernel; not deducible from
 *                  the arguments, so callers must name it.
 * @tparam Context  Type of the device context.
 * @param dev_ctx   Device context handed to both phi::Empty and SumKernel.
 * @param x         Input tensor (read-only).
 * @param axis      Forwarded as SumKernel's `dims` argument.
 * @param dtype     Forwarded as SumKernel's `out_dtype` argument and to
 *                  SumInferMeta.
 * @param keep_dim  Forwarded unchanged to SumInferMeta and SumKernel.
 * @return The tensor that was passed to SumKernel as its output.
 */
template <typename T, typename Context>
DenseTensor Sum(const Context& dev_ctx,
                const DenseTensor& x,
                const std::vector<int64_t>& axis,
                DataType dtype,
                bool keep_dim) {
  // NOTE(review): the output is created via phi::Empty<T, Context> while
  // `dtype` is supplied separately; confirm the intended behavior when `dtype`
  // does not correspond to T.
  auto dense_out = phi::Empty<T, Context>(dev_ctx);
  MetaTensor meta_out(&dense_out);
  SumInferMeta(x, axis, dtype, keep_dim, &meta_out);
  SumKernel<T, Context>(dev_ctx, x, axis, dtype, keep_dim, &dense_out);
  return dense_out;
}

}  // namespace phi