/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/infermeta/backward.h"

#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"

namespace phi {

W
WangZhen 已提交
22 23 24 25 26 27
void AngleGradInferMeta(const MetaTensor& x,
                        const MetaTensor& out_grad,
                        MetaTensor* x_grad) {
  UnchangedInferMeta(x, x_grad);
}

28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
void BilinearTensorProductGradInferMeta(const MetaTensor& x,
                                        const MetaTensor& y,
                                        const MetaTensor& weight,
                                        const MetaTensor& dout,
                                        MetaTensor* dx,
                                        MetaTensor* dy,
                                        MetaTensor* dweight,
                                        MetaTensor* dbias) {
  auto x_dims = x.dims();
  auto y_dims = y.dims();
  auto weight_dims = weight.dims();
  auto out_dims = dout.dims();

  PADDLE_ENFORCE_EQ(
      out_dims.size(),
      2UL,
      errors::InvalidArgument("The input(Out@GRAD) must be a 2D Tensor."));
  PADDLE_ENFORCE_EQ(
      x_dims[0],
      out_dims[0],
      errors::InvalidArgument(
          "The first dimension(batch_size) of input(Out@GRAD) must be "
          "equal to the first dimension of the Input(X)."));
  PADDLE_ENFORCE_EQ(
      weight_dims[0],
      out_dims[1],
      errors::InvalidArgument(
          "The second dimension of input(Out@GRAD) must be equal to "
          "the third dimension of the Input(Weight)."));

  if (dx) {
    dx->set_dims(x_dims);
    dx->set_dtype(x.dtype());
  }
  if (dy) {
    dy->set_dims(y_dims);
    dy->set_dtype(y.dtype());
  }
  if (dweight) {
    dweight->set_dims(weight_dims);
    dweight->set_dtype(weight.dtype());
  }
  if (dbias) {
    dbias->set_dims({1, out_dims[1]});
    dbias->set_dtype(dout.dtype());
  }
}

76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
void ChannelShuffleGradInferMeta(const MetaTensor& out_grad,
                                 int groups,
                                 const std::string& data_format,
                                 MetaTensor* x_grad) {
  auto do_dims = out_grad.dims();
  PADDLE_ENFORCE_EQ(do_dims.size(),
                    4,
                    phi::errors::InvalidArgument(
                        "Input should be a 4-D tensor of format [N, C, H, W] "
                        "or [N, H, W, C], but got %u.",
                        do_dims.size()));
  auto dx_dims = do_dims;
  x_grad->set_dims(dx_dims);
  x_grad->set_dtype(out_grad.dtype());
}

92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
void ComplexGradInferMeta(const MetaTensor& x,
                          const MetaTensor& y,
                          const MetaTensor& dout,
                          MetaTensor* dx,
                          MetaTensor* dy) {
  auto x_dims = x.dims();
  if (dx) {
    dx->set_dims(x_dims);
    dx->set_dtype(x.dtype());
  }
  auto y_dims = y.dims();
  if (dy) {
    dy->set_dims(y_dims);
    dy->set_dtype(y.dtype());
  }
}

F
From00 已提交
109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
void ConvTransposeGradInferMeta(const MetaTensor& x,
                                const MetaTensor& filter,
                                const MetaTensor& dout,
                                const std::vector<int>& strides,
                                const std::vector<int>& paddings,
                                const std::vector<int>& output_padding,
                                const std::vector<int>& output_size,
                                const std::string& padding_algorithm,
                                int groups,
                                const std::vector<int>& dilations,
                                const std::string& data_format,
                                MetaTensor* dx,
                                MetaTensor* dfilter) {
  GeneralBinaryGradInferMeta(x, filter, dx, dfilter);
}

// Infers the meta of the double-gradient outputs of conv2d_transpose.
//
// dx and dfilter share the meta of `x` and `filter` (when requested).
// ddout, when requested, shares the meta of `dout`. The second-order inputs
// (ddx, ddfilter) and the convolution attributes are not read.
void Conv2dTransposeDoubleGradInferMeta(const MetaTensor& x,
                                        const MetaTensor& filter,
                                        const MetaTensor& dout,
                                        const MetaTensor& ddx,
                                        const MetaTensor& ddfilter,
                                        const std::vector<int>& strides,
                                        const std::vector<int>& paddings,
                                        const std::vector<int>& output_padding,
                                        const std::vector<int>& output_size,
                                        const std::string& padding_algorithm,
                                        int groups,
                                        const std::vector<int>& dilations,
                                        const std::string& data_format,
                                        MetaTensor* dx,
                                        MetaTensor* dfilter,
                                        MetaTensor* ddout) {
  GeneralBinaryGradInferMeta(x, filter, dx, dfilter);

  if (ddout == nullptr) {
    return;
  }
  ddout->share_meta(dout);
}

148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210
void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label,
                                          const MetaTensor& softmax,
                                          const MetaTensor& loss_grad,
                                          bool soft_label,
                                          bool use_softmax,
                                          bool numeric_stable_mode,
                                          int ignore_index,
                                          int axis,
                                          MetaTensor* logits_grad,
                                          MetaConfig config) {
  auto softmax_dims = softmax.dims();
  auto labels_dims = label.dims();
  auto softmax_rank = softmax_dims.size();
  PADDLE_ENFORCE_GE(axis,
                    -softmax_rank,
                    phi::errors::InvalidArgument(
                        "Attr(axis) value should be in range [-R, R-1], "
                        "R is the rank of Input(Logits)."));
  PADDLE_ENFORCE_LT(axis,
                    softmax_rank,
                    phi::errors::InvalidArgument(
                        "Attr(axis) value should be in range [-R, R-1], "
                        "R is the rank of Input(Logits)."));

  axis = phi::funcs::CanonicalAxis(axis, softmax_rank);
  for (int i = 0; i < softmax_rank; i++) {
    if (i != axis) {
      if (config.is_runtime || (softmax_dims[i] > 0 && labels_dims[i] > 0)) {
        PADDLE_ENFORCE_EQ(
            softmax_dims[i],
            labels_dims[i],
            phi::errors::InvalidArgument(
                "Input(Logits) and Input(Label) should in same shape in "
                "dimensions except axis."));
      }
    }
  }

  if (soft_label) {
    if (config.is_runtime ||
        (softmax_dims[axis] > 0 && labels_dims[axis] > 0)) {
      PADDLE_ENFORCE_EQ(softmax_dims[axis],
                        labels_dims[axis],
                        phi::errors::InvalidArgument(
                            "If Attr(soft_label) == true, "
                            "the axis dimension of "
                            "Input(X) and Input(Label) should be equal."));
    }
  } else {
    if (config.is_runtime || labels_dims[axis] > 0) {
      PADDLE_ENFORCE_EQ(
          labels_dims[axis],
          1UL,
          phi::errors::InvalidArgument("If Attr(soft_label) == false, "
                                       "the axis dimension of "
                                       "Input(Label) should be 1."));
    }
  }

  logits_grad->set_dims(softmax.dims());
  logits_grad->set_dtype(softmax.dtype());
}

