/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

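// Argument mappings from the legacy fluid elementwise_* ops to PHI kernel
// signatures. A KernelSignature lists the PHI kernel name followed by the
// op's input, attribute, and output argument names, e.g.
//   KernelSignature("add", {"X", "Y"}, {}, {"Out"})
// maps elementwise_add with the default axis to the "add" kernel.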
#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

KernelSignature ElementwiseAddOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
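  // axis == -1 is the attribute's default value (plain broadcasting), so it
  // carries no information the kernel needs; only a non-default axis has to
  // be forwarded, via the "add_raw" signature. The same dispatch pattern
  // repeats for the other binary ops below.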
  if (axis == -1) {
    return KernelSignature("add", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("add_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

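// grad_add is a standalone add used on gradients (no broadcasting axis), so
// it maps one-to-one.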
KernelSignature ElementwiseGradAddOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("grad_add", {"X", "Y"}, {}, {"Out"});
}

KernelSignature ElementwiseSubOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("subtract", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseMulOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
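  // multiply also has SelectedRows kernel variants; the "_sr" signatures are
  // selected when X is not a DenseTensor.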
  if (ctx.IsDenseTensorInput("X")) {
    if (axis == -1) {
      return KernelSignature("multiply", {"X", "Y"}, {}, {"Out"});
    }
    return KernelSignature("multiply_raw", {"X", "Y"}, {"axis"}, {"Out"});
  } else {
    if (axis == -1) {
      return KernelSignature("multiply_sr", {"X", "Y"}, {}, {"Out"});
    }
    return KernelSignature("multiply_raw_sr", {"X", "Y"}, {"axis"}, {"Out"});
  }
}

KernelSignature ElementwiseDivOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseMaxOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
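  // When mapping for InferShape, return the *_raw signature without reading
  // the axis attribute; the raw form is valid for any axis value. Mod and
  // FloorDiv below use the same guard.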
  if (ctx.IsForInferShape()) {
    return KernelSignature("maximum_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("maximum", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("maximum_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseMinOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("minimum", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("minimum_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseModOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  if (ctx.IsForInferShape()) {
    return KernelSignature("remainder_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("remainder", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("remainder_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseFloorDivOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  if (ctx.IsForInferShape()) {
    return KernelSignature("floor_divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("floor_divide", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("floor_divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwisePowOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("elementwise_pow", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("elementwise_pow_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

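// Gradient mappings. The fluid grad ops feed Out@GRAD and produce
// X@GRAD / Y@GRAD; the double- and triple-grad variants keep the fluid
// DOut/DDX/DDY naming for their arguments.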
KernelSignature ElementwiseAddGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "add_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}

KernelSignature ElementwiseAddDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "add_double_grad", {"Y", "DOut", "DDX", "DDY"}, {"axis"}, {"DDOut"});
}

KernelSignature ElementwiseAddTripleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("add_triple_grad",
                         {"DDX", "DDY", "D_DDOut"},
                         {"axis"},
                         {"D_DDX", "D_DDY"});
}

KernelSignature ElementwiseSubGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "subtract_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}

KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "subtract_double_grad", {"Y", "DOut", "DDX", "DDY"}, {"axis"}, {"DDOut"});
}

KernelSignature ElementwiseDivGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("divide_grad",
                         {"X", "Y", "Out", "Out@GRAD"},
                         {"axis"},
                         {"X@GRAD", "Y@GRAD"});
}

KernelSignature ElementwiseFMinGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "fmin_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
}

KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("divide_double_grad",
                         {"Y", "Out", "DX", "DDX", "DDY"},
                         {"axis"},
                         {"Y@GRAD", "DOut", "DDOut"});
}

KernelSignature ElementwiseMulGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "multiply_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}

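// fmax and fmin map unconditionally; their PHI kernels take no attributes.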
KernelSignature ElementwiseFMaxOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("fmax", {"X", "Y"}, {}, {"Out"});
}

KernelSignature ElementwiseFMinOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("fmin", {"X", "Y"}, {}, {"Out"});
}

KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("multiply_double_grad",
                         {"X", "Y", "DOut", "DDX", "DDY"},
                         {"axis"},
                         {"X@GRAD", "Y@GRAD", "DDOut"});
}

KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "multiply_triple_grad",
      {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
      {"axis"},
      {"D_X", "D_Y", "D_DOut", "D_DDX", "D_DDY"});
}

KernelSignature ElementwiseMinGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "minimum_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
}

KernelSignature ElementwisePowGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "elementwise_pow_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
}
}  // namespace phi

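// Map each legacy op name onto its PHI base kernel name so that lookups on
// the old fluid names resolve to the new kernels.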
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add, add);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_max, maximum);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_min, minimum);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mod, remainder);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_floordiv, floor_divide);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad_grad, add_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_triple_grad, add_triple_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad_grad, subtract_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad, divide_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad_grad, divide_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad, multiply_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad_grad, multiply_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_triple_grad, multiply_triple_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax, fmax);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin, fmin);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin_grad, fmin_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_min_grad, minimum_grad);

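// Register the argument mapping functions defined above.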
PD_REGISTER_ARG_MAPPING_FN(elementwise_add,
                           phi::ElementwiseAddOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub,
                           phi::ElementwiseSubOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul,
                           phi::ElementwiseMulOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_div,
                           phi::ElementwiseDivOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_max,
                           phi::ElementwiseMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_min,
                           phi::ElementwiseMinOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mod,
                           phi::ElementwiseModOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_floordiv,
                           phi::ElementwiseFloorDivOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_pow,
                           phi::ElementwisePowOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad,
                           phi::ElementwiseAddGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad_grad,
                           phi::ElementwiseAddDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_triple_grad,
                           phi::ElementwiseAddTripleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub_grad,
                           phi::ElementwiseSubGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub_grad_grad,
                           phi::ElementwiseSubDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad,
                           phi::ElementwiseDivGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad_grad,
                           phi::ElementwiseDivDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad,
                           phi::ElementwiseMulGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad_grad,
                           phi::ElementwiseMulDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_triple_grad,
                           phi::ElementwiseMulTripleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax,
                           phi::ElementwiseFMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin,
                           phi::ElementwiseFMinOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin_grad,
                           phi::ElementwiseFMinGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_min_grad,
                           phi::ElementwiseMinGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_pow_grad,
                           phi::ElementwisePowGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(grad_add, phi::ElementwiseGradAddOpArgumentMapping);