// set_value_sig.cc

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {
  if (ctx.IsDenseTensorInput("Input")) {
22 23 24
    if (ctx.InputSize("StartsTensorList") > 0) {
      if (ctx.InputSize("EndsTensorList") > 0) {
        if (ctx.InputSize("StepsTensorList") > 0) {
25 26 27 28 29 30 31 32 33 34
          if (ctx.HasInput("ValueTensor")) {
            return KernelSignature("set_value_with_tensor",
                                   {"Input", "ValueTensor"},
                                   {"StartsTensorList",
                                    "EndsTensorList",
                                    "StepsTensorList",
                                    "axes",
                                    "decrease_axes",
                                    "none_axes"},
                                   {"Out"});
35
          } else {
36 37 38 39 40 41 42 43 44
            return KernelSignature("set_value",
                                   {"Input"},
                                   {"StartsTensorList",
                                    "EndsTensorList",
                                    "StepsTensorList",
                                    "axes",
                                    "decrease_axes",
                                    "none_axes",
                                    "shape",
45
                                    "values"},
46 47 48 49 50 51 52 53 54 55 56 57 58
                                   {"Out"});
          }
        } else {
          if (ctx.HasInput("ValueTensor")) {
            return KernelSignature("set_value_with_tensor",
                                   {"Input", "ValueTensor"},
                                   {"StartsTensorList",
                                    "EndsTensorList",
                                    "steps",
                                    "axes",
                                    "decrease_axes",
                                    "none_axes"},
                                   {"Out"});
59
          } else {
傅剑寒 已提交
60 61 62 63 64 65 66 67 68
            return KernelSignature("set_value",
                                   {"Input"},
                                   {"StartsTensorList",
                                    "EndsTensorList",
                                    "steps",
                                    "axes",
                                    "decrease_axes",
                                    "none_axes",
                                    "shape",
69
                                    "values"},
傅剑寒 已提交
70
                                   {"Out"});
71 72 73
          }
        }
      } else {
74
        if (ctx.InputSize("StepsTensorList") > 0) {
75 76 77 78 79 80 81 82 83 84
          if (ctx.HasInput("ValueTensor")) {
            return KernelSignature("set_value_with_tensor",
                                   {"Input", "ValueTensor"},
                                   {"StartsTensorList",
                                    "ends",
                                    "StepsTensorList",
                                    "axes",
                                    "decrease_axes",
                                    "none_axes"},
                                   {"Out"});
85
          } else {
86 87 88 89 90 91 92 93 94
            return KernelSignature("set_value",
                                   {"Input"},
                                   {"StartsTensorList",
                                    "ends",
                                    "StepsTensorList",
                                    "axes",
                                    "decrease_axes",
                                    "none_axes",
                                    "shape",
95
                                    "values"},
96 97 98 99 100 101 102 103 104 105 106 107 108
                                   {"Out"});
          }
        } else {
          if (ctx.HasInput("ValueTensor")) {
            return KernelSignature("set_value_with_tensor",
                                   {"Input", "ValueTensor"},
                                   {"StartsTensorList",
                                    "ends",
                                    "steps",
                                    "axes",
                                    "decrease_axes",
                                    "none_axes"},
                                   {"Out"});
109
          } else {
110 111 112 113 114 115 116 117 118
            return KernelSignature("set_value",
                                   {"Input"},
                                   {"StartsTensorList",
                                    "ends",
                                    "steps",
                                    "axes",
                                    "decrease_axes",
                                    "none_axes",
                                    "shape",
119
                                    "values"},
120 121 122 123 124
                                   {"Out"});
          }
        }
      }
    } else {
125 126
      if (ctx.InputSize("EndsTensorList") > 0) {
        if (ctx.InputSize("StepsTensorList") > 0) {
127 128 129 130 131 132 133 134 135 136
          if (ctx.HasInput("ValueTensor")) {
            return KernelSignature("set_value_with_tensor",
                                   {"Input", "ValueTensor"},
                                   {"starts",
                                    "EndsTensorList",
                                    "StepsTensorList",
                                    "axes",
                                    "decrease_axes",
                                    "none_axes"},
                                   {"Out"});
137
          } else {
138 139 140 141 142 143 144 145 146
            return KernelSignature("set_value",
                                   {"Input"},
                                   {"starts",
                                    "EndsTensorList",
                                    "StepsTensorList",
                                    "axes",
                                    "decrease_axes",
                                    "none_axes",
                                    "shape",
147
                                    "values"},
148 149 150 151 152 153 154 155 156 157 158 159 160
                                   {"Out"});
          }
        } else {
          if (ctx.HasInput("ValueTensor")) {
            return KernelSignature("set_value_with_tensor",
                                   {"Input", "ValueTensor"},
                                   {"starts",
                                    "EndsTensorList",
                                    "steps",
                                    "axes",
                                    "decrease_axes",
                                    "none_axes"},
                                   {"Out"});
161
          } else {
162 163 164 165 166 167 168 169 170
            return KernelSignature("set_value",
                                   {"Input"},
                                   {"starts",
                                    "EndsTensorList",
                                    "steps",
                                    "axes",
                                    "decrease_axes",
                                    "none_axes",
                                    "shape",
171
                                    "values"},
172 173 174 175
                                   {"Out"});
          }
        }
      } else {
176
        if (ctx.InputSize("StepsTensorList") > 0) {
177 178 179 180 181 182 183 184 185 186
          if (ctx.HasInput("ValueTensor")) {
            return KernelSignature("set_value_with_tensor",
                                   {"Input", "ValueTensor"},
                                   {"starts",
                                    "ends",
                                    "StepsTensorList",
                                    "axes",
                                    "decrease_axes",
                                    "none_axes"},
                                   {"Out"});
187
          } else {
188 189 190 191 192 193 194 195 196
            return KernelSignature("set_value",
                                   {"Input"},
                                   {"starts",
                                    "ends",
                                    "StepsTensorList",
                                    "axes",
                                    "decrease_axes",
                                    "none_axes",
                                    "shape",
197
                                    "values"},
198 199 200 201 202 203 204 205 206 207 208 209 210
                                   {"Out"});
          }
        } else {
          if (ctx.HasInput("ValueTensor")) {
            return KernelSignature("set_value_with_tensor",
                                   {"Input", "ValueTensor"},
                                   {"starts",
                                    "ends",
                                    "steps",
                                    "axes",
                                    "decrease_axes",
                                    "none_axes"},
                                   {"Out"});
211
          } else {
212 213 214 215 216 217 218 219 220
            return KernelSignature("set_value",
                                   {"Input"},
                                   {"starts",
                                    "ends",
                                    "steps",
                                    "axes",
                                    "decrease_axes",
                                    "none_axes",
                                    "shape",
221
                                    "values"},
222
                                   {"Out"});
223 224 225 226 227 228 229
          }
        }
      }
    }
  }
  return KernelSignature("unregistered", {}, {}, {});
}
230 231 232