211 212 213
void DeformableConvGradInferMeta(const MetaTensor& x,
                                 const MetaTensor& offset,
                                 const MetaTensor& filter,
214
                                 const MetaTensor& mask,
215 216 217 218 219 220 221 222 223 224 225 226 227
                                 const MetaTensor& out_grad,
                                 const std::vector<int>& strides,
                                 const std::vector<int>& paddings,
                                 const std::vector<int>& dilations,
                                 int deformable_groups,
                                 int groups,
                                 int im2col_step,
                                 MetaTensor* dx,
                                 MetaTensor* offset_grad,
                                 MetaTensor* filter_grad,
                                 MetaTensor* mask_grad) {
  GeneralTernaryGradInferMeta(x, offset, filter, dx, offset_grad, filter_grad);
  if (mask) {
228
    UnchangedInferMeta(mask, mask_grad);
229 230 231
  }
}

232 233 234 235 236 237 238 239 240 241 242 243
void EigGradInferMeta(const MetaTensor& out_w,
                      const MetaTensor& out_v,
                      const MetaTensor& dout_w,
                      const MetaTensor& dout_v,
                      MetaTensor* dx) {
  auto dims = out_v.dims();

  if (dx) {
    dx->set_dims(dims);
  }
}

244 245 246 247 248 249 250 251
void GatherNdGradInferMeta(const MetaTensor& x,
                           const MetaTensor& index,
                           const MetaTensor& out_grad,
                           MetaTensor* x_grad) {
  const auto& dtype = out_grad.dtype();
  x_grad->set_dims(x.dims());
  x_grad->share_lod(x);
  x_grad->set_dtype(dtype);
252 253
}

