backward.cc 40.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/phi/infermeta/backward.h"
Z
zyfncg 已提交
16
#include "paddle/phi/common/type_traits.h"
F
Feiyu Chan 已提交
17
#include "paddle/phi/core/utils/data_type.h"
18 19
#include "paddle/phi/kernels/funcs/axis_utils.h"

20
namespace phi {
21

22 23 24 25 26 27 28 29 30 31
void AffineGridGradInferMeta(const MetaTensor& output_grad,
                             const IntArray& outputShape,
                             bool align_corners,
                             MetaTensor* input_grad) {
  if (input_grad) {
    auto output_dims = output_grad.dims();
    input_grad->set_dims(phi::make_ddim({output_dims[0], 2, 3}));
  }
}

W
WangZhen 已提交
32 33 34 35 36 37
void AngleGradInferMeta(const MetaTensor& x,
                        const MetaTensor& out_grad,
                        MetaTensor* x_grad) {
  UnchangedInferMeta(x, x_grad);
}

38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
// Infers metadata for the gradients of bilinear_tensor_product.
//
// Validates that Out@GRAD is 2-D, that its first dimension equals X's first
// (batch) dimension, and that its second dimension equals Weight's first
// (index 0) dimension. dx/dy/dweight mirror the dims and dtype of x/y/weight;
// dbias is shaped [1, out_dims[1]] with Out@GRAD's dtype. Null outputs are
// skipped.
void BilinearTensorProductGradInferMeta(const MetaTensor& x,
                                        const MetaTensor& y,
                                        const MetaTensor& weight,
                                        const MetaTensor& dout,
                                        MetaTensor* dx,
                                        MetaTensor* dy,
                                        MetaTensor* dweight,
                                        MetaTensor* dbias) {
  const auto x_dims = x.dims();
  const auto y_dims = y.dims();
  const auto weight_dims = weight.dims();
  const auto out_dims = dout.dims();

  PADDLE_ENFORCE_EQ(
      out_dims.size(),
      2UL,
      errors::InvalidArgument("The input(Out@GRAD) must be a 2D Tensor."));
  PADDLE_ENFORCE_EQ(
      x_dims[0],
      out_dims[0],
      errors::InvalidArgument(
          "The first dimension(batch_size) of input(Out@GRAD) must be "
          "equal to the first dimension of the Input(X)."));
  // The width of Out@GRAD is compared against weight_dims[0], i.e. the FIRST
  // dimension of Weight, so the message must name that dimension.
  PADDLE_ENFORCE_EQ(
      weight_dims[0],
      out_dims[1],
      errors::InvalidArgument(
          "The second dimension of input(Out@GRAD) must be equal to "
          "the first dimension of the Input(Weight)."));

  if (dx) {
    dx->set_dims(x_dims);
    dx->set_dtype(x.dtype());
  }
  if (dy) {
    dy->set_dims(y_dims);
    dy->set_dtype(y.dtype());
  }
  if (dweight) {
    dweight->set_dims(weight_dims);
    dweight->set_dtype(weight.dtype());
  }
  if (dbias) {
    dbias->set_dims({1, out_dims[1]});
    dbias->set_dtype(dout.dtype());
  }
}

B
BiynXu 已提交
86 87 88 89 90
void BmmGradInferMeta(const MetaTensor& x,
                      const MetaTensor& y,
                      const MetaTensor& out_grad,
                      MetaTensor* x_grad,
                      MetaTensor* y_grad) {
91 92 93 94 95 96 97 98
  if (x_grad) {
    x_grad->set_dims(x.dims());
    x_grad->set_dtype(x.dtype());
  }
  if (y_grad) {
    y_grad->set_dims(y.dims());
    y_grad->set_dtype(y.dtype());
  }
B
BiynXu 已提交
99 100
}

101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
// Metadata inference for channel_shuffle's input gradient.
// Out@GRAD must be rank-4; X@GRAD then takes Out@GRAD's dims and dtype.
// `groups` and `data_format` are not read here.
void ChannelShuffleGradInferMeta(const MetaTensor& out_grad,
                                 int groups,
                                 const std::string& data_format,
                                 MetaTensor* x_grad) {
  const auto out_grad_dims = out_grad.dims();
  PADDLE_ENFORCE_EQ(out_grad_dims.size(),
                    4,
                    phi::errors::InvalidArgument(
                        "Input should be a 4-D tensor of format [N, C, H, W] "
                        "or [N, H, W, C], but got %u.",
                        out_grad_dims.size()));
  x_grad->set_dims(out_grad_dims);
  x_grad->set_dtype(out_grad.dtype());
}

117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
// Metadata inference for the gradients of `complex`.
// dx/dy, when requested, take the dims and dtype of x/y; dout is not read.
void ComplexGradInferMeta(const MetaTensor& x,
                          const MetaTensor& y,
                          const MetaTensor& dout,
                          MetaTensor* dx,
                          MetaTensor* dy) {
  if (dx != nullptr) {
    dx->set_dims(x.dims());
    dx->set_dtype(x.dtype());
  }
  if (dy != nullptr) {
    dy->set_dims(y.dims());
    dy->set_dtype(y.dtype());
  }
}

F
From00 已提交
134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
void ConvTransposeGradInferMeta(const MetaTensor& x,
                                const MetaTensor& filter,
                                const MetaTensor& dout,
                                const std::vector<int>& strides,
                                const std::vector<int>& paddings,
                                const std::vector<int>& output_padding,
                                const std::vector<int>& output_size,
                                const std::string& padding_algorithm,
                                int groups,
                                const std::vector<int>& dilations,
                                const std::string& data_format,
                                MetaTensor* dx,
                                MetaTensor* dfilter) {
  GeneralBinaryGradInferMeta(x, filter, dx, dfilter);
}

150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165
void Conv2dTransposeGradInferMeta(const MetaTensor& x,
                                  const MetaTensor& filter,
                                  const MetaTensor& dout,
                                  const std::vector<int>& strides,
                                  const std::vector<int>& paddings,
                                  const std::vector<int>& output_padding,
                                  const IntArray& output_size,
                                  const std::string& padding_algorithm,
                                  int groups,
                                  const std::vector<int>& dilations,
                                  const std::string& data_format,
                                  MetaTensor* dx,
                                  MetaTensor* dfilter) {
  GeneralBinaryGradInferMeta(x, filter, dx, dfilter);
}

F
From00 已提交
166 167 168 169 170 171 172 173
void Conv2dTransposeDoubleGradInferMeta(const MetaTensor& x,
                                        const MetaTensor& filter,
                                        const MetaTensor& dout,
                                        const MetaTensor& ddx,
                                        const MetaTensor& ddfilter,
                                        const std::vector<int>& strides,
                                        const std::vector<int>& paddings,
                                        const std::vector<int>& output_padding,
174
                                        const IntArray& output_size,
F
From00 已提交
175 176 177 178 179 180 181 182 183 184 185 186 187 188
                                        const std::string& padding_algorithm,
                                        int groups,
                                        const std::vector<int>& dilations,
                                        const std::string& data_format,
                                        MetaTensor* dx,
                                        MetaTensor* dfilter,
                                        MetaTensor* ddout) {
  GeneralBinaryGradInferMeta(x, filter, dx, dfilter);

  if (ddout) {
    ddout->share_meta(dout);
  }
}

189 190 191 192
void CropGradInferMeta(const MetaTensor& out_grad,
                       const MetaTensor& x,
                       const IntArray& offsets,
                       MetaTensor* x_grad) {
193 194 195 196 197 198 199 200
  auto x_dims = x.dims();

  if (x_grad != nullptr) {
    x_grad->set_dims(x_dims);
    x_grad->set_dtype(x.dtype());
  }
}

201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217
// Metadata inference for the gradients of flash attention.
// dq mirrors q whenever requested. dk/dv mirror k/v only when both the output
// slot is non-null and the corresponding forward tensor is initialized.
void FlashAttnGradInferMeta(const MetaTensor& q,
                            const MetaTensor& k,
                            const MetaTensor& v,
                            MetaTensor* dq,
                            MetaTensor* dk,
                            MetaTensor* dv) {
  if (dq != nullptr) {
    dq->share_meta(q);
  }
  if (dk != nullptr && k) {
    dk->share_meta(k);
  }
  if (dv != nullptr && v) {
    dv->share_meta(v);
  }
}

218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280
void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label,
                                          const MetaTensor& softmax,
                                          const MetaTensor& loss_grad,
                                          bool soft_label,
                                          bool use_softmax,
                                          bool numeric_stable_mode,
                                          int ignore_index,
                                          int axis,
                                          MetaTensor* logits_grad,
                                          MetaConfig config) {
  auto softmax_dims = softmax.dims();
  auto labels_dims = label.dims();
  auto softmax_rank = softmax_dims.size();
  PADDLE_ENFORCE_GE(axis,
                    -softmax_rank,
                    phi::errors::InvalidArgument(
                        "Attr(axis) value should be in range [-R, R-1], "
                        "R is the rank of Input(Logits)."));
  PADDLE_ENFORCE_LT(axis,
                    softmax_rank,
                    phi::errors::InvalidArgument(
                        "Attr(axis) value should be in range [-R, R-1], "
                        "R is the rank of Input(Logits)."));

  axis = phi::funcs::CanonicalAxis(axis, softmax_rank);
  for (int i = 0; i < softmax_rank; i++) {
    if (i != axis) {
      if (config.is_runtime || (softmax_dims[i] > 0 && labels_dims[i] > 0)) {
        PADDLE_ENFORCE_EQ(
            softmax_dims[i],
            labels_dims[i],
            phi::errors::InvalidArgument(
                "Input(Logits) and Input(Label) should in same shape in "
                "dimensions except axis."));
      }
    }
  }

  if (soft_label) {
    if (config.is_runtime ||
        (softmax_dims[axis] > 0 && labels_dims[axis] > 0)) {
      PADDLE_ENFORCE_EQ(softmax_dims[axis],
                        labels_dims[axis],
                        phi::errors::InvalidArgument(
                            "If Attr(soft_label) == true, "
                            "the axis dimension of "
                            "Input(X) and Input(Label) should be equal."));
    }
  } else {
    if (config.is_runtime || labels_dims[axis] > 0) {
      PADDLE_ENFORCE_EQ(
          labels_dims[axis],
          1UL,
          phi::errors::InvalidArgument("If Attr(soft_label) == false, "
                                       "the axis dimension of "
                                       "Input(Label) should be 1."));
    }
  }

  logits_grad->set_dims(softmax.dims());
  logits_grad->set_dtype(softmax.dtype());
}

