/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {
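
// These argument-mapping functions translate the legacy fluid
// "elementwise_*" operators into PHI kernel signatures. The common
// pattern: when the "axis" attribute is -1 (the default, i.e. ordinary
// trailing-dimension broadcasting), the plain kernel is selected;
// any other value selects the "*_raw" variant, which receives "axis"
// explicitly. For example, elementwise_add with axis == -1 resolves to
//   KernelSignature("add", {"X", "Y"}, {}, {"Out"})
// and otherwise to
//   KernelSignature("add_raw", {"X", "Y"}, {"axis"}, {"Out"}).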

KernelSignature ElementwiseAddOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("add", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("add_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseGradAddOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
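  // The mapping is unconditional, so the context goes unused; UNUSED
  // suppresses the resulting unused-parameter warning.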
  return KernelSignature("grad_add", {"X", "Y"}, {}, {"Out"});
}

KernelSignature ElementwiseSubOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("subtract", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseMulOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
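  // "X" may be a DenseTensor or a SelectedRows input; the "_sr"-suffixed
  // kernels handle the SelectedRows case.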
  if (ctx.IsDenseTensorInput("X")) {
    if (axis == -1) {
      return KernelSignature("multiply", {"X", "Y"}, {}, {"Out"});
    }
    return KernelSignature("multiply_raw", {"X", "Y"}, {"axis"}, {"Out"});
  } else {
    if (axis == -1) {
      return KernelSignature("multiply_sr", {"X", "Y"}, {}, {"Out"});
    }
    return KernelSignature("multiply_raw_sr", {"X", "Y"}, {"axis"}, {"Out"});
  }
}

KernelSignature ElementwiseDivOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseMaxOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
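  // When mapping for InferShape the axis attribute cannot be inspected,
  // so conservatively return the raw signature, which always carries
  // "axis" (the mod, floor_div, and pow mappings below do the same).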
  if (ctx.IsForInferShape()) {
    return KernelSignature("maximum_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("maximum", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("maximum_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseMinOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("minimum", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("minimum_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseModOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  if (ctx.IsForInferShape()) {
    return KernelSignature("remainder_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("remainder", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("remainder_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseFloorDivOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  if (ctx.IsForInferShape()) {
    return KernelSignature("floor_divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("floor_divide", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("floor_divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwisePowOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  if (ctx.IsForInferShape()) {
    return KernelSignature(
        "elementwise_pow_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("elementwise_pow", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("elementwise_pow_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseAddGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "add_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}

KernelSignature ElementwiseAddDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
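  // Naming convention in the higher-order grads: "DOut" is the upstream
  // gradient of Out, while "DDX"/"DDY"/"DDOut" are the corresponding
  // second-order gradients.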
  return KernelSignature(
      "add_double_grad", {"Y", "DOut", "DDX", "DDY"}, {"axis"}, {"DDOut"});
}

KernelSignature ElementwiseAddTripleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("add_triple_grad",
                         {"DDX", "DDY", "D_DDOut"},
                         {"axis"},
                         {"D_DDX", "D_DDY"});
}

KernelSignature ElementwiseSubGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "subtract_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}

KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "subtract_double_grad", {"Y", "DOut", "DDX", "DDY"}, {"axis"}, {"DDOut"});
}

KernelSignature ElementwiseDivGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("divide_grad",
                         {"X", "Y", "Out", "Out@GRAD"},
                         {"axis"},
                         {"X@GRAD", "Y@GRAD"});
}

KernelSignature ElementwiseFMinGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "fmin_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
}

KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("divide_double_grad",
                         {"Y", "Out", "DX", "DDX", "DDY"},
                         {"axis"},
                         {"Y@GRAD", "DOut", "DDOut"});
}

KernelSignature ElementwiseMulGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "multiply_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}

KernelSignature ElementwiseFMaxOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("fmax", {"X", "Y"}, {}, {"Out"});
}

KernelSignature ElementwiseFMinOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("fmin", {"X", "Y"}, {}, {"Out"});
}

KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("multiply_double_grad",
                         {"X", "Y", "DOut", "DDX", "DDY"},
                         {"axis"},
                         {"X@GRAD", "Y@GRAD", "DDOut"});
}

KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "multiply_triple_grad",
      {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
      {"axis"},
      {"D_X", "D_Y", "D_DOut", "D_DDX", "D_DDY"});
}

KernelSignature ElementwiseMinGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "minimum_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
}

}  // namespace phi
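
// PD_REGISTER_BASE_KERNEL_NAME maps each legacy operator name to its PHI
// kernel base name; PD_REGISTER_ARG_MAPPING_FN registers the function
// that builds the full kernel signature for that operator.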

PD_REGISTER_BASE_KERNEL_NAME(elementwise_add, add);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_max, maximum);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_min, minimum);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mod, remainder);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_floordiv, floor_divide);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad_grad, add_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_triple_grad, add_triple_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad_grad, subtract_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad, divide_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad_grad, divide_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad, multiply_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad_grad, multiply_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_triple_grad, multiply_triple_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax, fmax);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin, fmin);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin_grad, fmin_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_min_grad, minimum_grad);

PD_REGISTER_ARG_MAPPING_FN(elementwise_add,
                           phi::ElementwiseAddOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub,
                           phi::ElementwiseSubOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul,
                           phi::ElementwiseMulOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_div,
                           phi::ElementwiseDivOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_max,
                           phi::ElementwiseMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_min,
                           phi::ElementwiseMinOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mod,
                           phi::ElementwiseModOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_floordiv,
                           phi::ElementwiseFloorDivOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_pow,
                           phi::ElementwisePowOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad,
                           phi::ElementwiseAddGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad_grad,
                           phi::ElementwiseAddDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_triple_grad,
                           phi::ElementwiseAddTripleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub_grad,
                           phi::ElementwiseSubGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub_grad_grad,
                           phi::ElementwiseSubDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad,
                           phi::ElementwiseDivGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad_grad,
                           phi::ElementwiseDivDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad,
                           phi::ElementwiseMulGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad_grad,
                           phi::ElementwiseMulDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_triple_grad,
                           phi::ElementwiseMulTripleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax,
                           phi::ElementwiseFMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin,
                           phi::ElementwiseFMinOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin_grad,
                           phi::ElementwiseFMinGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_min_grad,
                           phi::ElementwiseMinGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(grad_add, phi::ElementwiseGradAddOpArgumentMapping);