254 255 256 257
void GeneralBinaryGradInferMeta(const MetaTensor& x,
                                const MetaTensor& y,
                                MetaTensor* dx,
                                MetaTensor* dy) {
258 259 260 261 262 263
  if (dx) {
    dx->share_meta(x);
  }
  if (dy) {
    dy->share_meta(y);
  }
264 265
}

266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281
void GeneralTernaryGradInferMeta(const MetaTensor& x,
                                 const MetaTensor& y,
                                 const MetaTensor& z,
                                 MetaTensor* dx,
                                 MetaTensor* dy,
                                 MetaTensor* dz) {
  if (dx) {
    dx->share_meta(x);
  }
  if (dy) {
    dy->share_meta(y);
  }
  if (dz) {
    dz->share_meta(z);
  }
}
282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329
void GeneralQuaternaryGradInferMeta(const MetaTensor& x,
                                    const MetaTensor& y,
                                    const MetaTensor& z,
                                    const MetaTensor& k,
                                    MetaTensor* dx,
                                    MetaTensor* dy,
                                    MetaTensor* dz,
                                    MetaTensor* dk) {
  if (dx) {
    dx->share_meta(x);
  }
  if (dy) {
    dy->share_meta(y);
  }
  if (dz) {
    dz->share_meta(z);
  }
  if (dk) {
    dk->share_meta(k);
  }
}

void GeneralQuinaryGradInferMeta(const MetaTensor& x,
                                 const MetaTensor& y,
                                 const MetaTensor& z,
                                 const MetaTensor& k,
                                 const MetaTensor& l,
                                 MetaTensor* dx,
                                 MetaTensor* dy,
                                 MetaTensor* dz,
                                 MetaTensor* dk,
                                 MetaTensor* dl) {
  if (dx) {
    dx->share_meta(x);
  }
  if (dy) {
    dy->share_meta(y);
  }
  if (dz) {
    dz->share_meta(z);
  }
  if (dk) {
    dk->share_meta(k);
  }
  if (dl) {
    dl->share_meta(l);
  }
}
330

331 332 333 334 335 336
void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx) {
  if (dx) {
    dx->share_meta(x);
  }
}

F
From00 已提交
337 338 339 340 341 342 343 344 345 346
void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
                                const MetaTensor& dout,
                                int axis,
                                MetaTensor* dx) {
  PADDLE_ENFORCE_EQ(
      out.dims(),
      dout.dims(),
      errors::InvalidArgument(
          "Input(Out) and its gradients should have the same shape."));

347
  dx->share_meta(dout);
348 349
}

350
void InstanceNormGradInferMeta(const MetaTensor& x,
351
                               const MetaTensor& scale,
352 353
                               const MetaTensor& saved_mean,
                               const MetaTensor& saved_variance,
354
                               const MetaTensor& y_grad,
355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375
                               float epsilon,
                               MetaTensor* x_grad,
                               MetaTensor* scale_grad,
                               MetaTensor* bias_grad) {
  PADDLE_ENFORCE_NE(
      x_grad,
      nullptr,
      phi::errors::InvalidArgument(
          "The X@GRAD in InstanceNormGradInferMeta can't be nullptr."));
  const auto x_dims = x.dims();
  const int C = x_dims[1];
  x_grad->set_dims(x_dims);
  x_grad->set_dtype(x.dtype());
  x_grad->set_layout(x.layout());
  if (scale_grad) {
    scale_grad->set_dims({C});
  }
  if (bias_grad) {
    bias_grad->set_dims({C});
  }
}
376 377 378 379 380 381 382 383 384 385 386 387
void InstanceNormDoubleGradInferMeta(const MetaTensor& x,
                                     const MetaTensor& scale,
                                     const MetaTensor& saved_mean,
                                     const MetaTensor& saved_variance,
                                     const MetaTensor& dy,
                                     const MetaTensor& ddx,
                                     const MetaTensor& ddscale,
                                     const MetaTensor& ddbias,
                                     float epsilon,
                                     MetaTensor* dx,
                                     MetaTensor* dscale,
                                     MetaTensor* ddy) {
388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405
  PADDLE_ENFORCE_NE(
      dx,
      nullptr,
      phi::errors::InvalidArgument(
          "The DX in InstanceNormDoubleGradInferMeta can't be nullptr."));
  const auto x_dims = x.dims();
  const int C = x_dims[1];
  dx->set_dims(x_dims);
  dx->set_dtype(x.dtype());
  dx->set_layout(x.layout());
  if (dscale) {
    dscale->set_dims({C});
  }
  if (ddy) {
    ddy->share_dims(x);
  }
}

