elementwise_sig.cc 13.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/phi/core/compat/op_utils.h"
16

17
namespace phi {
18 19 20 21

// Chooses the kernel signature for the elementwise_add operator.
// An "axis" attribute of -1 selects the "add" kernel, which takes no
// attributes; any other value selects "add_raw" and forwards "axis" to it.
KernelSignature ElementwiseAddOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const bool explicit_axis = paddle::any_cast<int>(ctx.Attr("axis")) != -1;
  if (explicit_axis) {
    return KernelSignature("add_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  return KernelSignature("add", {"X", "Y"}, {}, {"Out"});
}

28 29 30 31 32
// Returns the fixed signature of the "grad_add" kernel: inputs X and Y, no
// attributes, output Out. The mapping context is not consulted.
KernelSignature ElementwiseGradAddOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("grad_add",
                         /*inputs=*/{"X", "Y"},
                         /*attrs=*/{},
                         /*outputs=*/{"Out"});
}

33 34 35
// Chooses the kernel signature for the elementwise_sub operator.
// An "axis" attribute of -1 selects the "subtract" kernel, which takes no
// attributes; any other value selects "subtract_raw" and forwards "axis".
KernelSignature ElementwiseSubOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const bool explicit_axis = paddle::any_cast<int>(ctx.Attr("axis")) != -1;
  if (explicit_axis) {
    return KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  return KernelSignature("subtract", {"X", "Y"}, {}, {"Out"});
}

// Chooses the kernel signature for the elementwise_mul operator.
// Two properties drive the choice:
//   * whether input "X" is a dense tensor (otherwise the "_sr" kernel
//     variants are used -- presumably SelectedRows; confirm against the
//     kernel registrations), and
//   * whether the "axis" attribute is -1 (plain kernel, no attributes) or
//     not ("_raw" kernel, which receives "axis").
KernelSignature ElementwiseMulOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  // The attribute is read before the input type is queried, as before.
  const bool default_axis = paddle::any_cast<int>(ctx.Attr("axis")) == -1;

  if (!ctx.IsDenseTensorInput("X")) {
    if (default_axis) {
      return KernelSignature("multiply_sr", {"X", "Y"}, {}, {"Out"});
    }
    return KernelSignature("multiply_raw_sr", {"X", "Y"}, {"axis"}, {"Out"});
  }

