未验证 提交 2e6e188a 编写于 作者: H HongyuJia 提交者: GitHub

[Tensor Operants & Prim-Relevant] Tensor API support default value (#50928)

上级 539293e2
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/fluid/prim/utils/static/static_global_utils.h" #include "paddle/fluid/prim/utils/static/static_global_utils.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/flags.h" #include "paddle/phi/core/flags.h"
......
...@@ -31,8 +31,10 @@ using gpuStream_t = hipStream_t; ...@@ -31,8 +31,10 @@ using gpuStream_t = hipStream_t;
#include "paddle/phi/api/include/dll_decl.h" #include "paddle/phi/api/include/dll_decl.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/layout.h" #include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
namespace phi { namespace phi {
class DenseTensor; class DenseTensor;
...@@ -47,16 +49,6 @@ namespace paddle { ...@@ -47,16 +49,6 @@ namespace paddle {
namespace experimental { namespace experimental {
class Tensor;
template <typename T>
class ScalarBase;
using Scalar = paddle::experimental::ScalarBase<Tensor>;
template <typename T>
class IntArrayBase;
using IntArray = paddle::experimental::IntArrayBase<Tensor>;
class AbstractAutogradMeta { class AbstractAutogradMeta {
public: public:
// No AbstractAutogradMeta should be created // No AbstractAutogradMeta should be created
...@@ -684,22 +676,29 @@ class PADDLE_API Tensor final { ...@@ -684,22 +676,29 @@ class PADDLE_API Tensor final {
Tensor floor() const; Tensor floor() const;
Tensor gather_nd(const Tensor& index) const; Tensor gather_nd(const Tensor& index) const;
Tensor log() const; Tensor log() const;
Tensor roll(const IntArray& shifts, const std::vector<int64_t>& axis) const; Tensor roll(const IntArray& shifts = {},
const std::vector<int64_t>& axis = {}) const;
Tensor scatter(const Tensor& index, Tensor scatter(const Tensor& index,
const Tensor& updates, const Tensor& updates,
bool overwrite) const; bool overwrite = true) const;
Tensor scatter_nd_add(const Tensor& index, const Tensor& updates) const; Tensor scatter_nd_add(const Tensor& index, const Tensor& updates) const;
Tensor abs() const; Tensor abs() const;
Tensor assign() const; Tensor assign() const;
Tensor elementwise_pow(const Tensor& y) const; Tensor elementwise_pow(const Tensor& y) const;
Tensor expand(const IntArray& shape) const; Tensor expand(const IntArray& shape) const;
Tensor matmul(const Tensor& y, bool transpose_x, bool transpose_y) const; Tensor matmul(const Tensor& y,
Tensor max(const IntArray& axis, bool keepdim) const; bool transpose_x = false,
bool transpose_y = false) const;
Tensor max(const IntArray& axis = {}, bool keepdim = false) const;
Tensor maximum(const Tensor& y) const; Tensor maximum(const Tensor& y) const;
Tensor minimum(const Tensor& y) const; Tensor minimum(const Tensor& y) const;
Tensor scale(const Scalar& scale, float bias, bool bias_after_scale) const; Tensor scale(const Scalar& scale = 1.0,
Tensor sum(const IntArray& axis, DataType dtype, bool keepdim) const; float bias = 0.0,
Tensor tile(const IntArray& repeat_times) const; bool bias_after_scale = true) const;
Tensor sum(const IntArray& axis = {},
DataType dtype = DataType::UNDEFINED,
bool keepdim = false) const;
Tensor tile(const IntArray& repeat_times = {}) const;
}; };
PADDLE_API Tensor operator+(const Scalar& x, const Tensor& y); PADDLE_API Tensor operator+(const Scalar& x, const Tensor& y);
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/int_array.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/lib/tensor_copy.h" #include "paddle/phi/api/lib/tensor_copy.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/lib/tensor_copy.h" #include "paddle/phi/api/lib/tensor_copy.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
......
...@@ -14,8 +14,10 @@ limitations under the License. */ ...@@ -14,8 +14,10 @@ limitations under the License. */
#pragma once #pragma once
#include <vector>
#include "paddle/phi/api/ext/exception.h" #include "paddle/phi/api/ext/exception.h"
#include "paddle/phi/api/include/tensor.h" #include "paddle/phi/common/data_type.h"
namespace paddle { namespace paddle {
namespace experimental { namespace experimental {
...@@ -103,8 +105,8 @@ class IntArrayBase { ...@@ -103,8 +105,8 @@ class IntArrayBase {
bool is_from_tensor_{false}; bool is_from_tensor_{false};
}; };
using IntArray = class Tensor;
paddle::experimental::IntArrayBase<paddle::experimental::Tensor>; using IntArray = paddle::experimental::IntArrayBase<Tensor>;
} // namespace experimental } // namespace experimental
} // namespace paddle } // namespace paddle
......
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
#include <limits> #include <limits>
#include "paddle/phi/api/ext/exception.h" #include "paddle/phi/api/ext/exception.h"
#include "paddle/phi/api/include/tensor.h" #include "paddle/phi/common/data_type.h"
namespace paddle { namespace paddle {
namespace experimental { namespace experimental {
...@@ -228,7 +228,8 @@ void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst) { ...@@ -228,7 +228,8 @@ void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst) {
dst->data_.c128 = src.data_.c128; dst->data_.c128 = src.data_.c128;
} }
using Scalar = paddle::experimental::ScalarBase<paddle::experimental::Tensor>; class Tensor;
using Scalar = paddle::experimental::ScalarBase<Tensor>;
} // namespace experimental } // namespace experimental
} // namespace paddle } // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册