From a5a0e8fef76aa436b2d7b1c80bf221f663707121 Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Sat, 6 May 2023 17:27:51 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90prim=E3=80=91Elementwise=20double=20gr?= =?UTF-8?q?ad=20(#53014)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add mul doubel grad * add sub_double_grad * add add sub high test * add mutiply test * modify other unsqueeze * delete api.yaml * only for make ci run * midify unsqueeze * modify unsqueeze * tmp * modify operants gen * review modify * modify review * debug * debug * modify ci cross boundary * delete log --- paddle/fluid/prim/api/manual_prim/utils/utils.h | 17 ++++++++++++----- test/prim/prim/vjp/test_comp_high_grad.py | 2 +- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/prim/api/manual_prim/utils/utils.h b/paddle/fluid/prim/api/manual_prim/utils/utils.h index d72da6461da..d37a50c21a8 100644 --- a/paddle/fluid/prim/api/manual_prim/utils/utils.h +++ b/paddle/fluid/prim/api/manual_prim/utils/utils.h @@ -115,17 +115,24 @@ static std::vector unsafe_vector_cast(const std::vector& src) { } // This fucction compute unsqueeze dims for reshape to replace unsqueeze. -static std::vector get_unsqueeze_dims(const Tensor& origin, - const IntArray& axis) { +static std::vector get_unsqueeze_dims( + const Tensor& origin, const std::vector& axis) { auto origin_dims = origin.shape(); auto total_shape_size = origin_dims.size() + axis.size(); - std::vector result; - int j = 0, k = 0; + std::vector result; + size_t j = 0, k = 0; for (size_t i = 0; i < total_shape_size; ++i) { - if (axis[j] == int64_t(i)) { + if (j < axis.size() && axis[j] == int64_t(i)) { result.push_back(1); j++; } else { + PADDLE_ENFORCE_LT( + k, + origin_dims.size(), + platform::errors::OutOfRange("Your index [%lu] exceeds the number of " + "elements in origin_dims[%lu].", + k, + origin_dims.size())); result.push_back(origin_dims[k]); k++; } diff --git a/test/prim/prim/vjp/test_comp_high_grad.py b/test/prim/prim/vjp/test_comp_high_grad.py index 77fa38684d9..76283528e24 100644 --- a/test/prim/prim/vjp/test_comp_high_grad.py +++ b/test/prim/prim/vjp/test_comp_high_grad.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. -- GitLab