  if (default_axis) {
    return KernelSignature("multiply", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("multiply_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

// Chooses the kernel signature for the elementwise_div operator.
// An "axis" attribute of -1 selects the "divide" kernel, which takes no
// attributes; any other value selects "divide_raw" and forwards "axis".
KernelSignature ElementwiseDivOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const bool explicit_axis = paddle::any_cast<int>(ctx.Attr("axis")) != -1;
  if (explicit_axis) {
    return KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  return KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
}

67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
// Chooses the kernel signature for the elementwise_max operator.
// An "axis" attribute of -1 selects the "maximum" kernel, which takes no
// attributes; any other value selects "maximum_raw" and forwards "axis".
KernelSignature ElementwiseMaxOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const bool explicit_axis = paddle::any_cast<int>(ctx.Attr("axis")) != -1;
  if (explicit_axis) {
    return KernelSignature("maximum_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  return KernelSignature("maximum", {"X", "Y"}, {}, {"Out"});
}

// Chooses the kernel signature for the elementwise_min operator.
// An "axis" attribute of -1 selects the "minimum" kernel, which takes no
// attributes; any other value selects "minimum_raw" and forwards "axis".
KernelSignature ElementwiseMinOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const bool explicit_axis = paddle::any_cast<int>(ctx.Attr("axis")) != -1;
  if (explicit_axis) {
    return KernelSignature("minimum_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  return KernelSignature("minimum", {"X", "Y"}, {}, {"Out"});
}

// Chooses the kernel signature for the elementwise_mod operator.
// An "axis" attribute of -1 selects the "remainder" kernel, which takes no
// attributes; any other value selects "remainder_raw" and forwards "axis".
KernelSignature ElementwiseModOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const bool explicit_axis = paddle::any_cast<int>(ctx.Attr("axis")) != -1;
  if (explicit_axis) {
    return KernelSignature("remainder_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  return KernelSignature("remainder", {"X", "Y"}, {}, {"Out"});
}

94 95 96 97 98 99 100 101 102
// Chooses the kernel signature for the elementwise_floordiv operator.
// An "axis" attribute of -1 selects the "floor_divide" kernel, which takes
// no attributes; any other value selects "floor_divide_raw" and forwards
// "axis".
KernelSignature ElementwiseFloorDivOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const bool explicit_axis = paddle::any_cast<int>(ctx.Attr("axis")) != -1;
  if (explicit_axis) {
    return KernelSignature("floor_divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  return KernelSignature("floor_divide", {"X", "Y"}, {}, {"Out"});
}

103 104 105 106 107 108 109 110 111 112
// Chooses the kernel signature for the elementwise_heaviside operator.
// An "axis" attribute of -1 selects the "elementwise_heaviside" kernel,
// which takes no attributes; any other value selects
// "elementwise_heaviside_raw" and forwards "axis".
KernelSignature ElementwiseHeavisideOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const bool explicit_axis = paddle::any_cast<int>(ctx.Attr("axis")) != -1;
  if (explicit_axis) {
    return KernelSignature("elementwise_heaviside_raw",
                           {"X", "Y"},
                           {"axis"},
                           {"Out"});
  }
  return KernelSignature("elementwise_heaviside", {"X", "Y"}, {}, {"Out"});
}

113 114 115 116 117 118 119 120 121
// Chooses the kernel signature for the elementwise_pow operator.
// An "axis" attribute of -1 selects the "elementwise_pow" kernel, which
// takes no attributes; any other value selects "elementwise_pow_raw" and
// forwards "axis".
KernelSignature ElementwisePowOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const bool explicit_axis = paddle::any_cast<int>(ctx.Attr("axis")) != -1;
  if (explicit_axis) {
    return KernelSignature(
        "elementwise_pow_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  return KernelSignature("elementwise_pow", {"X", "Y"}, {}, {"Out"});
}

122 123
// Returns the fixed signature of the "add_grad" kernel. The mapping context
// is not consulted.
KernelSignature ElementwiseAddGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("add_grad",
                         /*inputs=*/{"X", "Y", "Out@GRAD"},
                         /*attrs=*/{"axis"},
                         /*outputs=*/{"X@GRAD", "Y@GRAD"});
}

128 129 130
// Returns the fixed signature of the "add_double_grad" kernel. The mapping
// context is not consulted.
KernelSignature ElementwiseAddDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("add_double_grad",
                         /*inputs=*/{"Y", "DOut", "DDX", "DDY"},
                         /*attrs=*/{"axis"},
                         /*outputs=*/{"DDOut"});
}

// Returns the fixed signature of the "add_triple_grad" kernel. The mapping
// context is not consulted.
KernelSignature ElementwiseAddTripleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "add_triple_grad",
      /*inputs=*/{"DDX", "DDY", "D_DDOut"},
      /*attrs=*/{"axis"},
      /*outputs=*/{"D_DDX", "D_DDY"});
}

// Returns the fixed signature of the "subtract_grad" kernel. The mapping
// context is not consulted.
KernelSignature ElementwiseSubGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("subtract_grad",
                         /*inputs=*/{"X", "Y", "Out@GRAD"},
                         /*attrs=*/{"axis"},
                         /*outputs=*/{"X@GRAD", "Y@GRAD"});
}

148 149 150
// Returns the fixed signature of the "subtract_double_grad" kernel. The
// mapping context is not consulted.
KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("subtract_double_grad",
                         /*inputs=*/{"Y", "DOut", "DDX", "DDY"},
                         /*attrs=*/{"axis"},
                         /*outputs=*/{"DDOut"});
}

154 155 156
// Returns the fixed signature of the "divide_grad" kernel. Unlike the other
// first-order gradients in this file, the forward output "Out" is also an
// input. The mapping context is not consulted.
KernelSignature ElementwiseDivGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "divide_grad",
      /*inputs=*/{"X", "Y", "Out", "Out@GRAD"},
      /*attrs=*/{"axis"},
      /*outputs=*/{"X@GRAD", "Y@GRAD"});
}