281 282 283
void DeformableConvGradInferMeta(const MetaTensor& x,
                                 const MetaTensor& offset,
                                 const MetaTensor& filter,
284
                                 const MetaTensor& mask,
285 286 287 288 289 290 291 292 293 294 295 296 297
                                 const MetaTensor& out_grad,
                                 const std::vector<int>& strides,
                                 const std::vector<int>& paddings,
                                 const std::vector<int>& dilations,
                                 int deformable_groups,
                                 int groups,
                                 int im2col_step,
                                 MetaTensor* dx,
                                 MetaTensor* offset_grad,
                                 MetaTensor* filter_grad,
                                 MetaTensor* mask_grad) {
  GeneralTernaryGradInferMeta(x, offset, filter, dx, offset_grad, filter_grad);
  if (mask) {
298
    UnchangedInferMeta(mask, mask_grad);
299 300 301
  }
}

302 303 304 305 306 307 308 309 310 311 312 313
// Metadata inference for eig's input gradient.
// Only dims are set (taken from out_v); the dtype of dx is not set here.
void EigGradInferMeta(const MetaTensor& out_w,
                      const MetaTensor& out_v,
                      const MetaTensor& dout_w,
                      const MetaTensor& dout_v,
                      MetaTensor* dx) {
  if (dx == nullptr) {
    return;
  }
  dx->set_dims(out_v.dims());
}

