/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

// Maps the fluid `elementwise_add` op to the phi add kernels.
// axis == -1 is the default broadcasting case, mapped to the attribute-free
// `add` kernel; any explicit axis is forwarded to `add_raw`.
// (Fix: removed interleaved code-viewer residue lines that were not valid C++.)
KernelSignature ElementwiseAddOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("add", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("add_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

// Maps the fluid `elementwise_sub` op to the phi subtract kernels.
// axis == -1 selects the attribute-free `subtract` kernel; otherwise the
// explicit axis is carried by `subtract_raw`.
// (Fix: removed interleaved code-viewer residue lines that were not valid C++.)
KernelSignature ElementwiseSubOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("subtract", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

// Maps the fluid `elementwise_mul` op to the phi multiply kernels.
// Dense-tensor inputs use `multiply`/`multiply_raw`; non-dense inputs
// (presumably SelectedRows — the `_sr` suffix) use the `_sr` variants.
// In both families, axis == -1 picks the attribute-free kernel.
// (Fix: removed interleaved code-viewer residue lines that were not valid C++.)
KernelSignature ElementwiseMulOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (ctx.IsDenseTensorInput("X")) {
    if (axis == -1) {
      return KernelSignature("multiply", {"X", "Y"}, {}, {"Out"});
    }
    return KernelSignature("multiply_raw", {"X", "Y"}, {"axis"}, {"Out"});
  } else {
    if (axis == -1) {
      return KernelSignature("multiply_sr", {"X", "Y"}, {}, {"Out"});
    }
    return KernelSignature("multiply_raw_sr", {"X", "Y"}, {"axis"}, {"Out"});
  }
}

// Maps the fluid `elementwise_div` op to the phi divide kernels.
// axis == -1 selects the attribute-free `divide` kernel; otherwise the
// explicit axis is carried by `divide_raw`.
// (Fix: removed interleaved code-viewer residue lines that were not valid C++.)
KernelSignature ElementwiseDivOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
KernelSignature ElementwiseMaxOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("maximum", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("maximum_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

// Maps the fluid `elementwise_min` op to phi: `minimum` for the default
// broadcast (axis == -1), `minimum_raw` when the axis attribute matters.
KernelSignature ElementwiseMinOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const int broadcast_axis = paddle::any_cast<int>(ctx.Attr("axis"));
  return broadcast_axis == -1
             ? KernelSignature("minimum", {"X", "Y"}, {}, {"Out"})
             : KernelSignature("minimum_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

// Maps the fluid `elementwise_mod` op to phi: `modulo` for the default
// broadcast (axis == -1), `modulo_raw` when the axis attribute matters.
KernelSignature ElementwiseModOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const int broadcast_axis = paddle::any_cast<int>(ctx.Attr("axis"));
  return broadcast_axis == -1
             ? KernelSignature("modulo", {"X", "Y"}, {}, {"Out"})
             : KernelSignature("modulo_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
KernelSignature ElementwiseFloorDivOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("floor_divide", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("floor_divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

// Maps the fluid `elementwise_pow` op to phi: `elementwise_pow` for the
// default broadcast (axis == -1), `elementwise_pow_raw` otherwise.
KernelSignature ElementwisePowOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const int broadcast_axis = paddle::any_cast<int>(ctx.Attr("axis"));
  return broadcast_axis == -1
             ? KernelSignature("elementwise_pow", {"X", "Y"}, {}, {"Out"})
             : KernelSignature(
                   "elementwise_pow_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

107 108
KernelSignature ElementwiseAddGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
Y
YuanRisheng 已提交
109 110 111 112
  return KernelSignature("add_grad",
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
113 114
}

115 116 117
KernelSignature ElementwiseAddDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
118
      "add_double_grad", {"Y", "DOut", "DDX", "DDY"}, {"axis"}, {"DDOut"});
119 120 121 122 123 124 125 126 127 128 129 130
}

// Maps `elementwise_add_triple_grad` to the phi `add_triple_grad` kernel.
KernelSignature ElementwiseAddTripleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("add_triple_grad",
                         {"DDX", "DDY", "D_DDOut"},
                         {"axis"},
                         {"D_DDX", "D_DDY"});
}

// Maps `elementwise_sub_grad` to the phi `subtract_grad` kernel.
// (Fix: removed interleaved code-viewer residue lines that were not valid C++.)
KernelSignature ElementwiseSubGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("subtract_grad",
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

137 138 139 140 141 142
KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "subtract_double_grad", {"Y", "DDX", "DDY", "DOut"}, {"axis"}, {"DDOut"});
}

143 144 145 146 147 148 149 150
KernelSignature ElementwiseDivGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("divide_grad",
                         {"X", "Y", "Out", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

151 152
KernelSignature ElementwiseFMinGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
Y
YuanRisheng 已提交
153
  return KernelSignature("fmin_grad",
154 155 156 157 158
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

159 160 161 162 163 164 165 166
KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("divide_double_grad",
                         {"Y", "Out", "DX", "DDX", "DDY"},
                         {"axis"},
                         {GradVarName("Y"), "DOut", "DDOut"});
}

Y
YuanRisheng 已提交
167 168 169 170 171 172 173 174
KernelSignature ElementwiseMulGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("multiply_grad",
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

Y
YuanRisheng 已提交
175 176 177 178 179 180 181 182 183 184
KernelSignature ElementwiseFMaxOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("fmax", {"X", "Y"}, {"axis"}, {"Out"});
}

// Maps `elementwise_fmin` to the phi `fmin` kernel; the axis attribute is
// always forwarded (no attribute-free variant is used here).
KernelSignature ElementwiseFMinOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  KernelSignature fmin_sig("fmin", {"X", "Y"}, {"axis"}, {"Out"});
  return fmin_sig;
}

185 186
KernelSignature ElementwiseFMaxGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
Y
YuanRisheng 已提交
187
  return KernelSignature("fmax_grad",
188 189 190 191 192
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

Y
YuanRisheng 已提交
193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("multiply_double_grad",
                         {"X", "Y", "DOut", "DDX", "DDY"},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y"), "DDOut"});
}

// Maps `elementwise_mul_triple_grad` to the phi `multiply_triple_grad` kernel.
KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "multiply_triple_grad",
      {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
      {"axis"},
      {"D_X", "D_Y", "D_DOut", "D_DDX", "D_DDY"});
}

210 211 212 213 214 215 216 217 218 219 220 221 222 223 224
KernelSignature ElementwiseMaxGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("maximum_grad",
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

// Maps `elementwise_min_grad` to the phi `minimum_grad` kernel.
KernelSignature ElementwiseMinGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "minimum_grad",
      {"X", "Y", GradVarName("Out")},
      {"axis"},
      {GradVarName("X"), GradVarName("Y")});
}
225 226 227 228 229 230 231
KernelSignature ElementwisePowGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("elementwise_pow_grad",
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}
}  // namespace phi
233

234 235 236 237
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add, add);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide);
238 239 240
PD_REGISTER_BASE_KERNEL_NAME(elementwise_max, maximum);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_min, minimum);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mod, modulo);
241
PD_REGISTER_BASE_KERNEL_NAME(elementwise_floordiv, floor_divide);
242 243 244 245
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad_grad, add_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_triple_grad, add_triple_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad);
246
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad_grad, subtract_double_grad);
247 248
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad, divide_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad_grad, divide_double_grad);
Y
YuanRisheng 已提交
249 250 251
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad, multiply_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad_grad, multiply_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_triple_grad, multiply_triple_grad);
Y
YuanRisheng 已提交
252 253 254 255
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax, fmax);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin, fmin);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax_grad, fmax_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin_grad, fmin_grad);
256 257
PD_REGISTER_BASE_KERNEL_NAME(elementwise_max_grad, maximum_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_min_grad, minimum_grad);
258 259

