diff --git a/paddle/fluid/prim/api/manual_prim/utils/utils.h b/paddle/fluid/prim/api/manual_prim/utils/utils.h index d72da6461da75d150a68441e184ced7c730f893b..d37a50c21a8e7bcb4c687c0f2ed08fe8019445d6 100644 --- a/paddle/fluid/prim/api/manual_prim/utils/utils.h +++ b/paddle/fluid/prim/api/manual_prim/utils/utils.h @@ -115,17 +115,24 @@ static std::vector unsafe_vector_cast(const std::vector& src) { } // This fucction compute unsqueeze dims for reshape to replace unsqueeze. -static std::vector get_unsqueeze_dims(const Tensor& origin, - const IntArray& axis) { +static std::vector get_unsqueeze_dims( + const Tensor& origin, const std::vector& axis) { auto origin_dims = origin.shape(); auto total_shape_size = origin_dims.size() + axis.size(); - std::vector result; - int j = 0, k = 0; + std::vector result; + size_t j = 0, k = 0; for (size_t i = 0; i < total_shape_size; ++i) { - if (axis[j] == int64_t(i)) { + if (j < axis.size() && axis[j] == int64_t(i)) { result.push_back(1); j++; } else { + PADDLE_ENFORCE_LT( + k, + origin_dims.size(), + platform::errors::OutOfRange("Your index [%lu] exceeds the number of " + "elements in origin_dims[%lu].", + k, + origin_dims.size())); result.push_back(origin_dims[k]); k++; } diff --git a/test/prim/prim/vjp/test_comp_high_grad.py b/test/prim/prim/vjp/test_comp_high_grad.py index 77fa38684d9068835a13fcd9bbf067c0220159b0..76283528e24043a0561d4d7fcdb432dd620a43c4 100644 --- a/test/prim/prim/vjp/test_comp_high_grad.py +++ b/test/prim/prim/vjp/test_comp_high_grad.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.