/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

// Maps the legacy elementwise_add op to the phi "add" kernel.
// axis == -1 (default broadcast) selects the attribute-free kernel;
// any other axis routes to the "_raw" variant that consumes it.
KernelSignature ElementwiseAddOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("add", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("add_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

28
KernelSignature ElementwiseGradAddOpArgumentMapping(
G
Galaxy1458 已提交
29
    const ArgumentMappingContext& ctx UNUSED) {
30 31 32
  return KernelSignature("grad_add", {"X", "Y"}, {}, {"Out"});
}

33 34 35
KernelSignature ElementwiseSubOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
Y
YuanRisheng 已提交
36 37
  if (axis == -1) {
    return KernelSignature("subtract", {"X", "Y"}, {}, {"Out"});
38
  }
Y
YuanRisheng 已提交
39
  return KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, {"Out"});
40 41 42 43 44 45 46 47 48 49
}

// Maps the legacy elementwise_mul op to a phi multiply kernel.
// Dense-tensor inputs use "multiply"/"multiply_raw"; otherwise the
// selected-rows ("_sr") variants are chosen. axis == -1 picks the
// attribute-free kernel in each case.
KernelSignature ElementwiseMulOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (ctx.IsDenseTensorInput("X")) {
    if (axis == -1) {
      return KernelSignature("multiply", {"X", "Y"}, {}, {"Out"});
    }
    return KernelSignature("multiply_raw", {"X", "Y"}, {"axis"}, {"Out"});
  } else {
    if (axis == -1) {
      return KernelSignature("multiply_sr", {"X", "Y"}, {}, {"Out"});
    }
    return KernelSignature("multiply_raw_sr", {"X", "Y"}, {"axis"}, {"Out"});
  }
}

// Maps the legacy elementwise_div op to the phi "divide" kernel,
// choosing the "_raw" variant when an explicit broadcast axis is set.
KernelSignature ElementwiseDivOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