406 407 408 409 410 411 412
void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx) {
  auto xshape_dims = xshape.dims();
  auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
  dx->set_dims(x_dims);
  dx->share_lod(xshape);
}

F
From00 已提交
413 414 415 416 417 418 419 420 421 422 423 424
void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
                                   const MetaTensor& mask,
                                   const MetaTensor& dout,
                                   const std::vector<int>& kernel_size,
                                   const std::vector<int>& strides,
                                   const std::vector<int>& paddings,
                                   bool global_pooling,
                                   bool adaptive,
                                   MetaTensor* dx) {
  dx->share_meta(x);
}

425 426
void MeshgridGradInferMeta(const std::vector<const MetaTensor*>& inputs,
                           const std::vector<const MetaTensor*>& outputs_grad,
Y
YuanRisheng 已提交
427 428 429 430 431 432 433 434 435 436 437 438
                           std::vector<MetaTensor*> inputs_grad) {
  PADDLE_ENFORCE_GT(outputs_grad.size(),
                    1,
                    errors::InvalidArgument(
                        "Number of Inputs(Out@Grad) should be larger than 1."
                        "But received Inputs(Out@Grad)' size = %d .",
                        outputs_grad.size()));
  for (size_t i = 0; i < inputs.size(); i++) {
    inputs_grad[i]->share_meta(*inputs[i]);
  }
}

439
void MultiDotGradInferMeta(const std::vector<const MetaTensor*>& x,
440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470
                           const MetaTensor& out_grad,
                           std::vector<MetaTensor*> x_grad) {
  PADDLE_ENFORCE_EQ(
      x.size(),
      x_grad.size(),
      errors::InvalidArgument(
          "Number of Inputs(X) should be equal with Outputs(X@Grad)."
          "But received Inputs(X)' size = %d , Outputs(X@Grad)' size = %d.",
          x.size(),
          x_grad.size()));
  for (size_t i = 0; i < x.size(); i++) {
    if (x_grad[i] != nullptr) {
      x_grad[i]->set_dims(x[i]->dims());
      x_grad[i]->share_lod(*x[i]);
    }
  }
}

// Infers the meta of the gradients of multiplex.
//
// At least one X@GRAD slot must exist. Every requested (non-null) slot gets
// the dims of Out@GRAD. Null slots (gradients not needed for that candidate)
// are skipped instead of being dereferenced. `ids` is not read.
void MultiplexGradInferMeta(const MetaTensor& ids,
                            const MetaTensor& out_grad,
                            std::vector<MetaTensor*> ins_grad) {
  PADDLE_ENFORCE_NE(
      ins_grad.empty(),
      true,
      errors::InvalidArgument("Output(X@Grad) should not be null."));
  const auto dout_dim = out_grad.dims();
  for (auto* in_grad : ins_grad) {
    if (in_grad != nullptr) {
      in_grad->set_dims(dout_dim);
    }
  }
}

471 472 473 474 475 476 477 478 479 480 481
void NanmedianGradInferMeta(const MetaTensor& x,
                            const MetaTensor& median_index,
                            const MetaTensor& out_grad,
                            const IntArray& axes,
                            bool keep_dim,
                            MetaTensor* x_grad) {
  auto x_dims = x.dims();
  x_grad->set_dims(x_dims);
  x_grad->set_dtype(x.dtype());
}