314 315 316 317 318 319 320 321 322 323 324 325
// Metadata inference for eigvalsh's input gradient.
// X@GRAD takes the dims and dtype of out_v; `uplo` and `is_test` are not read.
void EigvalshGradInferMeta(const MetaTensor& out_v,
                           const MetaTensor& out_w_grad,
                           const std::string& uplo,
                           bool is_test,
                           MetaTensor* x_grad) {
  if (x_grad == nullptr) {
    return;
  }
  x_grad->set_dims(out_v.dims());
  x_grad->set_dtype(out_v.dtype());
}

326 327 328 329 330 331 332 333 334
// Metadata inference for embedding's weight gradient.
// When `weight` is initialized, `out` shares its dims (dtype is not touched).
void EmbeddingGradInferMeta(const MetaTensor& x,
                            const MetaTensor& weight,
                            MetaTensor* out) {
  static_cast<void>(x);  // x does not affect the gradient's metadata.
  if (!weight) {
    return;
  }
  out->share_dims(weight);
}

F
Feiyu Chan 已提交
335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375
void FFTC2RGradInferMeta(const MetaTensor& x,
                         const std::vector<int64_t>& axes,
                         const std::string& normalization,
                         bool forward,
                         int64_t last_dim_size,
                         MetaTensor* out,
                         MetaConfig config) {
  PADDLE_ENFORCE_NOT_NULL(out,
                          phi::errors::InvalidArgument(
                              "Output of fft_c2r _grad should not be null."));
  const phi::DDim x_dim = x.dims();

  // only ensure that fft axes' size greater than zero at runtime
  // they might be -1 to indicate unknown size ar compile time
  if (config.is_runtime) {
    for (size_t i = 0; i < axes.size(); i++) {
      PADDLE_ENFORCE_GT(x_dim[axes[i]],
                        0,
                        phi::errors::InvalidArgument(
                            "Invalid fft n-point (%d).", x_dim[axes[i]]));
    }
  }

  out->set_layout(x.layout());
  out->set_dtype(ToComplexType(x.dtype()));

  phi::DDim out_dim = x.dims();
  const int64_t last_fft_axis = axes.back();
  if (last_dim_size > 0) {
    out_dim.at(last_fft_axis) = last_dim_size / 2 + 1;
  } else if (config.is_runtime) {
    const int64_t last_fft_dim_size = x_dim[last_fft_axis];
    out_dim.at(last_fft_axis) = last_fft_dim_size / 2 + 1;
  } else {
    const int64_t last_fft_dim_size = x_dim[last_fft_axis];
    out_dim.at(last_fft_axis) =
        last_fft_dim_size == -1 ? -1 : last_fft_dim_size / 2 + 1;
  }
  out->set_dims(out_dim);
}

Z
zhiboniu 已提交
376 377 378 379 380 381 382 383 384 385 386 387
// Metadata inference for fill_diagonal's input gradient.
// dx takes dout's dims and dtype; the fill attributes are not read.
void FillDiagonalGradInferMeta(const MetaTensor& dout,
                               float value,
                               int offset,
                               bool wrap,
                               MetaTensor* dx) {
  if (dx == nullptr) {
    return;
  }
  dx->set_dims(dout.dims());
  dx->set_dtype(dout.dtype());
}

Z
zhiboniu 已提交
388 389 390 391 392 393 394 395 396 397 398
// Metadata inference for fill_diagonal_tensor's input gradient.
// X@GRAD takes Out@GRAD's dims and dtype; offset/dim1/dim2 are not read.
void FillDiagonalTensorGradInferMeta(const MetaTensor& out_grad,
                                     int64_t offset,
                                     int dim1,
                                     int dim2,
                                     MetaTensor* x_grad) {
  if (x_grad == nullptr) {
    return;
  }
  const auto grad_dims = out_grad.dims();
  x_grad->set_dims(grad_dims);
  x_grad->set_dtype(out_grad.dtype());
}

399 400 401 402 403 404 405 406
// Metadata inference for gather_nd's input gradient.
// X@GRAD takes x's dims and LoD, but Out@GRAD's dtype; `index` is not read.
void GatherNdGradInferMeta(const MetaTensor& x,
                           const MetaTensor& index,
                           const MetaTensor& out_grad,
                           MetaTensor* x_grad) {
  const auto grad_dtype = out_grad.dtype();
  x_grad->set_dims(x.dims());
  x_grad->share_lod(x);
  x_grad->set_dtype(grad_dtype);
}

409 410 411 412
void GeneralBinaryGradInferMeta(const MetaTensor& x,
                                const MetaTensor& y,
                                MetaTensor* dx,
                                MetaTensor* dy) {
413 414 415 416 417 418
  if (dx) {
    dx->share_meta(x);
  }
  if (dy) {
    dy->share_meta(y);
  }
419 420
}

421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436
void GeneralTernaryGradInferMeta(const MetaTensor& x,
                                 const MetaTensor& y,
                                 const MetaTensor& z,
                                 MetaTensor* dx,
                                 MetaTensor* dy,
                                 MetaTensor* dz) {
  if (dx) {
    dx->share_meta(x);
  }
  if (dy) {
    dy->share_meta(y);
  }
  if (dz) {
    dz->share_meta(z);
  }
}
437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484
void GeneralQuaternaryGradInferMeta(const MetaTensor& x,
                                    const MetaTensor& y,
                                    const MetaTensor& z,
                                    const MetaTensor& k,
                                    MetaTensor* dx,
                                    MetaTensor* dy,
                                    MetaTensor* dz,
                                    MetaTensor* dk) {
  if (dx) {
    dx->share_meta(x);
  }
  if (dy) {
    dy->share_meta(y);
  }
  if (dz) {
    dz->share_meta(z);
  }
  if (dk) {
    dk->share_meta(k);
  }
}