PD_REGISTER_ARG_MAPPING_FN(elementwise_add,
260
                           phi::ElementwiseAddOpArgumentMapping);
261
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub,
262
                           phi::ElementwiseSubOpArgumentMapping);
263
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul,
264
                           phi::ElementwiseMulOpArgumentMapping);
265
PD_REGISTER_ARG_MAPPING_FN(elementwise_div,
266
                           phi::ElementwiseDivOpArgumentMapping);
267 268 269 270 271 272
PD_REGISTER_ARG_MAPPING_FN(elementwise_max,
                           phi::ElementwiseMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_min,
                           phi::ElementwiseMinOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mod,
                           phi::ElementwiseModOpArgumentMapping);
273 274 275 276
PD_REGISTER_ARG_MAPPING_FN(elementwise_floordiv,
                           phi::ElementwiseFloorDivOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_pow,
                           phi::ElementwisePowOpArgumentMapping);
277
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad,
278
                           phi::ElementwiseAddGradOpArgumentMapping);
279
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad_grad,
280
                           phi::ElementwiseAddDoubleGradOpArgumentMapping);
281
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_triple_grad,
282
                           phi::ElementwiseAddTripleGradOpArgumentMapping);
283
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub_grad,
284
                           phi::ElementwiseSubGradOpArgumentMapping);
285 286
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub_grad_grad,
                           phi::ElementwiseSubDoubleGradOpArgumentMapping);
287 288 289 290
PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad,
                           phi::ElementwiseDivGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad_grad,
                           phi::ElementwiseDivDoubleGradOpArgumentMapping);
Y
YuanRisheng 已提交
291 292 293 294 295 296
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad,
                           phi::ElementwiseMulGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad_grad,
                           phi::ElementwiseMulDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_triple_grad,
                           phi::ElementwiseMulTripleGradOpArgumentMapping);
Y
YuanRisheng 已提交
297 298 299 300
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax,
                           phi::ElementwiseFMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin,
                           phi::ElementwiseFMinOpArgumentMapping);
301 302 303 304
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax_grad,
                           phi::ElementwiseFMaxGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin_grad,
                           phi::ElementwiseFMinGradOpArgumentMapping);
305 306 307 308
PD_REGISTER_ARG_MAPPING_FN(elementwise_max_grad,
                           phi::ElementwiseMaxGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_min_grad,
                           phi::ElementwiseMinGradOpArgumentMapping);
309 310
PD_REGISTER_ARG_MAPPING_FN(elementwise_pow_grad,
                           phi::ElementwisePowGradOpArgumentMapping);