Z
zyfncg 已提交
482 483
void NllLossGradInferMeta(const MetaTensor& x,
                          const MetaTensor& label,
484
                          const MetaTensor& weight,
Z
zyfncg 已提交
485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547
                          const MetaTensor& total_weight,
                          const MetaTensor& out_grad,
                          int64_t ignore_index,
                          const std::string& reduction,
                          MetaTensor* dx,
                          MetaConfig config) {
  const auto& x_dims = x.dims();
  const auto& label_dims = label.dims();
  const auto& dout_dims = out_grad.dims();
  bool contain_unknown_dim =
      phi::contain_unknown_dim(x_dims) || phi::contain_unknown_dim(dout_dims);
  bool check = config.is_runtime || !contain_unknown_dim;

  if (check) {
    auto batch_size = x_dims[0];
    if (x_dims.size() == 2) {
      PADDLE_ENFORCE_EQ(dout_dims.size(),
                        1,
                        phi::errors::InvalidArgument(
                            "The dimensions of Input(Out@Grad) must be 1"));
      if (reduction == "none") {
        PADDLE_ENFORCE_EQ(
            dout_dims[0],
            batch_size,
            phi::errors::InvalidArgument(
                "The unreduced size ofInput(Out@Grad) must be the "
                "same as batch_size."));
      } else {
        PADDLE_ENFORCE_EQ(dout_dims[0],
                          1,
                          phi::errors::InvalidArgument(
                              "The reduced size of Input(Out@Grad) must be 1"));
      }
    } else if (x_dims.size() == 4) {
      if (reduction == "none") {
        PADDLE_ENFORCE_EQ(
            dout_dims.size(),
            3,
            phi::errors::InvalidArgument(
                "The dimensions of Input(Out@Grad) must be 3,But got [%s].",
                dout_dims.size()));
        PADDLE_ENFORCE_EQ(dout_dims[0] == label_dims[0] &&
                              dout_dims[1] == label_dims[1] &&
                              dout_dims[2] == label_dims[2],
                          true,
                          phi::errors::InvalidArgument(
                              "The dimensions of Input(Out@Grad) must be match "
                              "to Input(Label) dimensions."));
      } else {
        PADDLE_ENFORCE_EQ(dout_dims[0],
                          1,
                          phi::errors::InvalidArgument(
                              "The reduced size of Input(Out@Grad) must be 1"));
      }
    }
  }

  if (dx) {
    dx->set_dims(x_dims);
    dx->set_dtype(x.dtype());
  }
}

548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577
void PixelUnshuffleGradInferMeta(const MetaTensor& out_grad,
                                 int downscale_factor,
                                 const std::string& data_format,
                                 MetaTensor* x_grad) {
  auto do_dims = out_grad.dims();
  PADDLE_ENFORCE_EQ(do_dims.size(),
                    4,
                    phi::errors::InvalidArgument(
                        "Input should be a 4-D tensor of format [N, C, H, W] "
                        "or [N, H, W, C], but got %u.",
                        do_dims.size()));

  const bool channel_last = (data_format == "NHWC");

  auto dx_dims = do_dims;
  dx_dims[0] = do_dims[0];

  if (!channel_last) {
    dx_dims[1] = do_dims[1] / (downscale_factor * downscale_factor);
    dx_dims[2] = do_dims[2] * downscale_factor;
    dx_dims[3] = do_dims[3] * downscale_factor;
  } else {
    dx_dims[1] = do_dims[1] * downscale_factor;
    dx_dims[2] = do_dims[2] * downscale_factor;
    dx_dims[3] = do_dims[3] / (downscale_factor * downscale_factor);
  }
  x_grad->set_dims(dx_dims);
  x_grad->set_dtype(out_grad.dtype());
}

F
From00 已提交
578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594
void PoolGradInferMeta(const MetaTensor& x,
                       const MetaTensor& out,
                       const MetaTensor& dout,
                       const std::vector<int>& kernel_size,
                       const std::vector<int>& strides,
                       const std::vector<int>& paddings,
                       bool ceil_mode,
                       bool exclusive,
                       const std::string& data_format,
                       const std::string& pooling_type,
                       bool global_pooling,
                       bool adaptive,
                       const std::string& padding_algorithm,
                       MetaTensor* dx) {
  dx->share_meta(x);
}

