未验证 提交 90dad8b2 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Add `cast` method for Tensor and rename the `to` method to `copy_to` (#37423)

* rename the `to` api to `copy_to`

* support cast method for tensor

* fix compile failure
上级 73f4601d
...@@ -86,16 +86,13 @@ class AbstractAutogradMeta { ...@@ -86,16 +86,13 @@ class AbstractAutogradMeta {
class PD_DLL_DECL Tensor final { class PD_DLL_DECL Tensor final {
public: public:
/* Part 1: Construction and destruction methods */
/** /**
* @brief Construct a new Tensor object * @brief Construct a new Tensor object
*/ */
Tensor() = default; Tensor() = default;
/**
* @brief Construct a new Tensor object with name
* */
explicit Tensor(const std::string& name) { name_ = name; }
/** /**
* @brief Construct a new Tensor object by copy * @brief Construct a new Tensor object by copy
*/ */
...@@ -132,18 +129,14 @@ class PD_DLL_DECL Tensor final { ...@@ -132,18 +129,14 @@ class PD_DLL_DECL Tensor final {
Tensor(const PlaceType& place, const std::vector<int64_t>& shape); Tensor(const PlaceType& place, const std::vector<int64_t>& shape);
/** /**
* @brief Return the name of Tensor. * @brief Construct a new Tensor object with name
* *
* @return const std::string& * @note Used to adapt original execution mechanism and debug analysis
*/ * in the development of new dygraph. It may be removed in the future.
const std::string& name() const { return name_; } * */
explicit Tensor(const std::string& name) : name_(name) {}
/** /* Part 2: Dimension, DataType and DataLayout methods */
* @brief Set name of Tensor.
*
* @param const std::string& name
*/
void set_name(const std::string& name) { name_ = name; }
/** /**
* @brief Return the number of elements of Tensor. * @brief Return the number of elements of Tensor.
...@@ -179,7 +172,7 @@ class PD_DLL_DECL Tensor final { ...@@ -179,7 +172,7 @@ class PD_DLL_DECL Tensor final {
/** /**
* @brief Reset the shape of the tensor. * @brief Reset the shape of the tensor.
* Note: This method means Reset the shape of the tensor, * @note: This method means Reset the shape of the tensor,
* and must be called before calling mutable_data() or * and must be called before calling mutable_data() or
* copy_to(const PlaceType& place), this is not a standard definition of * copy_to(const PlaceType& place), this is not a standard definition of
* reshape behavior, so we will deprecate this feature in the future. * reshape behavior, so we will deprecate this feature in the future.
...@@ -329,14 +322,33 @@ class PD_DLL_DECL Tensor final { ...@@ -329,14 +322,33 @@ class PD_DLL_DECL Tensor final {
gpuStream_t stream() const; gpuStream_t stream() const;
#endif #endif
/**
 * @brief Get the name currently attached to this Tensor.
 * @note The name is kept to adapt the original execution mechanism and to
 * help debug analysis during the development of the new dygraph. It may be
 * removed in the future.
 *
 * @return const std::string& Read-only reference to the stored name.
 */
const std::string& name() const {
  return name_;
}
/**
* @brief Set name of Tensor.
* @note Used to adapt original execution mechanism and debug analysis
* in the development of new dygraph. It may be removed in the future.
*
* @param const std::string& name
*/
void set_name(const std::string& name) { name_ = name; }
/* Part 5: Data Transform methods */ /* Part 5: Data Transform methods */
/** /**
* @brief Copy the current Tensor data to the specified device * @brief Copy the current Tensor data to the specified device
* and return the new Tensor. It's usually used to set the input tensor data. * and return the new Tensor. It's usually used to set the input tensor data.
* Note: The Tensor's `copy_to` method is deprecated since version 2.3, and * @note The Tensor's `copy_to` method is deprecated since version 2.3, and
* will be removed in version 2.4, please use `to` method instead. reason: * will be removed in version 2.4, please use `copy_to` method without
* copying a Tensor to another device does not need to specify the * template argument instead.
* reason: copying a Tensor to another device does not need to specify the
* data type template argument * data type template argument
* *
* @tparam T * @tparam T
...@@ -352,9 +364,8 @@ class PD_DLL_DECL Tensor final { ...@@ -352,9 +364,8 @@ class PD_DLL_DECL Tensor final {
* @param place, the target place of which the tensor will copy to. * @param place, the target place of which the tensor will copy to.
* @return Tensor * @return Tensor
*/ */
// TODO(chenweihang): replace Backend by new Place, may be append dtype and // TODO(chenweihang): replace Backend by new Place
// layout arguments in the future Tensor copy_to(Backend backend, bool blocking) const;
Tensor to(Backend backend, bool blocking) const;
/** /**
* @brief Cast datatype from one to another * @brief Cast datatype from one to another
...@@ -362,7 +373,7 @@ class PD_DLL_DECL Tensor final { ...@@ -362,7 +373,7 @@ class PD_DLL_DECL Tensor final {
* @param target_type * @param target_type
* @return Tensor * @return Tensor
*/ */
Tensor cast(const DataType& target_type) const; Tensor cast(DataType target_type) const;
/* Part 6: Status utils methods */ /* Part 6: Status utils methods */
...@@ -470,7 +481,7 @@ class PD_DLL_DECL Tensor final { ...@@ -470,7 +481,7 @@ class PD_DLL_DECL Tensor final {
std::shared_ptr<AbstractAutogradMeta> autograd_meta_{nullptr}; std::shared_ptr<AbstractAutogradMeta> autograd_meta_{nullptr};
/** /**
* Tensor name: used for adapt original execution mechanism and debug analysis * Tensor name: used to adapt original execution mechanism and debug analysis
* in the development of new dygraph. It may be removed in the future. * in the development of new dygraph. It may be removed in the future.
*/ */
std::string name_; std::string name_;
......
...@@ -21,8 +21,7 @@ namespace paddle { ...@@ -21,8 +21,7 @@ namespace paddle {
namespace experimental { namespace experimental {
// TODO(chenweihang): Replace backend by place when place is ready // TODO(chenweihang): Replace backend by place when place is ready
// TODO(chenweihang): Add layout and dtype argument if needed PD_DLL_DECL Tensor copy_to(const Tensor& x, Backend backend, bool blocking);
PD_DLL_DECL Tensor to(const Tensor& x, Backend backend, bool blocking);
} // namespace experimental } // namespace experimental
} // namespace paddle } // namespace paddle
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/pten/api/include/manipulation.h"
#include "paddle/pten/api/include/utils.h" #include "paddle/pten/api/include/utils.h"
#include "paddle/pten/api/lib/ext_compat_utils.h" #include "paddle/pten/api/lib/ext_compat_utils.h"
#include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/api/lib/utils/allocator.h"
...@@ -281,11 +282,11 @@ gpuStream_t Tensor::stream() const { ...@@ -281,11 +282,11 @@ gpuStream_t Tensor::stream() const {
template <typename T> template <typename T>
Tensor Tensor::copy_to(const PlaceType &target_place) const { Tensor Tensor::copy_to(const PlaceType &target_place) const {
LOG(WARNING) << "The Tensor's `copy_to` method is deprecated since version " LOG(WARNING) << "The Tensor's `copy_to` method is deprecated since version "
"2.3, and will be removed in version 2.4, please use `to` " "2.3, and will be removed in version 2.4, please use "
"method instead. " "`copy_to` method without template argument instead. "
"reason: copying a Tensor to another device does not need " "reason: copying a Tensor to another device does not need "
"to specify the data type template argument."; "to specify the data type template argument.";
return to(ConvertExtPlaceToBackend(target_place), /*blocking=*/false); return copy_to(ConvertExtPlaceToBackend(target_place), /*blocking=*/false);
} }
template PD_DLL_DECL Tensor template PD_DLL_DECL Tensor
...@@ -311,15 +312,12 @@ template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex<double>>( ...@@ -311,15 +312,12 @@ template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex<double>>(
template PD_DLL_DECL Tensor template PD_DLL_DECL Tensor
Tensor::copy_to<paddle::platform::float16>(const PlaceType &target_place) const; Tensor::copy_to<paddle::platform::float16>(const PlaceType &target_place) const;
Tensor Tensor::to(Backend backend, bool blocking) const { Tensor Tensor::copy_to(Backend backend, bool blocking) const {
return experimental::to(*this, backend, blocking); return experimental::copy_to(*this, backend, blocking);
} }
Tensor Tensor::cast(const DataType &target_type) const { Tensor Tensor::cast(DataType target_type) const {
PADDLE_THROW(platform::errors::Unimplemented( return experimental::cast(*this, target_type);
"The cast operation is not supported now, "
"and it will be implemented by calling the cast kernel later."));
return Tensor();
} }
/* Part 6: Status utils methods */ /* Part 6: Status utils methods */
......
...@@ -34,7 +34,7 @@ PT_DECLARE_MODULE(UtilsCUDA); ...@@ -34,7 +34,7 @@ PT_DECLARE_MODULE(UtilsCUDA);
namespace paddle { namespace paddle {
namespace experimental { namespace experimental {
PD_DLL_DECL Tensor to(const Tensor& x, Backend backend, bool blocking) { PD_DLL_DECL Tensor copy_to(const Tensor& x, Backend backend, bool blocking) {
// 1. Get kernel signature and kernel // 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x); auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend); kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
......
if(WITH_ROCM) if(WITH_ROCM)
hip_test(test_pten_tensor SRCS test_pten_tensor.cc DEPS pten_tensor utils_api glog) hip_test(test_pten_tensor SRCS test_pten_tensor.cc DEPS pten_tensor utils_api manipulation_api glog)
else() else()
cc_test(test_pten_tensor SRCS test_pten_tensor.cc DEPS pten_tensor utils_api glog) cc_test(test_pten_tensor SRCS test_pten_tensor.cc DEPS pten_tensor utils_api manipulation_api glog)
endif() endif()
cc_test(test_pten_exception SRCS test_pten_exception.cc DEPS gtest) cc_test(test_pten_exception SRCS test_pten_exception.cc DEPS gtest)
......
...@@ -15,17 +15,15 @@ limitations under the License. */ ...@@ -15,17 +15,15 @@ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include "paddle/pten/api/include/creation.h"
#include "paddle/pten/api/include/manipulation.h" #include "paddle/pten/api/include/manipulation.h"
#include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/core/kernel_registry.h"
PT_DECLARE_MODULE(ManipulationCPU); namespace pten {
namespace tests {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_MODULE(ManipulationCUDA);
#endif
namespace framework = paddle::framework; namespace framework = paddle::framework;
using DDim = paddle::framework::DDim; using DDim = paddle::framework::DDim;
...@@ -67,3 +65,24 @@ TEST(API, cast) { ...@@ -67,3 +65,24 @@ TEST(API, cast) {
ASSERT_NEAR(dense_out_data[i], static_cast<double>(dense_x_data[i]), 1e-6f); ASSERT_NEAR(dense_out_data[i], static_cast<double>(dense_x_data[i]), 1e-6f);
} }
} }
// Verifies Tensor::cast: a {3, 4} FLOAT32 tensor filled with 1.0 is cast to
// INT32, then the result's shape, element count, place, dtype, layout,
// initialization state and element values are checked.
TEST(Tensor, cast) {
  auto x = paddle::experimental::full({3, 4}, 1.0, pten::DataType::FLOAT32);
  auto y = x.cast(pten::DataType::INT32);
  // check cast result: meta information
  ASSERT_EQ(y.dims().size(), 2);
  ASSERT_EQ(y.dims()[0], 3);
  ASSERT_EQ(y.dims()[1], 4);
  ASSERT_EQ(y.numel(), 12);
  ASSERT_TRUE(y.is_cpu());
  ASSERT_EQ(y.type(), pten::DataType::INT32);
  ASSERT_EQ(y.layout(), pten::DataLayout::NCHW);
  ASSERT_TRUE(y.initialized());
  // check cast result: element values. The data pointer is fetched once
  // instead of on every loop iteration.
  auto* y_data = y.mutable_data<int>();
  for (int64_t i = 0; i < y.numel(); ++i) {
    ASSERT_EQ(y_data[i], 1);
  }
}
} // namespace tests
} // namespace pten
...@@ -58,39 +58,39 @@ void CheckOutputResult(const paddle::experimental::Tensor& out) { ...@@ -58,39 +58,39 @@ void CheckOutputResult(const paddle::experimental::Tensor& out) {
} }
} }
TEST(API, to) { TEST(API, copy_to) {
// 1. create tensor // 1. create tensor
auto x = CreateInputTensor(); auto x = CreateInputTensor();
// 2. test API // 2. test API
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto tmp = paddle::experimental::to(x, pten::Backend::CUDA, false); auto tmp = paddle::experimental::copy_to(x, pten::Backend::CUDA, false);
auto out = paddle::experimental::to(tmp, pten::Backend::CPU, true); auto out = paddle::experimental::copy_to(tmp, pten::Backend::CPU, true);
#else #else
auto out = paddle::experimental::to(x, pten::Backend::CPU, false); auto out = paddle::experimental::copy_to(x, pten::Backend::CPU, false);
#endif #endif
// 3. check result // 3. check result
CheckOutputResult(out); CheckOutputResult(out);
} }
TEST(Tensor, to) { TEST(Tensor, copy_to) {
// 1. create tensor // 1. create tensor
auto x = CreateInputTensor(); auto x = CreateInputTensor();
// 2. test API // 2. test API
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto tmp = x.to(pten::Backend::CUDA, false); auto tmp = x.copy_to(pten::Backend::CUDA, false);
auto out = tmp.to(pten::Backend::CPU, true); auto out = tmp.copy_to(pten::Backend::CPU, true);
#else #else
auto out = x.to(pten::Backend::CPU, false); auto out = x.copy_to(pten::Backend::CPU, false);
#endif #endif
// 3. check result // 3. check result
CheckOutputResult(out); CheckOutputResult(out);
} }
TEST(Tensor, copy_to) { TEST(Tensor, old_copy_to) {
// 1. create tensor // 1. create tensor
auto x = CreateInputTensor(); auto x = CreateInputTensor();
......
...@@ -23,12 +23,6 @@ limitations under the License. */ ...@@ -23,12 +23,6 @@ limitations under the License. */
#include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/core/kernel_registry.h"
PT_DECLARE_MODULE(ManipulationCPU);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_MODULE(ManipulationCUDA);
#endif
namespace framework = paddle::framework; namespace framework = paddle::framework;
using DDim = paddle::framework::DDim; using DDim = paddle::framework::DDim;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册