/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {
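
// The argument mapping functions below translate the inputs, attributes, and
// outputs of the legacy fluid elementwise_* operators into the corresponding
// phi kernel signatures.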

KernelSignature ElementwiseAddOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    // Default axis: use the kernel that takes no explicit axis attribute.
    return KernelSignature("add", {"X", "Y"}, {}, {"Out"});
  }
  // Explicit axis: use the *_raw kernel variant that carries the axis.
  return KernelSignature("add_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseSubOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("subtract", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseMulOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (ctx.IsDenseTensorInput("X")) {
    if (axis == -1) {
      return KernelSignature("multiply", {"X", "Y"}, {}, {"Out"});
    }
    return KernelSignature("multiply_raw", {"X", "Y"}, {"axis"}, {"Out"});
  } else {
    // X is not a DenseTensor (e.g. SelectedRows): use the *_sr kernel variants.
    if (axis == -1) {
      return KernelSignature("multiply_sr", {"X", "Y"}, {}, {"Out"});
    }
    return KernelSignature("multiply_raw_sr", {"X", "Y"}, {"axis"}, {"Out"});
  }
}

KernelSignature ElementwiseDivOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseMaxOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("maximum", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("maximum_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseMinOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("minimum", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("minimum_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseModOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int axis = paddle::any_cast<int>(ctx.Attr("axis"));
  if (axis == -1) {
    return KernelSignature("modulo", {"X", "Y"}, {}, {"Out"});
  }
  return KernelSignature("modulo_raw", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseAddGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  // GradVarName maps a variable name to the name of its gradient variable.
  return KernelSignature("add_grad",
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

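// Higher-order gradient naming convention used below: "DOut" is the gradient
// of "Out", "DDX"/"DDY" are the gradients of the first-order gradients of
// "X"/"Y", and "D_DDOut" is the gradient of "DDOut".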
KernelSignature ElementwiseAddDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "add_double_grad", {"Y", "DDX", "DDY", "DOut"}, {"axis"}, {"DDOut"});
}

KernelSignature ElementwiseAddTripleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("add_triple_grad",
                         {"DDX", "DDY", "D_DDOut"},
                         {"axis"},
                         {"D_DDX", "D_DDY"});
}

KernelSignature ElementwiseSubGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("subtract_grad",
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "subtract_double_grad", {"Y", "DDX", "DDY", "DOut"}, {"axis"}, {"DDOut"});
}

KernelSignature ElementwiseDivGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("divide_grad",
                         {"X", "Y", "Out", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

KernelSignature ElementwiseFMinGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("fmin_grad",
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("divide_double_grad",
                         {"Y", "Out", "DX", "DDX", "DDY"},
                         {"axis"},
                         {GradVarName("Y"), "DOut", "DDOut"});
}

KernelSignature ElementwiseMulGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("multiply_grad",
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

KernelSignature ElementwiseFMaxOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("fmax", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseFMinOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("fmin", {"X", "Y"}, {"axis"}, {"Out"});
}

KernelSignature ElementwiseFMaxGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("fmax_grad",
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("multiply_double_grad",
                         {"X", "Y", "DOut", "DDX", "DDY"},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y"), "DDOut"});
}

KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "multiply_triple_grad",
      {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
      {"axis"},
      {"D_X", "D_Y", "D_DOut", "D_DDX", "D_DDY"});
}

KernelSignature ElementwiseMaxGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("maximum_grad",
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}

KernelSignature ElementwiseMinGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("minimum_grad",
                         {"X", "Y", GradVarName("Out")},
                         {"axis"},
                         {GradVarName("X"), GradVarName("Y")});
}
}  // namespace phi

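// Map the legacy elementwise_* operator names to their phi base kernel names.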
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add, add);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_max, maximum);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_min, minimum);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mod, modulo);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad_grad, add_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_triple_grad, add_triple_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad_grad, subtract_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad, divide_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad_grad, divide_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad, multiply_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad_grad, multiply_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_triple_grad, multiply_triple_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax, fmax);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin, fmin);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax_grad, fmax_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin_grad, fmin_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_max_grad, maximum_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_min_grad, minimum_grad);

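// Register the argument mapping functions defined above for the legacy
// elementwise operators.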
PD_REGISTER_ARG_MAPPING_FN(elementwise_add,
                           phi::ElementwiseAddOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub,
                           phi::ElementwiseSubOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul,
                           phi::ElementwiseMulOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_div,
                           phi::ElementwiseDivOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_max,
                           phi::ElementwiseMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_min,
                           phi::ElementwiseMinOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mod,
                           phi::ElementwiseModOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad,
                           phi::ElementwiseAddGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad_grad,
                           phi::ElementwiseAddDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_triple_grad,
                           phi::ElementwiseAddTripleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub_grad,
                           phi::ElementwiseSubGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_sub_grad_grad,
                           phi::ElementwiseSubDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad,
                           phi::ElementwiseDivGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad_grad,
                           phi::ElementwiseDivDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad,
                           phi::ElementwiseMulGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad_grad,
                           phi::ElementwiseMulDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_triple_grad,
                           phi::ElementwiseMulTripleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax,
                           phi::ElementwiseFMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin,
                           phi::ElementwiseFMinOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax_grad,
                           phi::ElementwiseFMaxGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin_grad,
                           phi::ElementwiseFMinGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_max_grad,
                           phi::ElementwiseMaxGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_min_grad,
                           phi::ElementwiseMinGradOpArgumentMapping);