未验证 提交 8c844356 编写于 作者: H HongyuJia 提交者: GitHub

[Tensor operants] Polish tensor operants implementation (#50634)

* polish tensor operants implementation

* change year, 2021->2023
上级 6b3c48c1
......@@ -1191,6 +1191,7 @@ cc_library(
phi_tensor
op_meta_info
phi_api
tensor_api
phi_tensor_operants
operants_manager)
......
......@@ -35,13 +35,15 @@ cc_test_old(
phi_dygraph_api
static_global_utils
static_tensor_operants
tensor_api
operants_manager)
if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_library(
init_env_utils
SRCS init_env_utils.cc
DEPS operants_manager eager_tensor_operants static_tensor_operants)
DEPS operants_manager tensor_api eager_tensor_operants
static_tensor_operants)
cc_test_old(
test_comp_eager
......
......@@ -497,6 +497,7 @@ if(WITH_PYTHON)
list(APPEND PYBIND_DEPS python)
list(APPEND PYBIND_DEPS custom_operator)
list(APPEND PYBIND_DEPS custom_operator_node)
list(APPEND PYBIND_DEPS tensor_api)
list(APPEND PYBIND_DEPS operants_manager)
list(APPEND PYBIND_DEPS eager_tensor_operants)
list(APPEND PYBIND_DEPS static_tensor_operants)
......
......@@ -524,6 +524,20 @@ class PADDLE_API Tensor final {
*/
Tensor& operator=(Tensor&& x) &;
/**
* @brief Tensor operants
*
* @param other
* @return Tensor
*/
Tensor operator+(const Tensor& other) const;
Tensor operator-(const Tensor& other) const;
Tensor operator*(const Tensor& other) const;
Tensor operator/(const Tensor& other) const;
/* Part 8: Autograd methods */
/**
......@@ -633,13 +647,5 @@ class PADDLE_API Tensor final {
std::string name_{""};
};
PADDLE_API Tensor operator+(const Tensor& x, const Tensor& y);
PADDLE_API Tensor operator-(const Tensor& x, const Tensor& y);
PADDLE_API Tensor operator*(const Tensor& x, const Tensor& y);
PADDLE_API Tensor operator/(const Tensor& x, const Tensor& y);
} // namespace experimental
} // namespace paddle
......@@ -5,19 +5,19 @@ if(WITH_GPU)
phi_tensor_raw
SRCS tensor.cc
DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool
operants_manager)
tensor_api)
elseif(WITH_ROCM)
hip_library(
phi_tensor_raw
SRCS tensor.cc
DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool
operants_manager)
tensor_api)
else()
cc_library(
phi_tensor_raw
SRCS tensor.cc
DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool
operants_manager)
tensor_api)
endif()
set(api_gen_base ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/api_base.py)
......@@ -388,3 +388,7 @@ cc_library(
operants_manager
SRCS ${operants_manager_source_file}
DEPS phi_enforce)
cc_library(
tensor_api
SRCS tensor_api.cc
DEPS operants_manager)
......@@ -21,7 +21,6 @@ limitations under the License. */
#include "glog/logging.h"
#include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/api/include/operants_manager.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
......@@ -434,21 +433,5 @@ void Tensor::reset_inplace_version(bool set_to_zero) {
}
}
PADDLE_API Tensor operator+(const Tensor &x, const Tensor &y) {
return paddle::OperantsManager::Instance().add(x, y);
}
PADDLE_API Tensor operator-(const Tensor &x, const Tensor &y) {
return paddle::OperantsManager::Instance().subtract(x, y);
}
PADDLE_API Tensor operator*(const Tensor &x, const Tensor &y) {
return paddle::OperantsManager::Instance().multiply(x, y);
}
PADDLE_API Tensor operator/(const Tensor &x, const Tensor &y) {
return paddle::OperantsManager::Instance().divide(x, y);
}
} // namespace experimental
} // namespace paddle
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/include/operants_manager.h"
namespace paddle {
namespace experimental {
Tensor Tensor::operator+(const Tensor &other) const {
return paddle::OperantsManager::Instance().add(
static_cast<const Tensor &>(*this), other);
}
Tensor Tensor::operator-(const Tensor &other) const {
return paddle::OperantsManager::Instance().subtract(
static_cast<const Tensor &>(*this), other);
}
Tensor Tensor::operator*(const Tensor &other) const {
return paddle::OperantsManager::Instance().multiply(
static_cast<const Tensor &>(*this), other);
}
Tensor Tensor::operator/(const Tensor &other) const {
return paddle::OperantsManager::Instance().divide(
static_cast<const Tensor &>(*this), other);
}
} // namespace experimental
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册