void GeneralQuinaryGradInferMeta(const MetaTensor& x,
                                 const MetaTensor& y,
                                 const MetaTensor& z,
                                 const MetaTensor& k,
                                 const MetaTensor& l,
                                 MetaTensor* dx,
                                 MetaTensor* dy,
                                 MetaTensor* dz,
                                 MetaTensor* dk,
                                 MetaTensor* dl) {
  if (dx) {
    dx->share_meta(x);
  }
  if (dy) {
    dy->share_meta(y);
  }
  if (dz) {
    dz->share_meta(z);
  }
  if (dk) {
    dk->share_meta(k);
  }
  if (dl) {
    dl->share_meta(l);
  }
}
485

486 487 488 489 490 491
// Generic metadata inference for a single-input gradient: dx, when
// requested, shares the metadata of x.
void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx) {
  if (dx == nullptr) {
    return;
  }
  dx->share_meta(x);
}

F
From00 已提交
492 493 494 495 496 497 498 499 500 501
void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
                                const MetaTensor& dout,
                                int axis,
                                MetaTensor* dx) {
  PADDLE_ENFORCE_EQ(
      out.dims(),
      dout.dims(),
      errors::InvalidArgument(
          "Input(Out) and its gradients should have the same shape."));

502
  dx->share_meta(dout);
503 504
}

505
void InstanceNormGradInferMeta(const MetaTensor& x,
506
                               const MetaTensor& scale,
507 508
                               const MetaTensor& saved_mean,
                               const MetaTensor& saved_variance,
509
                               const MetaTensor& y_grad,
510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530
                               float epsilon,
                               MetaTensor* x_grad,
                               MetaTensor* scale_grad,
                               MetaTensor* bias_grad) {
  PADDLE_ENFORCE_NE(
      x_grad,
      nullptr,
      phi::errors::InvalidArgument(
          "The X@GRAD in InstanceNormGradInferMeta can't be nullptr."));
  const auto x_dims = x.dims();
  const int C = x_dims[1];
  x_grad->set_dims(x_dims);
  x_grad->set_dtype(x.dtype());
  x_grad->set_layout(x.layout());
  if (scale_grad) {
    scale_grad->set_dims({C});
  }
  if (bias_grad) {
    bias_grad->set_dims({C});
  }
}
531 532 533 534 535 536 537 538 539 540 541 542
void InstanceNormDoubleGradInferMeta(const MetaTensor& x,
                                     const MetaTensor& scale,
                                     const MetaTensor& saved_mean,
                                     const MetaTensor& saved_variance,
                                     const MetaTensor& dy,
                                     const MetaTensor& ddx,
                                     const MetaTensor& ddscale,
                                     const MetaTensor& ddbias,
                                     float epsilon,
                                     MetaTensor* dx,
                                     MetaTensor* dscale,
                                     MetaTensor* ddy) {
543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560
  PADDLE_ENFORCE_NE(
      dx,
      nullptr,
      phi::errors::InvalidArgument(
          "The DX in InstanceNormDoubleGradInferMeta can't be nullptr."));
  const auto x_dims = x.dims();
  const int C = x_dims[1];
  dx->set_dims(x_dims);
  dx->set_dtype(x.dtype());
  dx->set_layout(x.layout());
  if (dscale) {
    dscale->set_dims({C});
  }
  if (ddy) {
    ddy->share_dims(x);
  }
}

561 562 563 564 565
void InverseGradInferMeta(const MetaTensor& out,
                          const MetaTensor& dout,
                          MetaTensor* dx) {
  if (dx) {
    dx->set_dims(dout.dims());
566
    dx->set_dtype(out.dtype());
567 568 569
  }
}

570 571 572 573 574 575 576
// Infers dx from an XShape tensor: dx's dims are xshape's dims with the first
// dimension dropped, and dx shares xshape's LoD.
void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx) {
  const auto full_dims = xshape.dims();
  dx->set_dims(phi::slice_ddim(full_dims, 1, full_dims.size()));
  dx->share_lod(xshape);
}

L
Lin Manhui 已提交
577 578 579 580 581 582 583 584 585 586 587 588 589 590
// Metadata inference for lu's input gradient.
// X@GRAD takes x's dims and dtype; the other inputs and `pivot` are not read.
void LUGradInferMeta(const MetaTensor& x,
                     const MetaTensor& out,
                     const MetaTensor& pivots,
                     const MetaTensor& out_grad,
                     bool pivot,
                     MetaTensor* x_grad) {
  if (x_grad == nullptr) {
    return;
  }
  x_grad->set_dims(x.dims());
  x_grad->set_dtype(x.dtype());
}

591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608
// Metadata inference for lu_unpack's input gradient.
// X@GRAD takes x's dims and dtype; the remaining inputs and flags are not read.
void LUUnpackGradInferMeta(const MetaTensor& x,
                           const MetaTensor& pivots,
                           const MetaTensor& l,
                           const MetaTensor& u,
                           const MetaTensor& pmat,
                           const MetaTensor& l_grad,
                           const MetaTensor& u_grad,
                           bool unpack_ludata,
                           bool unpack_pivots,
                           MetaTensor* x_grad) {
  if (x_grad == nullptr) {
    return;
  }
  x_grad->set_dims(x.dims());
  x_grad->set_dtype(x.dtype());
}

609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632
// Metadata inference for margin_cross_entropy's logits gradient.
// Logits@GRAD (required) takes Softmax's dims and dtype; the margin, scale
// and distributed attributes are not read here.
void MarginCrossEntropyGradInferMeta(const MetaTensor& logits,
                                     const MetaTensor& label,
                                     const MetaTensor& softmax,
                                     const MetaTensor& loss_grad,
                                     bool return_softmax,
                                     int ring_id,
                                     int rank,
                                     int nranks,
                                     float margin1,
                                     float margin2,
                                     float margin3,
                                     float scale,
                                     MetaTensor* logits_grad) {
  PADDLE_ENFORCE_NE(
      logits_grad,
      nullptr,
      phi::errors::InvalidArgument(
          "The Logits@GRAD in MarginCrossEntropy can't be nullptr."));
  logits_grad->set_dims(softmax.dims());
  logits_grad->set_dtype(softmax.dtype());
}