162 163
// Returns the fixed signature of the "fmin_grad" kernel. The mapping context
// is not consulted.
KernelSignature ElementwiseFMinGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("fmin_grad",
                         /*inputs=*/{"X", "Y", "Out@GRAD"},
                         /*attrs=*/{"axis"},
                         /*outputs=*/{"X@GRAD", "Y@GRAD"});
}

168 169 170 171 172
// Returns the fixed signature of the "divide_double_grad" kernel. The
// mapping context is not consulted.
KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "divide_double_grad",
      /*inputs=*/{"Y", "Out", "DX", "DDX", "DDY"},
      /*attrs=*/{"axis"},
      /*outputs=*/{"Y@GRAD", "DOut", "DDOut"});
}

Y
YuanRisheng 已提交
176 177
// Returns the fixed signature of the "multiply_grad" kernel. The mapping
// context is not consulted.
KernelSignature ElementwiseMulGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("multiply_grad",
                         /*inputs=*/{"X", "Y", "Out@GRAD"},
                         /*attrs=*/{"axis"},
                         /*outputs=*/{"X@GRAD", "Y@GRAD"});
}

Y
YuanRisheng 已提交
182 183 184 185 186 187 188 189 190 191
// Returns the fixed signature of the "fmax" kernel. Unlike the forward
// mappings that branch on "axis", this one always forwards the attribute and
// never reads the mapping context.
KernelSignature ElementwiseFMaxOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("fmax",
                         /*inputs=*/{"X", "Y"},
                         /*attrs=*/{"axis"},
                         /*outputs=*/{"Out"});
}

// Returns the fixed signature of the "fmin" kernel. Unlike the forward
// mappings that branch on "axis", this one always forwards the attribute and
// never reads the mapping context.
KernelSignature ElementwiseFMinOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("fmin",
                         /*inputs=*/{"X", "Y"},
                         /*attrs=*/{"axis"},
                         /*outputs=*/{"Out"});
}

192 193
// Returns the fixed signature of the "fmax_grad" kernel. The mapping context
// is not consulted.
KernelSignature ElementwiseFMaxGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("fmax_grad",
                         /*inputs=*/{"X", "Y", "Out@GRAD"},
                         /*attrs=*/{"axis"},
                         /*outputs=*/{"X@GRAD", "Y@GRAD"});
}

Y
YuanRisheng 已提交
198 199 200 201 202
// Returns the fixed signature of the "multiply_double_grad" kernel. The
// mapping context is not consulted.
KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "multiply_double_grad",
      /*inputs=*/{"X", "Y", "DOut", "DDX", "DDY"},
      /*attrs=*/{"axis"},
      /*outputs=*/{"X@GRAD", "Y@GRAD", "DDOut"});
}

// Returns the fixed signature of the "multiply_triple_grad" kernel. The
// mapping context is not consulted.
KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "multiply_triple_grad",
      /*inputs=*/
      {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
      /*attrs=*/{"axis"},
      /*outputs=*/{"D_X", "D_Y", "D_DOut", "D_DDX", "D_DDY"});
}

215 216
// Returns the fixed signature of the "maximum_grad" kernel. The mapping
// context is not consulted.
KernelSignature ElementwiseMaxGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("maximum_grad",
                         /*inputs=*/{"X", "Y", "Out@GRAD"},
                         /*attrs=*/{"axis"},
                         /*outputs=*/{"X@GRAD", "Y@GRAD"});
}

// Returns the fixed signature of the "minimum_grad" kernel. The mapping
// context is not consulted.
KernelSignature ElementwiseMinGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("minimum_grad",
                         /*inputs=*/{"X", "Y", "Out@GRAD"},
                         /*attrs=*/{"axis"},
                         /*outputs=*/{"X@GRAD", "Y@GRAD"});
}
226 227 228 229 230 231 232 233 234