67 68
KernelSignature ElementwiseMaxOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
69 70 71
  if (ctx.IsForInferShape()) {
    return KernelSignature("maximum_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("maximum", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("maximum_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

// Maps the legacy elementwise_min op to the phi "minimum" kernel,
// picking the "_raw" variant when an explicit broadcast axis is given.
KernelSignature ElementwiseMinOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  return axis == -1
             ? KernelSignature("minimum", {"X", "Y"}, {}, {"Out"})
             : KernelSignature("minimum_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

// Maps the legacy elementwise_mod op to the phi "remainder" kernel,
// choosing the "_raw" variant when an explicit broadcast axis is set.
KernelSignature ElementwiseModOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("remainder", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("remainder_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

97 98
KernelSignature ElementwiseFloorDivOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
99 100 101
  if (ctx.IsForInferShape()) {
    return KernelSignature("floor_divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
102 103 104 105 106 107 108
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("floor_divide", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("floor_divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

109
KernelSignature ElementwiseHeavisideOpArgumentMapping(
G
Galaxy1458 已提交
110
    const ArgumentMappingContext& ctx UNUSED) {
111
  return KernelSignature("heaviside", {"X", "Y"}, {}, {"Out"});
112 113
}

114
KernelSignature ElementwisePowOpArgumentMapping(
G
Galaxy1458 已提交
115
    const ArgumentMappingContext& ctx UNUSED) {
116 117 118 119 120 121 122
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("elementwise_pow", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("elementwise_pow_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

123
KernelSignature ElementwiseAddGradOpArgumentMapping(
G
Galaxy1458 已提交
124
    const ArgumentMappingContext& ctx UNUSED) {
125 126
  return KernelSignature(
      "add_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
127 128
}

129
KernelSignature ElementwiseAddDoubleGradOpArgumentMapping(
G
Galaxy1458 已提交
130
    const ArgumentMappingContext& ctx UNUSED) {
131
  return KernelSignature(
132
      "add_double_grad", {"Y", "DOut", "DDX", "DDY"}, {"axis"}, {"DDOut"});
133 134 135
}

// Maps elementwise_add_triple_grad to the phi "add_triple_grad" kernel.
KernelSignature ElementwiseAddTripleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("add_triple_grad",
                         {"DDX", "DDY", "D_DDOut"},
                         {"axis"},
                         {"D_DDX", "D_DDY"});
}

// Maps elementwise_sub_grad to the phi "subtract_grad" kernel,
// forwarding the "axis" attribute.
KernelSignature ElementwiseSubGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "subtract_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}

149
KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
G
Galaxy1458 已提交
150
    const ArgumentMappingContext& ctx UNUSED) {
151
  return KernelSignature(
152
      "subtract_double_grad", {"Y", "DOut", "DDX", "DDY"}, {"axis"}, {"DDOut"});
153 154
}

155
KernelSignature ElementwiseDivGradOpArgumentMapping(
G
Galaxy1458 已提交
156
    const ArgumentMappingContext& ctx UNUSED) {
157
  return KernelSignature("divide_grad",
158
                         {"X", "Y", "Out", "Out@GRAD"},
159
                         {"axis"},
160
                         {"X@GRAD", "Y@GRAD"});
161 162
}

163
KernelSignature ElementwiseFMinGradOpArgumentMapping(
G
Galaxy1458 已提交
164
    const ArgumentMappingContext& ctx UNUSED) {
165
  return KernelSignature(
Z
zyfncg 已提交
166
      "fmin_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
167 168
}

169
KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
G
Galaxy1458 已提交
170
    const ArgumentMappingContext& ctx UNUSED) {
171 172 173
  return KernelSignature("divide_double_grad",
                         {"Y", "Out", "DX", "DDX", "DDY"},
                         {"axis"},
174
                         {"Y@GRAD", "DOut", "DDOut"});
175 176
}

Y
YuanRisheng 已提交
177
KernelSignature ElementwiseMulGradOpArgumentMapping(
G
Galaxy1458 已提交
178
    const ArgumentMappingContext& ctx UNUSED) {
179 180
  return KernelSignature(
      "multiply_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
Y
YuanRisheng 已提交
181 182
}

Y
YuanRisheng 已提交
183
KernelSignature ElementwiseFMaxOpArgumentMapping(
G
Galaxy1458 已提交
184
    const ArgumentMappingContext& ctx UNUSED) {
Z
zhangyuqin1998 已提交
185
  return KernelSignature("fmax", {"X", "Y"}, {}, {"Out"});
Y
YuanRisheng 已提交
186 187 188
}

// Maps elementwise_fmin to the phi "fmin" kernel; no attributes are
// forwarded.
KernelSignature ElementwiseFMinOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("fmin", {"X", "Y"}, {}, {"Out"});
}

Y
YuanRisheng 已提交
193
KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
G
Galaxy1458 已提交
194
    const ArgumentMappingContext& ctx UNUSED) {
Y
YuanRisheng 已提交
195 196 197
  return KernelSignature("multiply_double_grad",
                         {"X", "Y", "DOut", "DDX", "DDY"},
                         {"axis"},
198
                         {"X@GRAD", "Y@GRAD", "DDOut"});
Y
YuanRisheng 已提交
199 200 201
}

// Maps elementwise_mul_triple_grad to the phi "multiply_triple_grad"
// kernel.
KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "multiply_triple_grad",
      {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
      {"axis"},
      {"D_X", "D_Y", "D_DOut", "D_DDX", "D_DDY"});
}

210
KernelSignature ElementwiseMinGradOpArgumentMapping(
G
Galaxy1458 已提交
211
    const ArgumentMappingContext& ctx UNUSED) {
212
  return KernelSignature(
213
      "minimum_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
214
}
215 216

// Maps elementwise_heaviside_grad to the phi "heaviside_grad" kernel;
// no attributes are forwarded.
KernelSignature ElementwiseHeavisideGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature(
      "heaviside_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
}

222
KernelSignature ElementwisePowGradOpArgumentMapping(
G
Galaxy1458 已提交
223
    const ArgumentMappingContext& ctx UNUSED) {
224 225
  return KernelSignature(
      "elementwise_pow_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
226
}
}  // namespace phi
228

229 230 231 232
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add, add);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide);
233 234
PD_REGISTER_BASE_KERNEL_NAME(elementwise_max, maximum);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_min, minimum);
C
Chen Weihang 已提交
235
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mod, remainder);
236
PD_REGISTER_BASE_KERNEL_NAME(elementwise_floordiv, floor_divide);
237
PD_REGISTER_BASE_KERNEL_NAME(elementwise_heaviside, heaviside);
238 239 240 241
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad_grad, add_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_triple_grad, add_triple_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad);
242
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad_grad, subtract_double_grad);
243 244
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad, divide_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad_grad, divide_double_grad);
Y
YuanRisheng 已提交
245 246 247
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad, multiply_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad_grad, multiply_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_triple_grad, multiply_triple_grad);
Y
YuanRisheng 已提交
248 249 250
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax, fmax);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin, fmin);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin_grad, fmin_grad);
251
PD_REGISTER_BASE_KERNEL_NAME(elementwise_min_grad, minimum_grad);
252
PD_REGISTER_BASE_KERNEL_NAME(elementwise_heaviside_grad, heaviside_grad);
253 254