F
From00 已提交
633 634 635 636 637 638 639 640 641 642 643 644
void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
                                   const MetaTensor& mask,
                                   const MetaTensor& dout,
                                   const std::vector<int>& kernel_size,
                                   const std::vector<int>& strides,
                                   const std::vector<int>& paddings,
                                   bool global_pooling,
                                   bool adaptive,
                                   MetaTensor* dx) {
  dx->share_meta(x);
}

Z
ZhangDY-6483 已提交
645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729
void MemoryEfficientAttentionGradInferMeta(const MetaTensor& query,
                                           const MetaTensor& key,
                                           const MetaTensor& value,
                                           const MetaTensor& bias,
                                           const MetaTensor& cu_seqlens_q,
                                           const MetaTensor& cu_seqlens_k,
                                           const MetaTensor& output,
                                           const MetaTensor& logsumexp,
                                           const MetaTensor& seed_and_offset,
                                           const MetaTensor& output_grad,
                                           const Scalar& max_seqlen_q,
                                           const Scalar& max_seqlen_k,
                                           const bool causal,
                                           const double dropout_p,
                                           const float scale,
                                           MetaTensor* query_grad,
                                           MetaTensor* key_grad,
                                           MetaTensor* value_grad,
                                           MetaTensor* bias_grad) {
  PADDLE_ENFORCE_EQ(
      output_grad.dims().size(),
      4,
      phi::errors::InvalidArgument("Key should be a 4-D tensor"
                                   "But received Key dimension(%s)",
                                   output_grad.dims().size()));
  PADDLE_ENFORCE_EQ(
      output.dims().size(),
      4,
      phi::errors::InvalidArgument("Key should be a 4-D tensor"
                                   "But received Key dimension(%s)",
                                   output_grad.dims().size()));

  const int64_t query_batch_size = query.dims()[0];
  const int64_t query_seq_length = query.dims()[1];
  const int64_t query_num_head = query.dims()[2];
  const int64_t query_head_size = query.dims()[3];

  const int64_t key_batch_size = key.dims()[0];
  const int64_t key_seq_length = key.dims()[1];
  const int64_t key_num_head = key.dims()[2];
  const int64_t key_head_size = key.dims()[3];

  const int64_t value_batch_size = value.dims()[0];
  const int64_t value_seq_length = value.dims()[1];
  const int64_t value_num_head = value.dims()[2];
  const int64_t value_head_size = value.dims()[3];

  std::vector<int64_t> query_grad_dims(
      {query_batch_size, query_seq_length, query_num_head, query_head_size});
  std::vector<int64_t> key_grad_dims(
      {key_batch_size, key_seq_length, key_num_head, key_head_size});
  std::vector<int64_t> value_grad_dims(
      {value_batch_size, value_seq_length, value_num_head, value_head_size});

  query_grad->set_dims(phi::make_ddim(query_grad_dims));
  query_grad->share_lod(query);
  query_grad->set_dtype(query.dtype());
  query_grad->set_layout(query.layout());

  key_grad->set_dims(phi::make_ddim(key_grad_dims));
  key_grad->share_lod(key);
  key_grad->set_dtype(key.dtype());
  key_grad->set_layout(key.layout());

  value_grad->set_dims(phi::make_ddim(value_grad_dims));
  value_grad->share_lod(value);
  value_grad->set_dtype(value.dtype());
  value_grad->set_layout(value.layout());

  if (bias) {
    const int64_t bias_batch_size = bias.dims()[0];
    const int64_t bias_seq_length = bias.dims()[1];
    const int64_t bias_num_head = bias.dims()[2];
    const int64_t bias_head_size = bias.dims()[3];

    std::vector<int64_t> bias_grad_dims(
        {bias_batch_size, bias_seq_length, bias_num_head, bias_head_size});

    bias_grad->set_dims(phi::make_ddim(bias_grad_dims));
    bias_grad->share_lod(bias);
    bias_grad->set_dtype(bias.dtype());
    bias_grad->set_layout(bias.layout());
  }
}

730 731
void MeshgridGradInferMeta(const std::vector<const MetaTensor*>& inputs,
                           const std::vector<const MetaTensor*>& outputs_grad,
Y
YuanRisheng 已提交
732 733 734 735 736 737 738 739 740 741 742 743
                           std::vector<MetaTensor*> inputs_grad) {
  PADDLE_ENFORCE_GT(outputs_grad.size(),
                    1,
                    errors::InvalidArgument(
                        "Number of Inputs(Out@Grad) should be larger than 1."
                        "But received Inputs(Out@Grad)' size = %d .",
                        outputs_grad.size()));
  for (size_t i = 0; i < inputs.size(); i++) {
    inputs_grad[i]->share_meta(*inputs[i]);
  }
}

744
// Infers metadata for the gradients of multi_dot.
//
// Requires one gradient slot per input. Each non-null X@GRAD[i] takes the
// dims, LoD and dtype of X[i]; out_grad is not read.
void MultiDotGradInferMeta(const std::vector<const MetaTensor*>& x,
                           const MetaTensor& out_grad,
                           std::vector<MetaTensor*> x_grad) {
  PADDLE_ENFORCE_EQ(
      x.size(),
      x_grad.size(),
      errors::InvalidArgument(
          "Number of Inputs(X) should be equal with Outputs(X@Grad). "
          "But received Inputs(X)' size = %d , Outputs(X@Grad)' size = %d.",
          x.size(),
          x_grad.size()));
  for (size_t i = 0; i < x.size(); i++) {
    if (x_grad[i] != nullptr) {
      x_grad[i]->set_dims(x[i]->dims());
      x_grad[i]->share_lod(*x[i]);
      // Without this the gradient's dtype would be left undefined.
      x_grad[i]->set_dtype(x[i]->dtype());
    }
  }
}

