// strided_slice_sig.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>

#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/small_vector.h"

namespace phi {

KernelSignature StridedSliceOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const auto& starts = paddle::any_cast<std::vector<int>>(ctx.Attr("starts"));
  const auto& ends = paddle::any_cast<std::vector<int>>(ctx.Attr("ends"));
  const auto& strides = paddle::any_cast<std::vector<int>>(ctx.Attr("strides"));

  bool use_attr_starts = !ctx.IsRuntime() && !starts.empty();
  bool use_attr_ends = !ctx.IsRuntime() && !ends.empty();
  bool use_attr_strides = !ctx.IsRuntime() && !strides.empty();

  std::string starts_key =
      ctx.HasInput("StartsTensor")
          ? "StartsTensor"
          : (ctx.InputSize("StartsTensorList") > 0
                 ? (use_attr_starts ? "starts" : "StartsTensorList")
                 : "starts");
  std::string ends_key =
      ctx.HasInput("EndsTensor")
          ? "EndsTensor"
          : (ctx.InputSize("EndsTensorList") > 0
                 ? (use_attr_ends ? "ends" : "EndsTensorList")
                 : "ends");
  std::string strides_key =
      ctx.HasInput("StridesTensor")
          ? "StridesTensor"
          : (ctx.InputSize("StridesTensorList") > 0
                 ? (use_attr_strides ? "strides" : "StridesTensorList")
                 : "strides");

  paddle::SmallVector<std::string> inputs = {"Input"};
  paddle::SmallVector<std::string> attrs = {"axes",
                                            starts_key,
                                            ends_key,
                                            strides_key,
                                            "infer_flags",
                                            "decrease_axis"};
  paddle::SmallVector<std::string> outputs = {"Out"};

60
  std::string kernel_name;
61
  if (ctx.IsDenseTensorVectorInput("Input")) {
62
    kernel_name = "strided_slice_array";
63
  } else {
64
    kernel_name = "strided_slice_raw";
65 66
  }
  // NOTE(dev): Use this to avoid regularization.
67
  KernelSignature sig(kernel_name, inputs, attrs, outputs);
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
  return sig;
}

KernelSignature StridedSliceGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const auto& starts = paddle::any_cast<std::vector<int>>(ctx.Attr("starts"));
  const auto& ends = paddle::any_cast<std::vector<int>>(ctx.Attr("ends"));
  const auto& strides = paddle::any_cast<std::vector<int>>(ctx.Attr("strides"));

  bool use_attr_starts = !ctx.IsRuntime() && !starts.empty();
  bool use_attr_ends = !ctx.IsRuntime() && !ends.empty();
  bool use_attr_strides = !ctx.IsRuntime() && !strides.empty();

  std::string starts_key =
      ctx.HasInput("StartsTensor")
          ? "StartsTensor"
          : (ctx.InputSize("StartsTensorList") > 0
                 ? (use_attr_starts ? "starts" : "StartsTensorList")
                 : "starts");
  std::string ends_key =
      ctx.HasInput("EndsTensor")
          ? "EndsTensor"
          : (ctx.InputSize("EndsTensorList") > 0
                 ? (use_attr_ends ? "ends" : "EndsTensorList")
                 : "ends");
  std::string strides_key =
      ctx.HasInput("StridesTensor")
          ? "StridesTensor"
          : (ctx.InputSize("StridesTensorList") > 0
                 ? (use_attr_strides ? "strides" : "StridesTensorList")
                 : "strides");

  paddle::SmallVector<std::string> inputs = {"Input", GradVarName("Out")};
  paddle::SmallVector<std::string> attrs = {"axes",
                                            starts_key,
                                            ends_key,
                                            strides_key,
                                            "infer_flags",
                                            "decrease_axis"};
  paddle::SmallVector<std::string> outputs = {GradVarName("Input")};

109
  std::string kernel_name;
110
  if (ctx.IsDenseTensorVectorInput("Input")) {
111
    kernel_name = "strided_slice_array_grad";
112
  } else {
113
    kernel_name = "strided_slice_raw_grad";
114 115 116
  }

  // NOTE(dev): Use this to avoid regularization.