KernelSignature SetValueGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
233 234 235
  if (ctx.InputSize("StartsTensorList") > 0) {
    if (ctx.InputSize("EndsTensorList") > 0) {
      if (ctx.InputSize("StepsTensorList") > 0) {
236 237 238 239 240 241 242 243 244
        return KernelSignature("set_value_grad",
                               {"Out@GRAD"},
                               {"StartsTensorList",
                                "EndsTensorList",
                                "StepsTensorList",
                                "axes",
                                "decrease_axes",
                                "none_axes"},
                               {"Input@GRAD", "ValueTensor@GRAD"});
245
      } else {
246 247 248 249 250 251 252 253 254
        return KernelSignature("set_value_grad",
                               {"Out@GRAD"},
                               {"StartsTensorList",
                                "EndsTensorList",
                                "steps",
                                "axes",
                                "decrease_axes",
                                "none_axes"},
                               {"Input@GRAD", "ValueTensor@GRAD"});
255 256
      }
    } else {
257
      if (ctx.InputSize("StepsTensorList") > 0) {
258 259 260 261 262 263 264 265 266
        return KernelSignature("set_value_grad",
                               {"Out@GRAD"},
                               {"StartsTensorList",
                                "ends",
                                "StepsTensorList",
                                "axes",
                                "decrease_axes",
                                "none_axes"},
                               {"Input@GRAD", "ValueTensor@GRAD"});
267
      } else {
268 269 270 271 272 273 274 275 276
        return KernelSignature("set_value_grad",
                               {"Out@GRAD"},
                               {"StartsTensorList",
                                "ends",
                                "steps",
                                "axes",
                                "decrease_axes",
                                "none_axes"},
                               {"Input@GRAD", "ValueTensor@GRAD"});
277 278 279
      }
    }
  } else {
280 281
    if (ctx.InputSize("EndsTensorList") > 0) {
      if (ctx.InputSize("StepsTensorList") > 0) {
282 283 284 285 286 287 288 289 290
        return KernelSignature("set_value_grad",
                               {"Out@GRAD"},
                               {"starts",
                                "EndsTensorList",
                                "StepsTensorList",
                                "axes",
                                "decrease_axes",
                                "none_axes"},
                               {"Input@GRAD", "ValueTensor@GRAD"});
291
      } else {
292 293 294 295 296 297 298 299 300
        return KernelSignature("set_value_grad",
                               {"Out@GRAD"},
                               {"starts",
                                "EndsTensorList",
                                "steps",
                                "axes",
                                "decrease_axes",
                                "none_axes"},
                               {"Input@GRAD", "ValueTensor@GRAD"});
301 302
      }
    } else {
303
      if (ctx.InputSize("StepsTensorList") > 0) {
304 305 306 307 308 309 310 311 312
        return KernelSignature("set_value_grad",
                               {"Out@GRAD"},
                               {"starts",
                                "ends",
                                "StepsTensorList",
                                "axes",
                                "decrease_axes",
                                "none_axes"},
                               {"Input@GRAD", "ValueTensor@GRAD"});
313 314 315
      } else {
        return KernelSignature(
            "set_value_grad",
316
            {"Out@GRAD"},
317
            {"starts", "ends", "steps", "axes", "decrease_axes", "none_axes"},
318
            {"Input@GRAD", "ValueTensor@GRAD"});
319 320 321 322 323
      }
    }
  }
}

324 325 326
}  // namespace phi

// Register the argument-mapping functions for the forward and backward ops.
PD_REGISTER_ARG_MAPPING_FN(set_value, phi::SetValueOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(set_value_grad, phi::SetValueGradOpArgumentMapping);