void MultiplexGradInferMeta(const MetaTensor& ids,
                            const MetaTensor& out_grad,
                            std::vector<MetaTensor*> ins_grad) {
  PADDLE_ENFORCE_NE(
      ins_grad.empty(),
      true,
      errors::InvalidArgument("Output(X@Grad) should not be null."));
  auto dout_dim = out_grad.dims();
  for (auto in_grad : ins_grad) {
772 773 774
    if (in_grad != nullptr) {
      in_grad->set_dims(dout_dim);
    }
775 776 777
  }
}

778 779 780 781 782 783 784 785 786 787 788
void NanmedianGradInferMeta(const MetaTensor& x,
                            const MetaTensor& median_index,
                            const MetaTensor& out_grad,
                            const IntArray& axes,
                            bool keep_dim,
                            MetaTensor* x_grad) {
  auto x_dims = x.dims();
  x_grad->set_dims(x_dims);
  x_grad->set_dtype(x.dtype());
}

Z
zyfncg 已提交
789 790
void NllLossGradInferMeta(const MetaTensor& x,
                          const MetaTensor& label,
791
                          const MetaTensor& weight,
Z
zyfncg 已提交
792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854
                          const MetaTensor& total_weight,
                          const MetaTensor& out_grad,
                          int64_t ignore_index,
                          const std::string& reduction,
                          MetaTensor* dx,
                          MetaConfig config) {
  const auto& x_dims = x.dims();
  const auto& label_dims = label.dims();
  const auto& dout_dims = out_grad.dims();
  bool contain_unknown_dim =
      phi::contain_unknown_dim(x_dims) || phi::contain_unknown_dim(dout_dims);
  bool check = config.is_runtime || !contain_unknown_dim;

  if (check) {
    auto batch_size = x_dims[0];
    if (x_dims.size() == 2) {
      PADDLE_ENFORCE_EQ(dout_dims.size(),
                        1,
                        phi::errors::InvalidArgument(
                            "The dimensions of Input(Out@Grad) must be 1"));
      if (reduction == "none") {
        PADDLE_ENFORCE_EQ(
            dout_dims[0],
            batch_size,
            phi::errors::InvalidArgument(
                "The unreduced size ofInput(Out@Grad) must be the "
                "same as batch_size."));
      } else {
        PADDLE_ENFORCE_EQ(dout_dims[0],
                          1,
                          phi::errors::InvalidArgument(
                              "The reduced size of Input(Out@Grad) must be 1"));
      }
    } else if (x_dims.size() == 4) {
      if (reduction == "none") {
        PADDLE_ENFORCE_EQ(
            dout_dims.size(),
            3,
            phi::errors::InvalidArgument(
                "The dimensions of Input(Out@Grad) must be 3,But got [%s].",
                dout_dims.size()));
        PADDLE_ENFORCE_EQ(dout_dims[0] == label_dims[0] &&
                              dout_dims[1] == label_dims[1] &&
                              dout_dims[2] == label_dims[2],
                          true,
                          phi::errors::InvalidArgument(
                              "The dimensions of Input(Out@Grad) must be match "
                              "to Input(Label) dimensions."));
      } else {
        PADDLE_ENFORCE_EQ(dout_dims[0],
                          1,
                          phi::errors::InvalidArgument(
                              "The reduced size of Input(Out@Grad) must be 1"));
      }
    }
  }

  if (dx) {
    dx->set_dims(x_dims);
    dx->set_dtype(x.dtype());
  }
}

855 856 857 858 859 860 861 862 863 864 865 866
// Backward meta inference for overlap_add.
// When X@Grad is requested, it copies the dims and dtype of the forward input
// x. hop_length and axis do not affect the result.
void OverlapAddGradInferMeta(const MetaTensor& x,
                             const MetaTensor& out_grad,
                             int hop_length,
                             int axis,
                             MetaTensor* x_grad) {
  if (x_grad == nullptr) {
    return;
  }
  x_grad->set_dims(x.dims());
  x_grad->set_dtype(x.dtype());
}

867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896
// Backward meta inference for pixel_unshuffle.
//
// Reverses the forward shape transform. The channel dim is divided by
// downscale_factor^2 and both spatial dims are multiplied by downscale_factor.
// The batch dim is unchanged.
//
// Out@Grad must be 4-D. data_format == "NHWC" selects channel-last layout;
// any other value is treated as channel-first (NCHW).
void PixelUnshuffleGradInferMeta(const MetaTensor& out_grad,
                                 int downscale_factor,
                                 const std::string& data_format,
                                 MetaTensor* x_grad) {
  auto grad_dims = out_grad.dims();
  PADDLE_ENFORCE_EQ(grad_dims.size(),
                    4,
                    phi::errors::InvalidArgument(
                        "Input should be a 4-D tensor of format [N, C, H, W] "
                        "or [N, H, W, C], but got %u.",
                        grad_dims.size()));

  const int factor_sq = downscale_factor * downscale_factor;

  // Start from a copy so the batch dim (index 0) carries over as-is.
  auto dx_dims = grad_dims;
  if (data_format == "NHWC") {
    dx_dims[1] = grad_dims[1] * downscale_factor;
    dx_dims[2] = grad_dims[2] * downscale_factor;
    dx_dims[3] = grad_dims[3] / factor_sq;
  } else {
    dx_dims[1] = grad_dims[1] / factor_sq;
    dx_dims[2] = grad_dims[2] * downscale_factor;
    dx_dims[3] = grad_dims[3] * downscale_factor;
  }

  x_grad->set_dims(dx_dims);
  x_grad->set_dtype(out_grad.dtype());
}