F
From00 已提交
595 596
void PsroiPoolGradInferMeta(const MetaTensor& x,
                            const MetaTensor& rois,
597
                            const MetaTensor& rois_num,
F
From00 已提交
598 599 600 601 602 603 604 605 606
                            const MetaTensor& dout,
                            int pooled_height,
                            int pooled_width,
                            int output_channels,
                            float spatial_scale,
                            MetaTensor* dx) {
  dx->share_meta(x);
}

Z
zyfncg 已提交
607 608 609 610 611 612
void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx) {
  dx->set_dims(out_grad.dims());
  dx->set_dtype(dtype::ToComplex(out_grad.dtype()));
  dx->set_layout(out_grad.layout());
}

613 614 615 616 617 618 619 620
void ReshapeDoubleGradInferMeta(const MetaTensor& out_grad,
                                const MetaTensor& x_grad_grad,
                                MetaTensor* out_grad_grad) {
  if (out_grad_grad != nullptr) {
    out_grad_grad->share_dims(out_grad);
  }
}

621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655
void ScatterGradInferMeta(const MetaTensor& index,
                          const MetaTensor& updates,
                          const MetaTensor& out_grad,
                          bool overwrite,
                          MetaTensor* x_grad,
                          MetaTensor* updates_grad) {
  const auto& dtype = out_grad.dtype();
  if (updates_grad) {
    updates_grad->set_dims(updates.dims());
    updates_grad->set_dtype(dtype);
  }

  if (x_grad) {
    x_grad->set_dims(out_grad.dims());
    x_grad->set_dtype(dtype);
  }
}

// Infers the meta of the gradients of scatter_nd_add.
//
// Both gradients take Out@GRAD's dtype. Updates@GRAD (when requested) takes
// the dims of `updates`. X@GRAD (when requested) takes the dims of Out@GRAD.
// `index` is not read.
void ScatterNdAddGradInferMeta(const MetaTensor& index,
                               const MetaTensor& updates,
                               const MetaTensor& out_grad,
                               MetaTensor* x_grad,
                               MetaTensor* updates_grad) {
  const auto grad_dtype = out_grad.dtype();

  if (updates_grad != nullptr) {
    updates_grad->set_dims(updates.dims());
    updates_grad->set_dtype(grad_dtype);
  }
  if (x_grad != nullptr) {
    x_grad->set_dims(out_grad.dims());
    x_grad->set_dtype(grad_dtype);
  }
}

656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691
void StackGradInferMeta(const MetaTensor& out_grad,
                        int axis,
                        std::vector<MetaTensor*> x_grad) {
  auto dy_dim = out_grad.dims();
  int rank = dy_dim.size();
  PADDLE_ENFORCE_GE(
      axis,
      -rank,
      phi::errors::InvalidArgument(
          "Attr(axis) must be inside [-rank, rank), where rank = %d, "
          "but received axis is:%d.",
          rank,
          axis));
  PADDLE_ENFORCE_LT(
      axis,
      rank,
      phi::errors::InvalidArgument(
          "Attr(axis) must be inside [-rank, rank), where rank = %d, "
          "but received axis is:%d.",
          rank,
          axis));

  if (axis < 0) axis += rank;
  PADDLE_ENFORCE_LE(
      x_grad.size(),
      static_cast<size_t>(dy_dim[axis]),
      phi::errors::InvalidArgument(
          "Number of Outputs(X@Grad) should be less than or equal to dy dim "
          "at axis, but received outputs size is:%d, dy dims is:%d.",
          x_grad.size(),
          static_cast<size_t>(dy_dim[axis])));

  auto vec = phi::vectorize<int>(dy_dim);
  vec.erase(vec.begin() + axis);

  for (size_t i = 0; i < x_grad.size(); ++i) {
692 693 694 695
    if (x_grad[i]) {
      x_grad[i]->set_dims(phi::make_ddim(vec));
      x_grad[i]->set_dtype(out_grad.dtype());
    }
696 697 698
  }
}

}  // namespace phi