未验证 提交 2e6e188a 编写于 作者: H HongyuJia 提交者: GitHub

[Tensor Operants & Prim-Relevant] Tensor API support default value (#50928)

上级 539293e2
......@@ -28,6 +28,7 @@
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/fluid/prim/utils/static/static_global_utils.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/flags.h"
......
......@@ -31,8 +31,10 @@ using gpuStream_t = hipStream_t;
#include "paddle/phi/api/include/dll_decl.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
namespace phi {
class DenseTensor;
......@@ -47,16 +49,6 @@ namespace paddle {
namespace experimental {
class Tensor;
template <typename T>
class ScalarBase;
using Scalar = paddle::experimental::ScalarBase<Tensor>;
template <typename T>
class IntArrayBase;
using IntArray = paddle::experimental::IntArrayBase<Tensor>;
class AbstractAutogradMeta {
public:
// No AbstractAutogradMeta should be created
......@@ -684,22 +676,29 @@ class PADDLE_API Tensor final {
Tensor floor() const;
Tensor gather_nd(const Tensor& index) const;
Tensor log() const;
Tensor roll(const IntArray& shifts, const std::vector<int64_t>& axis) const;
Tensor roll(const IntArray& shifts = {},
const std::vector<int64_t>& axis = {}) const;
Tensor scatter(const Tensor& index,
const Tensor& updates,
bool overwrite) const;
bool overwrite = true) const;
Tensor scatter_nd_add(const Tensor& index, const Tensor& updates) const;
Tensor abs() const;
Tensor assign() const;
Tensor elementwise_pow(const Tensor& y) const;
Tensor expand(const IntArray& shape) const;
Tensor matmul(const Tensor& y, bool transpose_x, bool transpose_y) const;
Tensor max(const IntArray& axis, bool keepdim) const;
Tensor matmul(const Tensor& y,
bool transpose_x = false,
bool transpose_y = false) const;
Tensor max(const IntArray& axis = {}, bool keepdim = false) const;
Tensor maximum(const Tensor& y) const;
Tensor minimum(const Tensor& y) const;
Tensor scale(const Scalar& scale, float bias, bool bias_after_scale) const;
Tensor sum(const IntArray& axis, DataType dtype, bool keepdim) const;
Tensor tile(const IntArray& repeat_times) const;
Tensor scale(const Scalar& scale = 1.0,
float bias = 0.0,
bool bias_after_scale = true) const;
Tensor sum(const IntArray& axis = {},
DataType dtype = DataType::UNDEFINED,
bool keepdim = false) const;
Tensor tile(const IntArray& repeat_times = {}) const;
};
PADDLE_API Tensor operator+(const Scalar& x, const Tensor& y);
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/lib/tensor_copy.h"
#include "paddle/phi/common/place.h"
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/lib/tensor_copy.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
......
......@@ -14,8 +14,10 @@ limitations under the License. */
#pragma once
#include <vector>
#include "paddle/phi/api/ext/exception.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/data_type.h"
namespace paddle {
namespace experimental {
......@@ -103,8 +105,8 @@ class IntArrayBase {
bool is_from_tensor_{false};
};
using IntArray =
paddle::experimental::IntArrayBase<paddle::experimental::Tensor>;
class Tensor;
using IntArray = paddle::experimental::IntArrayBase<Tensor>;
} // namespace experimental
} // namespace paddle
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include <limits>
#include "paddle/phi/api/ext/exception.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/data_type.h"
namespace paddle {
namespace experimental {
......@@ -228,7 +228,8 @@ void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst) {
dst->data_.c128 = src.data_.c128;
}
using Scalar = paddle::experimental::ScalarBase<paddle::experimental::Tensor>;
class Tensor;
using Scalar = paddle::experimental::ScalarBase<Tensor>;
} // namespace experimental
} // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册