math_kernel.h 3.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17 18 19
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/binary.h"
namespace phi {
20

21 22 23 24 25 26 27
// Declares the "raw" add kernel. It takes a device context, two input
// tensors `x` and `y`, an integer `axis`, and an output tensor pointer `out`.
// Only the declaration is visible in this header; the definition lives
// elsewhere.
// NOTE(review): `axis` presumably controls how `y` is aligned against `x`
// when their shapes differ — confirm against the definition.
template <typename T, typename Context>
void AddRawKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  const DenseTensor& y,
                  int axis,
                  DenseTensor* out);

28 29 30 31 32 33
// Declares the add kernel without an `axis` argument. It takes a device
// context, two input tensors `x` and `y`, and an output tensor pointer `out`.
// Only the declaration is visible here. The Add() wrapper in this header
// calls it with an output tensor it has just constructed.
template <typename T, typename Context>
void AddKernel(const Context& dev_ctx,
               const DenseTensor& x,
               const DenseTensor& y,
               DenseTensor* out);

34 35 36 37 38 39 40
// Declares the "raw" subtract kernel. It takes a device context, two input
// tensors `x` and `y`, an integer `axis`, and an output tensor pointer `out`.
// Only the declaration is visible in this header; the definition lives
// elsewhere.
// NOTE(review): `axis` presumably controls how `y` is aligned against `x`
// when their shapes differ — confirm against the definition.
template <typename T, typename Context>
void SubtractRawKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& y,
                       int axis,
                       DenseTensor* out);

41 42 43 44 45 46
// Declares the subtract kernel without an `axis` argument. It takes a device
// context, two input tensors `x` and `y`, and an output tensor pointer `out`.
// Only the declaration is visible here. The Subtract() wrapper in this header
// calls it with an output tensor it has just constructed.
template <typename T, typename Context>
void SubtractKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const DenseTensor& y,
                    DenseTensor* out);

47 48 49 50 51 52 53
// Declares the "raw" divide kernel. It takes a device context, two input
// tensors `x` and `y`, an integer `axis`, and an output tensor pointer `out`.
// Only the declaration is visible in this header; the definition lives
// elsewhere.
// NOTE(review): `axis` presumably controls how `y` is aligned against `x`
// when their shapes differ — confirm against the definition.
template <typename T, typename Context>
void DivideRawKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const DenseTensor& y,
                     int axis,
                     DenseTensor* out);

54 55 56 57 58 59
// Declares the divide kernel without an `axis` argument. It takes a device
// context, two input tensors `x` and `y`, and an output tensor pointer `out`.
// Only the declaration is visible here. The Divide() wrapper in this header
// calls it with an output tensor it has just constructed.
template <typename T, typename Context>
void DivideKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  const DenseTensor& y,
                  DenseTensor* out);

60 61 62 63 64 65 66
// Declares the "raw" multiply kernel. It takes a device context, two input
// tensors `x` and `y`, an integer `axis`, and an output tensor pointer `out`.
// Only the declaration is visible in this header; the definition lives
// elsewhere.
// NOTE(review): `axis` presumably controls how `y` is aligned against `x`
// when their shapes differ — confirm against the definition.
template <typename T, typename Context>
void MultiplyRawKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& y,
                       int axis,
                       DenseTensor* out);

67 68 69 70 71 72
// Declares the multiply kernel without an `axis` argument. It takes a device
// context, two input tensors `x` and `y`, and an output tensor pointer `out`.
// Only the declaration is visible here. The Multiply() wrapper in this header
// calls it with an output tensor it has just constructed.
template <typename T, typename Context>
void MultiplyKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const DenseTensor& y,
                    DenseTensor* out);

C
Chen Weihang 已提交
73 74
// Value-returning wrapper around AddKernel.
//
// Creates a local DenseTensor, wraps it in a MetaTensor, and hands that
// together with `x` and `y` to ElementwiseInferMeta (presumably to derive the
// output's metadata — confirm against its definition). It then runs
// AddKernel<T, Context> with the local tensor as the output and returns that
// tensor by value.
template <typename T, typename Context>
DenseTensor Add(const Context& dev_ctx,
                const DenseTensor& x,
                const DenseTensor& y) {
  DenseTensor result;
  MetaTensor result_meta(&result);
  // Meta inference runs before the kernel, as in the original ordering.
  ElementwiseInferMeta(x, y, &result_meta);
  AddKernel<T, Context>(dev_ctx, x, y, &result);
  return result;
}

C
Chen Weihang 已提交
84 85
// Value-returning wrapper around SubtractKernel.
//
// Creates a local DenseTensor, wraps it in a MetaTensor, and hands that
// together with `x` and `y` to ElementwiseInferMeta (presumably to derive the
// output's metadata — confirm against its definition). It then runs
// SubtractKernel<T, Context> with the local tensor as the output and returns
// that tensor by value.
template <typename T, typename Context>
DenseTensor Subtract(const Context& dev_ctx,
                     const DenseTensor& x,
                     const DenseTensor& y) {
  DenseTensor result;
  MetaTensor result_meta(&result);
  // Meta inference runs before the kernel, as in the original ordering.
  ElementwiseInferMeta(x, y, &result_meta);
  SubtractKernel<T, Context>(dev_ctx, x, y, &result);
  return result;
}

C
Chen Weihang 已提交
95 96
// Value-returning wrapper around DivideKernel.
//
// Creates a local DenseTensor, wraps it in a MetaTensor, and hands that
// together with `x` and `y` to ElementwiseInferMeta (presumably to derive the
// output's metadata — confirm against its definition). It then runs
// DivideKernel<T, Context> with the local tensor as the output and returns
// that tensor by value.
template <typename T, typename Context>
DenseTensor Divide(const Context& dev_ctx,
                   const DenseTensor& x,
                   const DenseTensor& y) {
  DenseTensor result;
  MetaTensor result_meta(&result);
  // Meta inference runs before the kernel, as in the original ordering.
  ElementwiseInferMeta(x, y, &result_meta);
  DivideKernel<T, Context>(dev_ctx, x, y, &result);
  return result;
}

C
Chen Weihang 已提交
106 107
// Value-returning wrapper around MultiplyKernel.
//
// Creates a local DenseTensor, wraps it in a MetaTensor, and hands that
// together with `x` and `y` to ElementwiseInferMeta (presumably to derive the
// output's metadata — confirm against its definition). It then runs
// MultiplyKernel<T, Context> with the local tensor as the output and returns
// that tensor by value.
template <typename T, typename Context>
DenseTensor Multiply(const Context& dev_ctx,
                     const DenseTensor& x,
                     const DenseTensor& y) {
  DenseTensor result;
  MetaTensor result_meta(&result);
  // Meta inference runs before the kernel, as in the original ordering.
  ElementwiseInferMeta(x, y, &result_meta);
  MultiplyKernel<T, Context>(dev_ctx, x, y, &result);
  return result;
}

117
}  // namespace phi