897 898 899 900 901 902 903 904 905 906 907 908
// Backward meta inference for prelu.
// dx mirrors the dims of x and dy mirrors the dims of y (the alpha weight).
// Either output may be nullptr, in which case it is skipped.
void PreluGradInferMeta(const MetaTensor& x,
                        const MetaTensor& y,
                        MetaTensor* dx,
                        MetaTensor* dy) {
  if (dx != nullptr) {
    dx->share_dims(x);
  }
  if (dy != nullptr) {
    dy->share_dims(y);
  }
}

F
From00 已提交
909 910
void PsroiPoolGradInferMeta(const MetaTensor& x,
                            const MetaTensor& rois,
911
                            const MetaTensor& rois_num,
F
From00 已提交
912 913 914 915 916 917 918 919 920
                            const MetaTensor& dout,
                            int pooled_height,
                            int pooled_width,
                            int output_channels,
                            float spatial_scale,
                            MetaTensor* dx) {
  dx->share_meta(x);
}

Z
zyfncg 已提交
921 922 923 924 925 926
// Backward meta inference shared by real/imag.
// dx keeps the dims and layout of Out@Grad. Its dtype is the complex
// counterpart of Out@Grad's dtype, as given by dtype::ToComplex.
void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx) {
  const auto grad_dtype = out_grad.dtype();
  dx->set_dims(out_grad.dims());
  dx->set_dtype(dtype::ToComplex(grad_dtype));
  dx->set_layout(out_grad.layout());
}

927 928 929 930 931 932 933 934
// Double-backward meta inference for reshape.
// When requested, out_grad_grad takes the dims of out_grad. x_grad_grad is not
// consulted.
void ReshapeDoubleGradInferMeta(const MetaTensor& out_grad,
                                const MetaTensor& x_grad_grad,
                                MetaTensor* out_grad_grad) {
  if (out_grad_grad == nullptr) {
    return;
  }
  out_grad_grad->share_dims(out_grad);
}

Y
YuanRisheng 已提交
935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961
// Backward meta inference for rnn.
// Requires non-empty pre_state and weight_grad_list. The gradient outputs then
// mirror their forward counterparts:
//   x_grad         <- x            (if requested)
//   pre_state_grad <- pre_state    (if any slots are given)
//   weight_grad    <- weight_list  (if any slots are given)
void RnnGradInferMeta(const MetaTensor& x,
                      const std::vector<const MetaTensor*>& pre_state,
                      const std::vector<const MetaTensor*>& weight_list,
                      MetaTensor* x_grad,
                      std::vector<MetaTensor*> pre_state_grad,
                      std::vector<MetaTensor*> weight_grad_list) {
  PADDLE_ENFORCE_GT(
      pre_state.size(),
      0UL,
      phi::errors::InvalidArgument(
          "The input pre_state in RnnGradInferMeta can't be empty."));
  PADDLE_ENFORCE_GT(
      weight_grad_list.size(),
      0UL,
      phi::errors::InvalidArgument(
          "The input weight_grad_list in RnnGradInferMeta can't be empty."));

  if (x_grad != nullptr) {
    UnchangedInferMeta(x, x_grad);
  }
  if (!pre_state_grad.empty()) {
    UnchangedMultiInferMeta(pre_state, pre_state_grad);
  }
  // Always true after the enforce above; kept as a guard for clarity.
  if (!weight_grad_list.empty()) {
    UnchangedMultiInferMeta(weight_list, weight_grad_list);
  }
}

962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996
void ScatterGradInferMeta(const MetaTensor& index,
                          const MetaTensor& updates,
                          const MetaTensor& out_grad,
                          bool overwrite,
                          MetaTensor* x_grad,
                          MetaTensor* updates_grad) {
  const auto& dtype = out_grad.dtype();
  if (updates_grad) {
    updates_grad->set_dims(updates.dims());
    updates_grad->set_dtype(dtype);
  }

  if (x_grad) {
    x_grad->set_dims(out_grad.dims());
    x_grad->set_dtype(dtype);
  }
}

// Backward meta inference for scatter_nd_add.
// updates_grad takes the dims of updates and x_grad takes the dims of
// Out@Grad. Both use Out@Grad's dtype. A nullptr output is skipped, and index
// is not consulted.
void ScatterNdAddGradInferMeta(const MetaTensor& index,
                               const MetaTensor& updates,
                               const MetaTensor& out_grad,
                               MetaTensor* x_grad,
                               MetaTensor* updates_grad) {
  const auto grad_dtype = out_grad.dtype();

  if (updates_grad != nullptr) {
    const auto updates_dims = updates.dims();
    updates_grad->set_dims(updates_dims);
    updates_grad->set_dtype(grad_dtype);
  }

  if (x_grad != nullptr) {
    const auto grad_dims = out_grad.dims();
    x_grad->set_dims(grad_dims);
    x_grad->set_dtype(grad_dtype);
  }
}

997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011
// Backward meta inference for spectral_norm.
// When requested, weight_grad takes the dims of weight and the dtype of
// Out@Grad. u, v and the iteration attributes are not consulted.
void SpectralNormGradInferMeta(const MetaTensor& weight,
                               const MetaTensor& u,
                               const MetaTensor& v,
                               const MetaTensor& out_grad,
                               int dim,
                               int power_iters,
                               float eps,
                               MetaTensor* weight_grad) {
  if (!weight_grad) {
    return;
  }
  weight_grad->set_dims(weight.dims());
  weight_grad->set_dtype(out_grad.dtype());
}