117
  KernelSignature sig(kernel_name, inputs, attrs, outputs);
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
  return sig;
}

}  // namespace phi

// Register the argument-mapping functions defined above for the forward op
// (`strided_slice`) and its gradient op (`strided_slice_grad`); presumably
// the macro keys each mapping by the given op name — confirm in op_utils.h.
PD_REGISTER_ARG_MAPPING_FN(strided_slice, phi::StridedSliceOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(strided_slice_grad,
                           phi::StridedSliceGradOpArgumentMapping);

/*
******************************************************************
NOTE: The following codes are for 'get_compat_kernel_signature.py'
      DO NOT EDIT IT if you don't know the mechanism.
******************************************************************

############################  Forward ############################

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "StartsTensor", "EndsTensor", "StridesTensor",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "StartsTensor", "EndsTensor", "StridesTensorList",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "StartsTensor", "EndsTensor", "strides",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "StartsTensor", "EndsTensorList", "StridesTensor",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "StartsTensor", "EndsTensorList", "StridesTensorList",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "StartsTensor", "EndsTensorList", "strides",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "StartsTensor", "ends", "StridesTensor",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "StartsTensor", "ends", "StridesTensorList",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "StartsTensor", "ends", "strides",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "StartsTensorList", "EndsTensor", "StridesTensor",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "StartsTensorList", "EndsTensor", "StridesTensorList",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "StartsTensorList", "EndsTensor", "strides",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "StartsTensorList", "EndsTensorList", "StridesTensor",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "StartsTensorList", "EndsTensorList",
               "StridesTensorList", "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "StartsTensorList", "EndsTensorList", "strides",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "StartsTensorList", "ends", "StridesTensor",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "StartsTensorList", "ends", "StridesTensorList",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "StartsTensorList", "ends", "strides",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "starts", "EndsTensor", "StridesTensor",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "starts", "EndsTensor", "StridesTensorList",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "starts", "EndsTensor", "strides",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "starts", "EndsTensorList", "StridesTensor",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "starts", "EndsTensorList", "StridesTensorList",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "starts", "EndsTensorList", "strides",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "starts", "ends", "StridesTensor",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "starts", "ends", "StridesTensorList",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_raw", {"Input"},
              {"axes", "starts", "ends", "strides",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "StartsTensor", "EndsTensor", "StridesTensor",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "StartsTensor", "EndsTensor", "StridesTensorList",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "StartsTensor", "EndsTensor", "strides",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "StartsTensor", "EndsTensorList", "StridesTensor",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "StartsTensor", "EndsTensorList", "StridesTensorList",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "StartsTensor", "EndsTensorList", "strides",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "StartsTensor", "ends", "StridesTensor",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "StartsTensor", "ends", "StridesTensorList",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "StartsTensor", "ends", "strides",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "StartsTensorList", "EndsTensor", "StridesTensor",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "StartsTensorList", "EndsTensor", "StridesTensorList",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "StartsTensorList", "EndsTensor", "strides",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "StartsTensorList", "EndsTensorList", "StridesTensor",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "StartsTensorList", "EndsTensorList",
               "StridesTensorList", "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "StartsTensorList", "EndsTensorList", "strides",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "StartsTensorList", "ends", "StridesTensor",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "StartsTensorList", "ends", "StridesTensorList",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "StartsTensorList", "ends", "strides",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "starts", "EndsTensor", "StridesTensor",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "starts", "EndsTensor", "StridesTensorList",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "starts", "EndsTensor", "strides",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "starts", "EndsTensorList", "StridesTensor",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "starts", "EndsTensorList", "StridesTensorList",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "starts", "EndsTensorList", "strides",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "starts", "ends", "StridesTensor",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "starts", "ends", "StridesTensorList",
               "infer_flags", "decrease_axis"},
              {"Out"});

return KernelSignature("strided_slice_array", {"Input"},
              {"axes", "starts", "ends", "strides",
               "infer_flags", "decrease_axis"},
              {"Out"});
*/