elementwise_sig.cc 12.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/phi/core/compat/op_utils.h"
16

17
namespace phi {
18 19 20 21

KernelSignature ElementwiseAddOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
Y
YuanRisheng 已提交
22 23
  if (axis == -1) {
    return KernelSignature("add", {"X", "Y"}, {}, {"Out"});
24
  }
Y
YuanRisheng 已提交
25
  return KernelSignature("add_raw", {"X", "Y"}, {"axis"}, {"Out"});
26 27 28 29 30
}

// Argument mapping for `elementwise_sub`.
// axis == -1 selects the "subtract" kernel (no attributes); any other axis
// selects "subtract_raw", which also receives "axis". Inputs X, Y; output Out.
KernelSignature ElementwiseSubOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("subtract", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

// Argument mapping for `elementwise_mul`.
// Only a DenseTensor X is mapped to a real kernel: "multiply" when the int
// attribute "axis" is -1, otherwise "multiply_raw" (which also takes "axis").
// For any other kind of X the result is a signature named "unregistered"
// with no inputs, attributes, or outputs.
KernelSignature ElementwiseMulOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  // "axis" is read before the input-type check, as in every branch below.
  const int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (!ctx.IsDenseTensorInput("X")) {
    return KernelSignature("unregistered", {}, {}, {});
  }
  if (axis != -1) {
    return KernelSignature("multiply_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  return KernelSignature("multiply", {"X", "Y"}, {}, {"Out"});
}

// Argument mapping for `elementwise_div`.
// axis == -1 selects the "divide" kernel (no attributes); any other axis
// selects "divide_raw", which also receives "axis". Inputs X, Y; output Out.
KernelSignature ElementwiseDivOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
KernelSignature ElementwiseMaxOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("maximum", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("maximum_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

// Argument mapping for `elementwise_min`.
// A non-default axis (anything but -1) needs "minimum_raw", which takes the
// "axis" attribute; axis == -1 uses the attribute-free "minimum" kernel.
KernelSignature ElementwiseMinOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis != -1) {
    return KernelSignature("minimum_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  return KernelSignature("minimum", {"X", "Y"}, {}, {"Out"});
}

// Argument mapping for `elementwise_mod`.
// A non-default axis (anything but -1) needs "modulo_raw", which takes the
// "axis" attribute; axis == -1 uses the attribute-free "modulo" kernel.
KernelSignature ElementwiseModOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis != -1) {
    return KernelSignature("modulo_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  return KernelSignature("modulo", {"X", "Y"}, {}, {"Out"});
}

85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
KernelSignature ElementwiseFloorDivOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("floor_divide", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("floor_divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

// Argument mapping for `elementwise_pow`.
// A non-default axis (anything but -1) needs "elementwise_pow_raw", which
// takes the "axis" attribute; axis == -1 uses "elementwise_pow" without it.
KernelSignature ElementwisePowOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis != -1) {
    return KernelSignature(
        "elementwise_pow_raw", {"X", "Y"}, {"axis"}, {"Out"});
  }
  return KernelSignature("elementwise_pow", {"X", "Y"}, {}, {"Out"});
}

103 104
KernelSignature ElementwiseAddGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
Y
YuanRisheng 已提交
105 106 107 108
  return KernelSignature("add_grad",
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
109 110
}

111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
KernelSignature ElementwiseAddDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "add_double_grad", {"Y", "DDX", "DDY", "DOut"}, {"axis"}, {"DDOut"});
}

// Argument mapping for `elementwise_add_triple_grad`: kernel
// "add_triple_grad" consumes DDX, DDY and D_DDOut, receives "axis", and
// writes D_DDX and D_DDY.
KernelSignature ElementwiseAddTripleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "add_triple_grad",
      {"DDX", "DDY", "D_DDOut"},
      {"axis"},
      {"D_DDX", "D_DDY"});
}

// Argument mapping for `elementwise_sub_grad`: kernel "subtract_grad" takes
// X, Y and the gradient variable of Out, receives "axis", and produces the
// gradient variables of X and Y.
KernelSignature ElementwiseSubGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("subtract_grad",
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

133 134 135 136 137 138
KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "subtract_double_grad", {"Y", "DDX", "DDY", "DOut"}, {"axis"}, {"DDOut"});
}