1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047
void StackGradInferMeta(const MetaTensor& out_grad,
                        int axis,
                        std::vector<MetaTensor*> x_grad) {
  auto dy_dim = out_grad.dims();
  int rank = dy_dim.size();
  PADDLE_ENFORCE_GE(
      axis,
      -rank,
      phi::errors::InvalidArgument(
          "Attr(axis) must be inside [-rank, rank), where rank = %d, "
          "but received axis is:%d.",
          rank,
          axis));
  PADDLE_ENFORCE_LT(
      axis,
      rank,
      phi::errors::InvalidArgument(
          "Attr(axis) must be inside [-rank, rank), where rank = %d, "
          "but received axis is:%d.",
          rank,
          axis));

  if (axis < 0) axis += rank;
  PADDLE_ENFORCE_LE(
      x_grad.size(),
      static_cast<size_t>(dy_dim[axis]),
      phi::errors::InvalidArgument(
          "Number of Outputs(X@Grad) should be less than or equal to dy dim "
          "at axis, but received outputs size is:%d, dy dims is:%d.",
          x_grad.size(),
          static_cast<size_t>(dy_dim[axis])));

  auto vec = phi::vectorize<int>(dy_dim);
  vec.erase(vec.begin() + axis);

  for (size_t i = 0; i < x_grad.size(); ++i) {
1048 1049 1050 1051
    if (x_grad[i]) {
      x_grad[i]->set_dims(phi::make_ddim(vec));
      x_grad[i]->set_dtype(out_grad.dtype());
    }
1052 1053 1054
  }
}

1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072
// Backward meta inference for uniform_random_inplace.
// X@Grad is mandatory and takes the dims and dtype of Out@Grad. The random
// attributes are not consulted.
void UniformRandomInplaceGradInferMeta(const MetaTensor& out_grad,
                                       float min,
                                       float max,
                                       int seed,
                                       int diag_num,
                                       int diag_step,
                                       float diag_val,
                                       MetaTensor* x_grad) {
  PADDLE_ENFORCE_NE(
      x_grad,
      nullptr,
      phi::errors::InvalidArgument(
          "The X@GRAD in UniformRandomInplaceGradInferMeta can't be nullptr."));
  x_grad->set_dims(out_grad.dims());
  x_grad->set_dtype(out_grad.dtype());
}

1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114
// Backward meta inference for unstack.
//
// All Y@Grad inputs must be non-empty as a list and share identical dims.
// axis must lie in [-(rank+1), rank+1), where rank is the rank of one Y@Grad;
// negative values count from the end.
//
// X@Grad gets the common Y@Grad dims with a new dimension of size
// out_grad.size() inserted at axis, and the dtype of the first Y@Grad.
void UnStackGradInferMeta(const std::vector<const MetaTensor*>& out_grad,
                          int axis,
                          MetaTensor* x_grad) {
  // Guard the out_grad[0] / input_dims[0] accesses below. An empty list would
  // otherwise be an out-of-bounds vector access.
  PADDLE_ENFORCE_GT(
      out_grad.size(),
      0UL,
      phi::errors::InvalidArgument(
          "The input Y@Grad in UnStackGradInferMeta can't be empty."));

  std::vector<phi::DDim> input_dims(out_grad.size());
  for (size_t i = 0; i < out_grad.size(); ++i) {
    input_dims[i] = out_grad[i]->dims();
  }
  for (size_t i = 1; i < input_dims.size(); ++i) {
    PADDLE_ENFORCE_EQ(
        input_dims[i],
        input_dims[0],
        phi::errors::InvalidArgument(
            "The dimensions of all Inputs(Y@Grad) must be the same,"
            "but received Inputs(Y@Grad)'s %d-th dimension is %d, "
            "Inputs(Y@Grad)'s 0-th to %d-th dimension is %d.",
            i,
            input_dims[i],
            i - 1,
            input_dims[0]));
  }

  int rank = input_dims[0].size();
  PADDLE_ENFORCE_GE(axis,
                    -(rank + 1),
                    phi::errors::InvalidArgument(
                        "The attribute axis is out of range, it must be "
                        "inside [-(rank+1), rank+1), where rank = %d",
                        rank));
  PADDLE_ENFORCE_LT(axis,
                    rank + 1,
                    phi::errors::InvalidArgument(
                        "The attribute axis is out of range, it must be "
                        "inside [-(rank+1), rank+1), where rank = %d",
                        rank));
  if (axis < 0) axis += (rank + 1);

  auto vec = phi::vectorize<int>(input_dims[0]);
  // The inserted extent is the number of unstacked pieces (explicit narrowing).
  vec.insert(vec.begin() + axis, static_cast<int>(input_dims.size()));
  x_grad->set_dims(phi::make_ddim(vec));
  x_grad->set_dtype(out_grad[0]->dtype());
}

1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132
void YoloLossGradInferMeta(const MetaTensor& x,
                           const MetaTensor& gt_box,
                           const MetaTensor& gt_label,
                           const MetaTensor& gt_score,
                           const MetaTensor& objectness_mask,
                           const MetaTensor& gt_match_mask,
                           const MetaTensor& loss_grad,
                           const std::vector<int>& anchors,
                           const std::vector<int>& anchor_mask,
                           int class_num,
                           float ignore_thresh,
                           int downsample_ratio,
                           bool use_label_smooth,
                           float scale_x_y,
                           MetaTensor* x_grad,
                           MetaTensor* gt_box_grad,
                           MetaTensor* gt_label_grad,
                           MetaTensor* gt_score_grad) {
1133 1134 1135 1136 1137 1138
  if (x_grad) {
    x_grad->set_dims(x.dims());
    x_grad->set_dtype(x.dtype());
  }
}

L
Li Min 已提交
1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160
void IndexAddGradInferMeta(const MetaTensor& index,
                           const MetaTensor& add_value,
                           const MetaTensor& out_grad,
                           int axis,
                           MetaTensor* x_grad,
                           MetaTensor* add_value_grad) {
  auto do_dims = out_grad.dims();
  auto add_value_dims = add_value.dims();
  if (x_grad) {
    x_grad->set_dims(do_dims);
    x_grad->set_dtype(out_grad.dtype());
    x_grad->set_layout(out_grad.layout());
    x_grad->share_lod(out_grad);
  }
  if (add_value_grad) {
    add_value_grad->set_dims(add_value_dims);
    add_value_grad->set_dtype(add_value.dtype());
    add_value_grad->set_layout(add_value.layout());
    add_value_grad->share_lod(add_value);
  }
}

1161
}  // namespace phi