/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {
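
// Each mapping function below translates the legacy fluid elementwise
// operator signature (inputs X/Y, attribute axis, output Out) into a PHI
// KernelSignature. axis == -1 is the default broadcast behavior, so the
// plain kernel is selected; any other axis selects the "*_raw" variant,
// which still consumes the axis attribute.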

KernelSignature ElementwiseAddOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("add", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("add_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

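// grad_add (used to accumulate two gradients) has no axis attribute and
// maps one-to-one to the grad_add kernel.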
KernelSignature ElementwiseGradAddOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("grad_add", {"X", "Y"}, {}, {"Out"});
}

KernelSignature ElementwiseSubOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("subtract", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseMulOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (ctx.IsDenseTensorInput("X")) {
    if (axis == -1) {
      return KernelSignature("multiply", {"X", "Y"}, {}, {"Out"});
    }
    return KernelSignature("multiply_raw", {"X", "Y"}, {"axis"}, {"Out"});
  } else {
    if (axis == -1) {
      return KernelSignature("multiply_sr", {"X", "Y"}, {}, {"Out"});
    }
    return KernelSignature("multiply_raw_sr", {"X", "Y"}, {"axis"}, {"Out"});
  }
}

KernelSignature ElementwiseDivOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseMaxOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
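  // When mapping for shape inference, the axis attribute may not be
  // readable from the context yet, so conservatively choose the raw
  // kernel, which keeps the attribute.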
  if (ctx.IsForInferShape()) {
    return KernelSignature("maximum_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("maximum", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("maximum_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseMinOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("minimum", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("minimum_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseModOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("remainder", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("remainder_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseFloorDivOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("floor_divide", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("floor_divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseHeavisideOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("heaviside", {"X", "Y"}, {}, {"Out"});
}

KernelSignature ElementwisePowOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("elementwise_pow", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("elementwise_pow_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseAddGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "add_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}

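// The higher-order grad mappings keep the fluid naming convention: DOut is
// the upstream gradient, DDX/DDY are the second-order gradient inputs, and
// D_DDOut feeds the triple-grad kernel.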
KernelSignature ElementwiseAddDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "add_double_grad", {"Y", "DOut", "DDX", "DDY"}, {"axis"}, {"DDOut"});
}

KernelSignature ElementwiseAddTripleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("add_triple_grad",
                         {"DDX", "DDY", "D_DDOut"},
                         {"axis"},
                         {"D_DDX", "D_DDY"});
}

KernelSignature ElementwiseSubGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "subtract_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}

KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "subtract_double_grad", {"Y", "DOut", "DDX", "DDY"}, {"axis"}, {"DDOut"});
}

KernelSignature ElementwiseDivGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("divide_grad",
                         {"X", "Y", "Out", "Out@GRAD"},
                         {"axis"},
                         {"X@GRAD", "Y@GRAD"});
}

KernelSignature ElementwiseFMinGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "fmin_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
}

KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("divide_double_grad",
                         {"Y", "Out", "DX", "DDX", "DDY"},
                         {"axis"},
                         {"Y@GRAD", "DOut", "DDOut"});
}

KernelSignature ElementwiseMulGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "multiply_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}

KernelSignature ElementwiseFMaxOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("fmax", {"X", "Y"}, {}, {"Out"});
}

KernelSignature ElementwiseFMinOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("fmin", {"X", "Y"}, {}, {"Out"});
}

KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("multiply_double_grad",
                         {"X", "Y", "DOut", "DDX", "DDY"},
                         {"axis"},
                         {"X@GRAD", "Y@GRAD", "DDOut"});
}

KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "multiply_triple_grad",
      {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
      {"axis"},
      {"D_X", "D_Y", "D_DOut", "D_DDX", "D_DDY"});
}

KernelSignature ElementwiseMinGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "minimum_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
}

KernelSignature ElementwiseHeavisideGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "heaviside_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
}

KernelSignature ElementwisePowGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "elementwise_pow_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
}
}  // namespace phi

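// Map the legacy operator names onto their PHI base kernel names.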
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add, add);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_max, maximum);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_min, minimum);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mod, remainder);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_floordiv, floor_divide);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_heaviside, heaviside);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad_grad, add_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_triple_grad, add_triple_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad_grad, subtract_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad, divide_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad_grad, divide_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad, multiply_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad_grad, multiply_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_triple_grad, multiply_triple_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax, fmax);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin, fmin);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin_grad, fmin_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_min_grad, minimum_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_heaviside_grad, heaviside_grad);

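// Register each mapping function under its legacy operator name.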
PD_REGISTER_ARG_MAPPING_FN(elementwise_add,
                           phi::ElementwiseAddOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub,
                           phi::ElementwiseSubOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul,
                           phi::ElementwiseMulOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_div,
                           phi::ElementwiseDivOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_max,
                           phi::ElementwiseMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_min,
                           phi::ElementwiseMinOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mod,
                           phi::ElementwiseModOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_floordiv,
                           phi::ElementwiseFloorDivOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_heaviside,
                           phi::ElementwiseHeavisideOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_pow,
                           phi::ElementwisePowOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad,
                           phi::ElementwiseAddGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad_grad,
                           phi::ElementwiseAddDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_triple_grad,
                           phi::ElementwiseAddTripleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub_grad,
                           phi::ElementwiseSubGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub_grad_grad,
                           phi::ElementwiseSubDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad,
                           phi::ElementwiseDivGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad_grad,
                           phi::ElementwiseDivDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad,
                           phi::ElementwiseMulGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad_grad,
                           phi::ElementwiseMulDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_triple_grad,
                           phi::ElementwiseMulTripleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax,
                           phi::ElementwiseFMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin,
                           phi::ElementwiseFMinOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin_grad,
                           phi::ElementwiseFMinGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_min_grad,
                           phi::ElementwiseMinGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_heaviside_grad,
                           phi::ElementwiseHeavisideGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_pow_grad,
                           phi::ElementwisePowGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(grad_add, phi::ElementwiseGradAddOpArgumentMapping);