139 140 141 142 143 144 145 146
KernelSignature ElementwiseDivGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("divide_grad",
                         {"X", "Y", "Out", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

147 148
KernelSignature ElementwiseFMinGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
Y
YuanRisheng 已提交
149
  return KernelSignature("fmin_grad",
150 151 152 153 154
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

155 156 157 158 159 160 161 162
KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("divide_double_grad",
                         {"Y", "Out", "DX", "DDX", "DDY"},
                         {"axis"},
                         {GradVarName("Y"), "DOut", "DDOut"});
}

Y
YuanRisheng 已提交
163 164 165 166 167 168 169 170
KernelSignature ElementwiseMulGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("multiply_grad",
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

Y
YuanRisheng 已提交
171 172 173 174 175 176 177 178 179 180
KernelSignature ElementwiseFMaxOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("fmax", {"X", "Y"}, {"axis"}, {"Out"});
}

// Argument mapping for `elementwise_fmin`: unconditionally uses the "fmin"
// kernel, passing "axis" through regardless of its value.
KernelSignature ElementwiseFMinOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("fmin",
                         {"X", "Y"},
                         {"axis"},
                         {"Out"});
}

181 182
KernelSignature ElementwiseFMaxGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
Y
YuanRisheng 已提交
183
  return KernelSignature("fmax_grad",
184 185 186 187 188
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

Y
YuanRisheng 已提交
189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205
KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("multiply_double_grad",
                         {"X", "Y", "DOut", "DDX", "DDY"},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y"), "DDOut"});
}

// Argument mapping for `elementwise_mul_triple_grad`: kernel
// "multiply_triple_grad" consumes the forward inputs (X, Y), the first- and
// second-order tensors (DOut, DDX, DDY) and the incoming D_* gradients, takes
// "axis", and writes D_X, D_Y, D_DOut, D_DDX and D_DDY.
KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("multiply_triple_grad",
                         {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY",
                          "D_DDOut"},
                         {"axis"},
                         {"D_X", "D_Y", "D_DOut", "D_DDX", "D_DDY"});
}

206 207 208 209 210 211 212 213 214 215 216 217 218 219 220
KernelSignature ElementwiseMaxGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("maximum_grad",
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

// Argument mapping for `elementwise_min_grad`: kernel "minimum_grad" takes X,
// Y and the gradient variable of Out, receives "axis", and produces the
// gradient variables of X and Y.
KernelSignature ElementwiseMinGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("minimum_grad",
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

// Argument mapping for `elementwise_pow_grad`: kernel "elementwise_pow_grad"
// takes X, Y and the gradient variable of Out, receives "axis", and produces
// the gradient variables of X and Y.
KernelSignature ElementwisePowGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("elementwise_pow_grad",
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}
}  // namespace phi
229

230 231 232 233
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add, add);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide);
234 235 236
PD_REGISTER_BASE_KERNEL_NAME(elementwise_max, maximum);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_min, minimum);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mod, modulo);
237
PD_REGISTER_BASE_KERNEL_NAME(elementwise_floordiv, floor_divide);
238 239 240 241
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad_grad, add_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_triple_grad, add_triple_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad);
242
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad_grad, subtract_double_grad);
243 244
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad, divide_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad_grad, divide_double_grad);
Y
YuanRisheng 已提交
245 246 247
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad, multiply_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad_grad, multiply_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_triple_grad, multiply_triple_grad);
Y
YuanRisheng 已提交
248 249 250 251
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax, fmax);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin, fmin);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax_grad, fmax_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin_grad, fmin_grad);
252 253
PD_REGISTER_BASE_KERNEL_NAME(elementwise_max_grad, maximum_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_min_grad, minimum_grad);
254 255

PD_REGISTER_ARG_MAPPING_FN(elementwise_add,
256
                           phi::ElementwiseAddOpArgumentMapping);
257
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub,
258
                           phi::ElementwiseSubOpArgumentMapping);
259
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul,
260
                           phi::ElementwiseMulOpArgumentMapping);
261
PD_REGISTER_ARG_MAPPING_FN(elementwise_div,
262
                           phi::ElementwiseDivOpArgumentMapping);
263 264 265 266 267 268
PD_REGISTER_ARG_MAPPING_FN(elementwise_max,
                           phi::ElementwiseMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_min,
                           phi::ElementwiseMinOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mod,
                           phi::ElementwiseModOpArgumentMapping);
269 270 271 272
PD_REGISTER_ARG_MAPPING_FN(elementwise_floordiv,
                           phi::ElementwiseFloorDivOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_pow,
                           phi::ElementwisePowOpArgumentMapping);
273
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad,
274
                           phi::ElementwiseAddGradOpArgumentMapping);
275
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad_grad,
276
                           phi::ElementwiseAddDoubleGradOpArgumentMapping);
277
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_triple_grad,
278
                           phi::ElementwiseAddTripleGradOpArgumentMapping);
279
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub_grad,
280
                           phi::ElementwiseSubGradOpArgumentMapping);
281 282
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub_grad_grad,
                           phi::ElementwiseSubDoubleGradOpArgumentMapping);
283 284 285 286
PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad,
                           phi::ElementwiseDivGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad_grad,
                           phi::ElementwiseDivDoubleGradOpArgumentMapping);
Y
YuanRisheng 已提交
287 288 289 290 291 292
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad,
                           phi::ElementwiseMulGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad_grad,
                           phi::ElementwiseMulDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_triple_grad,
                           phi::ElementwiseMulTripleGradOpArgumentMapping);
Y
YuanRisheng 已提交
293 294 295 296
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax,
                           phi::ElementwiseFMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin,
                           phi::ElementwiseFMinOpArgumentMapping);
297 298 299 300
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax_grad,
                           phi::ElementwiseFMaxGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin_grad,
                           phi::ElementwiseFMinGradOpArgumentMapping);
301 302 303 304
PD_REGISTER_ARG_MAPPING_FN(elementwise_max_grad,
                           phi::ElementwiseMaxGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_min_grad,
                           phi::ElementwiseMinGradOpArgumentMapping);
305 306
PD_REGISTER_ARG_MAPPING_FN(elementwise_pow_grad,
                           phi::ElementwisePowGradOpArgumentMapping);