PD_REGISTER_ARG_MAPPING_FN(elementwise_add,
255
                           phi::ElementwiseAddOpArgumentMapping);
256
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub,
257
                           phi::ElementwiseSubOpArgumentMapping);
258
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul,
259
                           phi::ElementwiseMulOpArgumentMapping);
260
PD_REGISTER_ARG_MAPPING_FN(elementwise_div,
261
                           phi::ElementwiseDivOpArgumentMapping);
262 263 264 265 266 267
PD_REGISTER_ARG_MAPPING_FN(elementwise_max,
                           phi::ElementwiseMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_min,
                           phi::ElementwiseMinOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mod,
                           phi::ElementwiseModOpArgumentMapping);
268 269
PD_REGISTER_ARG_MAPPING_FN(elementwise_floordiv,
                           phi::ElementwiseFloorDivOpArgumentMapping);
270 271
PD_REGISTER_ARG_MAPPING_FN(elementwise_heaviside,
                           phi::ElementwiseHeavisideOpArgumentMapping);
272 273
PD_REGISTER_ARG_MAPPING_FN(elementwise_pow,
                           phi::ElementwisePowOpArgumentMapping);
274
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad,
275
                           phi::ElementwiseAddGradOpArgumentMapping);
276
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad_grad,
277
                           phi::ElementwiseAddDoubleGradOpArgumentMapping);
278
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_triple_grad,
279
                           phi::ElementwiseAddTripleGradOpArgumentMapping);
280
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub_grad,
281
                           phi::ElementwiseSubGradOpArgumentMapping);
282 283
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub_grad_grad,
                           phi::ElementwiseSubDoubleGradOpArgumentMapping);
284 285 286 287
PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad,
                           phi::ElementwiseDivGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad_grad,
                           phi::ElementwiseDivDoubleGradOpArgumentMapping);
Y
YuanRisheng 已提交
288 289 290 291 292 293
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad,
                           phi::ElementwiseMulGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad_grad,
                           phi::ElementwiseMulDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_triple_grad,
                           phi::ElementwiseMulTripleGradOpArgumentMapping);
Y
YuanRisheng 已提交
294 295 296 297
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax,
                           phi::ElementwiseFMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin,
                           phi::ElementwiseFMinOpArgumentMapping);
298 299
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin_grad,
                           phi::ElementwiseFMinGradOpArgumentMapping);
300 301
PD_REGISTER_ARG_MAPPING_FN(elementwise_min_grad,
                           phi::ElementwiseMinGradOpArgumentMapping);
302 303
PD_REGISTER_ARG_MAPPING_FN(elementwise_heaviside_grad,
                           phi::ElementwiseHeavisideGradOpArgumentMapping);
304 305
PD_REGISTER_ARG_MAPPING_FN(elementwise_pow_grad,
                           phi::ElementwisePowGradOpArgumentMapping);
306
PD_REGISTER_ARG_MAPPING_FN(grad_add, phi::ElementwiseGradAddOpArgumentMapping);