未验证 提交 a5a0e8fe 编写于 作者: X xiaoguoguo626807 提交者: GitHub

【prim】Elementwise double grad (#53014)

* add mul doubel grad

* add sub_double_grad

* add add sub high test

* add mutiply test

* modify other unsqueeze

* delete api.yaml

* only for make ci run

* midify unsqueeze

* modify unsqueeze

* tmp

* modify operants gen

* review modify

* modify review

* debug

* debug

* modify ci cross boundary

* delete log
上级 f5476dad
......@@ -115,17 +115,24 @@ static std::vector<DST_T> unsafe_vector_cast(const std::vector<SRC_T>& src) {
}
// This fucction compute unsqueeze dims for reshape to replace unsqueeze.
static std::vector<int> get_unsqueeze_dims(const Tensor& origin,
const IntArray& axis) {
static std::vector<int64_t> get_unsqueeze_dims(
const Tensor& origin, const std::vector<int64_t>& axis) {
auto origin_dims = origin.shape();
auto total_shape_size = origin_dims.size() + axis.size();
std::vector<int> result;
int j = 0, k = 0;
std::vector<int64_t> result;
size_t j = 0, k = 0;
for (size_t i = 0; i < total_shape_size; ++i) {
if (axis[j] == int64_t(i)) {
if (j < axis.size() && axis[j] == int64_t(i)) {
result.push_back(1);
j++;
} else {
PADDLE_ENFORCE_LT(
k,
origin_dims.size(),
platform::errors::OutOfRange("Your index [%lu] exceeds the number of "
"elements in origin_dims[%lu].",
k,
origin_dims.size()));
result.push_back(origin_dims[k]);
k++;
}
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册