// Returns the fixed signature of the "elementwise_heaviside_grad" kernel.
// The mapping context is not consulted.
KernelSignature ElementwiseHeavisideGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "elementwise_heaviside_grad",
      /*inputs=*/{"X", "Y", "Out@GRAD"},
      /*attrs=*/{"axis"},
      /*outputs=*/{"X@GRAD", "Y@GRAD"});
}

235 236 237
// Returns the fixed signature of the "elementwise_pow_grad" kernel. The
// mapping context is not consulted.
KernelSignature ElementwisePowGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "elementwise_pow_grad",
      /*inputs=*/{"X", "Y", "Out@GRAD"},
      /*attrs=*/{"axis"},
      /*outputs=*/{"X@GRAD", "Y@GRAD"});
}
242
}  // namespace phi
243

244 245 246 247
// Base-kernel-name registrations: each entry pairs an operator name (first
// argument) with the kernel base name its mapping function returns (second
// argument). Order is unchanged from the original list.

// Forward operators.
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add, add);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_max, maximum);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_min, minimum);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mod, remainder);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_floordiv, floor_divide);

// Add / subtract gradients.
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad_grad, add_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_triple_grad, add_triple_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad_grad, subtract_double_grad);

// Divide / multiply gradients.
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad, divide_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad_grad, divide_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad, multiply_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad_grad, multiply_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_triple_grad, multiply_triple_grad);

// fmax / fmin and their gradients.
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax, fmax);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin, fmin);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax_grad, fmax_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin_grad, fmin_grad);

// max / min gradients.
PD_REGISTER_BASE_KERNEL_NAME(elementwise_max_grad, maximum_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_min_grad, minimum_grad);
268 269

// Argument-mapping registrations: each entry binds an operator name (first
// argument) to the phi:: function that produces its KernelSignature (second
// argument). Order is unchanged from the original list.

// Forward operators.
PD_REGISTER_ARG_MAPPING_FN(elementwise_add,
                           phi::ElementwiseAddOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub,
                           phi::ElementwiseSubOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul,
                           phi::ElementwiseMulOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_div,
                           phi::ElementwiseDivOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_max,
                           phi::ElementwiseMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_min,
                           phi::ElementwiseMinOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mod,
                           phi::ElementwiseModOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_floordiv,
                           phi::ElementwiseFloorDivOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_heaviside,
                           phi::ElementwiseHeavisideOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_pow,
                           phi::ElementwisePowOpArgumentMapping);

// Add / subtract gradients.
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad,
                           phi::ElementwiseAddGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad_grad,
                           phi::ElementwiseAddDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_triple_grad,
                           phi::ElementwiseAddTripleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub_grad,
                           phi::ElementwiseSubGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub_grad_grad,
                           phi::ElementwiseSubDoubleGradOpArgumentMapping);

// Divide / multiply gradients.
PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad,
                           phi::ElementwiseDivGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad_grad,
                           phi::ElementwiseDivDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad,
                           phi::ElementwiseMulGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad_grad,
                           phi::ElementwiseMulDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_triple_grad,
                           phi::ElementwiseMulTripleGradOpArgumentMapping);

// fmax / fmin and their gradients.
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax,
                           phi::ElementwiseFMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin,
                           phi::ElementwiseFMinOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax_grad,
                           phi::ElementwiseFMaxGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin_grad,
                           phi::ElementwiseFMinGradOpArgumentMapping);

// max / min / heaviside / pow gradients.
PD_REGISTER_ARG_MAPPING_FN(elementwise_max_grad,
                           phi::ElementwiseMaxGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_min_grad,
                           phi::ElementwiseMinGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_heaviside_grad,
                           phi::ElementwiseHeavisideGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_pow_grad,
                           phi::ElementwisePowGradOpArgumentMapping);

// grad_add.
PD_REGISTER_ARG_MAPPING_FN(grad_add, phi::ElementwiseGradAddOpArgumentMapping);